diff --git a/.bazelrc b/.bazelrc index e21a1a32917..f11c376df65 100644 --- a/.bazelrc +++ b/.bazelrc @@ -39,6 +39,7 @@ # # Feature and Third party library support options: # xla: Build TF with XLA +# tpu: Build TF with TPU support # using_cuda: CUDA is available to build system. # cuda: Build with full cuda support. # rocm: Build with AMD GPU support (rocm). @@ -57,13 +58,12 @@ # # # Remote build execution options (only configured to work with TF team projects for now.) -# rbe: General RBE options shared by all flavors. -# rbe_linux: General RBE options used on all linux builds. -# rbe_win: General RBE options used on all windows builds. +# rbe: General RBE options shared by all flavors. +# rbe_linux: General RBE options used on all linux builds. +# rbe_win: General RBE options used on all windows builds. # -# rbe_cpu_linux: RBE options to build with only CPU support. -# rbe_linux_cuda_nvcc: RBE options to build with GPU support using nvcc. -# rbe_gpu_linux: An alias for rbe_linux_cuda_nvcc +# rbe_cpu_linux: RBE options to build with only CPU support. +# rbe_linux_cuda_nvcc_py*: RBE options to build with GPU support using nvcc. # # rbe_linux_py2: Linux Python 2 RBE config. # rbe_linux_py3: Linux Python 3 RBE config @@ -180,6 +180,9 @@ build:dbg --cxxopt -DTF_LITE_DISABLE_X86_NEON # AWS SDK must be compiled in release mode. see: https://github.com/tensorflow/tensorflow/issues/37498 build:dbg --copt -DDEBUG_BUILD +# Config to build TPU backend +build:tpu --define=with_tpu_support=true + build:tensorrt --action_env TF_NEED_TENSORRT=1 build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain @@ -396,33 +399,48 @@ build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 build:rbe_linux_cuda_base --repo_env=TF_NEED_CUDA=1 test:rbe_linux_cuda_base --test_env=LD_LIBRARY_PATH="/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64" -build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base -build:rbe_linux_cuda_nvcc --crosstool_top="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc --extra_toolchains="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc --extra_execution_platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_nvcc --host_platform="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_nvcc --platforms="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_nvcc --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" -build:rbe_linux_cuda_nvcc --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" -build:rbe_linux_cuda_nvcc --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-py3-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" -build:rbe_linux_cuda_nvcc --define=using_cuda_nvcc=true -test:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_base +build:rbe_linux_cuda10.1_nvcc_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda10.1_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda10.1_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda10.1_nvcc_base 
--extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda10.1_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda10.1_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda10.1_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" +build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" +build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" +build:rbe_linux_cuda10.1_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" +build:rbe_linux_cuda10.1_nvcc_py2.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7" +build:rbe_linux_cuda10.1_nvcc_py3.5 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5" +build:rbe_linux_cuda10.1_nvcc_py3.6 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6" +build:rbe_linux_cuda10.1_nvcc_py3.7 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" +build:rbe_linux_cuda10.1_nvcc_py3.8 --config=rbe_linux_cuda10.1_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" -build:rbe_linux_cuda_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain-linux-x86_64" -build:rbe_linux_cuda_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_platform//:platform" -build:rbe_linux_cuda_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda" -build:rbe_linux_cuda_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_tensorrt" -build:rbe_linux_cuda_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_nccl" -build:rbe_linux_cuda_nvcc_base --define=using_cuda_nvcc=true -build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python2.7" -build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.5" 
-build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.6" -build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" -build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" +build:rbe_linux_cuda11.0_nvcc_base --config=rbe_linux_cuda_base +build:rbe_linux_cuda11.0_nvcc_base --define=using_cuda_nvcc=true +build:rbe_linux_cuda11.0_nvcc_base --crosstool_top="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda11.0_nvcc_base --extra_toolchains="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda11.0_nvcc_base --extra_execution_platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" +build:rbe_linux_cuda11.0_nvcc_base --host_platform="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" +build:rbe_linux_cuda11.0_nvcc_base --platforms="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_platform//:platform" +build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_cuda" +build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_TENSORRT_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_tensorrt" +build:rbe_linux_cuda11.0_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_nccl" +build:rbe_linux_cuda11.0_nvcc_py2.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python2.7" +build:rbe_linux_cuda11.0_nvcc_py3.5 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.5" +build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.6" +build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7" +build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8" + +# Map default to CUDA 10.1. +build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7 +build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5 +build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6 +build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7 +build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8 + +# Deprecated configs that people might still use. 
+build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36 +build:rbe_gpu_linux --config=rbe_linux_cuda_nvcc build:rbe_linux_cuda_clang_base --config=rbe_linux_cuda_base build:rbe_linux_cuda_clang_base --crosstool_top="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_cuda//crosstool:toolchain" @@ -440,8 +458,6 @@ build:rbe_linux_cuda_clang_py36 --config=rbe_linux_cuda_clang_base --repo_env=TF build:rbe_linux_cuda_clang_py37 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.7" build:rbe_linux_cuda_clang_py38 --config=rbe_linux_cuda_clang_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu16.04-clang_manylinux2010-cuda10.1-cudnn7-tensorrt6.0_config_python3.8" -common:rbe_gpu_linux --config=rbe_linux_cuda_nvcc - build:rbe_linux_py2 --config=rbe_linux build:rbe_linux_py2 --repo_env=PYTHON_BIN_PATH="/usr/bin/python2" build:rbe_linux_py2 --python_path="/usr/bin/python2" diff --git a/README.md b/README.md index 54c9470b04b..73a345706a4 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,7 @@ for general questions and discussion, and please direct specific questions to The TensorFlow project strives to abide by generally accepted best practices in open-source software development: +[![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow) [![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md) diff --git a/RELEASE.md b/RELEASE.md index f93626cc876..68d9399676a 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,57 @@ +# Release 2.4.0 + + + +## Breaking Changes + +* +* + +## Known Caveats + +* + +## Major Features and Improvements + +* +* + +## Bug Fixes and Other Changes + +* +* +* +* TF Core: + * +* `tf.data`: + * +* `tf.distribute`: + * +* `tf.keras`: + * +* `tf.function`/AutoGraph: + * +* `tf.lite`: + * +* `tf.random`: + * +* Math and Linear Algebra: + * +* TPU Enhancements: + * +* XLA Support: + * +* Tracing and Debugging: + * +* Other: + * + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +, , , , , + # Release 2.3.0 ## Breaking Changes diff --git a/tensorflow/BUILD b/tensorflow/BUILD index bd0619b0c05..d00608ccc98 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -467,6 +467,13 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag enables experimental TPU support +config_setting( + name = "with_tpu_support", + values = {"define": "with_tpu_support=true"}, + visibility = ["//visibility:public"], +) + # Specifies via a config setting if this is a mobile build or not, makes # it easier to combine settings later. 
selects.config_setting_group( diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index e9e6d470c68..831c6a0ad40 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -624,7 +624,7 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, const int num_inputs = input_shapes->num_items; NodeDef node_def; - tensorflow::AbstractOperationInterface* op = tensorflow::unwrap(tfe_op); + tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(tfe_op); node_def.set_name(op->Name()); node_def.set_op(op->Name()); for (int i = 0; i < num_inputs; ++i) { diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 9d3c79e0ae7..5f7ab4a1f59 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -38,9 +38,10 @@ tf_cuda_library( "//tensorflow/core:portable_tensorflow_lib_lite", ], "//conditions:default": [ - ":context_interface", - ":operation_interface", - ":tensor_handle_interface", + ":immediate_execution_context", + ":immediate_execution_operation", + ":immediate_execution_tensor_handle", + ":abstract_tensor_handle", ":tfe_context_internal", ":tfe_cancellation_manager_internal", ":tfe_executor_internal", @@ -101,13 +102,17 @@ tf_cuda_library( filegroup( name = "pywrap_required_hdrs", srcs = [ + "abstract_context.h", + "abstract_function.h", + "abstract_operation.h", + "abstract_tensor_handle.h", "c_api_experimental.h", "c_api_internal.h", "c_api_unified_experimental.h", - "context_interface.h", "dlpack.h", - "operation_interface.h", - "tensor_handle_interface.h", + "immediate_execution_context.h", + "immediate_execution_operation.h", + "immediate_execution_tensor_handle.h", "tfe_cancellation_manager_internal.h", "tfe_executor_internal.h", "tfe_monitoring_internal.h", @@ -163,12 +168,22 @@ cc_library( ) cc_library( - name = "tensor_handle_interface", - hdrs = ["tensor_handle_interface.h"], + name = "abstract_tensor_handle", + hdrs = ["abstract_tensor_handle.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [], +) + +cc_library( + name = "immediate_execution_tensor_handle", + hdrs = ["immediate_execution_tensor_handle.h"], visibility = [ "//tensorflow:internal", ], deps = [ + ":abstract_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -177,13 +192,13 @@ cc_library( ) cc_library( - name = "operation_interface", - hdrs = ["operation_interface.h"], + name = "abstract_operation", + hdrs = ["abstract_operation.h"], visibility = [ "//tensorflow:internal", ], deps = [ - ":tensor_handle_interface", + ":abstract_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -193,14 +208,58 @@ cc_library( ) cc_library( - name = "context_interface", - hdrs = ["context_interface.h"], + name = "immediate_execution_operation", + hdrs = ["immediate_execution_operation.h"], visibility = [ "//tensorflow:internal", ], deps = [ - ":operation_interface", - ":tensor_handle_interface", + ":abstract_operation", + ":abstract_tensor_handle", + ":immediate_execution_tensor_handle", + "//tensorflow/c:tensor_interface", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "abstract_context", + hdrs = ["abstract_context.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_function", + ":abstract_operation", + ], +) + +cc_library( + name = 
"abstract_function", + hdrs = ["abstract_function.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c:c_api", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:status", + ], +) + +cc_library( + name = "immediate_execution_context", + hdrs = ["immediate_execution_context.h"], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":immediate_execution_operation", + ":immediate_execution_tensor_handle", "//tensorflow/c:tensor_interface", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -217,7 +276,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":context_interface", + ":immediate_execution_context", "//tensorflow/c:conversion_macros", ], ) @@ -277,7 +336,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":operation_interface", + ":immediate_execution_operation", "//tensorflow/c:conversion_macros", ], ) @@ -300,7 +359,7 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":tensor_handle_interface", + ":immediate_execution_tensor_handle", "//tensorflow/c:conversion_macros", ], ) @@ -480,6 +539,9 @@ tf_cuda_library( ":tfe_context_internal", ":tfe_op_internal", ":tfe_tensorhandle_internal", + ":abstract_operation", + ":abstract_context", + ":abstract_tensor_handle", "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu", diff --git a/tensorflow/c/eager/abstract_context.h b/tensorflow/c/eager/abstract_context.h new file mode 100644 index 00000000000..36d983e1408 --- /dev/null +++ b/tensorflow/c/eager/abstract_context.h @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ + +#include +#include + +#include "tensorflow/c/eager/abstract_function.h" +#include "tensorflow/c/eager/abstract_operation.h" + +namespace tensorflow { + +// Abstract interface to a context. +// +// This serves as a factory for creating `AbstractOperation`s and for +// registering traced functions. +// Operations creation within a context can only be executed in that context +// (for now at least). +// Implementations of the context may contain some state e.g. an execution +// environment, a traced representation etc. +class AbstractContext { + protected: + enum AbstractContextKind { kTracing, kImmediateExecution }; + explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {} + virtual ~AbstractContext() {} + + public: + AbstractContextKind getKind() const { return kind_; } + + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus clients MUST call Release() in order to + // destroy an instance of this class. 
+ virtual void Release() = 0; + + // Creates an operation builder and ties it to this context. + // The returned object can be used for setting operation's attributes, + // adding inputs and finally executing (immediately or lazily as in tracing) + // it in this context. + virtual AbstractOperation* CreateOperation() = 0; + + // Registers a function with this context, after this the function is + // available to be called/referenced by its name in this context. + virtual Status RegisterFunction(AbstractFunction*) = 0; + // Remove a function. 'func' argument is the name of a previously added + // FunctionDef. The name is in fdef.signature.name. + virtual Status RemoveFunction(const string& func) = 0; + + private: + const AbstractContextKind kind_; +}; + +namespace internal { +struct AbstractContextDeleter { + void operator()(AbstractContext* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using AbstractContextPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_CONTEXT_H_ diff --git a/tensorflow/c/eager/abstract_function.h b/tensorflow/c/eager/abstract_function.h new file mode 100644 index 00000000000..e322b31f2b4 --- /dev/null +++ b/tensorflow/c/eager/abstract_function.h @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ + +#include "tensorflow/c/c_api.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace tensorflow { + +// A traced function: this hides the complexity of converting the serialized +// representation between various supported formats e.g. FunctionDef and Mlir +// function. +class AbstractFunction { + protected: + enum AbstractFunctionKind { kGraphFunc, kMlirFunc }; + explicit AbstractFunction(AbstractFunctionKind kind) : kind_(kind) {} + + public: + // Returns which subclass is this instance of. + AbstractFunctionKind getKind() const { return kind_; } + virtual ~AbstractFunction() = default; + + // Returns the AbstractFunction as a FunctionDef. + virtual Status GetFunctionDef(FunctionDef**) = 0; + + private: + const AbstractFunctionKind kind_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_FUNCTION_H_ diff --git a/tensorflow/c/eager/operation_interface.h b/tensorflow/c/eager/abstract_operation.h similarity index 77% rename from tensorflow/c/eager/operation_interface.h rename to tensorflow/c/eager/abstract_operation.h index 844ba6c14bd..817d7656ec8 100644 --- a/tensorflow/c/eager/operation_interface.h +++ b/tensorflow/c/eager/abstract_operation.h @@ -12,24 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ -#define TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ + +#include #include "absl/types/span.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" -#include "tensorflow/core/framework/device_attributes.pb.h" -#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" -struct TFE_Op; - namespace tensorflow { // Abstract interface to an operation. -class AbstractOperationInterface { +// This interface allows building and executing an operation in either +// tracing or immediate execution mode. +class AbstractOperation { + protected: + enum AbstractOperationKind { kTracing, kImmediateExecution }; + explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} + virtual ~AbstractOperation() {} + public: + AbstractOperationKind getKind() const { return kind_; } + // Release any underlying resources, including the interface object. // // WARNING: The destructor of this class is marked as protected to disallow @@ -38,7 +45,6 @@ class AbstractOperationInterface { // clients MUST call Release() in order to destroy an instance of this class. virtual void Release() = 0; - virtual void Clear() = 0; virtual Status Reset(const char* op, const char* raw_device_name) = 0; virtual const string& Name() const = 0; @@ -66,12 +72,10 @@ class AbstractOperationInterface { // existing and given constraints will be performed. virtual Status SetDeviceName(const char* name) = 0; - virtual Status AddInput(AbstractTensorHandleInterface* input) = 0; - virtual Status AddInputList( - absl::Span inputs) = 0; - virtual Status Execute(absl::Span retvals, + virtual Status AddInput(AbstractTensorHandle* input) = 0; + virtual Status AddInputList(absl::Span inputs) = 0; + virtual Status Execute(absl::Span retvals, int* num_retvals) = 0; - virtual const tensorflow::OpDef* OpDef() const = 0; virtual Status SetAttrString(const char* attr_name, const char* data, size_t length) = 0; @@ -82,7 +86,7 @@ class AbstractOperationInterface { virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) = 0; virtual Status SetAttrFunction(const char* attr_name, - const AbstractOperationInterface* value) = 0; + const AbstractOperation* value) = 0; virtual Status SetAttrFunctionName(const char* attr_name, const char* value, size_t length) = 0; virtual Status SetAttrTensor(const char* attr_name, @@ -102,19 +106,25 @@ class AbstractOperationInterface { virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, const int* num_dims, int num_values) = 0; virtual Status SetAttrFunctionList( - const char* attr_name, - absl::Span values) = 0; + const char* attr_name, absl::Span values) = 0; - virtual Status InputLength(const char* input_name, int* length) = 0; - virtual Status OutputLength(const char* output_name, int* length) = 0; - - // Experimental - virtual Status SetUseXla(bool enable) = 0; - - protected: - virtual ~AbstractOperationInterface() {} + private: + const AbstractOperationKind kind_; }; +namespace internal { +struct AbstractOperationDeleter { + void operator()(AbstractOperation* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + 
+using AbstractOpPtr = + std::unique_ptr; + } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_OPERATION_INTERFACE_H_ +#endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h new file mode 100644 index 00000000000..64b941d0729 --- /dev/null +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ +#define TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ + +#include + +namespace tensorflow { + +// Abstract interface to a Tensor handle in either tracing or immediate +// execution mode. +class AbstractTensorHandle { + protected: + enum AbstractTensorHandleKind { kTracing, kImmediateExecution }; + explicit AbstractTensorHandle(AbstractTensorHandleKind kind) : kind_(kind) {} + virtual ~AbstractTensorHandle() {} + + public: + AbstractTensorHandleKind getKind() const { return kind_; } + + // Release any underlying resources, including the interface object. + // + // WARNING: The destructor of this class is marked as protected to disallow + // clients from directly destroying this object since it may manage it's own + // lifetime through ref counting. Thus this must be allocated on the heap and + // clients MUST call Release() in order to destroy an instance of this class. + virtual void Release() = 0; + + private: + const AbstractTensorHandleKind kind_; +}; + +namespace internal { +struct AbstractTensorHandleDeleter { + void operator()(AbstractTensorHandle* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using AbstractTensorHandlePtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_ABSTRACT_TENSOR_HANDLE_H_ diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index fdc91675f8b..4be3cdd7c2d 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "tensorflow/c/eager/abstract_tensor_handle.h" + // clang-format off #include "tensorflow/core/platform/platform.h" // clang-format on @@ -31,8 +33,8 @@ limitations under the License. 
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_internal.h" -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" @@ -1119,7 +1121,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h, TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, TF_Status* status) { - tensorflow::AbstractOperationInterface* new_op = + tensorflow::ImmediateExecutionOperation* new_op = tensorflow::unwrap(ctx)->CreateOperation(); status->status = new_op->Reset(op_or_function_name, nullptr); if (!status->status.ok()) { @@ -1164,7 +1166,9 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) { void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status) { status->status = tensorflow::unwrap(op)->AddInputList( - {tensorflow::unwrap(inputs), static_cast(num_inputs)}); + {reinterpret_cast( + tensorflow::unwrap(inputs)), + static_cast(num_inputs)}); } TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, @@ -1324,7 +1328,9 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, const TFE_Op** value, int num_values) { auto s = tensorflow::unwrap(op)->SetAttrFunctionList( - attr_name, {tensorflow::unwrap(value), static_cast(num_values)}); + attr_name, {reinterpret_cast( + tensorflow::unwrap(value)), + static_cast(num_values)}); if (!s.ok()) { LOG(WARNING) << "Unable to set attribute: " << attr_name; } @@ -1368,7 +1374,10 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op, void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, TF_Status* status) { status->status = tensorflow::unwrap(op)->Execute( - absl::MakeSpan(tensorflow::unwrap(retvals), *num_retvals), num_retvals); + absl::MakeSpan(reinterpret_cast( + tensorflow::unwrap(retvals)), + *num_retvals), + num_retvals); } TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 0d71b11531b..7390cf243be 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -38,7 +38,7 @@ using tensorflow::string; void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, const char* raw_device_name, TF_Status* status) { if (op_to_reset) { - tensorflow::AbstractOperationInterface* op = + tensorflow::ImmediateExecutionOperation* op = tensorflow::unwrap(op_to_reset); op->Clear(); status->status = op->Reset(op_or_function_name, raw_device_name); @@ -60,6 +60,12 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { context->SetShouldStoreGraphs(false); } +uint64_t TFE_GetContextId(TFE_Context* ctx) { + tensorflow::EagerContext* context = + tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); + return context->GetContextId(); +} + void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, int64_t value) { cell->cell.IncrementBy(value); diff --git a/tensorflow/c/eager/c_api_experimental.h b/tensorflow/c/eager/c_api_experimental.h index 1b8efe61ee0..1af76c01154 100644 --- 
a/tensorflow/c/eager/c_api_experimental.h +++ b/tensorflow/c/eager/c_api_experimental.h @@ -300,6 +300,14 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetLazyRemoteInputsCopy( TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*, bool use_tfrt); +// Returns the context_id from the EagerContext which is used by the +// EagerService to maintain consistency between client and worker. The +// context_id is initialized with a dummy value and is later set when the worker +// is initialized (either locally or remotely). The context_id can change during +// the process lifetime although this should cause the worker to be +// reinitialized (e.g. cleared caches) as well. +TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx); + // ----------------------------------------------------------------------------- // Cancellation APIs. diff --git a/tensorflow/c/eager/context_interface.h b/tensorflow/c/eager/immediate_execution_context.h similarity index 77% rename from tensorflow/c/eager/context_interface.h rename to tensorflow/c/eager/immediate_execution_context.h index e5a770a6826..77d59dd23e2 100644 --- a/tensorflow/c/eager/context_interface.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -12,15 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ -#define TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ +#include #include #include "absl/types/optional.h" #include "absl/types/span.h" -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/abstract_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/numeric_types.h" @@ -34,16 +36,9 @@ namespace tensorflow { // // A context is responsible for creating key objects such as Tensors, // TensorHandles & Operations. -class AbstractContextInterface { +class ImmediateExecutionContext : public AbstractContext { public: - // Release any underlying resources, including the interface object. - // - // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own - // lifetime through ref counting. Thus clients MUST call Release() in order to - // destroy an instance of this class. - virtual void Release() = 0; - + static constexpr AbstractContextKind kKind = kImmediateExecution; // Optimized scalar creation functions virtual AbstractTensorInterface* CreateInt64Scalar(int64 value) = 0; virtual AbstractTensorInterface* CreateUint64Scalar(uint64 value) = 0; @@ -74,15 +69,15 @@ class AbstractContextInterface { void* memory_releaser_arg) = 0; // Create a handle to wrap and manage a Tensor - virtual AbstractTensorHandleInterface* CreateLocalHandle( + virtual ImmediateExecutionTensorHandle* CreateLocalHandle( AbstractTensorInterface* t) = 0; // Copy the handle to another device. 
- virtual AbstractTensorHandleInterface* CopyTensorHandleToDevice( - AbstractTensorHandleInterface* handle, const char* device_name, + virtual ImmediateExecutionTensorHandle* CopyTensorHandleToDevice( + ImmediateExecutionTensorHandle* handle, const char* device_name, Status* status) = 0; // Create an operation to perform op execution - virtual AbstractOperationInterface* CreateOperation() = 0; + ImmediateExecutionOperation* CreateOperation() override = 0; // Returns whether the runtime is backed by TFRT or the legacy TF Eager // Runtime. This is necessary to decouple runtime-dependent @@ -107,14 +102,26 @@ class AbstractContextInterface { // be executed as an op. Return error if the function with the same name // already exists. virtual Status AddFunctionDef(const FunctionDef& fdef) = 0; - // Remove a function. 'func' argument is the name of a previously added - // FunctionDef. The name is in fdef.signature.name. - virtual Status RemoveFunction(const string& func) = 0; protected: - virtual ~AbstractContextInterface() {} + ImmediateExecutionContext() : AbstractContext(kKind) {} + ~ImmediateExecutionContext() override {} }; +namespace internal { +struct ImmediateExecutionContextDeleter { + void operator()(ImmediateExecutionContext* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using ImmediateContextPtr = + std::unique_ptr; + } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_CONTEXT_INTERFACE_H_ +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_CONTEXT_H_ diff --git a/tensorflow/c/eager/immediate_execution_operation.h b/tensorflow/c/eager/immediate_execution_operation.h new file mode 100644 index 00000000000..4e2959ba7af --- /dev/null +++ b/tensorflow/c/eager/immediate_execution_operation.h @@ -0,0 +1,69 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" +#include "tensorflow/core/framework/device_attributes.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/casts.h" +#include "tensorflow/core/platform/status.h" + +struct TFE_Op; + +namespace tensorflow { + +// Abstract interface to an operation. 
+class ImmediateExecutionOperation : public AbstractOperation { + public: + static constexpr AbstractOperationKind kKind = kImmediateExecution; + virtual void Clear() = 0; + + virtual const tensorflow::OpDef* OpDef() const = 0; + + virtual Status InputLength(const char* input_name, int* length) = 0; + virtual Status OutputLength(const char* output_name, int* length) = 0; + + // Experimental + virtual Status SetUseXla(bool enable) = 0; + + protected: + ImmediateExecutionOperation() : AbstractOperation(kKind) {} + ~ImmediateExecutionOperation() override {} +}; + +namespace internal { +struct ImmediateExecutionOperationDeleter { + void operator()(ImmediateExecutionOperation* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using ImmediateOpPtr = + std::unique_ptr; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_OPERATION_H_ diff --git a/tensorflow/c/eager/tensor_handle_interface.h b/tensorflow/c/eager/immediate_execution_tensor_handle.h similarity index 69% rename from tensorflow/c/eager/tensor_handle_interface.h rename to tensorflow/c/eager/immediate_execution_tensor_handle.h index 1ca40daec41..31aa3aa0f75 100644 --- a/tensorflow/c/eager/tensor_handle_interface.h +++ b/tensorflow/c/eager/immediate_execution_tensor_handle.h @@ -12,9 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ -#define TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ +#ifndef TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ +#define TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" @@ -30,15 +31,9 @@ namespace tensorflow { // files. The interface lists the common functionality that must be provided by // any concrete implementation. However, in cases where the true concrete class // is needed a static_cast can be applied. -class AbstractTensorHandleInterface { +class ImmediateExecutionTensorHandle : public AbstractTensorHandle { public: - // Release any underlying resources, including the interface object. - // - // WARNING: The destructor of this class is marked as protected to disallow - // clients from directly destroying this object since it may manage it's own - // lifetime through ref counting. Thus this must be allocated on the heap and - // clients MUST call Release() in order to destroy an instance of this class. - virtual void Release() = 0; + static constexpr AbstractTensorHandleKind kKind = kImmediateExecution; // Returns tensor dtype. virtual tensorflow::DataType DataType() const = 0; @@ -57,12 +52,27 @@ class AbstractTensorHandleInterface { virtual AbstractTensorInterface* Resolve(Status* status) = 0; // Return a copy of the handle. 
- virtual AbstractTensorHandleInterface* Copy() = 0; + virtual ImmediateExecutionTensorHandle* Copy() = 0; protected: - virtual ~AbstractTensorHandleInterface() {} + ImmediateExecutionTensorHandle() : AbstractTensorHandle(kKind) {} + ~ImmediateExecutionTensorHandle() override {} }; +namespace internal { +struct ImmediateExecutionTensorHandleDeleter { + void operator()(ImmediateExecutionTensorHandle* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using ImmediateTensorHandlePtr = + std::unique_ptr; + } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_ +#endif // TENSORFLOW_C_EAGER_IMMEDIATE_EXECUTION_TENSOR_HANDLE_H_ diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index d0149b29c08..768f686bd88 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -262,14 +262,14 @@ std::unique_ptr ParallelDevice::DeviceIDs( components.reserve(underlying_devices_.size()); for (int device_index = 0; device_index < underlying_devices_.size(); ++device_index) { - int64_t* device_id = new int64_t; + int32_t* device_id = new int32_t; *device_id = device_index; std::unique_ptr tensor( TF_NewTensor( - TF_INT64, /*dims=*/nullptr, /*num_dims=*/0, device_id, - sizeof(int64_t), + TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id, + sizeof(int32_t), [](void* data, size_t, void* arg) { - delete reinterpret_cast(data); + delete reinterpret_cast(data); }, nullptr), TF_DeleteTensor); @@ -283,7 +283,7 @@ std::unique_ptr ParallelDevice::DeviceIDs( if (TF_GetCode(status) != TF_OK) return nullptr; TFE_OpSetAttrTensor(const_op.get(), "value", tensor.get(), status); if (TF_GetCode(status) != TF_OK) return nullptr; - TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT64); + TFE_OpSetAttrType(const_op.get(), "dtype", TF_INT32); TFE_TensorHandle* device_handle; int num_outputs = 1; TFE_Execute(const_op.get(), &device_handle, &num_outputs, status); diff --git a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc index fba47865c36..828dcbae093 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_testlib.cc @@ -296,8 +296,8 @@ void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device, TFE_DeleteTensorHandle(result_handle); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - ExpectScalarEq(components[0].get(), 0); - ExpectScalarEq(components[1].get(), 1); + ExpectScalarEq(components[0].get(), 0); + ExpectScalarEq(components[1].get(), 1); std::string first_device = TFE_TensorHandleBackingDeviceName(components[0].get(), status.get()); ASSERT_EQ(underlying_devices[0], first_device); diff --git a/tensorflow/c/eager/tfe_context_internal.h b/tensorflow/c/eager/tfe_context_internal.h index 1d29bee9ee3..1f2035317fa 100644 --- a/tensorflow/c/eager/tfe_context_internal.h +++ b/tensorflow/c/eager/tfe_context_internal.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_TFE_CONTEXT_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/context_interface.h" +#include "tensorflow/c/eager/immediate_execution_context.h" // Wraps a pointer to a context implementation. 
// @@ -28,7 +28,7 @@ typedef struct TFE_Context TFE_Context; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractContextInterface, TFE_Context); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionContext, TFE_Context); } // namespace tensorflow diff --git a/tensorflow/c/eager/tfe_op_internal.h b/tensorflow/c/eager/tfe_op_internal.h index 6ca7f741d16..3fe94d358b6 100644 --- a/tensorflow/c/eager/tfe_op_internal.h +++ b/tensorflow/c/eager/tfe_op_internal.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_TFE_OP_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/operation_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" // Wraps a pointer to an operation implementation. // @@ -28,8 +28,8 @@ typedef struct TFE_Op TFE_Op; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface, TFE_Op); -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractOperationInterface*, TFE_Op*); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation, TFE_Op); +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionOperation*, TFE_Op*); } // namespace tensorflow diff --git a/tensorflow/c/eager/tfe_tensorhandle_internal.h b/tensorflow/c/eager/tfe_tensorhandle_internal.h index 543e5f1d932..308e8c24e2c 100644 --- a/tensorflow/c/eager/tfe_tensorhandle_internal.h +++ b/tensorflow/c/eager/tfe_tensorhandle_internal.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_C_EAGER_TFE_TENSORHANDLE_INTERNAL_H_ #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" // Wraps a pointer to a tensor handle implementation. // @@ -28,9 +28,9 @@ typedef struct TFE_TensorHandle TFE_TensorHandle; namespace tensorflow { -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface, +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle, TFE_TensorHandle); -DEFINE_CONVERSION_FUNCTIONS(tensorflow::AbstractTensorHandleInterface*, +DEFINE_CONVERSION_FUNCTIONS(tensorflow::ImmediateExecutionTensorHandle*, TFE_TensorHandle*); } // namespace tensorflow diff --git a/tensorflow/c/env.cc b/tensorflow/c/env.cc index 1c35ff9001d..ce715c43acb 100644 --- a/tensorflow/c/env.cc +++ b/tensorflow/c/env.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/path.h" #include "tensorflow/core/platform/types.h" struct TF_StringStream { @@ -146,6 +147,10 @@ TF_StringStream* TF_GetLocalTempDirectories() { return list; } +char* TF_GetTempFileName(const char* extension) { + return strdup(::tensorflow::io::GetTempFilename(extension).c_str()); +} + TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void) { return ::tensorflow::Env::Default()->NowNanos(); } diff --git a/tensorflow/c/env.h b/tensorflow/c/env.h index 2a763730bc3..7dc7ac32f08 100644 --- a/tensorflow/c/env.h +++ b/tensorflow/c/env.h @@ -152,6 +152,10 @@ TF_CAPI_EXPORT extern TF_StringStream* TF_GetChildren(const char* filename, // The caller is responsible for freeing the list (see TF_StringStreamDone). TF_CAPI_EXPORT extern TF_StringStream* TF_GetLocalTempDirectories(void); +// Creates a temporary file name with an extension. +// The caller is responsible for freeing the returned pointer. 
+TF_CAPI_EXPORT extern char* TF_GetTempFileName(const char* extension); + // Returns the number of nanoseconds since the Unix epoch. TF_CAPI_EXPORT extern uint64_t TF_NowNanos(void); diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index c9fee433589..f61aa8347d4 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -1,5 +1,5 @@ # Experimental gcs filesystem plugin. -load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object") +load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object", "tf_cc_test") package( licenses = ["notice"], # Apache 2.0 @@ -24,9 +24,45 @@ cc_library( "//tensorflow:windows": get_win_copts(), }), deps = [ + ":gcs_helper", + "//tensorflow/c:env", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", "@com_google_absl//absl/strings", ], ) + +cc_library( + name = "gcs_helper", + srcs = ["gcs_helper.cc"], + hdrs = ["gcs_helper.h"], + linkstatic = 1, + deps = [ + "//tensorflow/c:env", + ], +) + +tf_cc_test( + name = "gcs_filesystem_test", + srcs = [ + "gcs_filesystem.cc", + "gcs_filesystem_test.cc", + ], + local_defines = ["TF_GCS_FILESYSTEM_TEST"], + tags = [ + "manual", + "notap", + ], + deps = [ + ":gcs_helper", + "//tensorflow/c:env", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/filesystem:filesystem_interface", + "//tensorflow/core/platform:stacktrace_handler", + "//tensorflow/core/platform:test", + "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", + "@com_google_absl//absl/strings", + ], +) diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc index 8c54bc85439..8c5c035f939 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.cc @@ -15,11 +15,23 @@ limitations under the License. #include #include +#include + #include "absl/strings/string_view.h" #include "google/cloud/storage/client.h" +#include "tensorflow/c/env.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" #include "tensorflow/c/tf_status.h" +#ifdef TF_GCS_FILESYSTEM_TEST +// For testing purpose, we expose some functions. +#define TF_STATIC +#else +// Otherwise, we don't expose any symbol. +#define TF_STATIC static +#endif + // Implementation of a filesystem for GCS environments. // This filesystem will support `gs://` URI schemes. namespace gcs = google::cloud::storage; @@ -86,6 +98,20 @@ namespace tf_random_access_file { // SECTION 2. 
Implementation for `TF_WritableFile` // ---------------------------------------------------------------------------- namespace tf_writable_file { +typedef struct GCSFile { + const char* bucket; + const char* object; + gcs::Client* gcs_client; // not owned + TempFile outfile; + bool sync_need; +} GCSFile; + +static void Cleanup(TF_WritableFile* file) { + auto gcs_file = static_cast(file->plugin_file); + plugin_memory_free(const_cast(gcs_file->bucket)); + plugin_memory_free(const_cast(gcs_file->object)); + delete gcs_file; +} // TODO(vnvo2409): Implement later @@ -104,7 +130,7 @@ namespace tf_read_only_memory_region { namespace tf_gcs_filesystem { // TODO(vnvo2409): Add lazy-loading and customizing parameters. -static void Init(TF_Filesystem* filesystem, TF_Status* status) { +TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) { google::cloud::StatusOr client = gcs::Client::CreateDefaultClient(); if (!client) { @@ -117,8 +143,54 @@ static void Init(TF_Filesystem* filesystem, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); } +static void Cleanup(TF_Filesystem* filesystem) { + plugin_memory_free(filesystem->plugin_filesystem); +} + // TODO(vnvo2409): Implement later +static void NewWritableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + char* bucket; + char* object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_client = static_cast(filesystem->plugin_filesystem); + char* temp_file_name = TF_GetTempFileName(""); + file->plugin_file = new tf_writable_file::GCSFile( + {bucket, object, gcs_client, + TempFile(temp_file_name, std::ios::binary | std::ios::out), true}); + // We are responsible for freeing the pointer returned by TF_GetTempFileName + free(temp_file_name); + TF_SetStatus(status, TF_OK, ""); +} + +static void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, + TF_WritableFile* file, TF_Status* status) { + char* bucket; + char* object; + ParseGCSPath(path, false, &bucket, &object, status); + if (TF_GetCode(status) != TF_OK) return; + + auto gcs_client = static_cast(filesystem->plugin_filesystem); + char* temp_file_name = TF_GetTempFileName(""); + + auto gcs_status = gcs_client->DownloadToFile(bucket, object, temp_file_name); + TF_SetStatusFromGCSStatus(gcs_status, status); + auto status_code = TF_GetCode(status); + if (status_code != TF_OK && status_code != TF_NOT_FOUND) { + return; + } + // If this file does not exist on server, we will need to sync it. 
+ bool sync_need = (status_code == TF_NOT_FOUND); + file->plugin_file = new tf_writable_file::GCSFile( + {bucket, object, gcs_client, + TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need}); + free(temp_file_name); + TF_SetStatus(status, TF_OK, ""); +} + } // namespace tf_gcs_filesystem static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, @@ -126,9 +198,17 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops, TF_SetFilesystemVersionMetadata(ops); ops->scheme = strdup(uri); + ops->writable_file_ops = static_cast( + plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE)); + ops->writable_file_ops->cleanup = tf_writable_file::Cleanup; + ops->filesystem_ops = static_cast( plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE)); ops->filesystem_ops->init = tf_gcs_filesystem::Init; + ops->filesystem_ops->cleanup = tf_gcs_filesystem::Cleanup; + ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile; + ops->filesystem_ops->new_appendable_file = + tf_gcs_filesystem::NewAppendableFile; } void TF_InitPlugin(TF_FilesystemPluginInfo* info) { diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc new file mode 100644 index 00000000000..43221763791 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/stacktrace_handler.h" +#include "tensorflow/core/platform/test.h" + +#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x)) + +// Forward declaration +namespace tf_gcs_filesystem { +void Init(TF_Filesystem* filesystem, TF_Status* status); +} + +namespace tensorflow { +namespace { + +class GCSFilesystemTest : public ::testing::Test { + public: + void SetUp() override { + status_ = TF_NewStatus(); + filesystem_ = new TF_Filesystem; + tf_gcs_filesystem::Init(filesystem_, status_); + ASSERT_TF_OK(status_) << "Can not initialize filesystem. " + << TF_Message(status_); + } + void TearDown() override { + TF_DeleteStatus(status_); + // TODO(vnvo2409): Add filesystem cleanup + delete filesystem_; + } + + protected: + TF_Filesystem* filesystem_; + TF_Status* status_; +}; + +// We have to add this test here because there must be at least one test. +// This test will be removed in the future. 
+TEST_F(GCSFilesystemTest, TestInit) { ASSERT_TF_OK(status_); } + +} // namespace +} // namespace tensorflow + +GTEST_API_ int main(int argc, char** argv) { + tensorflow::testing::InstallStacktraceHandler(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.cc new file mode 100644 index 00000000000..4504a9f3b35 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.cc @@ -0,0 +1,34 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h" + +#include + +#include +#include +#include + +TempFile::TempFile(const char* temp_file_name, std::ios::openmode mode) + : std::fstream(temp_file_name, mode), name_(temp_file_name) {} + +TempFile::TempFile(TempFile&& rhs) + : std::fstream(std::move(rhs)), name_(std::move(rhs.name_)) {} + +TempFile::~TempFile() { + std::fstream::close(); + std::remove(name_.c_str()); +} + +const std::string TempFile::getName() const { return name_; } diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h new file mode 100644 index 00000000000..1a521ca4f1e --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ + +#include +#include + +class TempFile : public std::fstream { + public: + // We should specify openmode each time we call TempFile. 
+ TempFile(const char* temp_file_name, std::ios::openmode mode); + TempFile(TempFile&& rhs); + ~TempFile() override; + const std::string getName() const; + + private: + const std::string name_; +}; + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_HELPER_H_ diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 2e817ed02e0..dbe1b6d656c 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -23,8 +23,8 @@ cc_library( ], deps = [ ":function_metadata", - "//tensorflow/c/eager:operation_interface", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.cc b/tensorflow/c/experimental/saved_model/core/concrete_function.cc index d5da2ca9bf4..41bae4352fc 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.cc @@ -15,12 +15,12 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" namespace tensorflow { -const std::vector& +const std::vector& ConcreteFunction::GetCaptures() const { return captures_; } diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 6f8a5375277..22535641ef5 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -18,8 +18,8 @@ limitations under the License. #include -#include "tensorflow/c/eager/operation_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/core/framework/function.pb.h" @@ -38,15 +38,15 @@ class ConcreteFunction { virtual ~ConcreteFunction() = 0; // This method returns the "Call" Op used to execute the function. 
- virtual AbstractOperationInterface* GetCallOp() = 0; + virtual ImmediateExecutionOperation* GetCallOp() = 0; - const std::vector& GetCaptures() + const std::vector& GetCaptures() const; const FunctionMetadata& GetFunctionMetadata() const; private: FunctionMetadata metadata_; - std::vector captures_; + std::vector captures_; FunctionDef* function_; }; diff --git a/tensorflow/c/experimental/saved_model/core/ops/BUILD b/tensorflow/c/experimental/saved_model/core/ops/BUILD index b42e93c3716..1e2496487f9 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/BUILD +++ b/tensorflow/c/experimental/saved_model/core/ops/BUILD @@ -14,44 +14,6 @@ package( licenses = ["notice"], # Apache 2.0 ) -cc_library( - name = "owned_eager_op", - hdrs = [ - "owned_eager_op.h", - ], - deps = [ - "//tensorflow/c/eager:operation_interface", - ], -) - -cc_library( - name = "owned_tensor_handle", - hdrs = [ - "owned_tensor_handle.h", - ], - deps = [ - "//tensorflow/c/eager:tensor_handle_interface", - "//tensorflow/core/common_runtime/eager:tensor_handle", - ], -) - -cc_library( - name = "owned_eager_context", - hdrs = ["owned_eager_context.h"], - deps = [ - "//tensorflow/c/eager:context_interface", - "//tensorflow/core/common_runtime/eager:context", - ], -) - -cc_library( - name = "owned_tensor", - hdrs = ["owned_tensor.h"], - deps = [ - "//tensorflow/c:tensor_interface", - ], -) - cc_library( name = "variable_ops", srcs = [ @@ -61,10 +23,11 @@ cc_library( "variable_ops.h", ], deps = [ - ":owned_eager_op", - ":owned_tensor_handle", - "//tensorflow/c/eager:context_interface", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -78,10 +41,11 @@ tf_cc_test( "variable_ops_test.cc", ], deps = [ - ":owned_eager_context", - ":owned_tensor", - ":owned_tensor_handle", ":variable_ops", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h b/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h deleted file mode 100644 index 300059cd069..00000000000 --- a/tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_ - -#include - -#include "tensorflow/c/eager/context_interface.h" -#include "tensorflow/core/common_runtime/eager/context.h" - -namespace tensorflow { -namespace internal { - -struct AbstractContextInterfaceDeleter { - void operator()(AbstractContextInterface* p) const { - if (p != nullptr) { - p->Release(); - } - } -}; - -struct EagerContextDeleter { - void operator()(EagerContext* p) const { - if (p != nullptr) { - p->Release(); - } - } -}; - -} // namespace internal - -using AbstractContextPtr = - std::unique_ptr; - -using EagerContextPtr = - std::unique_ptr; - -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_ diff --git a/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h b/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h deleted file mode 100644 index e98d6554afb..00000000000 --- a/tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_ - -#include - -#include "tensorflow/c/eager/tensor_handle_interface.h" -#include "tensorflow/core/common_runtime/eager/tensor_handle.h" - -namespace tensorflow { -namespace internal { - -struct TensorHandleDeleter { - void operator()(TensorHandle* p) const { - if (p != nullptr) { - p->Release(); - } - } -}; - -struct AbstractTensorHandleDeleter { - void operator()(AbstractTensorHandleInterface* p) const { - if (p != nullptr) { - p->Release(); - } - } -}; - -} // namespace internal - -using TensorHandlePtr = - std::unique_ptr; - -using AbstractTensorHandlePtr = - std::unique_ptr; - -} // namespace tensorflow - -#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_ diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index a3b3ace7be9..67c592fc16b 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -16,9 +16,11 @@ limitations under the License. 
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h" #include "absl/types/span.h" -#include "tensorflow/c/eager/context_interface.h" -#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h" -#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -32,10 +34,10 @@ namespace internal { static const char kNoSharingResourceID[] = "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; -Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, +Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, - AbstractTensorHandlePtr* handle) { - AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation()); + ImmediateTensorHandlePtr* handle) { + ImmediateOpPtr varhandle_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr)); TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype)); @@ -50,18 +52,23 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, TF_RETURN_IF_ERROR(varhandle_op->SetAttrString( "shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID))); - AbstractTensorHandleInterface* var_handle = nullptr; + AbstractTensorHandle* var_handle = nullptr; int num_retvals = 1; TF_RETURN_IF_ERROR(varhandle_op->Execute( absl::MakeSpan(&var_handle, num_retvals), &num_retvals)); - handle->reset(var_handle); + AbstractTensorHandlePtr owned_var_handle(var_handle); + if (owned_var_handle->getKind() != ImmediateExecutionTensorHandle::kKind) { + return errors::Internal("Unexpected tensor handle kind."); + } + handle->reset(reinterpret_cast( + owned_var_handle.release())); return Status(); } -Status AssignVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, - DataType dtype, AbstractTensorHandleInterface* value) { - AbstractOpPtr assign_op(ctx->CreateOperation()); +Status AssignVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateExecutionTensorHandle* value) { + ImmediateOpPtr assign_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr)); TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype)); TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle)); @@ -72,25 +79,30 @@ Status AssignVariable(AbstractContextInterface* ctx, return Status(); } -Status ReadVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, - DataType dtype, AbstractTensorHandlePtr* output) { - AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation()); +Status ReadVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateTensorHandlePtr* output) { + ImmediateOpPtr read_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr)); TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype)); TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle)); - AbstractTensorHandleInterface* value = nullptr; + AbstractTensorHandle* value = nullptr; int 
num_retvals = 1; TF_RETURN_IF_ERROR( read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals)); - output->reset(value); + AbstractTensorHandlePtr owned_value(value); + if (owned_value->getKind() != ImmediateExecutionTensorHandle::kKind) { + return errors::Internal("Unexpected tensor handle kind."); + } + output->reset( + reinterpret_cast(owned_value.release())); return Status(); } -Status DestroyResource(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* handle) { - AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation()); +Status DestroyResource(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* handle) { + ImmediateOpPtr destroy_op(ctx->CreateOperation()); TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr)); TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true)); TF_RETURN_IF_ERROR(destroy_op->AddInput(handle)); diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h index 8a410328b9e..13c941a77fe 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -16,9 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H -#include "tensorflow/c/eager/context_interface.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" -#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/status.h" @@ -30,31 +29,31 @@ namespace internal { // TensorHandle associated with the variable. This is equivalent to creating an // unitialized TF2 tf.Variable. // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872 -Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, +Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, - AbstractTensorHandlePtr* handle); + ImmediateTensorHandlePtr* handle); // Executes an AssignVariableOp using `ctx`, assigning the variable associated // with `variable_handle` with `value`. `dtype` must be the datatype of the // underlying variable for `variable_handle`. Note that it is illegal to assign // a variable to a Tensor with a different dtype than what the variable was // created with. -Status AssignVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, - DataType dtype, AbstractTensorHandleInterface* value); +Status AssignVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateExecutionTensorHandle* value); // Executes a ReadVariableOp using `ctx`. This reads the underlying variable // value of `variable_handle` and copies the value to `output`. `dtype` must be // the dtype of the variable associated with `variable_handle`. 
-Status ReadVariable(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* variable_handle, - DataType dtype, AbstractTensorHandlePtr* output); +Status ReadVariable(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* variable_handle, + DataType dtype, ImmediateTensorHandlePtr* output); // Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to // the cleanup that occurs in a tf.Variable's EagerResourceDeleter: // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290 -Status DestroyResource(AbstractContextInterface* ctx, - AbstractTensorHandleInterface* handle); +Status DestroyResource(ImmediateExecutionContext* ctx, + ImmediateExecutionTensorHandle* handle); } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index 3c57ed4d38a..09c45332efc 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -17,9 +17,8 @@ limitations under the License. #include -#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h" -#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h" -#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" #include "tensorflow/core/framework/tensor.h" @@ -30,10 +29,10 @@ limitations under the License. 
namespace tensorflow { namespace { -AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context, - float value) { +ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context, + float value) { AbstractTensorPtr tensor(context->CreateFloatScalar(value)); - AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get())); + ImmediateTensorHandlePtr handle(context->CreateLocalHandle(tensor.get())); return handle; } @@ -62,7 +61,7 @@ class VariableOpsTest : public ::testing::Test { // Sanity check for variable creation TEST_F(VariableOpsTest, CreateVariableSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor - AbstractTensorHandlePtr handle; + ImmediateTensorHandlePtr handle; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( context(), DT_FLOAT, {}, &handle)); // The created TensorHandle should be a DT_Resource @@ -72,7 +71,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) { // Sanity check for variable destruction TEST_F(VariableOpsTest, DestroyVariableSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor - AbstractTensorHandlePtr handle; + ImmediateTensorHandlePtr handle; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( context(), DT_FLOAT, {}, &handle)); @@ -83,18 +82,18 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) { // Sanity check for handle assignment and reading TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor - AbstractTensorHandlePtr variable; + ImmediateTensorHandlePtr variable; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( context(), DT_FLOAT, {}, &variable)); // Create a Scalar float TensorHandle with value 42, and assign it to // the variable. - AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0); + ImmediateTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0); TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT, my_value.get())); // Read back the value from the variable, and check that it is 42. - AbstractTensorHandlePtr read_value_handle; + ImmediateTensorHandlePtr read_value_handle; TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT, &read_value_handle)); Status status; diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 72474940c16..888c284bb12 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -178,7 +178,7 @@ cc_library( ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", "//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", "//tensorflow/c/eager:tfe_tensorhandle_internal", ], ) @@ -190,7 +190,7 @@ cc_library( ], deps = [ "//tensorflow/c:conversion_macros", - "//tensorflow/c/eager:tensor_handle_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", ], ) diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc index 7d018658101..c8f00c1f7c0 100644 --- a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" diff --git a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h index 8cbec2806a8..566417df025 100644 --- a/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h +++ b/tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/c/conversion_macros.h" -#include "tensorflow/c/eager/tensor_handle_interface.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" // Internal structures used by the SavedModel C API. These are likely to // change and should not be depended on. @@ -29,7 +29,7 @@ typedef struct TF_TensorHandleList TF_TensorHandleList; namespace tensorflow { DEFINE_CONVERSION_FUNCTIONS( - std::vector, + std::vector, TF_TensorHandleList) } // namespace tensorflow diff --git a/tensorflow/c/tensor_interface.h b/tensorflow/c/tensor_interface.h index eb0d28b0bf9..d165c84980c 100644 --- a/tensorflow/c/tensor_interface.h +++ b/tensorflow/c/tensor_interface.h @@ -54,6 +54,20 @@ class AbstractTensorInterface { virtual ~AbstractTensorInterface() {} }; +namespace internal { +struct AbstractTensorInterfaceDeleter { + void operator()(AbstractTensorInterface* p) const { + if (p != nullptr) { + p->Release(); + } + } +}; +} // namespace internal + +using AbstractTensorPtr = + std::unique_ptr; + } // namespace tensorflow #endif // TENSORFLOW_C_TENSOR_INTERFACE_H_ diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 942ec08f451..f5a09e09dcd 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -259,9 +259,6 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } -// TODO(rocm): -// Re-enable this test once 3D pooling is supported on ROCm platform -#ifndef TENSORFLOW_USE_ROCM TEST_F(NNGradTest, MaxPool3DGradHelper) { TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape y_shape({1, 1, 1, 1, 1}); @@ -274,7 +271,6 @@ TEST_F(NNGradTest, MaxPool3DGradHelper) { SetRandomValuesForMaxPooling(&x_init_value); RunTest(x, x_init_value, y, y_shape); } -#endif TEST_F(NNGradTest, AvgPoolGradHelper) { TensorShape x_shape({1, 2, 2, 1}); @@ -287,9 +283,6 @@ TEST_F(NNGradTest, AvgPoolGradHelper) { RunTest(x, x_shape, y, y_shape); } -// TODO(rocm): -// Re-enable this test once 3D pooling is supported on ROCm platform -#ifndef TENSORFLOW_USE_ROCM TEST_F(NNGradTest, AvgPool3DGradHelper) { TensorShape x_shape({1, 3, 3, 3, 1}); TensorShape y_shape({1, 1, 1, 1, 1}); @@ -300,7 +293,6 @@ TEST_F(NNGradTest, AvgPool3DGradHelper) { auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); RunTest(x, x_shape, y, y_shape); } -#endif TEST_F(NNGradTest, LRN) { TensorShape x_shape({1, 1, 2, 1}); diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 0fc1a349adc..e3542586c89 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -108,8 +108,7 @@ class XlaExecutableClosure { explicit XlaExecutableClosure( xla::LocalClient* client, xla::LocalExecutable* executable, const XlaCompiler::CompilationResult* compilation_result, - std::map 
resource_var_snapshots, - int num_constant_args) + ResourceVarsSnapshot resource_var_snapshots, int num_constant_args) : client_(client), executable_(executable), compilation_result_(compilation_result), @@ -124,7 +123,7 @@ class XlaExecutableClosure { const XlaCompiler::CompilationResult* compilation_result() const { return compilation_result_; } - const std::map& resource_var_snapshots() const { + const ResourceVarsSnapshot& resource_var_snapshots() const { return resource_var_snapshots_; } int num_constant_args() const { return num_constant_args_; } @@ -133,7 +132,7 @@ class XlaExecutableClosure { xla::LocalClient* client_; xla::LocalExecutable* executable_; const XlaCompiler::CompilationResult* compilation_result_; - std::map resource_var_snapshots_; + ResourceVarsSnapshot resource_var_snapshots_; int num_constant_args_; TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosure); @@ -276,10 +275,10 @@ static Status BuildCompilationCache(OpKernelContext* ctx, static Status CompileToLocalExecutable( OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars, - const XlaPlatformInfo& platform_info, absl::Span resources, + const XlaPlatformInfo& platform_info, + absl::Span variable_infos, absl::Span constants, bool lazy, xla::LocalClient** client, - std::map* variables, - const XlaCompiler::CompilationResult** kernel, + const XlaCompiler::CompilationResult** compilation_result, xla::LocalExecutable** executable) { // We store information about the JIT-compiled XLA computation // in the ResourceMgr. @@ -299,7 +298,6 @@ static Status CompileToLocalExecutable( // this is more obviously correct.) core::ScopedUnref cache_ref(cache); - TF_RETURN_IF_ERROR(SnapshotResourceVariables(ctx, resources, variables)); *client = static_cast(cache->client()); absl::optional tf_allocator_adapter; @@ -337,11 +335,11 @@ static Status CompileToLocalExecutable( std::vector args; TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_args, *variables, ctx, &args)); + constant_args, variable_infos, ctx, &args)); return cache->Compile(options, function, args, compile_options, lazy ? 
XlaCompilationCache::CompileMode::kLazy : XlaCompilationCache::CompileMode::kStrict, - kernel, executable); + compilation_result, executable); } void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { @@ -349,16 +347,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { << Canonicalize(function_.name(), AttrSlice(&function_.attr())); xla::LocalClient* client; - const XlaCompiler::CompilationResult* kernel; + const XlaCompiler::CompilationResult* compilation_result; xla::LocalExecutable* executable; - std::map variables; + ResourceVarsSnapshot variables; { + std::vector variable_infos; + OP_REQUIRES_OK( + ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); + OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); Status s = CompileToLocalExecutable( ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_, - resources_, constants_, /*lazy=*/false, &client, &variables, &kernel, - &executable); + variable_infos, constants_, /*lazy=*/false, &client, + &compilation_result, &executable); OP_REQUIRES_OK(ctx, s); + OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_, + variable_infos, &variables)); } se::Stream* stream = @@ -373,7 +377,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { client, allocator, /*allocate_xla_tensors=*/platform_info_.is_on_xla_device(), platform_info_.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, kernel, variables, + launch_context.PopulateInputs(ctx, compilation_result, variables, /*missing_ctx_input_prefix=*/0); // Execute the computation. @@ -413,7 +417,7 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { executable->executable()->module().input_output_alias_config(); OP_REQUIRES_OK( ctx, launch_context.PopulateOutputs( - ctx, kernel, run_result.ConsumeValueOrDie(), + ctx, compilation_result, run_result.ConsumeValueOrDie(), /*missing_ctx_input_prefix=*/0, input_output_alias, variables)); VLOG(1) << "Done"; } @@ -494,7 +498,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { xla::LocalClient* client; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; - std::map variables; + ResourceVarsSnapshot variables; bool cannot_compile_cluster; { @@ -506,9 +510,16 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) { cannot_compile_cluster) { executable = nullptr; } else { + std::vector variable_infos; + OP_REQUIRES_OK( + ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos)); + OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos))); Status status = CompileToLocalExecutable( - ctx, function_, has_ref_vars_, platform_info_, resources_, constants_, - /*lazy=*/!must_compile_, &client, &variables, &kernel, &executable); + ctx, function_, has_ref_vars_, platform_info_, variable_infos, + constants_, + /*lazy=*/!must_compile_, &client, &kernel, &executable); + OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_, + variable_infos, &variables)); if (must_compile_ || status.code() != error::UNIMPLEMENTED) { OP_REQUIRES_OK(ctx, status); } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 9f5723f4fa4..dc5df94e963 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1837,7 +1837,7 @@ absl::flat_hash_map>* GetWhitelistTable() { "ConcatOffset", "Const", "MirrorPad", "Pack", "Pad", "PadV2", "Reverse", "ReverseV2", "ReverseSequence", "Slice", "Split", "SplitV", "StridedSlice", 
"StridedSliceGrad", "ResourceStridedSliceAssign", - "Tile", "Transpose", "InvertPermutation", "Unpack"}}}; + "Tile", "Transpose", "InvertPermutation", "Unpack", "DeviceIndex"}}}; // clang-format on return result; } diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index e1ad0e8c5af..afaee614f02 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -28,32 +28,23 @@ limitations under the License. namespace tensorflow { -namespace { -std::map GetVariables(OpKernelContext* ctx) { - std::map variables; - for (int64 i = 0; i < ctx->num_inputs(); ++i) { +// Returns argument indices corresponding to the resource variable inputs of +// kernel context `ctx`. +static std::vector GetResourceVariableIndices(OpKernelContext* ctx) { + std::vector out; + for (int64 i = 0; i < ctx->num_inputs(); i++) { if (ctx->input(i).dtype() == DT_RESOURCE) { - core::RefCountPtr variable; - ResourceHandle handle = HandleFromInput(ctx, i); - OptionalTensor& optional = variables[i]; - optional.name = handle.name(); - if (LookupResource(ctx, handle, &variable).ok()) { - tf_shared_lock lock(*variable->mu()); - optional.present = true; - optional.value = *variable->tensor(); - } + out.push_back(i); } } - return variables; + return out; } -} // namespace Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, const XlaCompiler::CompilationResult* result, - xla::LocalExecutable* executable) { - std::map variables = GetVariables(ctx); - + xla::LocalExecutable* executable, + const ResourceVarsSnapshot& variable_args) { xla::LocalClient* client = metadata.client(); // Builds an XLA allocator for the device. @@ -62,7 +53,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, /*allocate_xla_tensors=*/true, /*use_multiple_streams=*/metadata.UseMultipleStreams()); - launch_context.PopulateInputs(ctx, result, variables, + launch_context.PopulateInputs(ctx, result, variable_args, /*missing_ctx_input_prefix=*/0); se::Stream* stream = @@ -87,7 +78,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx, executable->executable()->module().input_output_alias_config(); TF_RETURN_IF_ERROR(launch_context.PopulateOutputs( ctx, result, run_result.ConsumeValueOrDie(), - /*missing_ctx_input_prefix=*/0, input_output_alias, variables)); + /*missing_ctx_input_prefix=*/0, input_output_alias, variable_args)); return Status::OK(); } @@ -115,7 +106,7 @@ Status XlaCompileOnDemandOp::ShouldArgumentBeConstant( Status XlaCompileOnDemandOp::Compile( OpKernelContext* ctx, const XlaDevice::Metadata& metadata, const XlaCompiler::CompilationResult** result, - xla::LocalExecutable** executable) { + ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) { std::map constant_arguments; for (int64 i = 0; i < ctx->num_inputs(); ++i) { const Tensor& device_tensor = ctx->input(i); @@ -190,12 +181,18 @@ Status XlaCompileOnDemandOp::Compile( // rather than a one-element tuple. 
compile_options.always_return_tuple = false; - std::map variable_args = GetVariables(ctx); - + std::vector variables_indices = GetResourceVariableIndices(ctx); std::vector args; - - TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_arguments, variable_args, ctx, &args)); + { + std::vector variable_infos; + TF_RETURN_IF_ERROR( + GetVariableInfosFromCtxInputs(ctx, variables_indices, &variable_infos)); + TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); + TF_RETURN_IF_ERROR(SnapshotResourceVariables( + ctx, variables_indices, variable_infos, variable_args)); + TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments( + constant_arguments, variable_infos, ctx, &args)); + } return cache->CompileSingleOp(options, args, ctx, compile_options, result, executable); @@ -206,8 +203,10 @@ void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) { xla::LocalExecutable* executable; const XlaDevice::Metadata* metadata; OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata)); - OP_REQUIRES_OK(ctx, Compile(ctx, *metadata, &result, &executable)); - OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable)); + ResourceVarsSnapshot variable_args; + OP_REQUIRES_OK(ctx, + Compile(ctx, *metadata, &result, &variable_args, &executable)); + OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args)); } } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.h b/tensorflow/compiler/jit/xla_compile_on_demand_op.h index 98f634db98f..cc5f2f1e42f 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.h +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.h @@ -20,6 +20,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_JIT_XLA_COMPILE_ON_DEMAND_OP_H_ #include "tensorflow/compiler/jit/xla_device.h" +#include "tensorflow/compiler/jit/xla_launch_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/core/framework/function.h" @@ -47,10 +48,12 @@ class XlaCompileOnDemandOp : public OpKernel { bool* result); Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, const XlaCompiler::CompilationResult** result, + ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable); Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata, const XlaCompiler::CompilationResult* result, - xla::LocalExecutable* executable); + xla::LocalExecutable* executable, + const ResourceVarsSnapshot& variable_args); }; } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 25eed134e35..eb31b23c991 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -52,7 +52,8 @@ const char kPossibleNonVariableResourceHintMessage[] = "resource inputs to XLA."; } // anonymous namespace -VariableInfo::VariableInfo(int index, Var* var) : index_(index), var_(var) {} +VariableInfo::VariableInfo(int index, absl::string_view name, Var* var) + : index_(index), name_(name), var_(var) {} VariableInfo::VariableInfo(VariableInfo&& other) : index_(other.index_), var_(other.var_), lock_held_(other.lock_held_) { other.index_ = -1; @@ -87,16 +88,15 @@ VariableInfo::~VariableInfo() { // Returns a vector of VariableInfo instances for the resource variable inputs // to the kernel with context `ctx`. The input indices for the resource // variable inputs are in `variable_indices`. 
-static Status GetVariableInfosFromCtxInputs( - OpKernelContext* ctx, absl::Span variable_indices, - std::vector* result) { +Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, + absl::Span variable_indices, + std::vector* result) { std::vector resource_handles; absl::c_transform( variable_indices, std::back_inserter(resource_handles), [&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); }); std::vector> variables; - Status s = LookupResources(ctx, resource_handles, &variables); if (!s.ok()) { errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage); @@ -109,7 +109,9 @@ static Status GetVariableInfosFromCtxInputs( // *Release* the variable because we're going to unref it later in // ~VariableInfo. Var* variable = variables[i].release(); - result->emplace_back(variable_indices[i], variable); + int input_idx = variable_indices[i]; + std::string var_name = HandleFromInput(ctx, input_idx).name(); + result->emplace_back(input_idx, var_name, variable); } return Status::OK(); @@ -162,21 +164,12 @@ Status LockVariables(absl::Span variables) { Status SnapshotResourceVariables(OpKernelContext* ctx, absl::Span variable_indices, - std::map* result) { - std::vector variable_infos; - TF_RETURN_IF_ERROR( - GetVariableInfosFromCtxInputs(ctx, variable_indices, &variable_infos)); - TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos))); - + absl::Span variable_infos, + ResourceVarsSnapshot* result) { for (int i = 0; i < variable_indices.size(); i++) { - if (variable_infos[i].var()) { - OptionalTensor& tensor = (*result)[variable_indices[i]]; - tensor.name = HandleFromInput(ctx, variable_indices[i]).name(); - tensor.present = true; - tensor.value = *variable_infos[i].var()->tensor(); - } else { - (*result)[variable_indices[i]] = OptionalTensor(); - } + Var* var = variable_infos[i].var(); + (*result)[variable_indices[i]] = + var ? absl::make_optional(*var->tensor()) : absl::nullopt; } return Status::OK(); } @@ -197,8 +190,7 @@ XlaComputationLaunchContext::XlaComputationLaunchContext( void XlaComputationLaunchContext::PopulateInputs( OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, - const std::map& variables, - int missing_ctx_input_prefix) { + const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) { // Build ShapedBuffers that point directly to the Tensor buffers. arg_ptrs_ = std::vector(compilation_result->xla_input_shapes.size()); @@ -210,7 +202,7 @@ void XlaComputationLaunchContext::PopulateInputs( CHECK_GE(arg_num, missing_ctx_input_prefix); const xla::Shape& shape = compilation_result->xla_input_shapes[i]; const Tensor* t = variables.count(arg_num) - ? &(variables.at(arg_num).value) + ? &(variables.at(arg_num).value()) : &(ctx->input(arg_num - missing_ctx_input_prefix)); CHECK(t); @@ -262,7 +254,7 @@ static const Tensor* FindAliasedTensorForOutput( int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix, const xla::HloInputOutputAliasConfig& input_output_alias, absl::Span input_mapping, - const std::map& resource_var_snapshots) { + const ResourceVarsSnapshot& resource_var_snapshots) { if (MustAliasOutput(input_output_alias, output_num)) { int xla_param = input_output_alias.GetAliasedParameter({output_num}) .value() @@ -274,8 +266,8 @@ static const Tensor* FindAliasedTensorForOutput( // entry time. 
if (input_tensor->dtype() == DT_RESOURCE) { auto& v = resource_var_snapshots.at(missing_ctx_input_prefix + tf_param); - CHECK(v.present); - return &v.value; + CHECK(v.has_value()); + return &v.value(); } return input_tensor; } @@ -298,9 +290,9 @@ static Tensor GetOrCreateTensorForOutput( int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix, const xla::HloInputOutputAliasConfig& input_output_alias, absl::Span input_mapping, - const std::map& resource_var_snapshots, - DataType output_dtype, const TensorShape& output_shape, - se::DeviceMemoryBase output_buffer, Allocator* output_allocator) { + const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype, + const TensorShape& output_shape, se::DeviceMemoryBase output_buffer, + Allocator* output_allocator) { if (const Tensor* aliased_tensor = FindAliasedTensorForOutput( output_num, ctx, missing_ctx_input_prefix, input_output_alias, input_mapping, resource_var_snapshots)) { @@ -431,13 +423,13 @@ static xla::StatusOr> GatherVariableInfo( // TODO(b/35625933): tensorflow::Var should contain a PersistentTensor, // not a Tensor. Var* variable = nullptr; - TF_RETURN_IF_ERROR(LookupOrCreateResource( - ctx, HandleFromInput(ctx, actual_input_index), &variable, - [&write](Var** ptr) { - *ptr = new Var(write.type); - return Status::OK(); - })); - variable_infos.emplace_back(actual_input_index, variable); + const ResourceHandle handle = HandleFromInput(ctx, actual_input_index); + TF_RETURN_IF_ERROR(LookupOrCreateResource(ctx, handle, &variable, + [&write](Var** ptr) { + *ptr = new Var(write.type); + return Status::OK(); + })); + variable_infos.emplace_back(actual_input_index, handle.name(), variable); } return variable_infos; } @@ -447,7 +439,7 @@ Status XlaComputationLaunchContext::PopulateOutputs( const XlaCompiler::CompilationResult* compilation_result, ScopedShapedBuffer output, int missing_ctx_input_prefix, const xla::HloInputOutputAliasConfig& input_output_alias, - const std::map& resource_var_snapshots) { + const ResourceVarsSnapshot& resource_var_snapshots) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; Allocator* allocator = ctx->device()->GetAllocator({}); @@ -484,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs( stream->ThenRecordEvent(definition_event.get()); } + std::vector output_tensor_shapes; + output_tensor_shapes.reserve(ctx->num_outputs()); + if (output.on_host_shape().is_dynamic()) { + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + + xla::Shape output_host_shape = output.on_host_shape(); + xla::Shape output_device_shape = output.on_device_shape(); + TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes( + stream, &output, &output_host_shape, &output_device_shape)); + + output.set_shapes(output_host_shape, output_device_shape); + for (int i = 0; i < ctx->num_outputs(); ++i) { + const xla::Shape& subshape = + xla::ShapeUtil::GetSubshape(output_host_shape, {i}); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape)); + output_tensor_shapes.push_back(shape); + } + } else { + for (int i = 0; i < ctx->num_outputs(); ++i) { + output_tensor_shapes.push_back(compilation_result->outputs[i].shape); + } + } + // Copy XLA results to the OpOutputList. 
int output_num = 0; for (int i = 0; i < ctx->num_outputs(); ++i) { - const TensorShape& shape = compilation_result->outputs[i].shape; + const TensorShape& shape = output_tensor_shapes[i]; const DataType& type = compilation_result->outputs[i].type; VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type " << DataTypeString(type); @@ -564,12 +582,21 @@ Status XlaComputationLaunchContext::PopulateOutputs( Status XlaComputationLaunchContext::BuildXlaCompilerArguments( const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span variable_args, OpKernelContext* ctx, std::vector* args) { args->resize(ctx->num_inputs()); + absl::flat_hash_map variable_info_lookup; + for (const VariableInfo& info : variable_args) { + CHECK(!info.var() || info.lock_held()) + << "Need to hold the lock on resource variables " + "before calling BuildXlaCompilerArguments"; + variable_info_lookup.emplace(info.index(), &info); + } + for (int64 input_num = 0; input_num < ctx->num_inputs(); ++input_num) { XlaCompiler::Argument& arg = (*args)[input_num]; + if (constant_args.count(input_num) > 0) { // Handles compile-time constants. const Tensor& input = constant_args.at(input_num); @@ -578,7 +605,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( arg.type = input.dtype(); arg.shape = input.shape(); arg.constant_value = input; - } else if (variable_args.count(input_num) == 0) { + } else if (variable_info_lookup.count(input_num) == 0) { // Handles the non-constant arguments. const Tensor& input = ctx->input(input_num); TF_RET_CHECK(input.dtype() != DT_RESOURCE); @@ -594,14 +621,14 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments( // Handles resource variables. const Tensor& input = ctx->input(input_num); TF_RET_CHECK(input.dtype() == DT_RESOURCE); - const OptionalTensor& variable = variable_args.at(input_num); - arg.name = variable.name; + const VariableInfo& variable = *variable_info_lookup[input_num]; + arg.name = std::string(variable.name()); arg.kind = XlaCompiler::Argument::kResource; arg.resource_kind = XlaResource::kVariable; - if (variable.present) { - const Tensor& value = variable.value; - arg.type = value.dtype(); - arg.shape = value.shape(); + if (variable.var()) { + const Tensor* value = variable.var()->tensor(); + arg.type = value->dtype(); + arg.shape = value->shape(); arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index 9a7f20cb310..92b6c4c8a08 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -34,36 +34,17 @@ limitations under the License. namespace tensorflow { -// Struct that represents a possibly-absent Tensor. -struct OptionalTensor { - string name; // A descriptive name - bool present = false; // Is the tensor present? - Tensor value; // If present, what is the Tensor's value? -}; - -// Takes a snapshot of the values of resource variable arguments, whose indices -// are specified in `variable_indices` argument. We snapshot tensors that back -// resource variables since concurrent updates may modify the shape, and it is -// important that the shapes used for compilation match the true shapes of the -// buffers. -// -// We snapshot the entire set of resource variables as one atomic operation. -// This models Read->* dependencies between resource variable operations. 
See -// jit/resource_operation_safety_analysis for details. -// -// Returns a map of TensorFlow argument index to resource variable. If a -// resource variable is not initialized, the corresponding OptionalTensor -// will have its `present` field set to false. -Status SnapshotResourceVariables(OpKernelContext* ctx, - absl::Span variable_indices, - std::map* result); +// Snapshot of resource variables for a TF kernel invocation, mapping from +// parameter number to values at execution time. If the resource variable is not +// initialized, the value will not be present. +using ResourceVarsSnapshot = absl::flat_hash_map>; // Information about the state of a variable passed as input to the _XlaCompile // and _XlaRun operators. Unlocks the resource variable and decrements its // refcount on destruction. class VariableInfo { public: - explicit VariableInfo(int index, Var* var); + explicit VariableInfo(int index, absl::string_view name, Var* var); VariableInfo(VariableInfo&& other); VariableInfo& operator=(VariableInfo&& other); @@ -79,6 +60,9 @@ class VariableInfo { // "empty", i.e. it does not track a resource variable. Var* var() const { return var_; } + // Returns the variable name. + absl::string_view name() const { return name_; } + // Returns true if the resource variable lock was successfully acquired by // this thread. bool lock_held() const { return lock_held_; } @@ -88,6 +72,7 @@ class VariableInfo { private: int index_; + std::string name_; Var* var_; // We can't use a optional here because it confuses the compiler's @@ -96,6 +81,20 @@ class VariableInfo { bool lock_held_ = false; }; +// Takes a snapshot of the values of resource variable arguments, whose indices +// are specified in `variable_indices` argument. We snapshot tensors that back +// resource variables since concurrent updates may modify the shape, and it is +// important that the shapes used for compilation match the true shapes of the +// buffers. +// +// We snapshot the entire set of resource variables as one atomic operation. +// This models Read->* dependencies between resource variable operations. See +// jit/resource_operation_safety_analysis for details. +Status SnapshotResourceVariables(OpKernelContext* ctx, + absl::Span variable_indices, + absl::Span variable_infos, + ResourceVarsSnapshot* result); + // Acquires the mutexes for all the variables in `variables` using a // deadlock-safe protocol (acquire the mutexes in increasing-address order). // @@ -104,6 +103,13 @@ class VariableInfo { Status LockVariables(absl::Span variables) TF_EXCLUSIVE_LOCK_FUNCTION(); +// Returns a vector of VariableInfo instances for the resource variable inputs +// to the kernel with context `ctx`. The input indices for the resource +// variable inputs are in `variable_indices`. +Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx, + absl::Span variable_indices, + std::vector* result); + // Helper class to perform the marshalling of TensorFlow inputs and outputs to // ShapedBuffers suitable for passing to an XLA computation. class XlaComputationLaunchContext { @@ -123,9 +129,10 @@ class XlaComputationLaunchContext { // Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch // op. + // Precondition: variables in `variable_args` are locked. static Status BuildXlaCompilerArguments( const std::map& constant_args, - const std::map& variable_args, OpKernelContext* ctx, + absl::Span variable_args, OpKernelContext* ctx, std::vector* args); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). 
@@ -137,7 +144,7 @@ class XlaComputationLaunchContext { // (in other words, no inputs actually required by the kernel can be missing). void PopulateInputs(OpKernelContext* ctx, const XlaCompiler::CompilationResult* compilation_result, - const std::map& variables, + const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix); // Given the XLA output in `output`, populate all outputs of `ctx`. Also @@ -155,7 +162,7 @@ class XlaComputationLaunchContext { const XlaCompiler::CompilationResult* compilation_result, xla::ScopedShapedBuffer output, int missing_ctx_input_prefix, const xla::HloInputOutputAliasConfig& input_output_alias, - const std::map& resource_var_snapshots); + const ResourceVarsSnapshot& resource_var_snapshots); // Return the argument list. Only valid after PopulateInputs() has been // called. diff --git a/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md new file mode 100644 index 00000000000..06c55abf1fa --- /dev/null +++ b/tensorflow/compiler/mlir/g3doc/xla_gpu_codegen.md @@ -0,0 +1,265 @@ +# MLIR CodeGen for XLA + + + +XLA operates on `HloInstruction` and performs many optimizations on this +representation, sharing a lot of these between targeted devices. At some point a +linear schedule is computed and the memory buffer is assigned to each value +statically. The device-specific codegen operates by traversing this sequence and +calling "emitters" to generate a representation suitable for the device (for +example a single LLVM function per XLA computation on CPU, or a sequence of +"thunks" encapsulating GPU operations and possibly generated PTX when targeting +GPU). + +As a staging step, we're currently in the process of intercepting the process +right after XLA completes the buffer-assignment phase and emitting instead an MLIR +module in the `lhlo` dialect. From there we perform the codegen using MLIR +components (Linalg, affine, and GPU dialect mainly) depending on the device. + +Below is the plan of record to incrementally migrate XLA/GPU by using `lhlo` as +the codegen input. + +## Tasks + + | Host | Device ------------- | ------------------------ | ------------------------ Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1) Output format | xla::Thunk (Task 2) | LLVM IR (Task 3) + +* **Task 1** changes both host and device input format from HloInstruction* to + LHLO. +* **Task 2** changes the output format of the host from thunks to "some landing pad + for host" (see below). +* **Task 3** migrates device output from LLVM IR to some form of MLIR. It's + optional for this project; see the section "Migrating Device LLVM IR" for + details. + +This project prioritizes having end-to-end runnable models with LHLO-emitters +enabled as much as possible. This implies the following list of +objectives, ordered by priority: + +* Make XLA/GPU runnable with LHLO emitters, with existing Thunks and emitters + unmodified. +* Eliminate the references to HloInstruction\* in LHLO, case by case: + * Switch a legacy emitter to an MLIR-based emitter (e.g. Linalg), or + * Mechanically translate the existing emitter to take MLIR representation + (migrate to Standard with GPU Dialect). + +## Migrating Thunks (Task 2) + +xla::gpu::Thunk is a data structure that: + +* Can be called into from the host (xla::gpu::Thunk::ExecuteOnStream()). +* Carries various data in its subclasses. +* Interacts with BufferAllocation::Slice and StreamExecutor. +* Launches kernels. +* Calls into all runtime libraries.
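To make the host-side contract in the list above concrete, here is a minimal, self-contained C++ sketch of a Thunk-style interface. It is illustrative only: `Status`, `Stream`, and `BufferAllocations` are stand-ins for the real XLA/StreamExecutor types, and the actual `xla::gpu::Thunk` carries considerably more state (thunk kinds, launch dimensions, buffer slices, library configs).

```cpp
// A minimal sketch of a Thunk-style host interface; not the real xla::gpu::Thunk.
#include <memory>
#include <vector>

struct Status { bool ok = true; };   // stand-in for tensorflow::Status
struct Stream {};                    // stand-in for se::Stream
struct BufferAllocations {};         // stand-in for BufferAllocation::Slice lookups

class Thunk {
 public:
  virtual ~Thunk() = default;
  // Host-side entry point: enqueue a kernel launch, a library call (cuDNN,
  // cuBLAS, cuFFT, ...) or nested control flow onto `stream`, reading and
  // writing buffers resolved through `buffers`.
  virtual Status ExecuteOnStream(const BufferAllocations& buffers,
                                 Stream* stream) = 0;
};

// A compiled XLA/GPU program is currently a sequence of thunks run in
// schedule order.
using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
```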
+ +The cost of that includes: + +* Representing op-specific configuration data (e.g. convolution configs). +* Migrating op shape and operand shapes. +* Representing a tree of thunks (while, condition, etc.). + +The migration work is independent of the LHLO / emitter migration. Under limited +resources, it's prioritized behind the LHLO / emitter migration. + +We have several choices on how to lower the host-side part from LHLO: + +* TFRT + * (Pro) great CUDA and HIP wrappers for use. + * (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc.), as + TFRT ops are interpreted by C++ code. + * (Con) the host side is under development and not tested. + * (Con) the JAX integration isn’t clear from a runtime point of view. +* Jitted CPU code + * (Pro) lowers easily: create a few loops and conditions and it's + done. + * (Con) GPUDialect doesn't yet model chains/streams/asynchronicity/device + allocation. + * (Con) CUDA / HIP runtime support is minimal (toolkit path, version, + dynamic loading, etc.). +* Existing (interpreting) XLA runtime. + +Tentative conclusion: use jitted CPU code during the transition, and optionally +adopt TFRT in the end. + +## Migrating Device LLVM IR (Task 3) + +An elemental emitter generates the target op by filling it in element by element +(a small sketch of this idea follows at the end of this section). Each +output element depends on a set of elements from the operands. All elements are +described by combining the buffer with dynamic indices. It's sufficient to +describe almost all "math" ops, but for performance reasons only a large subset +of "math" ops are implemented directly in (Cpu|Gpu)ElementalIrEmitter. + +ElementalIrEmitter is unique in that: + +* A large portion of the code is shared between XLA/GPU and CPU. +* It represents a large portion of ops seen in models, including all + element-wise ops. +* Most fusions solely depend on ElementalIrEmitter. +* It's structurally simple, as it describes a data dependency DAG between op + elements and operand elements. +* It's mostly portable and high-level (e.g. unlike GPU kReduce and GPU kCopy). +* Dynamic shape support is easy for at least element-wise ops. + +Now, for all ops, elementally-emitted or not, there are several flavors of the +end state of each XLA op: + +1. Device code stays as LLVM IR. +1. Refactor the old emitter to be like LHLO -> MLIR LLVM Dialect: + * (Cost) Will be throw-away work if we want to ultimately migrate to + Standard. + * (Cost) It doesn't provide additional benefit compared to option 1. + * (Benefit) It is easy and mechanical. Can be done in a short period. +1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops: + * (Cost) Lifting existing emitters to Standard introduces some challenges. + Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring + AMDGPU completeness is another challenge. + * (Cost) XLA/GPU heavily relies on LLVM metadata: + * `range` for block/thread indices. + * `align`, `dereferenceable`, `invariant.load`, `alias.scope`, + `noalias` for load/stores. + * `llvm.loop.unroll.disable`, `llvm.loop.unroll.full`, + `llvm.loop.vectorize.enable` for sequential loops. + * (Benefit) Can be long-term. More portable. +1. Refactor old emitters to be LHLO -> Linalg, and write new Linalg emitters: + * (Cost) This is case by case. Compared to previous options, a new + implementation that matches XLA's performance needs to go through the + benchmark <-> optimize workflow, which can be a significant cost for + some ops. + * (Benefit) Unified stack; community support; portability; more + optimization potential.
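As a rough mental model of the elemental emission described above (a sketch only; not the actual `(Cpu|Gpu)ElementalIrEmitter` API, and the names here are invented), each op contributes a per-element generator, and fusion is just generator composition; the surrounding loop over output indices is emitted separately per target:

```c++
#include <cstdint>
#include <functional>

// A generator maps an output index to the value of that element.
using ElementGenerator = std::function<float(int64_t)>;

// Elementwise add: out[i] = lhs[i] + rhs[i].
ElementGenerator MakeAddGenerator(ElementGenerator lhs, ElementGenerator rhs) {
  return [lhs, rhs](int64_t i) { return lhs(i) + rhs(i); };
}

// A fused multiply-add is simply composed generators; no intermediate buffer
// is materialized, which is why most fusions only need the elemental emitter.
ElementGenerator MakeMulAddGenerator(ElementGenerator a, ElementGenerator b,
                                     ElementGenerator c) {
  return [a, b, c](int64_t i) { return a(i) * b(i) + c(i); };
}
```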
+ +## Prioritization + +While all three tasks mentioned above are parallelizable, under limited +resources they have to be serialized. The prioritization focuses on visible +results for the completion of each task. + +The prioritization is: Task 1 (LHLO for legacy emitters) > Task 2 (Thunks) > Task +3 (MLIR emitters). + +By the end of Task 1, users of XLA (e.g. the kernel generator) can generate an LHLO +and execute it. The compilation format will not be serializable MLIR. + +By the end of Task 2, LHLO lowers to proper, serializable MLIR. This enables +offline compilation. + +By the end of Task 3, all XLA emitters are MLIR-based in their implementation. + +## Detailed Design + +### Step 1: (Task 1) Complete LHLO and Make Legacy Emitters Take LHLO + +This step makes all existing XLA/GPU emitters interact with MLIR ops. This step +is pure refactoring and NFC. + +This step is mostly mechanical, but it's worth noticing the following +discrepancies between an unnested HloComputation and LHLO: + +* Each HloInstruction has direct access to its operands (a data-flow DAG). By + contrast, each LHLO op only has access to its operand buffers (a bipartite + graph between ops and buffers). LHLO ops have to go through use-def chains to + access their operand ops. +* Unnested legacy emitters empirically almost never access their operands. The + only exception is kReduce. +* Unnested legacy emitters access BufferAssignment only for getting slices, + not for accessing aux data structures like dataflow\_analysis() or + alias\_analysis(). llvm\_ir builds its own alias\_analysis() based on slice + information. + +The conclusion is that LHLO should fit right in without major hassle. + +### Step 2: (Optional) Profiling Support + +**This step is only needed if we start to discard some of the XLA Thunk logic +(see the next step).** + +Before actually turning on any MLIR-based emitters, we need profiling for +MLIR-based emitters. + +Currently XLA performs its own profiling by calling into StreamExecutor's timer. +The timer under the hood inserts two events before and after a kernel launch, +and measures the sync time between these two events. + +There are roughly two approaches to support profiling in MLIR: + +* Run a profiler end-to-end. +* Add a profile op for each op in LHLO, using an injected profiler. + +The "end-to-end" approach is transparent to MLIR, but suffers the same problem +that makes XLA not use it in the first place: library calls collected by a +profiler (nvprof/...) can't easily relate to HLO ops. For example, cuDNN +launches multiple kernels for each HLO, and it's hard to tell which kernels +correspond to which HLO. + +The "injected profiler" approach requires: + +* LHLO to take a profiler as a parameter. +* Inserting profile.start / profile.end before and after each op. +* A pass that lowers profile.{start,end} to a C++ implementation. + +Exact profiling can't easily be done for MLIR-generated ops, since: + +* MLIR doesn't have a timer, nor does it depend on TFRT / StreamExecutor. +* MLIR doesn't easily call into C functions with complicated parameters. + +### Step 3: (Task 2) Migrating Thunks + +This step migrates all host ops and library calls. This step will eliminate most +of the thunks and produce serializable MLIR instead. + +There are roughly three kinds of thunks: + +* KernelThunk, which launches a kernel. +* Control flow thunks, which carry host control flow logic (conditional, while, + for, sequence) and launch body kernels (see the sketch after this list). +* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
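As a hedged illustration of the second kind (control-flow thunks), the sketch below shows how a while-style thunk could drive device kernels from host C++; the types and names are hypothetical and independent of the real xla::gpu classes.

```c++
#include <functional>
#include <memory>
#include <utility>

struct StreamHandle;  // stands in for a device stream

// Minimal host-callable interface, mirroring the Thunk sketch given earlier.
struct HostThunk {
  virtual ~HostThunk() = default;
  virtual void ExecuteOnStream(StreamHandle* stream) = 0;
};

// A while-style control-flow thunk: the loop itself runs on the host and
// repeatedly re-launches the condition and body kernels on the device.
class WhileThunkSketch : public HostThunk {
 public:
  WhileThunkSketch(std::unique_ptr<HostThunk> cond, std::unique_ptr<HostThunk> body,
                   std::function<bool(StreamHandle*)> read_predicate)
      : cond_(std::move(cond)),
        body_(std::move(body)),
        read_predicate_(std::move(read_predicate)) {}

  void ExecuteOnStream(StreamHandle* stream) override {
    for (;;) {
      cond_->ExecuteOnStream(stream);       // compute the loop predicate on device
      if (!read_predicate_(stream)) break;  // copy the predicate back to the host
      body_->ExecuteOnStream(stream);       // re-launch the body kernels
    }
  }

 private:
  std::unique_ptr<HostThunk> cond_;
  std::unique_ptr<HostThunk> body_;
  std::function<bool(StreamHandle*)> read_predicate_;
};
```

Lowering this kind of host logic to serializable MLIR (or jitted CPU code) instead of C++ is exactly the "bottom line" choice discussed next.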
+ +The **bottom line** is to: + +* Create a Thunk dialect that provides (de)serialization logic for all existing + C++-based Thunks. +* Change emitters to emit a graph of the Thunk dialect. + +**Optionally**, we can relieve some thunks of their C++ implementation. KernelThunk +can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG +Dialect for loops and conditions, combined with LaunchKernelOp. This optional +step requires profiling and stream support. + +### Step 4: (Task 3) Migrated ElementalIrEmitter + +Once profiling is ready, we can complete and tune all ElementalIrEmitter-based +emitters in MLIR. Then we turn them on by default, assuming that all of these +MLIR-based emitters use a single stream. + +Notice that it's beneficial to migrate XLA/CPU's ElementalIrEmitter as well, +since they share a large portion of the code. + +With all benchmarking and performance hunting done (TODO: define performance +parity), we turn on the new MLIR-based elemental emitter, and delete the legacy +ElementalIrEmitter. + +This step also provides easy fusion transitions (nested ops) for the later +migration. + +### Step 5: Multi-Stream Support or Drop + +We can't delete +[some of the emitters](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/stream_assignment.cc#L140) +until we support multi-stream in MLIR, or we drop the feature. It's a relatively large +amount of work in MLIR and a small amount of gain for XLA. We should investigate +the current users of multi-stream XLA/GPU, and try to delete this feature if +reasonable. + +### Step 6: (Task 3) Migrated Device Ops + +This step migrates all unnested ops, then we can delete all unnested emitters. + +This calls for a rewrite/refactor of kCopy and kReduce. kReduce has already +received plenty of work, so the actual amount of remaining work remains +to be seen. diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index 8e9d615053c..8d4efeb3d60 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -314,7 +314,6 @@ tf_cc_test( cc_library( name = "tensorflow_lite_legalize_tf", srcs = [ - "transforms/device_index_selector.cc", "transforms/dilated_conv.cc", "transforms/generated_legalize_tf.inc", "transforms/generated_lower_static_tensor_list.inc", diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc index 6df569a8031..edead2037a3 100644 --- a/tensorflow/compiler/mlir/lite/converter_gen.cc +++ b/tensorflow/compiler/mlir/lite/converter_gen.cc @@ -446,7 +446,7 @@ static void GenOperandResultVerifier(raw_ostream &os, auto desc = definit->getDef()->getValueAsString("tflRuntimeTypeDescription"); - // Emit a loop to check all the dynamic values in the pack. + // Emit a loop to check all operands.
os << formatv(" for (Value v : top.getODS{0}{1}s({2})) {{\n", // Capitalize the first letter to match the function name valueKind.substr(0, 1).upper(), valueKind.substr(1), @@ -455,14 +455,10 @@ static void GenOperandResultVerifier(raw_ostream &os, os << " (void)v;\n" << " if (!(" << tgfmt(pred.getCondition(), &fctx.withSelf("v.getType()")) << ")) {\n" - << " if (failure_on_operand_type_mismatch) {\n" << formatv( - " return op->emitOpError(\"{0} #\") << index " + " return op->emitOpError(\"{0} #\") << index " "<< \" must be {1}, but got \" << v.getType();\n", valueKind, desc) - << " } else {\n" - << " return ::mlir::LogicalResult::Failure;\n" - << " }\n" << " }\n" // if << " ++index;\n" << " }\n"; // for @@ -487,8 +483,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { mlir::tblgen::FmtContext verify_ctx; os << "::mlir::LogicalResult " << op.getCppClassName() - << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool " - "failure_on_operand_type_mismatch) {\n"; + << "::VerifyTflRuntimeConstraints(::mlir::Operation *op) {\n"; os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n"; verify_ctx.withOp("top"); @@ -529,11 +524,8 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os, RecordKeeper &records) { mlir::tblgen::Pred pred(dyn_cast(val->getValue())); os << tgfmt( - " if (!($0)) {\n " - " if (failure_on_operand_type_mismatch) {\n" - " return top.emitOpError(\"failed to verify that $1\");\n" - " } else {\n" - " return ::mlir::LogicalResult::Failure;\n }\n }\n", + " if (!($0))\n" + " return top.emitOpError(\"failed to verify that $1\");\n", &verify_ctx, tgfmt(pred.getCondition(), &verify_ctx), desc); } os << " return top.verify();\n}\n"; diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index a260670015a..e34e7ae7ca6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -240,10 +240,10 @@ static bool IsValidTFLiteMlirModule(ModuleOp module) { } for (auto fn : module.getOps()) { - if (fn.getBlocks().size() != 1) { + if (!llvm::hasSingleElement(fn)) { return fn.emitError("should have exactly one basic block"), false; } - auto& bb = fn.getBlocks().front(); + auto& bb = fn.front(); for (auto arg : bb.getArguments()) { if (!HasValidTFLiteType(arg, fn)) @@ -1089,7 +1089,7 @@ void Translator::InitializeNamesFromAttribute(FuncOp fn, bool* has_input_attr) { dict_attr.get("outputs").dyn_cast_or_null()) { str.getValue().split(output_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); - auto term = fn.getBlocks().back().getTerminator(); + auto term = fn.back().getTerminator(); if (output_names.size() != term->getNumOperands()) { fn.emitWarning() << "output names (" << output_names.size() << ") != terminator operands (" << term->getNumOperands() diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td index 23101113a6f..a79d79b5970 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td @@ -94,8 +94,7 @@ def TFL_RuntimeVerification : OpInterface<"TflRuntimeVerifyOpInterface"> { let methods = [ StaticInterfaceMethod< [{Returns whether the op's operands/results are supported by runtime.}], - "LogicalResult", "VerifyTflRuntimeConstraints", - (ins "Operation*":$op, "bool":$failure_on_operand_type_mismatch) + "LogicalResult", "VerifyTflRuntimeConstraints", (ins "Operation*":$op) 
>, ]; } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index 6e9930271c8..16d256c7571 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -758,6 +758,22 @@ OpFoldResult ConcatenationOp::fold(ArrayRef operands) { return new_concat.getResult(); } +//===----------------------------------------------------------------------===// +// CustomOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CustomOp op) { + OpaqueElementsAttr opaque_attr = + op.custom_option().cast(); + if (!opaque_attr.getType().hasStaticShape()) + return op.emitOpError("custom_option should have a static shape."); + if (opaque_attr.getValue().size() != + opaque_attr.getType().cast().getDimSize(0)) + return op.emitOpError( + "custom_option should have the same length of content with shape."); + return success(); +} + //===----------------------------------------------------------------------===// // FullyConnectedOp //===----------------------------------------------------------------------===// @@ -2169,6 +2185,10 @@ static LogicalResult Verify(TransposeOp op) { return success(); } +//===----------------------------------------------------------------------===// +// WhileOp +//===----------------------------------------------------------------------===// + LogicalResult Verify(WhileOp op) { if (op.getNumOperands() != op.getNumResults()) return op.emitOpError(llvm::formatv( @@ -2178,18 +2198,6 @@ LogicalResult Verify(WhileOp op) { return success(); } -static LogicalResult Verify(CustomOp op) { - OpaqueElementsAttr opaque_attr = - op.custom_option().cast(); - if (!opaque_attr.getType().hasStaticShape()) - return op.emitOpError("custom_option should have a static shape."); - if (opaque_attr.getValue().size() != - opaque_attr.getType().cast().getDimSize(0)) - return op.emitOpError( - "custom_option should have the same length of content with shape."); - return success(); -} - namespace { // Canonicalize While op so that results and operands match and external values // are via implicit capture rather than via block args. diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 509c13ae161..f379b241f9d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -571,6 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ TFL_OperandHasRank<2, 4>, PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 2>>, + AccumulatorUniformScale<3, 1, 2>, + TFL_ChannelDimIndexInterface, AffineOpCoefficient<0, 2>, TFL_GpuTargetOp, TFL_SparseOp]> { let summary = "Transpose convolution operator"; @@ -596,6 +598,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [ let verifier = [{ return Verify(*this); }]; let extraClassDeclaration = [{ + // ChannelDimIndexInterface: + int GetChannelDimIndex() { return 0; } // SparseOpInterface: std::vector GetSparseOperands() { return {1}; } std::vector> GetFloatBlockSize() { return {}; } @@ -953,14 +957,14 @@ in the batch dimensions and broadcasting. 
}]; let arguments = (ins - TFL_TensorOf<[F32]>:$x, - TFL_TensorOf<[F32]>:$y, + TFL_TensorOf<[F32, QI8]>:$x, + TFL_TensorOf<[F32, QI8]>:$y, DefaultValuedAttr:$adj_x, DefaultValuedAttr:$adj_y ); let results = (outs - TFL_TensorOf<[F32]>:$output + TFL_TensorOf<[F32, QI8]>:$output ); let hasOptions = 1; diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index d924a3e82ac..6299a70b1df 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -76,7 +76,8 @@ class ImportQuantStatsPass // If the index is out of range, this method returns false. Otherwise it // returns true if the value is a float tensor. bool IsQuantizableResult(Operation *op, int index) { - if (index < 0 || index >= op->getNumResults()) return false; + if (index < 0 || index >= static_cast(op->getNumResults())) + return false; Value res = op->getResult(index); return res.getType().isa() && res.getType().cast().getElementType().isa(); @@ -158,7 +159,7 @@ void ImportQuantStatsPass::ImportAsStatsOps(OpBuilder b, Operation *op, InsertStatsOpAtResult(b, op->getResult(index), layer_stats, axis_stats, axis); } else { - for (int i = 0; i < op->getNumResults(); ++i) { + for (int i = 0, e = op->getNumResults(); i < e; ++i) { if (IsQuantizableResult(op, i)) { InsertStatsOpAtResult(b, op->getResult(i), layer_stats, axis_stats, axis); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc index 3edd9c36760..9adabde4f25 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_config.cc @@ -48,7 +48,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, std::vector> node_mins; if (!min_values.empty()) { std::vector node_mins_str = absl::StrSplit(min_values, ','); - for (int i = 0; i < node_mins_str.size(); i++) { + for (int i = 0, e = node_mins_str.size(); i < e; i++) { double value; if (!absl::SimpleAtod(node_mins_str[i], &value)) { return true; @@ -60,7 +60,7 @@ bool ParseInputNodeQuantSpecs(absl::string_view node_names, std::vector> node_maxs; if (!max_values.empty()) { std::vector node_maxs_str = absl::StrSplit(max_values, ','); - for (int i = 0; i < node_maxs_str.size(); i++) { + for (int i = 0, e = node_maxs_str.size(); i < e; i++) { double value; if (!absl::SimpleAtod(node_maxs_str[i], &value)) { llvm::errs() << "Unexpected mins: " << node_maxs_str[i] << "\n"; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 89443b1ec65..f3e746c7a43 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -294,7 +294,7 @@ class QuantizationDriver { return; if (current_op == op) llvm::errs() << "===>>>"; llvm::errs() << op->getName() << " : ("; - for (auto i = 0; i < op->getNumOperands(); ++i) { + for (int i = 0, e = op->getNumOperands(); i < e; ++i) { if (auto params = GetOperandQuantState(op, i).params) params.print(llvm::errs()); else @@ -303,7 +303,7 @@ class QuantizationDriver { llvm::errs() << ","; } llvm::errs() << ") -> ("; - for (auto i = 0; i < op->getNumResults(); ++i) { + for (int i = 0, e = op->getNumResults(); i < e; ++i) { if (auto params = 
GetResultQuantState(op, i).params) params.print(llvm::errs()); else diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc index 32f68aaae5f..b98739eac6e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.cc @@ -55,7 +55,7 @@ static Type GetQuantizedType(Builder builder, Type input_type, } else if (min.size() == max.size()) { auto shape = input_type.dyn_cast(); if (!shape || shape.getRank() <= quant_dim || - min.size() != shape.getDimSize(quant_dim)) { + static_cast(min.size()) != shape.getDimSize(quant_dim)) { return {}; } // TODO(b/141508873): the quantization dim is set to the last dimension. @@ -76,7 +76,8 @@ TypeAttr RescaleQuantizedType(Type input, Attribute factor) { if (auto qtype = ele_type.dyn_cast()) { ArrayRef scales = qtype.getScales(); // Broadcasting hasn't been implemented yet. - if (scales.size() != factor_values.getNumElements()) return {}; + if (static_cast(scales.size()) != factor_values.getNumElements()) + return {}; SmallVector new_scales; new_scales.reserve(scales.size()); auto scales_iter = scales.begin(); @@ -270,7 +271,7 @@ Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim, bool narrow_range) { Builder builder(attr.getContext()); auto shape = attr.getType().cast().getShape(); - if (shape.size() <= quant_dim) return {}; + if (static_cast(shape.size()) <= quant_dim) return {}; // `symmetric` can only be used when it is `signed` and `narrow_range`. if (symmetric && (!is_signed || !narrow_range)) return {}; @@ -335,7 +336,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( const std::vector& op_types) { if (op_types.empty()) return {}; - int axis_size = 1; + size_t axis_size = 1; int32_t quant_dim = -1; Type expressed_type; // Requires all the op types are valid UniformQuantizedTypes or @@ -369,7 +370,7 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( scales[index_scale.index()] *= index_scale.value(); } } else if (auto type = op_type.dyn_cast()) { - for (int index = 0; index != axis_size; ++index) { + for (int index = 0, e = axis_size; index != e; ++index) { scales[index] *= type.getScale(); } } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index 1ae789f5468..5756fa6dec2 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -990,6 +990,13 @@ func @batch_to_space_nd(%arg0: tensor<4x2x2x3xf32>, %arg1: tensor<2xi32>, %arg2: // CHECK: "tfl.batch_to_space_nd"(%arg0, %arg1, %arg2) : (tensor<4x2x2x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor } +func @batch_to_space_nd_unsupported(%arg0: tensor, %arg1: tensor<3xi32>, %arg2: tensor<3x2xi32>) -> tensor { + %0 = "tf.BatchToSpaceND"(%arg0, %arg1, %arg2) : (tensor, tensor<3xi32>, tensor<3x2xi32>) -> tensor + return %0 : tensor + // CHECK-LABEL: batch_to_space_nd_unsupported + // CHECK: "tf.BatchToSpaceND" +} + func @space_to_batch_nd(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<2xi32>, %arg2: tensor<2x2xi32>) -> tensor { %0 = "tf.SpaceToBatchND"(%arg0, %arg1, %arg2) : (tensor<1x4x4x3xf32>, tensor<2xi32>, tensor<2x2xi32>) -> tensor return %0 : tensor diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index e1f496b91f4..4a83616408e 100644 --- 
a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -70,6 +70,7 @@ func @prepareAdd(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { } // CHECK-LABEL: prepareConv2DSplat +// PerTensor-LABEL: prepareConv2DSplat func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> { %w = constant dense<127.0> : tensor<3x3x3x3xf32> %b = constant dense<0.0> : tensor<3xf32> @@ -89,6 +90,7 @@ func @prepareConv2DSplat(%arg0: tensor<1x5x5x3xf32>) -> tensor<1x5x5x3xf32> { } // CHECK-LABEL: prepareConv2D +// PerTensor-LABEL: prepareConv2D func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> { %w = constant dense<[[[[0.0]]], [[[127.0]]], [[[-127.0]]]]> : tensor<3x1x1x1xf32> %b = constant dense<0.0> : tensor<3xf32> @@ -108,6 +110,7 @@ func @prepareConv2D(%arg0: tensor<1x5x5x1xf32>) -> tensor<1x5x5x3xf32> { } // CHECK-LABEL: prepareDepthwiseConv2D +// PerTensor-LABEL: prepareDepthwiseConv2D func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { %w = constant dense<127.0> : tensor<32x3x3x3xf32> %b = constant dense<0.0> : tensor<32xf32> @@ -127,6 +130,7 @@ func @prepareDepthwiseConv2D(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112 } // CHECK-LABEL: QuantizeFullyConnected +// PerTensor-LABEL: QuantizeFullyConnected func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { %w = constant dense<127.0> : tensor<32x12xf32> %b = constant dense<0.0> : tensor<32xf32> @@ -143,3 +147,22 @@ func @QuantizeFullyConnected(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112 // PerTensor: %[[dq:.*]] = "tfl.dequantize"(%0) : (tensor<32x12x!quant.uniform:f32, 1.000000e+00>>) -> tensor<32x12xf32> // PerTensor: "tfl.fully_connected"(%arg0, %[[dq]] } + +// CHECK-LABEL: QuantizeTransposeConv +// PerTensor-LABEL: QuantizeTransposeConv +func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>) -> tensor<1x32x42x128xf32> { + %w = constant dense<127.0> : tensor<1x32x42x128xf32> + %b = constant dense<0.0> : tensor<1x32x42x128xf32> + %tc = "tfl.transpose_conv"(%arg1, %arg0, %w, %b) {padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<4xi32>, tensor<32x4x4x128xf32>, tensor<1x32x42x128xf32>, tensor<1x32x42x128xf32>) -> tensor<1x32x42x128xf32> + return %tc : tensor<1x32x42x128xf32> + +// CHECK: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32> +// CHECK: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>, volatile} +// CHECK: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform:f32:0, {1.000000e+00}>>) -> tensor<1x32x42x128xf32> +// CHECK: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]] + +// PerTensor: %[[CST:.*]] = constant dense<1.270000e+02> : tensor<1x32x42x128xf32> +// PerTensor: %[[QUANTIZE:.*]] = "tfl.quantize"(%[[CST]]) {qtype = tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>, volatile} +// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32> +// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]] +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index e95f3d011e2..719430959d0 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -528,6 
+528,26 @@ func @PadStridedSliceNewAxisMask2(%arg0: tensor<4x64x64x1xf32>) -> tensor<1x4x64 return %1 : tensor<1x4x64x64xf32> } +// CHECK-LABEL: @StridedSliceRewriteMasks +func @StridedSliceRewriteMasks(%arg0: tensor<8x4x16x2xf32>) -> tensor<8x4x16x1xf32> { + %cst = "tf.Const"() {device = "", value = dense<[1, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_0 = "tf.Const"() {device = "", value = dense<[1, 0, 0]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + + // CHECK: %[[CST:.*]] = constant dense<[1, 0, 0, 1]> : tensor<4xi32> + // CHECK: %[[CST0:.*]] = constant dense<[1, 0, 0, 0]> : tensor<4xi32> + // CHECK: %[[CST1:.*]] = constant dense<1> : tensor<4xi32> + // CHECK: %[[RESULT:.*]] = "tf.StridedSlice"(%arg0, %[[CST]], %[[CST0]], %[[CST1]]) + // CHECK-SAME: begin_mask = 7 : i64 + // CHECK-SAME: ellipsis_mask = 0 : i64 + // CHECK-SAME: end_mask = 14 : i64 + // CHECK-SAME: new_axis_mask = 0 : i64 + // CHECK-SAME: shrink_axis_mask = 0 : i64 + + %0 = "tf.StridedSlice"(%arg0, %cst, %cst_0, %cst_1) {begin_mask = 1 : i64, device = "", ellipsis_mask = 2 : i64, end_mask = 4 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<8x4x16x2xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<8x4x16x1xf32> + return %0 : tensor<8x4x16x1xf32> +} + // CHECK-LABEL: @MatrixSetDiagV2Conversion func @MatrixSetDiagV2Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) -> tensor<3x3xi32> { %cst = constant dense<0> : tensor diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 589515d6246..3fa2eae42f2 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -39,22 +39,18 @@ namespace tensorflow { void AddQuantizationPasses(const mlir::TFL::QuantizationSpecs& quant_specs, mlir::OpPassManager* pass_manager) { pass_manager->addPass(mlir::TFL::CreatePrepareQuantizePass(quant_specs)); - pass_manager->addPass(mlir::TFL::CreateQuantizePass()); - bool emit_quant_adaptor_ops = - quant_specs.inference_type != quant_specs.inference_input_type; - pass_manager->addPass( - mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); - if (quant_specs.default_ranges.first.hasValue() || quant_specs.default_ranges.second.hasValue()) { pass_manager->addPass(mlir::TFL::CreateDefaultQuantParamsPass( quant_specs.default_ranges.first.getValueOr(0.0), quant_specs.default_ranges.second.getValueOr(0.0), quant_specs.IsSignedInferenceType())); - pass_manager->addPass(mlir::TFL::CreateQuantizePass()); - pass_manager->addPass( - mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); } + pass_manager->addPass(mlir::TFL::CreateQuantizePass()); + bool emit_quant_adaptor_ops = + quant_specs.inference_type != quant_specs.inference_input_type; + pass_manager->addPass( + mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops)); } void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, @@ -63,7 +59,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config, standard_pipeline_options.enable_inliner = false; standard_pipeline_options.form_clusters = pass_config.form_clusters; mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options); - pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass()); + pass_manager->addPass(mlir::TF::CreateDeviceIndexSelectorPass()); if (pass_config.shape_inference) { 
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass()); @@ -212,9 +208,6 @@ void CreateTFLStandardPipeline(OpPassManager& pm, // Saved model pass to mark global tensors immutable. pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); - // Used to mark non-exported functions in saved model private. - pm.addPass(mlir::tf_saved_model:: - CreateMarkFunctionVisibilityUsingSavedModelLinkagePass()); // Op fusion pass. pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc index 46ed134d7ee..1328a2baf5d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_tf.cc @@ -28,9 +28,11 @@ limitations under the License. #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringSwitch.h" +#include "llvm/Support/Threading.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -767,13 +769,26 @@ void LegalizeTF::runOnFunction() { [](Operation* op) { auto tfl_op = dyn_cast_or_null(op); if (!tfl_op) return false; - return succeeded(tfl_op.VerifyTflRuntimeConstraints( - tfl_op.getOperation(), - /*failure_on_operand_type_mismatch=*/false)); + return succeeded(tfl_op.VerifyTflRuntimeConstraints(op)); })); } else { target.addLegalDialect(); } + + // Ignore transient errors by registering a no-op handler. + // Applying legalization patterns will emit unwanted, transient errors when + // the replaced TFLite ops do not meet the sanity checks. In order to ignore + // the transient errors, the following lines override the diagnostic handler + // with a no-op handler only while this pass runs. + uint64_t current_thread_id = llvm::get_threadid(); + ScopedDiagnosticHandler scoped_diag_handler( + context, [&current_thread_id](Diagnostic&) -> LogicalResult { + // Consume only errors that are coming from the same thread in order not + // to ignore errors from other passes that are running. Things running + // in the pass manager can be multi-threaded. + return success(current_thread_id == llvm::get_threadid()); + }); + // Keep trying to convert. + // TODO(karimnosseir): This is similar to what applying greedy patterns does. + // Look if there is a function that tries until it converges. diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 01e5eb1cb68..105c9394fb4 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -91,9 +91,6 @@ std::unique_ptr> CreateWhileOutlinePass(); // Verifies runtime constraints. std::unique_ptr> CreateRuntimeVerifyPass(); -// Creates function pass to select device index/fold tf.DeviceIndex.
-std::unique_ptr> CreateDeviceIndexSelectorPass(); - } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 9a1da0ad03d..33380e00543 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -52,7 +52,7 @@ class PostQuantizePass : public PassWrapper { void RemoveQuantizationAdaptorOps(FuncOp func) { mlir::OpBuilder builder(func.getBody()); - auto& bb = func.getBlocks().front(); + auto& bb = func.front(); auto* terminator = bb.getTerminator(); int num_args = bb.getNumArguments(); diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 3310c521a5a..6ee988496fa 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -584,46 +584,50 @@ struct ConvertTFStridedSlice : public RewritePattern { const int ellipsis_filled_dim_size = input_size - begin_shape[0] + 1; - llvm::APInt new_begin_mask = strided_slice_op.begin_mask(); - llvm::APInt new_end_mask = strided_slice_op.end_mask(); + int64_t begin_mask = strided_slice_op.begin_mask().getSExtValue(); + int64_t end_mask = strided_slice_op.end_mask().getSExtValue(); + int64_t new_begin_mask = 0; + int64_t new_end_mask = 0; SmallVector padded_begin; SmallVector padded_end; SmallVector padded_stride; // Before the ellipsis. - uint64_t index = 1; - int count = 0; - - while (index < ellipsis_mask) { - padded_begin.push_back(begin_dense_elem_attr.getValue(count)); - padded_end.push_back(end_dense_elem_attr.getValue(count)); - padded_stride.push_back(stride_dense_elem_attr.getValue(count)); - index <<= 1; - count++; + int index = 0; + int new_index = 0; + while (((ellipsis_mask >> index) & 1) == 0) { + padded_begin.push_back(begin_dense_elem_attr.getValue(index)); + padded_end.push_back(end_dense_elem_attr.getValue(index)); + padded_stride.push_back(stride_dense_elem_attr.getValue(index)); + if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index); + if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index); + ++index; + ++new_index; } // Ellipsis. - for (int i = 0; i < ellipsis_filled_dim_size; ++i) { - new_begin_mask |= ellipsis_mask; - new_end_mask |= ellipsis_mask; + for (; new_index < index + ellipsis_filled_dim_size; ++new_index) { + new_begin_mask |= (1 << new_index); + new_end_mask |= (1 << new_index); // Mimic the begin/end/strides mask behavior. padded_begin.push_back(0); padded_end.push_back(0); padded_stride.push_back(1); - - ellipsis_mask <<= 1; } // Account for ellipsis mask. - count++; + ++index; // After the ellipsis. 
- for (; count < begin_shape[0]; ++count) { - padded_begin.push_back(begin_dense_elem_attr.getValue(count)); - padded_end.push_back(end_dense_elem_attr.getValue(count)); - padded_stride.push_back(stride_dense_elem_attr.getValue(count)); + for (; index < begin_shape[0]; ++index) { + padded_begin.push_back(begin_dense_elem_attr.getValue(index)); + padded_end.push_back(end_dense_elem_attr.getValue(index)); + padded_stride.push_back(stride_dense_elem_attr.getValue(index)); + + if ((begin_mask >> index) & 1) new_begin_mask |= (1 << new_index); + if ((end_mask >> index) & 1) new_end_mask |= (1 << new_index); } auto attribute_type = rewriter.getIntegerType(64); @@ -645,7 +649,7 @@ struct ConvertTFStridedSlice : public RewritePattern { end_op.getResult(), stride_op.getResult(), rewriter.getIntegerAttr(attribute_type, new_begin_mask), rewriter.getIntegerAttr(attribute_type, new_end_mask), - rewriter.getI64IntegerAttr(0), + /*ellipsis_maks=*/rewriter.getI64IntegerAttr(0), rewriter.getIntegerAttr(attribute_type, strided_slice_op.new_axis_mask()), rewriter.getIntegerAttr(attribute_type, @@ -655,10 +659,12 @@ struct ConvertTFStridedSlice : public RewritePattern { LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const override { - // TODO(renjieliu): Consider expand the transformation for shrink - // mask as well. TF::StridedSliceOp strided_slice_op = llvm::cast(op); + // TODO(renjieliu): Consider expand the transformation for shrink mask as + // well. + if (strided_slice_op.shrink_axis_mask().getZExtValue()) return failure(); + // Handle new axis mask. uint64_t new_axis_mask = strided_slice_op.new_axis_mask().getZExtValue(); if (new_axis_mask != 0) { diff --git a/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc index 3268329b1c1..cc2e691180e 100644 --- a/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc +++ b/tensorflow/compiler/mlir/lite/transforms/runtime_verify.cc @@ -34,8 +34,7 @@ class RuntimeVerifyPass void RuntimeVerifyPass::runOnFunction() { getFunction().walk([&](TflRuntimeVerifyOpInterface op) { - if (failed(op.VerifyTflRuntimeConstraints( - op.getOperation(), /*failure_on_operand_type_mismatch=*/true))) + if (failed(op.VerifyTflRuntimeConstraints(op.getOperation()))) signalPassFailure(); }); } diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 904ccb7e820..b159815d5eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -57,6 +57,7 @@ gentbl( td_srcs = [ ":tensorflow_ops_td_files", ], + test = True, ) gentbl( @@ -88,6 +89,7 @@ gentbl( td_srcs = [ ":tensorflow_ops_td_files", ], + test = True, ) gentbl( @@ -112,6 +114,7 @@ gentbl( "@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", ], + test = True, ) gentbl( @@ -137,6 +140,7 @@ gentbl( "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td", ], + test = True, ) gentbl( @@ -161,6 +165,7 @@ gentbl( "@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", ], + test = True, ) gentbl( @@ -475,6 +480,7 @@ cc_library( "transforms/cluster_outlining.cc", "transforms/collection_ops_util.cc", "transforms/decompose_resource_ops_pass.cc", + "transforms/device_index_selector.cc", "transforms/einsum.cc", "transforms/executor_island_coarsening.cc", 
"transforms/executor_tpuv1_inline_tpu_island.cc", @@ -491,7 +497,6 @@ cc_library( "transforms/graph_pruning.cc", "transforms/launch_to_device_attribute.cc", "transforms/layout_optimization.cc", - "transforms/mark_function_visibility.cc", "transforms/materialize_mlir_passthrough_op.cc", "transforms/optimize.cc", "transforms/optimize_global_tensors.cc", @@ -661,7 +666,9 @@ cc_library( ":tensorflow_types", ":translate_utils", "//tensorflow/cc/saved_model:bundle_v2", + "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader_lite", + "//tensorflow/cc/saved_model:loader_util", "//tensorflow/compiler/jit:shape_inference_helpers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/tf2xla:functionalize_control_flow", @@ -673,6 +680,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler/utils:transitive_fanin", + "//tensorflow/core/platform:protobuf_internal", "//tensorflow/core/platform:types", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/algorithm:container", @@ -682,7 +690,6 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", "@llvm-project//llvm:Support", - "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:StandardOps", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index b8f0585040c..7dd74282487 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -299,13 +299,13 @@ ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) { parser->parseRegion(body, region_args, region_arg_types)) return failure(); - if (body.getBlocks().size() > 1) - return parser->emitError(loc) << "expects a single block region"; - // Ensure that the region is well formed: it contains at least a block with // a ReturnOp terminator. ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location); + if (!llvm::hasSingleElement(body)) + return parser->emitError(loc) << "expects a single block region"; + Operation& terminator = body.front().back(); if (!isa(terminator)) return parser->emitError(loc) << "expects a tf_device.return terminator"; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 3403651eef8..1e66eee06bb 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -220,13 +220,13 @@ ParseResult ParseGraphOp(OpAsmParser &parser, OperationState &result) { Region &body = *result.addRegion(); if (parser.parseRegion(body, llvm::None, llvm::None)) return failure(); - if (body.getBlocks().size() > 1) - return parser.emitError(loc) << "expects a single block region"; - // Ensure that the region is well formed: it contains at least a block with // a FetchOp terminator. GraphOp::ensureTerminator(body, parser.getBuilder(), result.location); + if (!llvm::hasSingleElement(body)) + return parser.emitError(loc) << "expects a single block region"; + // Get the results type from the terminator type inside the graph. 
Operation &fetch = body.back().back(); if (!isa(fetch)) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index d403462e6a6..65ca3ea4dbd 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -52,15 +52,12 @@ an output element, this operation computes \\(y = |x|\\). def TF_AcosOp : TF_Op<"Acos", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes acos of x element-wise."; - let description = [{ - }]; - let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -164,6 +161,81 @@ def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastable let hasFolder = 1; } +def TF_AdjustContrastv2Op : TF_Op<"AdjustContrastv2", [NoSideEffect]> { + let summary = "Adjust the contrast of one or more images."; + + let description = [{ +`images` is a tensor of at least 3 dimensions. The last 3 dimensions are +interpreted as `[height, width, channels]`. The other dimensions only +represent a collection of images, such as `[batch, height, width, channels].` + +Contrast is adjusted independently for each channel of each image. + +For each channel, the Op first computes the mean of the image pixels in the +channel and then adjusts each component of each pixel to +`(x - mean) * contrast_factor + mean`. + }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$images, + F32Tensor:$contrast_factor + ); + + let results = (outs + TensorOf<[F16, F32]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_AdjustHueOp : TF_Op<"AdjustHue", [NoSideEffect]> { + let summary = "Adjust the hue of one or more images."; + + let description = [{ +`images` is a tensor of at least 3 dimensions. The last dimension is +interpreted as channels, and must be three. + +The input image is considered in the RGB colorspace. Conceptually, the RGB +colors are first mapped into HSV. A delta is then applied all the hue values, +and then remapped back to RGB colorspace. + }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$images, + F32Tensor:$delta + ); + + let results = (outs + TensorOf<[F16, F32]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_AdjustSaturationOp : TF_Op<"AdjustSaturation", [NoSideEffect]> { + let summary = "Adjust the saturation of one or more images."; + + let description = [{ +`images` is a tensor of at least 3 dimensions. The last dimension is +interpreted as channels, and must be three. + +The input image is considered in the RGB colorspace. Conceptually, the RGB +colors are first mapped into HSV. A scale is then applied all the saturation +values, and then remapped back to RGB colorspace. + }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$images, + F32Tensor:$scale + ); + + let results = (outs + TensorOf<[F16, F32]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_AllOp : TF_Op<"All", [NoSideEffect]> { let summary = [{ Computes the "logical and" of elements across dimensions of a tensor. @@ -296,9 +368,6 @@ retained with length 1. 
def TF_ApproximateEqualOp : TF_Op<"ApproximateEqual", [Commutative, NoSideEffect]> { let summary = "Returns the truth value of abs(x-y) < tolerance element-wise."; - let description = [{ - }]; - let arguments = (ins TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y, @@ -436,11 +505,11 @@ tf.math.asin(y) # [1.047, 0.785] = x }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -569,11 +638,11 @@ tf.math.atan(y) # [1.047, 0.785] = x }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -659,9 +728,6 @@ window in `value`. def TF_AvgPoolGradOp : TF_Op<"AvgPoolGrad", [NoSideEffect]> { let summary = "Computes gradients of the average pooling function."; - let description = [{ - }]; - let arguments = (ins I32Tensor:$orig_input_shape, TF_FpTensor:$grad, @@ -855,48 +921,6 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } -def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Computes the Bessel i0e function of `x` element-wise."; - - let description = [{ -Exponentially scaled modified Bessel function of order 0 defined as -`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. - -This function is faster and numerically stabler than `bessel_i0(x)`. - }]; - - let arguments = (ins - TF_FpTensor:$x - ); - - let results = (outs - TF_FpTensor:$y - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - -def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Computes the Bessel i1e function of `x` element-wise."; - - let description = [{ -Exponentially scaled modified Bessel function of order 0 defined as -`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. - -This function is faster and numerically stabler than `bessel_i1(x)`. 
- }]; - - let arguments = (ins - TF_FpTensor:$x - ); - - let results = (outs - TF_FpTensor:$y - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; -} - def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { let summary = "Adds `bias` to `value`."; @@ -1327,9 +1351,6 @@ An n-way switch statement, implementing the following: def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { let summary = "Cast x of type SrcT to y of DstT."; - let description = [{ - }]; - let arguments = (ins TF_Tensor:$x, @@ -1349,9 +1370,6 @@ def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { def TF_CeilOp : TF_Op<"Ceil", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise smallest integer not less than x."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$x ); @@ -1410,9 +1428,6 @@ greater than `clip_value_max` are set to `clip_value_max`. def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> { let summary = "Receives a tensor value broadcast from another device."; - let description = [{ - }]; - let arguments = (ins I64Attr:$group_size, I64Attr:$group_key, @@ -1432,9 +1447,6 @@ def TF_CollectiveBcastRecvOp : TF_Op<"CollectiveBcastRecv", []> { def TF_CollectiveBcastSendOp : TF_Op<"CollectiveBcastSend", []> { let summary = "Broadcasts a tensor value to one or more other devices."; - let description = [{ - }]; - let arguments = (ins TensorOf<[F16, F32, F64, I1, I32, I64]>:$input, @@ -1458,9 +1470,6 @@ def TF_CollectiveGatherOp : TF_Op<"CollectiveGather", []> { Mutually accumulates multiple tensors of identical type and shape. }]; - let description = [{ - }]; - let arguments = (ins TensorOf<[F16, F32, F64, I32, I64]>:$input, @@ -1484,9 +1493,6 @@ def TF_CollectiveReduceOp : TF_Op<"CollectiveReduce", [SameOperandsAndResultType Mutually reduces multiple tensors of identical type and shape. }]; - let description = [{ - }]; - let arguments = (ins TensorOf<[F16, F32, F64, I32, I64]>:$input, @@ -1566,9 +1572,6 @@ value is computed as \\( \sqrt{a^2 + b^2}\\). def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> { let summary = "Concatenates tensors along one dimension."; - let description = [{ - }]; - let arguments = (ins I32Tensor:$concat_dim, Variadic:$values @@ -1625,9 +1628,6 @@ This is typically used by gradient computations for a concat operation. def TF_ConcatV2Op : TF_Op<"ConcatV2", [NoSideEffect]> { let summary = "Concatenates tensors along one dimension."; - let description = [{ - }]; - let arguments = (ins Variadic:$values, TF_I32OrI64Tensor:$axis @@ -1767,9 +1767,6 @@ def TF_Conv2DBackpropFilterOp : TF_Op<"Conv2DBackpropFilter", [NoSideEffect, TF_ Computes the gradients of convolution with respect to the filter. }]; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$input, I32Tensor:$filter_sizes, @@ -1803,9 +1800,6 @@ def TF_Conv2DBackpropInputOp : TF_Op<"Conv2DBackpropInput", [NoSideEffect, TF_La Computes the gradients of convolution with respect to the input. }]; - let description = [{ - }]; - let arguments = (ins I32Tensor:$input_sizes, TensorOf<[BF16, F16, F32, F64, I32]>:$filter, @@ -1877,9 +1871,6 @@ def TF_Conv3DBackpropFilterV2Op : TF_Op<"Conv3DBackpropFilterV2", [NoSideEffect] Computes the gradients of 3-D convolution with respect to the filter. 
}]; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$input, I32Tensor:$filter_sizes, @@ -1903,9 +1894,6 @@ def TF_Conv3DBackpropInputV2Op : TF_Op<"Conv3DBackpropInputV2", [NoSideEffect]> Computes the gradients of 3-D convolution with respect to the input. }]; - let description = [{ - }]; - let arguments = (ins TF_I32OrI64Tensor:$input_sizes, TF_FpTensor:$filter, @@ -2391,6 +2379,10 @@ def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; let description = [{ +Given a list of device names, this operation returns the index of the device +this op runs. The length of the list is returned in two cases: +(1) Device does not exist in the given device list. +(2) It is in XLA compilation. }]; let arguments = (ins @@ -2717,9 +2709,6 @@ def TF_EluGradOp : TF_Op<"EluGrad", [NoSideEffect, SameOperandsAndResultType]> { Computes gradients for the exponential linear (Elu) operation. }]; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$gradients, TF_FpTensor:$outputs @@ -2739,9 +2728,6 @@ Creates a tensor with the given shape. This operation creates a tensor of `shape` and `dtype`. }]; - let description = [{ - }]; - let arguments = (ins I32Tensor:$shape, @@ -2827,6 +2813,27 @@ the corresponding feature. TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } +def TF_EnsureShapeOp : TF_Op<"EnsureShape", [NoSideEffect]> { + let summary = "Ensures that the tensor's shape matches the expected shape."; + + let description = [{ +Raises an error if the input tensor's shape does not match the specified shape. +Returns the input tensor otherwise. + }]; + + let arguments = (ins + TF_Tensor:$input, + + TF_ShapeAttr:$shape + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_EqualOp : TF_Op<"Equal", [Commutative, NoSideEffect]> { let summary = "Returns the truth value of (x == y) element-wise."; @@ -2871,9 +2878,6 @@ tf.math.equal(x, y) ==> array([True, True]) def TF_ErfOp : TF_Op<"Erf", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the Gauss error function of `x` element-wise."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$x ); @@ -2890,9 +2894,6 @@ def TF_ErfcOp : TF_Op<"Erfc", [NoSideEffect, SameOperandsAndResultType]> { Computes the complementary error function of `x` element-wise. }]; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$x ); @@ -2907,9 +2908,6 @@ Computes the complementary error function of `x` element-wise. def TF_ErfinvOp : TF_Op<"Erfinv", [NoSideEffect]> { let summary = ""; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$x ); @@ -3107,6 +3105,25 @@ dimensions of `input`. TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>; } +def TF_FakeParamOp : TF_Op<"FakeParam", [NoSideEffect]> { + let summary = [{ + This op is used as a placeholder in If branch functions. It doesn't provide a + valid output when run, so must either be removed (e.g. replaced with a + function input) or guaranteed not to be used (e.g. if mirroring an + intermediate output needed for the gradient computation of the other branch). 
+ }]; + + let arguments = (ins + TF_ShapeAttr:$shape + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. @@ -3305,9 +3322,6 @@ fill([2, 3], 9) ==> [[9, 9, 9] def TF_FloorOp : TF_Op<"Floor", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns element-wise largest integer not greater than x."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$x ); @@ -3844,6 +3858,28 @@ tf.math.greater_equal(x, y) ==> [True, False, True, True] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_HSVToRGBOp : TF_Op<"HSVToRGB", [NoSideEffect]> { + let summary = "Convert one or more images from HSV to RGB."; + + let description = [{ +Outputs a tensor of the same shape as the `images` tensor, containing the RGB +value of the pixels. The output is only well defined if the value in `images` +are in `[0,1]`. + +See `rgb_to_hsv` for a description of the HSV encoding. + }]; + + let arguments = (ins + TF_FpTensor:$images + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_HashTableV2Op : TF_Op<"HashTableV2", []> { let summary = "Creates a non-initialized hash table."; @@ -4093,9 +4129,6 @@ def TF_IgammaGradAOp : TF_Op<"IgammaGradA", [NoSideEffect, ResultsBroadcastableS WithBroadcastableBinOpBuilder { let summary = "Computes the gradient of `igamma(a, x)` wrt `a`."; - let description = [{ - }]; - let arguments = (ins TF_F32OrF64Tensor:$a, TF_F32OrF64Tensor:$x @@ -4178,11 +4211,11 @@ I.e., \\(y = 1 / x\\). }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -4368,9 +4401,6 @@ tf.math.is_nan(x) ==> [False, True, False, True, False] def TF_IteratorGetNextOp : TF_Op<"IteratorGetNext", []> { let summary = "Gets the next output from the given iterator ."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$iterator ); @@ -4439,9 +4469,6 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> { let summary = "Gradients for Local Response Normalization."; - let description = [{ - }]; - let arguments = (ins TensorOf<[BF16, F16, F32]>:$input_grads, TensorOf<[BF16, F16, F32]>:$input_image, @@ -4463,9 +4490,6 @@ def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> { def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear: `max(features, features * alpha)`."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$features, @@ -4486,9 +4510,6 @@ def TF_LeakyReluGradOp : TF_Op<"LeakyReluGrad", [NoSideEffect, SameOperandsAndRe Computes rectified linear gradients for a LeakyRelu operation. 
}]; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$gradients, TF_FpTensor:$features, @@ -4769,9 +4790,6 @@ def TF_LogicalAndOp : TF_Op<"LogicalAnd", [Commutative, NoSideEffect, ResultsBro def TF_LogicalNotOp : TF_Op<"LogicalNot", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns the truth value of `NOT x` element-wise."; - let description = [{ - }]; - let arguments = (ins I1Tensor:$x ); @@ -4852,9 +4870,6 @@ The tensor `values` must be of the type of the table values. def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { let summary = "Computes the number of elements in the given table."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$table_handle ); @@ -5539,9 +5554,6 @@ retained with length 1. def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { let summary = "Performs max pooling on the input."; - let description = [{ - }]; - let arguments = (ins TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Qint8, TF_Uint16, TF_Uint8]>:$input, @@ -5568,9 +5580,6 @@ def TF_MaxPoolOp : TF_Op<"MaxPool", [NoSideEffect, TF_FoldOperandsTransposeInter def TF_MaxPool3DOp : TF_Op<"MaxPool3D", [NoSideEffect]> { let summary = "Performs 3D max pooling on the input."; - let description = [{ - }]; - let arguments = (ins TensorOf<[BF16, F16, F32]>:$input, @@ -5590,9 +5599,6 @@ def TF_MaxPool3DOp : TF_Op<"MaxPool3D", [NoSideEffect]> { def TF_MaxPool3DGradOp : TF_Op<"MaxPool3DGrad", [NoSideEffect]> { let summary = "Computes gradients of 3D max pooling function."; - let description = [{ - }]; - let arguments = (ins TensorOf<[BF16, F16, F32]>:$orig_input, TensorOf<[BF16, F16, F32]>:$orig_output, @@ -5615,9 +5621,6 @@ def TF_MaxPool3DGradOp : TF_Op<"MaxPool3DGrad", [NoSideEffect]> { def TF_MaxPoolGradOp : TF_Op<"MaxPoolGrad", [NoSideEffect]> { let summary = "Computes gradients of the maxpooling function."; - let description = [{ - }]; - let arguments = (ins TF_IntOrFpTensor:$orig_input, TF_IntOrFpTensor:$orig_output, @@ -5896,9 +5899,6 @@ Returns x * y element-wise. Returns zero if y is zero, even if x if infinite or def TF_MultinomialOp : TF_Op<"Multinomial", []> { let summary = "Draws samples from a multinomial distribution."; - let description = [{ - }]; - let arguments = (ins TF_IntOrFpTensor:$logits, I32Tensor:$num_samples, @@ -5918,9 +5918,6 @@ def TF_MultinomialOp : TF_Op<"Multinomial", []> { def TF_NdtriOp : TF_Op<"Ndtri", [NoSideEffect]> { let summary = ""; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$x ); @@ -5940,11 +5937,11 @@ I.e., \\(y = -x\\). }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -5955,9 +5952,6 @@ I.e., \\(y = -x\\). def TF_NoOp : TF_Op<"NoOp", [NoSideEffect]> { let summary = "Does nothing. 
Only useful as a placeholder for control edges."; - let description = [{ - }]; - let arguments = (ins); let results = (outs); @@ -6211,9 +6205,6 @@ output = def TF_OutfeedEnqueueTupleOp : TF_Op<"OutfeedEnqueueTuple", []> { let summary = "Enqueue multiple Tensor values on the computation outfeed."; - let description = [{ - }]; - let arguments = (ins Variadic:$inputs ); @@ -6498,9 +6489,6 @@ q_full, r_full = qr(a, full_matrices=True) def TF_QuantizeAndDequantizeOp : TF_Op<"QuantizeAndDequantize", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Use QuantizeAndDequantizeV2 instead."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$input, @@ -6711,15 +6699,47 @@ the dimension is padded with zeros. TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>; } +def TF_RGBToHSVOp : TF_Op<"RGBToHSV", [NoSideEffect]> { + let summary = "Converts one or more images from RGB to HSV."; + + let description = [{ +Outputs a tensor of the same shape as the `images` tensor, containing the HSV +value of the pixels. The output is only well defined if the value in `images` +are in `[0,1]`. + +`output[..., 0]` contains hue, `output[..., 1]` contains saturation, and +`output[..., 2]` contains value. All HSV values are in `[0,1]`. A hue of 0 +corresponds to pure red, hue 1/3 is pure green, and 2/3 is pure blue. + +Usage Example: + +>>> blue_image = tf.stack([ +... tf.zeros([5,5]), +... tf.zeros([5,5]), +... tf.ones([5,5])], +... axis=-1) +>>> blue_hsv_image = tf.image.rgb_to_hsv(blue_image) +>>> blue_hsv_image[0,0].numpy() +array([0.6666667, 1. , 1. ], dtype=float32) + }]; + + let arguments = (ins + TF_FpTensor:$images + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = [{ Computes the derivative of a Gamma random sample w.r.t. `alpha`. }]; - let description = [{ - }]; - let arguments = (ins TF_F32OrF64Tensor:$alpha, TF_F32OrF64Tensor:$sample @@ -6970,11 +6990,11 @@ I.e., \\(y = 1 / x\\). 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -7049,9 +7069,6 @@ array([ 0., 0., -0., 3.], dtype=float32) def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear 6: `min(max(features, 0), 6)`."; - let description = [{ - }]; - let arguments = (ins TF_IntOrFpTensor:$features ); @@ -7066,9 +7083,6 @@ def TF_Relu6Op : TF_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { def TF_Relu6GradOp : TF_Op<"Relu6Grad", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear 6 gradients for a Relu6 operation."; - let description = [{ - }]; - let arguments = (ins TF_IntOrFpTensor:$gradients, TF_IntOrFpTensor:$features @@ -7084,9 +7098,6 @@ def TF_Relu6GradOp : TF_Op<"Relu6Grad", [NoSideEffect, SameOperandsAndResultType def TF_ReluGradOp : TF_Op<"ReluGrad", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes rectified linear gradients for a Relu operation."; - let description = [{ - }]; - let arguments = (ins TF_IntOrFpTensor:$gradients, TF_IntOrFpTensor:$features @@ -7208,14 +7219,29 @@ Input images can be of different types but output images are always float. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ResizeBilinearGradOp : TF_Op<"ResizeBilinearGrad", [NoSideEffect]> { + let summary = "Computes the gradient of bilinear interpolation."; + + let arguments = (ins + F32Tensor:$grads, + TF_FpTensor:$original_image, + + DefaultValuedAttr:$align_corners, + DefaultValuedAttr:$half_pixel_centers + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_ResizeNearestNeighborOp : TF_Op<"ResizeNearestNeighbor", [NoSideEffect]> { let summary = [{ Resize `images` to `size` using nearest neighbor interpolation. }]; - let description = [{ - }]; - let arguments = (ins TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Uint16, TF_Uint8]>:$images, I32Tensor:$size, @@ -7332,9 +7358,6 @@ var <- var - mom def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> { let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$var, TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$alpha, @@ -7697,11 +7720,11 @@ according to the current system rounding mode use std::cint. 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8117,9 +8140,6 @@ select(condition, t, e) ==> [[1, 2], def TF_SelectV2Op : TF_Op<"SelectV2", [NoSideEffect, ResultsBroadcastableShape]> { let summary = ""; - let description = [{ - }]; - let arguments = (ins I1Tensor:$condition, TF_Tensor:$t, @@ -8168,9 +8188,6 @@ def TF_SeluGradOp : TF_Op<"SeluGrad", [NoSideEffect, SameOperandsAndResultType]> Computes gradients for the scaled exponential linear (Selu) operation. }]; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$gradients, TF_FpTensor:$outputs @@ -8421,9 +8438,6 @@ whose values are extracted from 'input' starting at the offsets in def TF_SnapshotOp : TF_Op<"Snapshot", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a copy of the input tensor."; - let description = [{ - }]; - let arguments = (ins TF_Tensor:$input ); @@ -8488,9 +8502,6 @@ Inputs are the logits, not probabilities. def TF_SoftplusOp : TF_Op<"Softplus", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes softplus: `log(exp(features) + 1)`."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$features ); @@ -8505,9 +8516,6 @@ def TF_SoftplusOp : TF_Op<"Softplus", [NoSideEffect, SameOperandsAndResultType]> def TF_SoftplusGradOp : TF_Op<"SoftplusGrad", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes softplus gradients for a softplus operation."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$gradients, TF_FpTensor:$features @@ -8523,9 +8531,6 @@ def TF_SoftplusGradOp : TF_Op<"SoftplusGrad", [NoSideEffect, SameOperandsAndResu def TF_SoftsignOp : TF_Op<"Softsign", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes softsign: `features / (abs(features) + 1)`."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$features ); @@ -8540,9 +8545,6 @@ def TF_SoftsignOp : TF_Op<"Softsign", [NoSideEffect, SameOperandsAndResultType]> def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes softsign gradients for a softsign operation."; - let description = [{ - }]; - let arguments = (ins TF_FpTensor:$gradients, TF_FpTensor:$features @@ -8790,9 +8792,6 @@ are checked during execution. def TF_SplitOp : TF_Op<"Split", [NoSideEffect]> { let summary = "Splits a tensor into `num_split` tensors along one dimension."; - let description = [{ - }]; - let arguments = (ins I32Tensor:$split_dim, TF_Tensor:$value @@ -8811,9 +8810,6 @@ def TF_SplitOp : TF_Op<"Split", [NoSideEffect]> { def TF_SplitVOp : TF_Op<"SplitV", [NoSideEffect]> { let summary = "Splits a tensor into `num_split` tensors along one dimension."; - let description = [{ - }]; - let arguments = (ins TF_Tensor:$value, TF_I32OrI64Tensor:$size_splits, @@ -8877,11 +8873,11 @@ I.e., \\(y = x * x = x^2\\). 
}]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -8950,9 +8946,6 @@ shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1] def TF_StackCloseV2Op : TF_Op<"StackCloseV2", []> { let summary = "Delete the stack from its resource container."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$handle ); @@ -8963,9 +8956,6 @@ def TF_StackCloseV2Op : TF_Op<"StackCloseV2", []> { def TF_StackPopV2Op : TF_Op<"StackPopV2", []> { let summary = "Pop the element at the top of the stack."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$handle ); @@ -8980,9 +8970,6 @@ def TF_StackPopV2Op : TF_Op<"StackPopV2", []> { def TF_StackPushV2Op : TF_Op<"StackPushV2", []> { let summary = "Push an element onto the stack."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$handle, TF_Tensor:$elem, @@ -9000,9 +8987,6 @@ def TF_StackPushV2Op : TF_Op<"StackPushV2", []> { def TF_StackV2Op : TF_Op<"StackV2", []> { let summary = "A stack that produces elements in first-in last-out order."; - let description = [{ - }]; - let arguments = (ins I32Tensor:$max_size, @@ -9015,6 +8999,32 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { ); } +def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> { + let summary = [{ +Outputs deterministic pseudorandom random values from a uniform distribution. + }]; + + let description = [{ +The generated values follow a uniform distribution in the range `[0, 1)`. The +lower bound 0 is included in the range, while the upper bound 1 is excluded. + +The outputs are a deterministic function of `shape` and `seed`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_StopGradientOp : TF_Op<"StopGradient", [NoSideEffect, TF_AllTypesMatch<["input", "output"]>]> { let summary = "Stops gradient computation."; @@ -9508,11 +9518,11 @@ Given an input tensor, this function computes tangent of every }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$x ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64]>:$y ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -9694,9 +9704,6 @@ calculation gets its own TensorArray accumulator. def TF_TensorArrayReadV3Op : TF_Op<"TensorArrayReadV3", []> { let summary = "Read an element from the TensorArray into output `value`."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$handle, I32Tensor:$index, @@ -9736,9 +9743,6 @@ Scatter the data from the input value into specific TensorArray elements. 
def TF_TensorArraySizeV3Op : TF_Op<"TensorArraySizeV3", []> { let summary = "Get the current size of the TensorArray."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$handle, F32Tensor:$flow_in @@ -9815,9 +9819,6 @@ Write data via Write and read via Read or Pack. def TF_TensorArrayWriteV3Op : TF_Op<"TensorArrayWriteV3", []> { let summary = "Push an element onto the tensor_array."; - let description = [{ - }]; - let arguments = (ins TF_ResourceTensor:$handle, I32Tensor:$index, @@ -9938,9 +9939,6 @@ values: The tensor. def TF_TensorListGetItemOp : TF_Op<"TensorListGetItem", [NoSideEffect]> { let summary = ""; - let description = [{ - }]; - let arguments = (ins TF_VariantTensor:$input_handle, I32Tensor:$index, @@ -10070,9 +10068,6 @@ output_handle: The TensorList. def TF_TensorListSetItemOp : TF_Op<"TensorListSetItem", [NoSideEffect]> { let summary = ""; - let description = [{ - }]; - let arguments = (ins TF_VariantTensor:$input_handle, I32Tensor:$index, @@ -10862,9 +10857,6 @@ def TF_XdivyOp : TF_Op<"Xdivy", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x / y otherwise, elementwise."; - let description = [{ - }]; - let arguments = (ins TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y @@ -11041,9 +11033,6 @@ def TF_XlaHostComputeOp : TF_Op<"XlaHostCompute", []> { A pseudo-op to represent host-side computation in an XLA program. }]; - let description = [{ - }]; - let arguments = (ins Variadic:$inputs, @@ -11114,9 +11103,6 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { let summary = "An op to receive a tensor from the host."; - let description = [{ - }]; - let arguments = (ins TF_ShapeAttr:$shape, StrAttr:$key @@ -11154,9 +11140,6 @@ https://www.tensorflow.org/performance/xla/operation_semantics#reduce . def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> { let summary = "Replica ID."; - let description = [{ - }]; - let arguments = (ins); let results = (outs @@ -11196,9 +11179,6 @@ i=0...N-1. def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { let summary = "An op to send a tensor to the host."; - let description = [{ - }]; - let arguments = (ins TF_Tensor:$input, @@ -11242,9 +11222,6 @@ tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[ def TF_Xlog1pyOp : TF_Op<"Xlog1py", [NoSideEffect]> { let summary = "Returns 0 if x == 0, and x * log1p(y) otherwise, elementwise."; - let description = [{ - }]; - let arguments = (ins TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y @@ -11261,9 +11238,6 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { let summary = "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."; - let description = [{ - }]; - let arguments = (ins TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$x, TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$y @@ -11279,9 +11253,6 @@ def TF_XlogyOp : TF_Op<"Xlogy", [NoSideEffect, ResultsBroadcastableShape]>, def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Returns a tensor of zeros with the same shape and type as x."; - let description = [{ - }]; - let arguments = (ins TF_Tensor:$x ); @@ -11388,9 +11359,6 @@ expected to create these operators. 
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> { let summary = "A host-side computation called from a TPU device."; - let description = [{ - }]; - let arguments = (ins Variadic:$inputs, @@ -11470,9 +11438,6 @@ def TF__XlaRecvAtHostOp : TF_Op<"_XlaRecvAtHost", []> { A placeholder op to receive values from a running XLA computation. }]; - let description = [{ - }]; - let arguments = (ins TF_StrTensor:$dynamic_key, @@ -11490,9 +11455,6 @@ A placeholder op to receive values from a running XLA computation. def TF__XlaSendFromHostOp : TF_Op<"_XlaSendFromHost", []> { let summary = "A placeholder op to send values to a running XLA computation."; - let description = [{ - }]; - let arguments = (ins Variadic:$inputs, TF_StrTensor:$dynamic_key, diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index d8675bb786f..f5d8fbae46a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -232,6 +232,7 @@ else_branch: A function that takes 'inputs' and returns a list of def TF_YieldOp : TF_Op<"Yield", [Terminator]> { let summary = "Yield operation"; + let description = [{ The "yield" operation represents a return operation within the conditional and body of structured control flow (e.g., if and while). The operation @@ -497,6 +498,7 @@ Inserts a placeholder for a tensor that will be always fed. def TF_PlaceholderWithDefaultOp : TF_Op<"PlaceholderWithDefault", [NoSideEffect]> { let summary = "Placeholder op"; + let description = [{ A placeholder op that passes through input when its output is not fed. }]; @@ -839,9 +841,6 @@ def TF_XlaShardingOp : TF_Op<"XlaSharding", [NoSideEffect]> { An op which shards the input based on the given sharding attribute. }]; - let description = [{ - }]; - let arguments = (ins TF_Tensor:$input, @@ -858,9 +857,6 @@ An op which shards the input based on the given sharding attribute. def TF_InfeedDequeueTupleOp : TF_Op<"InfeedDequeueTuple", []> { let summary = "Fetches multiple values from infeed as an XLA tuple."; - let description = [{ - }]; - let arguments = (ins OptionalAttr:$_XlaSharding ); @@ -904,9 +900,6 @@ def TF_BatchDatasetV2Op : TF_Op<"BatchDatasetV2", [NoSideEffect]> { Creates a dataset that batches `batch_size` elements from `input_dataset`. }]; - let description = [{ - }]; - let arguments = (ins TF_VariantTensor:$input_dataset, I64Tensor:$batch_size, @@ -1048,4 +1041,46 @@ operation create / operate on a copy of `x`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Bessel i0e function of `x` element-wise."; + + let description = [{ +Exponentially scaled modified Bessel function of order 0 defined as +`bessel_i0e(x) = exp(-abs(x)) bessel_i0(x)`. + +This function is faster and numerically stabler than `bessel_i0(x)`. + }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_BesselI1eOp : TF_Op<"BesselI1e", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Bessel i1e function of `x` element-wise."; + + let description = [{ +Exponentially scaled modified Bessel function of order 0 defined as +`bessel_i1e(x) = exp(-abs(x)) bessel_i1(x)`. + +This function is faster and numerically stabler than `bessel_i1(x)`. 
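+
+Usage example (an illustrative sketch; the `tf.math.bessel_i1e` Python binding
+is assumed here, and output values are omitted):
+
+>>> x = tf.constant([-1.0, 0.5, 1.0])
+>>> y = tf.math.bessel_i1e(x)  # elementwise exp(-|x|) * bessel_i1(x)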
+ }]; + + let arguments = (ins + TF_FpTensor:$x + ); + + let results = (outs + TF_FpTensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + #endif // TF_OPS diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index 140a778770c..5a7d81d4c0c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/IR/Identifier.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -76,6 +77,23 @@ static LogicalResult Verify(GlobalTensorOp global_tensor) { return success(); } +static LogicalResult Verify(SessionInitializerOp session_initializer) { + mlir::SymbolTable symbol_table( + session_initializer.getParentOfType()); + + auto init_func_op = + symbol_table.lookup(session_initializer.initializer()); + if (!init_func_op) + return session_initializer.emitOpError() + << "the initializer function does not exist"; + + if (!init_func_op.getType().getResults().empty()) + return session_initializer.emitOpError() + << "the initializer function should have no output"; + + return success(); +} + #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc" @@ -212,14 +230,36 @@ static LogicalResult VerifySavedModelModule( } } for (auto func : module.getOps()) { - if (HasAnyTfSavedModelArgAttr(func)) { - if (!IsExported(func)) { - return func.emitError() - << "can only apply 'tf_saved_model' argument attributes " - "to exported functions"; - } + const bool is_exported = IsExported(func); + + if (is_exported && !func.isPublic()) { + return func.emitError() + << "exported function @" << func.getName() << " should be public"; + } + + if (!is_exported && func.isPublic()) { + return func.emitError() << "non-exported function @" << func.getName() + << " should be private"; + } + + if (!is_exported && HasAnyTfSavedModelArgAttr(func)) { + return func.emitError() << "can only apply 'tf_saved_model' argument " + "attributes to exported functions"; } } + + auto session_initializers = module.getOps(); + if (!session_initializers.empty() && + !llvm::hasSingleElement(session_initializers)) { + return (*++session_initializers.begin()).emitError() + << "there must be no more than one session_initializer op"; + } + + auto is_init = [&session_initializers](mlir::FuncOp func) { + if (session_initializers.empty()) return false; + return (*session_initializers.begin()).initializer() == func.getName(); + }; + SymbolTable symbol_table(module); auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion()); if (!symbol_uses.hasValue()) { @@ -230,6 +270,12 @@ static LogicalResult VerifySavedModelModule( auto func = symbol_table.lookup( symbol_use.getSymbolRef().cast().getValue()); if (func && IsExported(func)) { + // If it is an init function, then it can be used by the unique + // session_initializer op. 
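+      // For instance, with an op such as
+      //   "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
+      // the reference to the exported @init function is accepted rather than
+      // reported as an internal reference. (Illustrative; @init is an example name.)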
+ if (is_init(func) && + llvm::isa(symbol_use.getUser())) + continue; + return symbol_use.getUser() ->emitError("exported function cannot be internally referenced") .attachNote(func.getLoc()) @@ -349,5 +395,39 @@ GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index, return symbol_table.lookup(attr.getValue()); } +SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) { + auto initializers = op.getOps(); + if (initializers.empty()) return {}; + return *initializers.begin(); +} + +class OptimizeSessionInitializerPattern + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SessionInitializerOp op, + PatternRewriter &rewriter) const override { + SymbolTable symbol_table(op.getParentOfType()); + auto init_func_op = symbol_table.lookup(op.initializer()); + + // The init function can only be referenced from the SessionInitializerOp. + // And there is at most one SessionInitializerOp in the module. So both ops + // have no other uses and can be simply erased. + if (init_func_op.front().begin()->isKnownTerminator()) { + rewriter.eraseOp(init_func_op); + rewriter.eraseOp(op); + return success(); + } + + return failure(); + } +}; + +void SessionInitializerOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h index 47ebb1a1be5..b6f8753cc51 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h @@ -61,6 +61,10 @@ GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index, // should have. Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor); +// Returns the session initializer of this module if it exists. Returns null +// otherwise. +SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op); + } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td index 4431a160edf..dc1210a4d2a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model_ops.td @@ -128,4 +128,30 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> { let verifier = [{ return Verify(*this); }]; } +def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> { + let summary = "Initializes TensorFlow session state."; + let description = [{ + The session initializer op marks a function that must be called by an + external agent exactly once to initialize TensorFlow session state, and this + must happen before any other exported functions are called. There must be no + more than one session initializer in a saved model. + + The `initializer` represents the initialization function. The function have + no output and this function should be only called once. 
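+
+    For example, a module carrying a session initializer would contain an op of
+    the following form (an illustrative sketch; the function name @init is not
+    prescribed by this op):
+
+      "tf_saved_model.session_initializer"() { initializer = @init } : () -> ()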
+ + This is used, for example, to initialize hash tables stored in resources and + accessed by resource name (rather than as resource handles or bound inputs + which is how `global_tensor`s are referenced) + }]; + + let arguments = (ins + FlatSymbolRefAttr:$initializer + ); + + + let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; +} + #endif // SAVED_MODEL_DIALECT diff --git a/tensorflow/compiler/mlir/tensorflow/tests/function_visibility.mlir b/tensorflow/compiler/mlir/tensorflow/tests/function_visibility.mlir deleted file mode 100644 index 55af3cffde3..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/function_visibility.mlir +++ /dev/null @@ -1,47 +0,0 @@ -// RUN: tf-opt -tf-saved-model-mark-func-visibility -split-input-file %s | FileCheck --check-prefix=SAVEDMODEL %s -// RUN: tf-opt -tf-mark-func-visibility -split-input-file -verify-diagnostics %s | FileCheck %s - - -module attributes {tf_saved_model.semantics} { - // SAVEDMODEL: func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} - func @func_exported_1() attributes {tf_saved_model.exported_names = ["func_exported_1"]} { - "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> () - return - } - - // SAVEDMODEL: func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} - func @func_exported_2() attributes {tf_saved_model.exported_names = ["func_exported_2"]} { - "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> () - return - } - - // SAVEDMODEL: func @func_not_exported() attributes {sym_visibility = "private"} - func @func_not_exported() { - return - } - -} - -// ----- - -module { - // CHECK: func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} - func @func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}} { - return %arg0 : tensor<1xi32> - } - - // CHECK: func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> attributes {sym_visibility = "private"} - func @func_without_entry_spec(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> { - %0 = "tf.AddV2"(%arg0, %arg1) {T = i32, device = ""} : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> - return %0 : tensor<*xi32> - } -} - -// ----- - -module { - // expected-error @+1 {{can't overwrite the visibility of function private_func_with_entry_spec with private visibility}} - func @private_func_with_entry_spec(%arg0: tensor<1xi32>) -> tensor<1xi32> attributes {tf.entry_function = {inputs = "x", outputs = "y"}, sym_visibility = "private"} { - return %arg0 : tensor<1xi32> - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 1599d53ed15..1af4ba6b3dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -433,4 +433,17 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK: return %[[CAST_RESULT_0]], %[[CAST_RESULT_1]], %[[ADDI]] return %27, %28, %2 : tensor<*xui8>, tensor<*xi8>, tensor<*xi8> } + + // CHECK-LABEL: infer_device_launch + func @infer_device_launch(%arg0: tensor<1x8x2xi32>) -> (tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %1 = "tf_device.launch"() ({ + %2 = "tf.Cast"(%arg0) {Truncate = false} : 
(tensor<1x8x2xi32>) -> tensor<1x8x2xf32> + tf_device.return %2 : tensor<1x8x2xf32> + // CHECK: () -> tensor<1x8x2xf32> + }) {device = "/device:CPU:0"} : () -> tensor<*xf32> + // CHECK: (tensor, tensor<1x8x2xf32>) -> (tensor<1x8x1xf32>, tensor<1x8x1xf32>) + %3:2 = "tf.Split"(%0, %1) {device = ""} : (tensor, tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) + return %3#0, %3#1 : tensor<*xf32>, tensor<*xf32> + } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl index 594afa10453..95ad05aa1e6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/build_defs.bzl @@ -4,8 +4,6 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "lit_test") def tf_saved_model_test(name, data, tags = None): """Create a SavedModel test.""" - if tags == None: - tags = ["no_rocm"] native.py_binary( name = name, testonly = 1, @@ -26,5 +24,5 @@ def tf_saved_model_test(name, data, tags = None): name = name + ".py", data = [name] + data, driver = "@llvm-project//mlir:run_lit.sh", - tags = tags, + tags = tags + ["no_rocm"], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py index 7171f63bb05..5bfcfa5378a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/common_v1.py @@ -46,7 +46,10 @@ def set_tf_options(): # This function needs to take a "create_module_fn", as opposed to just the # module itself, because the creation of the module has to be delayed until # after absl and tensorflow have run various initialization steps. -def do_test(signature_def_map, show_debug_info=False): +def do_test(signature_def_map, + init_op=None, + canonicalize=False, + show_debug_info=False): """Runs test. 1. Performs absl and tf "main"-like initialization that must run before almost @@ -61,6 +64,9 @@ def do_test(signature_def_map, show_debug_info=False): Args: signature_def_map: A map from string key to signature_def. The key will be used as function name in the resulting MLIR. + init_op: The initializer op for the saved model. If set, it will generate a + initializer graph in the resulting MLIR. + canonicalize: If true, canonicalizer will be run on the resulting MLIR. show_debug_info: If true, shows debug locations in the resulting MLIR. """ @@ -84,6 +90,7 @@ def do_test(signature_def_map, show_debug_info=False): builder.add_meta_graph_and_variables( sess, [tf.saved_model.tag_constants.SERVING], signature_def_map, + main_op=init_op, strip_default_attrs=True) builder.save() @@ -97,6 +104,9 @@ def do_test(signature_def_map, show_debug_info=False): mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'tf-standard-pipeline', show_debug_info) + if canonicalize: + mlir = pywrap_mlir.experimental_run_pass_pipeline(mlir, 'canonicalize', + show_debug_info) print(mlir) app.run(app_main) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py new file mode 100644 index 00000000000..16290455608 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model/hash_table_v1.py @@ -0,0 +1,92 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# RUN: %p/hash_table_v1 | FileCheck %s + +# pylint: disable=missing-docstring,line-too-long +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow.compat.v1 as tf +from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1 + +# Verify that the tf.versions attribute exists. It is difficult to enforce +# contents, since the version numbers change over time. The conversion logic +# itself is verified in the common graphdef converter, so here just assert +# it is being invoked. +# CHECK: module +# CHECK-SAME: tf.versions +# CHECK-SAME: bad_consumers +# CHECK-SAME: min_consumer +# CHECK-SAME: producer + +# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> () +# CHECK: "tf_saved_model.global_tensor"() + +# CHECK: func [[init]] +# CHECK-NEXT: [[R5:%.*]] = "tf.Const"() +# CHECK-NEXT: [[R6:%.*]] = "tf.Const"() +# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"() +# CHECK-SAME: shared_name = "[[hash_table:.*]]" +# CHECK-NEXT: "tf.LookupTableImportV2"([[R7]], [[R5]], [[R6]]) + +# CHECK: func {{@[a-zA-Z_0-9]+}}( +# CHECK-SAME: [[ARG0:%.*]]: tensor +# CHECK-SAME: [[ARG1:%.*]]: tensor, value = {{.*}} : tensor<1x3xf32>} : () -> () +# CHECK-NOT: session_initializer + +# CHECK: func {{@[a-zA-Z_0-9]+}}( +# CHECK-SAME: [[ARG0:%.*]]: tensor<3x1xf32> {tf_saved_model.index_path = ["x"]}, +# CHECK-SAME: [[ARG1:%.*]]: tensor>> {tf_saved_model.bound_input = @[[VAR]]}) +# CHECK-SAME: -> (tensor<3x3xf32> {tf_saved_model.index_path = ["r"]}) +# CHECK-SAME: attributes {{.*}} tf_saved_model.exported_names = ["key"] + +# CHECK-NEXT: [[R0:%.*]] = "tf.ReadVariableOp"([[ARG1]]) {{{.*}}} : (tensor>>) -> tensor<1x3xf32> +# CHECK-NEXT: [[R1:%.*]] = "tf.MatMul"([[ARG0]], [[R0]]) {{{.*}}} : (tensor<3x1xf32>, tensor<1x3xf32>) -> tensor<3x3xf32> +# CHECK-NEXT: return [[R1]] : tensor<3x3xf32> + + +def Test(): + + x = tf.constant([[1.0], [1.0], [1.0]]) + y = tf.compat.v1.get_variable( + name='y', + shape=(1, 3), + initializer=tf.random_normal_initializer(), + trainable=True) + r = tf.matmul(x, y) + + tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x) + tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r) + + return { + 'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def( + inputs={'x': tensor_info_x}, + outputs={'r': tensor_info_r}, + method_name='some_function')) + } + + +if __name__ == '__main__': + common_v1.set_tf_options() + common_v1.do_test( + Test(), tf.initializers.global_variables(), canonicalize=True) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir deleted file mode 100644 index 6f2c47a935f..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_delete_unused_funcs.mlir 
+++ /dev/null @@ -1,96 +0,0 @@ -// RUN: tf-opt -tf-saved-model-mark-func-visibility -symbol-dce -split-input-file %s | FileCheck %s - -module attributes {tf_saved_model.semantics} { - - // Test case: Unused function should be deleted. - - // CHECK-NOT: func @unused - func @unused() { - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Root calls child. Child should not be deleted. - - // CHECK: func @root - func @root() attributes {tf_saved_model.exported_names = ["root"]} { - "tf.some_call"() { callee = @child } : () -> () - return - } - - // CHECK: func @child - func @child() { - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Don't crash if attribute that doesn't reference a func. - - "tf.some_opaque_global_variable"() { sym_name = "some_global" } : () -> () - - func @root2() attributes {tf_saved_model.exported_names = ["root2"]} { - "tf.do_something_with_a_global"() { global = @some_global } : () -> () - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Delete recursively dead cycle. - - // CHECK-NOT: func @recursively_dead0 - func @recursively_dead0() { - "tf.some_call"() { callee = @recursively_dead1 } : () -> () - return - } - // CHECK-NOT: func @recursively_dead1 - func @recursively_dead1() { - "tf.some_call"() { callee = @recursively_dead0 } : () -> () - return - } - -} - -// ----- - -module attributes {tf_saved_model.semantics} { - - // Test case: Root calls child with a deeply nested symbol reference. - // Child should not be deleted. - - // CHECK: func @root - func @root() attributes {tf_saved_model.exported_names = ["root"]} { - "tf.some_call"() {callee = {callee = {callee = @child}}} : () -> () - return - } - - // CHECK: func @child - func @child() { - return - } - -} - -// ----- - -// Test case: If the module doesn't have tf_saved_model semantics, then this -// pass shouldn't do anything. -module { - // CHECK: func @not_dead() - func @not_dead() { - return - } -} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir index 38627b41b68..6c32a3bc4d6 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_freeze_global_tensors.mlir @@ -64,7 +64,7 @@ module attributes {tf_saved_model.semantics} { return } - func @f_callee(%arg0: tensor>>) { + func @f_callee(%arg0: tensor>>) attributes {sym_visibility = "private"} { return } } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir index 21e3bef8fd8..26cdf025a10 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops.mlir @@ -2,6 +2,11 @@ module attributes {tf_saved_model.semantics} { + // CHECK: tf_saved_model.session_initializer + "tf_saved_model.session_initializer"() { + initializer = @init + } : () -> () + // Representation for constants: (immutable) global tensor. 
// CHECK: tf_saved_model.global_tensor "tf_saved_model.global_tensor"() { @@ -35,7 +40,18 @@ module attributes {tf_saved_model.semantics} { return %arg0 : tensor } - func @f() { + func @f() attributes {sym_visibility = "private"} { + return + } + + // Representation for init functions + // CHECK: func @init + // CHECK-SAME: exported_names = ["__tf_saved_model_session_initializer"] + func @init( + %arg1: tensor>> {tf_saved_model.bound_input = @some_constant} + ) attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} + { + "tf.some_call"(%arg1) : (tensor>>) -> () return } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir index c055c6c9f56..260174b184f 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_ops_invalid.mlir @@ -3,7 +3,7 @@ module attributes {tf_saved_model.semantics} { // expected-error@+1 {{unknown tf_saved_model dialect arg attribute 'tf_saved_model.not_a_real_arg_attr'}} - func @f(%arg0: tensor {tf_saved_model.not_a_real_arg_attr = 1 : i32}) { + func @f(%arg0: tensor {tf_saved_model.not_a_real_arg_attr = 1 : i32}) attributes {sym_visibility = "private"} { return } @@ -233,7 +233,7 @@ module attributes {tf_saved_model.semantics} { "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<1.> : tensor<1xf32> } : () -> () // expected-error@+1 {{can only apply 'tf_saved_model' argument attributes to exported functions}} func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) - -> (tensor {tf_saved_model.index_path = []}) { + -> (tensor {tf_saved_model.index_path = []}) attributes {sym_visibility = "private"} { %0 = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor return %0 : tensor } @@ -258,3 +258,97 @@ module attributes {tf_saved_model.semantics} { // expected-error@+1 {{'type' attribute for immutable 'tf_saved_model.global_tensor' should have a static shape}} "tf_saved_model.global_tensor"() { sym_name = "v", type = tensor, value = dense<1.> : tensor<1xf32> } : () -> () } + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{the initializer function does not exist}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{the initializer function should have no output}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} { + %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> + return %0 : tensor<1xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + // expected-error@+1 {{there must be no more than one session_initializer op}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} { + %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> + return %0 : tensor<1xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{exported function @f should be public}} + func @f( + %arg0: tensor {tf.resource_name = "resource"} + ) attributes { sym_visibility = "private", 
tf_saved_model.exported_names = ["foo.some_func"] } { + return + } + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{non-exported function @f should be private}} + func @f( + %arg0: tensor {tf.resource_name = "resource"} + ) { + return + } + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{the initializer function does not exist}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + // expected-error@+1 {{the initializer function should have no output}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]}) + attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } { + %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> + return %0 : tensor<1xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + // expected-error@+1 {{there must be no more than one session_initializer op}} + "tf_saved_model.session_initializer"() { initializer = @init } : () -> () + func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]}) + attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } { + %0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32> + return %0 : tensor<1xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir index 9d8911d306d..0c68cf0cf64 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors.mlir @@ -1,7 +1,7 @@ // RUN: tf-opt -tf-saved-model-optimize-global-tensors -split-input-file %s | FileCheck %s //===----------------------------------------------------------------------===// -// Freezing. +// Immutability. //===----------------------------------------------------------------------===// module attributes {tf_saved_model.semantics} { @@ -142,3 +142,89 @@ module attributes {tf_saved_model.semantics} { // Test running the pass on a module that does not have // tf_saved_model.semantics. 
module {} + +// ----- + +// Test use as an input in unhandled op +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.unhandled_op"(%arg0) : (tensor>>) -> () + return + } +} + + +// ----- + +// Test use as a region capture in an unhandled op +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.unhandled"() ({ + %val = "tf.ReadVariableOp"(%arg0) : (tensor>>) -> tensor + "tf.unhandled_terminator"() : () -> () + }) : () -> () + return + } +} + +// ----- + +// Test use as region capture as well as input in an unhandled op +// to the unhandled op. +module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor, value = dense<22.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}, %arg1: tensor>> {tf_saved_model.bound_input = @u}) + attributes {tf_saved_model.exported_names = ["f"]} { + %0 = "tf.unhandled"(%arg0) ({ + %val = "tf.ReadVariableOp"(%arg1) : (tensor>>) -> tensor + "tf.unhandled_terminator"() : () -> () + }) : (tensor>>) -> (tensor>>) + return + } +} + +// ----- + +// Test multiple global tensors uses as operands for an unhandled op. 
+module attributes {tf_saved_model.semantics} { + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "v", type = tensor, value = dense<42.> : tensor } : () -> () + + // CHECK: "tf_saved_model.global_tensor"() { + // CHECK-SAME: is_mutable + // CHECK-SAME: } : () -> () + "tf_saved_model.global_tensor"() { is_mutable, sym_name = "u", type = tensor, value = dense<22.> : tensor } : () -> () + + func @f(%arg0: tensor>> {tf_saved_model.bound_input = @v}, %arg1: tensor>> {tf_saved_model.bound_input = @u}) + attributes {tf_saved_model.exported_names = ["f"]} { + "tf.unhandled"(%arg0, %arg1) : (tensor>>, tensor>>) -> () + return + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir index 91e8c9c4b66..14a0006cd3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_saved_model_optimize_global_tensors_interprocedural.mlir @@ -20,12 +20,12 @@ module attributes {tf_saved_model.semantics} { return %val : tensor } - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } - func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor return %val : tensor } @@ -59,7 +59,7 @@ module attributes {tf_saved_model.semantics} { return %val : tensor } - func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_common(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>) -> tensor return %val : tensor } @@ -85,7 +85,7 @@ module attributes {tf_saved_model.semantics} { return %val_2 : tensor } - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %cst_1 = constant dense<2.0> : tensor return %cst_1 : tensor } @@ -112,13 +112,13 @@ module attributes {tf_saved_model.semantics} { } // CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () return %c0 : tensor @@ -146,13 +146,13 @@ module attributes {tf_saved_model.semantics} { } // CHECK: func @f_callee(%arg0: 
tensor<*x!tf.resource>) -> tensor - func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } // CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor - func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor "tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () return %c0 : tensor @@ -179,13 +179,13 @@ module attributes {tf_saved_model.semantics} { // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor - func @f(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } // CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor - func @g(%arg0: tensor<*x!tf.resource>) -> tensor { + func @g(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor) return %val : tensor } @@ -212,7 +212,7 @@ module attributes {tf_saved_model.semantics} { // CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor - func @f(%arg0: tensor<*x!tf.resource>) -> tensor { + func @f(%arg0: tensor<*x!tf.resource>) -> tensor attributes {sym_visibility = "private"} { %c0 = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor "tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor) -> () return %c0 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 6bb8e99d796..d88489f5da0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -262,7 +262,6 @@ func @single_outside_compiled_input_output_single_outside_compilation(%arg0: ten return %1 : tensor } - // Tests extraction of a single outside compiled cluster with multiple input/output. // CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation @@ -439,3 +438,24 @@ func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor< return %1 : tensor } + +// Tests only directly used results of tpu cluster are remapped with +// parallel_execute. 
+ +// CHECK-LABEL: func @remapped_results +func @remapped_results(%arg0: tensor) -> tensor { + %0 = "tf.A"(%arg0) : (tensor) -> tensor + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute" + // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2:2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor) -> (tensor) + %5:2 = "tf.C"(%4) : (tensor) -> (tensor, tensor) + tf_device.return %5#0, %5#1 : tensor, tensor + }) {cluster_attr = "cluster_attr"} : () -> (tensor, tensor) + tf_device.return %2#1 : tensor + } + return %1 : tensor +} diff --git a/tensorflow/compiler/mlir/lite/transforms/device_index_selector.cc b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc similarity index 92% rename from tensorflow/compiler/mlir/lite/transforms/device_index_selector.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc index d4aed750dc8..550647a915a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/device_index_selector.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc @@ -21,11 +21,11 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" namespace mlir { -namespace TFL { +namespace TF { namespace { // Folds the DeviceIndex op to a constant value. The DeviceIndex return the @@ -55,8 +55,8 @@ void DeviceIndexSelector::runOnOperation() { // Convert all the DeviceIndex ops to constant values. func.getBody().walk([](TF::DeviceIndexOp op) { // This just selects the default in all cases where DeviceIndex feeds into - // tf.Case. This could be enhanced based on explicit TFLite specification or - // TAC in future. + // tf.Case. This could be enhanced to have some sort of policy in the + // future. OpBuilder b(op); RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32)); int index = op.device_names().size(); @@ -79,7 +79,7 @@ std::unique_ptr> CreateDeviceIndexSelectorPass() { } static PassRegistration pass( - "tfl-device-index-selector", "Fold tf.DeviceIndex to constant"); + "tf-device-index-selector", "Fold tf.DeviceIndex to constant"); -} // namespace TFL +} // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index 4d26747ebdc..b47378762a9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -199,7 +199,7 @@ static void MatchSwitchFoldOps(tf_executor::SwitchOp switch_op, // Folds merge nodes with only a single non-dead input. static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { // Create builder for val_index of MergeOp. 
- auto* block = &function.getBlocks().front(); + auto* block = &function.front(); OpBuilder builder = OpBuilder::atBlockEnd(block); auto type = builder.getIntegerType(32); auto build_index = [&](Location loc, int value) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index 4b10550df7b..d10f5e26e8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -52,11 +52,6 @@ struct FusedKernelMatcherPass void runOnFunction() override; }; -// Returns an op's name with the dialect prefix stripped off. -StringRef GetOpNameWithoutDialect(Operation *op) { - return op->getName().getStringRef().split(".").second; -} - bool IsActivationFunction(Operation *op) { return isa(op) || isa(op) || isa(op); } @@ -128,8 +123,8 @@ class FuseContractionWithBiasAdd : public OpRewritePattern { } SmallVector locations{contraction.getLoc(), bias_add.getLoc()}; - SmallVector fused_ops{ - StringAttr::get(GetOpNameWithoutDialect(bias_add), context)}; + SmallVector fused_ops{StringAttr::get( + bias_add.getOperation()->getName().stripDialect(), context)}; // BiasAdd may or may not feed into an activation function. auto activation = GetActivation(bias_add); @@ -143,7 +138,7 @@ class FuseContractionWithBiasAdd : public OpRewritePattern { if (fuse_activation) { locations.push_back(activation->getLoc()); fused_ops.push_back( - StringAttr::get(GetOpNameWithoutDialect(activation), context)); + StringAttr::get(activation->getName().stripDialect(), context)); result_type = activation->getResultTypes().front(); } else { result_type = bias_add.getResult().getType(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h index faecdf04368..0e6d844bed3 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h @@ -96,15 +96,19 @@ class FakeSession : public tensorflow::Session { for (const std::string& output_name : output_names) { Tensor output; if (output_name == "dense/bias") { - outputs->push_back( - Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50}))); + Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})); + t.flat().setZero(); + outputs->push_back(t); } else if (output_name == "dense/kernel") { - outputs->push_back( - Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50}))); + Tensor t = + Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})); + t.flat().setZero(); + outputs->push_back(t); } else { // Create a scalar float tensor. - outputs->push_back( - Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({}))); + Tensor t = Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})); + t.flat()(0) = 1.0f; + outputs->push_back(t); } } return Status::OK(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc deleted file mode 100644 index 31a80a4ecdb..00000000000 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_function_visibility.cc +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "llvm/ADT/STLExtras.h" -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" - -#define DEBUG_TYPE "tf-shape-inference" - -namespace mlir { - -namespace { - -LogicalResult MarkFunctionVisibility( - ModuleOp module, llvm::function_ref IsExternalVisible) { - LogicalResult result = success(); - - for (auto func : module.getOps()) { - FuncOp::Visibility old_visibility = func.getVisibility(); - - FuncOp::Visibility visibility = IsExternalVisible(func) - ? FuncOp::Visibility::Public - : FuncOp::Visibility::Private; - - auto get_visibility_name = [](FuncOp::Visibility v) { - return v == FuncOp::Visibility::Public - ? "public" - : v == FuncOp::Visibility::Private ? "private" : "nested"; - }; - - if (old_visibility != SymbolTable::Visibility::Public && - old_visibility != visibility) { - result = func.emitError() - << "can't overwrite the visibility of function " - << func.getName() << " with " - << get_visibility_name(old_visibility) << " visibility"; - } - - LLVM_DEBUG(llvm::dbgs() - << "function " << func.getName() << " has " - << get_visibility_name(visibility) << " visibility \n"); - - func.setVisibility(visibility); - } - - return result; -} - -} // anonymous namespace - -namespace TF { - -LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( - ModuleOp module) { - auto HasEntryFunctionSpecification = [](FuncOp func) -> bool { - auto attrs = func.getAttrOfType("tf.entry_function"); - return attrs && !attrs.empty(); - }; - return MarkFunctionVisibility(module, HasEntryFunctionSpecification); -} - -namespace { -struct MarkFunctionVisibilityUsingEntryFunctionSpecificationPass - : public PassWrapper< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass, - OperationPass> { - void runOnOperation() override { - if (failed(MarkFunctionVisibilityUsingEntryFunctionSpecification( - getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -static PassRegistration< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass> - pass("tf-mark-func-visibility", - "Use tf.entry_function to mark function visibility."); - -std::unique_ptr> -CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass() { - return std::make_unique< - MarkFunctionVisibilityUsingEntryFunctionSpecificationPass>(); -} - -// Marks the main function with public visibility, while other functions are -// marked with private visibility. 
-LogicalResult MarkOnlyMainFunctionWithPublicVisibility(ModuleOp module) { - for (auto func : module.getOps()) { - if (func.getName() == "main") { - func.setVisibility(FuncOp::Visibility::Public); - } else { - func.setVisibility(FuncOp::Visibility::Private); - } - } - return success(); -} - -namespace { -struct MarkOnlyMainFunctionWithPublicVisibilityPass - : public PassWrapper> { - void runOnOperation() override { - if (failed(MarkOnlyMainFunctionWithPublicVisibility(getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -std::unique_ptr> -CreateMarkOnlyMainFunctionWithPublicVisibilityPass() { - return std::make_unique(); -} - -} // namespace TF - -namespace tf_saved_model { - -static LogicalResult MarkFunctionVisibilityUsingSavedModelLinkage( - ModuleOp module) { - if (!tf_saved_model::HasTfSavedModelSemantics(module)) { - return success(); - } - return MarkFunctionVisibility(module, tf_saved_model::IsExported); -} - -namespace { -struct MarkFunctionVisibilityUsingSavedModelLinkagePass - : public PassWrapper> { - void runOnOperation() override { - if (failed(MarkFunctionVisibilityUsingSavedModelLinkage(getOperation()))) { - signalPassFailure(); - } - } -}; -} // namespace - -static PassRegistration pass( - "tf-saved-model-mark-func-visibility", - "Use tf_saved_model linkage information to mark function visibility."); - -std::unique_ptr> -CreateMarkFunctionVisibilityUsingSavedModelLinkagePass() { - return std::make_unique(); -} - -} // namespace tf_saved_model - -} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc index 94fdfb310ac..3ed27d7ce30 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/materialize_mlir_passthrough_op.cc @@ -71,7 +71,7 @@ void MaterializePassthroughOpPass::runOnFunction() { return; } Region &body = main.getBody(); - if (body.getBlocks().size() != 1) { + if (!llvm::hasSingleElement(body)) { op->emitError() << "MLIR Opaque Op expects a main() entry point with a " "single block\n"; return; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index cd8f988fd5f..07cc6203cbd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -56,14 +56,14 @@ struct GlobalTensorUse { using GlobalTensorUsesMap = std::map>; -static bool IsResourceType(Type type) { +bool IsResourceType(Type type) { if (auto tensor_type = type.dyn_cast()) { return tensor_type.getElementType().isa(); } return false; } -static bool IsResource(Value value) { return IsResourceType(value.getType()); } +bool IsResource(Value value) { return IsResourceType(value.getType()); } class ResourceAnalyzer { public: @@ -129,30 +129,24 @@ class ResourceAnalyzer { // this errs on the side of being conservative. We should improve // this by using either a property or a trait that clearly // identifies ops with resource mutating behavior. - if (PropagatePotentiallyWrittenWithinUnhandledOp(op)) { - return; - } + PropagatePotentiallyWrittenWithinUnhandledOp(op); }); return success(); } // If an op is not one of the handled ones, we assume all resource usages // within its purview are mutating in nature. 
- bool PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) { + void PropagatePotentiallyWrittenWithinUnhandledOp(Operation* op) { for (auto operand : op->getOperands()) { if (IsResource(operand)) { SetPotentiallyWritten(operand); - return true; } } - bool uses_resources = false; visitUsedValuesDefinedAbove(op->getRegions(), [&](OpOperand* operand) { if (IsResource(operand->get())) { SetPotentiallyWritten(operand->get()); - uses_resources = true; } }); - return uses_resources; } // Given a funcOp associated with the callee and operands from the @@ -212,7 +206,7 @@ bool IsImmutable(GlobalTensorOp global_tensor, return true; } -static GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) { +GlobalTensorUsesMap CreateGlobalTensorUsesMap(ModuleOp module) { GlobalTensorUsesMap global_tensor_uses; SymbolTable symbol_table(module); @@ -293,13 +287,13 @@ void OptimizeGlobalTensorsPass::runOnOperation() { EraseUnusedGlobalTensors(module, global_tensor_uses); } -} // namespace - // For "opt" to pick up this pass. -static PassRegistration pass( +PassRegistration pass( "tf-saved-model-optimize-global-tensors", "Optimize tf_saved_model.global_tensor's."); +} // namespace + std::unique_ptr> CreateOptimizeGlobalTensorsPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 7158d0f6be0..168b317641d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -117,21 +117,6 @@ std::unique_ptr> CreatePromoteVarHandlesToArgsPass(); std::unique_ptr> CreateConvertReadonlyReferenceVariablesToResourceVariablesPass(); -// Marks function visibility using tf.entry_function specification. That is, -// functions with tf.entry_function attributes are marked with public -// visibility while the other functions are marked with private visibility. -LogicalResult MarkFunctionVisibilityUsingEntryFunctionSpecification( - ModuleOp module); -// Creates a pass that uses tf.entry_function specification to mark function -// visibility. -std::unique_ptr> -CreateMarkFunctionVisibilityUsingEntryFunctionSpecificationPass(); - -// Creates a pass that marks the main function with public visibility, while -// other functions are marked with private visibility. -std::unique_ptr> -CreateMarkOnlyMainFunctionWithPublicVisibilityPass(); - // Creates a simple device assignment pass on TF dialect for CoreRT use case. std::unique_ptr> CreateSimpleTFDeviceAssignmentPass( llvm::StringRef default_device); @@ -162,6 +147,9 @@ std::unique_ptr> CreateLegalizeHloToTfPass(); // generally used beyond exporting to runtimes that supports these ops. In the // future these fusions may be codegen'd automatically. std::unique_ptr> CreateFusedKernelMatcherPass(); + +// Creates function pass to select device index/fold tf.DeviceIndex. +std::unique_ptr> CreateDeviceIndexSelectorPass(); } // namespace TF namespace tf_executor { @@ -296,7 +284,8 @@ std::unique_ptr> CreateTPUHostComputationExpansionPass(); // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) // ops to a separate parallel_execute region to run on CPU. 
-std::unique_ptr> CreateTPUExtractOutsideCompilationPass(); +std::unique_ptr> +CreateTPUExtractOutsideCompilationPass(); // Populates the supplied passmanager with the passes required to run the void CreateTPUBridgePipeline(OpPassManager& pm); @@ -315,13 +304,6 @@ std::unique_ptr> CreateOptimizeGlobalTensorsPass(); // Creates a pass that freezes tf_saved_model.global_tensor ops. std::unique_ptr> CreateFreezeGlobalTensorsPass(); -// Creates a pass that uses tf_saved_model dialect linkage information -// to mark function visibility. That is, exported functions are marked with -// public visibility while the other functions are marked with private -// visibility. -std::unique_ptr> -CreateMarkFunctionVisibilityUsingSavedModelLinkagePass(); - } // namespace tf_saved_model } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc index cece23b4750..af36770f496 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/promote_resources_to_args.cc @@ -80,11 +80,11 @@ constexpr char kResourceNameArgAttr[] = "tf.resource_name"; // Checks if a function has only one block. mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) { - if (!hasSingleElement(function.getBlocks())) + if (!llvm::hasSingleElement(function)) { return function.emitError() << "expects function '" << function.getName() << "' to have 1 block, got " << function.getBlocks().size(); - + } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index ed7ebc25c9f..799ab3a0f0d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -1113,7 +1113,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { // This routine should only be called when control flow operations are still // represented with TF IfOp and WhileOp operations. In this case, there should // be only one basic blocks in the MLIR representation. 
- if (!hasSingleElement(function.getBlocks())) { + if (!llvm::hasSingleElement(function)) { return function.emitError() << "expect the function to have 1 block while it has " << function.getBlocks().size(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 7e4baadc397..33ccf5caff2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -215,6 +215,10 @@ bool InferShapeForNonTFDialectOperation(Operation* op, Dialect* tf_dialect) { return InferShapeForPassThroughOps( tensor_cast.getOperation()->getOperands(), op, tf_dialect); } + if (auto launch_op = dyn_cast(op)) { + return InferShapeForPassThroughOps( + launch_op.GetBody().getTerminator()->getOperands(), op, tf_dialect); + } return false; } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index c349c2b4c3e..734a7d04a86 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -343,7 +343,7 @@ LogicalResult HandlePartitionedCallOp( } llvm::SmallDenseMap callee_map; FuncOp lowered_callee = callee; - if (callee.getVisibility() != SymbolTable::Visibility::Private) { + if (!callee.isPrivate()) { // Clone non-private callee in case of signature change. lowered_callee = callee.clone(); lowered_callee.setVisibility(SymbolTable::Visibility::Private); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index cfeb2b1f031..a9e1243714e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -759,7 +759,7 @@ LogicalResult HandlePartitionedCallOp( return it->getSecond().accumulate_on_write; }; FuncOp lowered_callee = callee; - if (callee.getVisibility() != SymbolTable::Visibility::Private) { + if (!callee.isPrivate()) { // Clone non-private callee in case of signature change. lowered_callee = callee.clone(); lowered_callee.setVisibility(SymbolTable::Visibility::Private); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc index 9733bfe2290..b118ab6c6c9 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_list_ops_decomposition.cc @@ -322,7 +322,7 @@ LogicalResult HandlePartitionedCallOp( // Rewrite the callee. llvm::SmallDenseMap callee_map; FuncOp lowered_callee = callee; - if (callee.getVisibility() != SymbolTable::Visibility::Private) { + if (!callee.isPrivate()) { // Clone non-private callee in case of signature change. 
lowered_callee = callee.clone(); lowered_callee.setVisibility(SymbolTable::Visibility::Private); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 54600faca4b..503c9869557 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -49,8 +49,9 @@ using OutsideClusterMap = // TODO(b/154363171): Add example tranformations. struct TPUExtractOutsideCompilation - : public PassWrapper { - void runOnFunction() override; + : public PassWrapper> { + void runOnOperation() override; }; // Collects and clusters ops in `block` with the same `_xla_outside_compilation` @@ -108,18 +109,6 @@ tf_device::LaunchOp CreateLaunchOpForOutsideCluster( return launch_op; } -// Propagates the return from `parallel_execute_op` to parent replicate -// op if it exists. -void PropagateParallelExecuteReturnToReplicate( - tf_device::ParallelExecuteOp parallel_execute_op) { - // Update the return for the parallel_execute op parent. - auto replicate = llvm::dyn_cast_or_null( - parallel_execute_op.getParentOp()); - if (replicate) - replicate.GetBody().getTerminator()->setOperands( - parallel_execute_op.execute_outputs()); -} - // Extracts all externally provided operands of `cluster_ops`. llvm::SmallSetVector GetExternalOperands( llvm::ArrayRef cluster_ops) { @@ -305,12 +294,21 @@ void CreateParallelExecuteFromOutsideClusters( tpu_cluster.getOperation()->moveBefore( parallel_execute_tpu_block.getTerminator()); - PropagateParallelExecuteReturnToReplicate(parallel_execute_op); + // Remap cluster results with parallel_execute results if user is outside of + // parallel_execute. + for (auto result : + llvm::zip(tpu_cluster.getResults(), parallel_execute_op.getResults())) { + Value tpu_cluster_result = std::get<0>(result); + Value parallel_execute_result = std::get<1>(result); + for (auto& use : llvm::make_early_inc_range(tpu_cluster_result.getUses())) + if (!parallel_execute_op.getOperation()->isProperAncestor(use.getOwner())) + use.set(parallel_execute_result); + } } -void TPUExtractOutsideCompilation::runOnFunction() { +void TPUExtractOutsideCompilation::runOnOperation() { auto extract_result = - getFunction().walk([&](tf_device::ClusterOp tpu_cluster) { + getOperation().walk([&](tf_device::ClusterOp tpu_cluster) { OutsideClusterMap clusters; if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), &clusters))) @@ -328,7 +326,7 @@ void TPUExtractOutsideCompilation::runOnFunction() { } // namespace -std::unique_ptr> +std::unique_ptr> CreateTPUExtractOutsideCompilationPass() { return std::make_unique(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index 696882cd105..ec9b3df525f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -146,6 +146,9 @@ LogicalResult EncapsulateFuncAndSerialize(FuncOp entry_func, // We can simply change name of TPU program's main function because there // should be no other reference to it. 
clone.setName("main"); + clone.setVisibility(FuncOp::Visibility::Public); + } else { + clone.setVisibility(FuncOp::Visibility::Private); } symbol_table.insert(clone); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc index f8b6e364f55..b05e87c6485 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc @@ -159,8 +159,7 @@ llvm::SmallVector ExtractFunctionsConnectedToArg( while (!functions_to_parse.empty()) { llvm::SmallVector newly_discovered_functions; for (auto function_info : functions_to_parse) { - Block& func_entry_block = - function_info.func.getBody().getBlocks().front(); + Block& func_entry_block = function_info.func.front(); auto argument = func_entry_block.getArgument(function_info.argument_index); @@ -186,8 +185,7 @@ void IdentifyXlaShardingForComputationInputs( StringRef logical_core_0_sharding, tf_device::ClusterFuncOp cluster_func_op, FuncOp cluster_function, Builder* builder) { // Look up function definition from module. - Block& cluster_function_block = - cluster_function.getBody().getBlocks().front(); + Block& cluster_function_block = cluster_function.front(); ModuleOp module = cluster_func_op.getParentOfType(); llvm::SmallVector sharding_for_args( @@ -215,8 +213,7 @@ void IdentifyXlaShardingForComputationInputs( const int function_argument_index = function_arg_info.argument_index; auto& parsed_function = function_arg_info.func; - Block& parsed_function_block = - parsed_function.getBody().getBlocks().front(); + Block& parsed_function_block = parsed_function.front(); arg_sharding = ParseInputSharding( parsed_function_block.getArgument(function_argument_index)); } @@ -245,7 +242,7 @@ void IdentifyXlaShardingForComputationOutputs( tf_device::ClusterFuncOp cluster_func, Builder* builder) { // By default return values from logical core 0 is used if no sharding // configuration is defined. - Block& function_block = func.getBody().getBlocks().front(); + Block& function_block = func.front(); Operation* terminator = function_block.getTerminator(); llvm::SmallVector sharding_for_rets( terminator->getNumOperands(), logical_core_0_sharding); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc index ec4a25c6fdd..d88982d9ee7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_variable_runtime_reformatting.cc @@ -261,7 +261,6 @@ tf_device::ReplicateOp AddInputsToReplicateOp( // placed in logical core 0. // TODO(b/148913020): Remove this constraint once model parallelism is // supported. - assert(devices.size() == 1); assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0)) ->getSecond() .size() == num_replicas); @@ -369,9 +368,6 @@ llvm::SmallVector CreateStateVars( // TODO(b/148913020): Remove this constraint once model parallelism is // supported. 
- assert(devices.size() == 1 && - "As model parallelism is not supported yet, tf_device.replicate " - "`devices` attribute should have one dictionary element."); const auto& device_list = devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond(); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 262f6f4e50c..8cd14894f8f 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -128,7 +128,7 @@ class LegalizedOpOrValLocNameMapper : public OpOrArgLocNameMapper { Status HasSingleGraphSingleOpIslandsFunctions(mlir::ModuleOp module) { Status status = Status::OK(); module.walk([&](mlir::FuncOp function) { - if (function.getBlocks().size() != 1) { + if (!llvm::hasSingleElement(function)) { status = errors::FailedPrecondition( kInvalidExecutorGraphMsg, "only single block functions are supported."); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 820d0ce31fb..fea809c0798 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -60,6 +60,8 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader_util.h" #include "tensorflow/compiler/jit/shape_inference_helpers.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" @@ -116,6 +118,7 @@ using mlir::NamedAttrList; using mlir::TensorType; using mlir::TF::VarHandleOp; using mlir::tf_saved_model::GlobalTensorOp; +using mlir::tf_saved_model::SessionInitializerOp; using stream_executor::port::StatusOr; namespace { @@ -2955,6 +2958,13 @@ void SortSavedModelModule(mlir::ModuleOp module) { named_global_tensor.global_tensor.getOperation()->moveBefore( &module.getBody()->front()); } + + auto initializers = module.getOps(); + if (!initializers.empty()) { + (*initializers.begin()) + .getOperation() + ->moveBefore(&module.getBody()->front()); + } } Status CreateSavedModelIR( @@ -3241,17 +3251,32 @@ class SavedModelSignatureDefImporter { absl::Span exported_names, mlir::MLIRContext* context) : bundle_(bundle), + flib_def_(OpRegistry::Global(), graph_def().library()), + debug_info_(), exported_names_(exported_names), - module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) {} + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(context))) { + // debug_info might not be loaded with loader_lite. + if (bundle_.debug_info != nullptr) debug_info_ = *bundle_.debug_info; + } // Converts the SavedModel to the SavedModel dialect. Creates an MLIR function // for each signature. StatusOr ConvertSignatures(); - Status ConvertSignature(const GraphDef& graphdef, - const std::string& sig_def_key, - const SignatureDef& signature_def, - const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def); + Status ConvertSignature(const std::string& sig_def_key, + const SignatureDef& signature_def); + + // Converts the initialization graph in the SavedModel to an MLIR function. + Status ConvertInitializer(); + + // Converts a graph with feeds and fetches to an MLIR function. 
+ StatusOr ConvertGraph( + const std::string& name, + const std::vector>& inputs, + const std::vector>& outputs, + const std::vector control_outputs); + + // Coarsens the islands in `module_`. + Status CoarsenIslands(); // Creates GlobalTensorOp for each variable and moves each VarHandle op to // the enclosing function's arguments. @@ -3262,6 +3287,10 @@ class SavedModelSignatureDefImporter { // tensor's shape is used to provide the most accurate nested shape. void LiftVariable(VarHandleOp op, GlobalTensorOp global_tensor); + // Removes the variable and related ops in the init function if it is already + // imported as a global tensor. + void RemoveVariable(VarHandleOp op); + using VarGlobalMap = llvm::MapVector< llvm::StringRef, std::pair>>; @@ -3273,18 +3302,68 @@ class SavedModelSignatureDefImporter { GraphImportConfig::InputArrays ParseInputArrays( const std::vector>& inputs); + const GraphDef& graph_def() const { + return bundle_.meta_graph_def.graph_def(); + } + const FunctionLibraryDefinition& flib_def() const { return flib_def_; } + const GraphDebugInfo& debug_info() const { return debug_info_; } + const SavedModelBundle& bundle_; + FunctionLibraryDefinition flib_def_; + GraphDebugInfo debug_info_; absl::Span exported_names_; mlir::OwningModuleRef module_; }; +Status SavedModelSignatureDefImporter::ConvertInitializer() { + std::vector asset_file_defs; + TF_RETURN_IF_ERROR( + internal::GetAssetFileDefs(bundle_.meta_graph_def, &asset_file_defs)); + + if (!asset_file_defs.empty()) + return errors::Unimplemented( + absl::StrCat("Assets are not supported in signaturedef importer")); + + std::string init_node_name; + TF_RETURN_IF_ERROR( + internal::GetInitOp("", bundle_.meta_graph_def, &init_node_name)); + + if (init_node_name.empty()) return Status::OK(); + + TF_ASSIGN_OR_RETURN(auto sub_module, + ConvertGraph(init_node_name, {}, {}, {init_node_name})); + + mlir::SymbolTable symbol_table(*sub_module); + + auto init_func_op = symbol_table.lookup(init_node_name); + + init_func_op.removeAttr("tf.entry_function"); + + mlir::OpBuilder builder(module_->getBodyRegion()); + + // Set the exported name of init function to an reserved name for + // tf_saved_model. + init_func_op.setAttr( + "tf_saved_model.exported_names", + builder.getStrArrayAttr({"__tf_saved_model_session_initializer"})); + + builder.create( + module_->getLoc(), builder.getSymbolRefAttr(init_func_op.getName())); + + // Move the converted functions to top level MLIR module. + auto* block = module_->getBody(); + auto* sub_block = sub_module->getBody(); + block->getOperations().splice( + mlir::Block::iterator(block->getTerminator()), sub_block->getOperations(), + sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator())); + + return Status::OK(); +} + StatusOr SavedModelSignatureDefImporter::ConvertSignatures() { const auto& signatures = bundle_.GetSignatures(); - const auto& graphdef = bundle_.meta_graph_def.graph_def(); - PopulateTfVersions(module_.get(), graphdef.versions()); - - FunctionLibraryDefinition flib_def(OpRegistry::Global(), graphdef.library()); + PopulateTfVersions(module_.get(), graph_def().versions()); // debug_info might not be loaded with loader_lite. 
GraphDebugInfo debug_info; @@ -3307,23 +3386,49 @@ SavedModelSignatureDefImporter::ConvertSignatures() { continue; } - TF_RETURN_IF_ERROR(ConvertSignature(graphdef, sig_def_key, signature_def, - debug_info, flib_def)); + TF_RETURN_IF_ERROR(ConvertSignature(sig_def_key, signature_def)); } - TF_RETURN_IF_ERROR(LiftVariables()); + + TF_RETURN_IF_ERROR(ConvertInitializer()); mlir::OpBuilder builder(module_->getBodyRegion()); module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr()); + + TF_RETURN_IF_ERROR(CoarsenIslands()); + TF_RETURN_IF_ERROR(LiftVariables()); + SortSavedModelModule(*module_); MarkSavedModelFunctionVisibility(*module_); return std::move(module_); } +StatusOr SavedModelSignatureDefImporter::ConvertGraph( + const std::string& name, + const std::vector>& inputs, + const std::vector>& outputs, + const std::vector control_outputs) { + GraphImportConfig specs; + specs.prune_unused_nodes = true; + specs.inputs = ParseInputArrays(inputs); + for (auto& output : outputs) specs.outputs.push_back(output.second.name()); + specs.control_outputs = control_outputs; + + // Convert sub-graphdef to sub-graph. + GraphConstructorOptions options; + options.allow_internal_ops = true; + options.add_default_attributes = true; + Graph graph(OpRegistry::Global()); + + TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(options, graph_def(), &graph)); + + // Convert sub-graph to MLIR module.true + return GraphDefImporter::Convert(module_->getContext(), graph, debug_info(), + flib_def(), specs, name); +} + Status SavedModelSignatureDefImporter::ConvertSignature( - const GraphDef& graphdef, const std::string& sig_def_key, - const SignatureDef& signature_def, const GraphDebugInfo& debug_info, - const FunctionLibraryDefinition& flib_def) { + const std::string& sig_def_key, const SignatureDef& signature_def) { // Create local vectors for the input and output and sort them to be // deterministic. We don't want anyone to really depend on the order, client // should lookup argument/result mapping by attribute name. @@ -3339,34 +3444,9 @@ Status SavedModelSignatureDefImporter::ConvertSignature( return lhs.first.size() < rhs.first.size() || lhs.first > rhs.first; }); - GraphImportConfig specs; - specs.prune_unused_nodes = true; - specs.inputs = ParseInputArrays(inputs); - for (auto& output : outputs) specs.outputs.push_back(output.second.name()); - - // Remove unused nodes and create sub-graphdef. - GraphDef sub_graph_def; - TF_RETURN_IF_ERROR(tensorflow::grappler::SetTransitiveFaninGraph( - graphdef, &sub_graph_def, - /*terminal_nodes=*/{specs.outputs.begin(), specs.outputs.end()})); - - // Set the function library definitions in the pruned graphdef. - *sub_graph_def.mutable_library() = flib_def.ToProto(); - - // Convert sub-graphdef to sub-graph. - GraphConstructorOptions options; - options.allow_internal_ops = true; - options.add_default_attributes = true; - Graph sub_graph(OpRegistry::Global()); - - TF_RETURN_IF_ERROR( - ConvertGraphDefToGraph(options, sub_graph_def, &sub_graph)); - // Convert sub-graph to MLIR module. - TF_ASSIGN_OR_RETURN( - auto sub_module, - GraphDefImporter::Convert(module_->getContext(), sub_graph, debug_info, - flib_def, specs, sig_def_key)); + TF_ASSIGN_OR_RETURN(auto sub_module, + ConvertGraph(sig_def_key, inputs, outputs, {})); mlir::OpBuilder builder(sub_module->getBodyRegion()); // Find the FuncOp which corresponds to current SignatureDef. 
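Note on the refactor above: both the per-signature import path and the new initializer import path now funnel through the shared ConvertGraph helper. A minimal sketch of the two call shapes, assuming the member signature introduced in this change; the variable names here are illustrative stand-ins for the values the importer computes:

// Sketch only; mirrors the helper added above.
// Initializer graph: no feeds/fetches, the init node is a control output.
TF_ASSIGN_OR_RETURN(auto init_module,
                    ConvertGraph(init_node_name, /*inputs=*/{}, /*outputs=*/{},
                                 /*control_outputs=*/{init_node_name}));
// Signature graph: feeds/fetches from the SignatureDef, no control outputs.
TF_ASSIGN_OR_RETURN(auto sig_module,
                    ConvertGraph(sig_def_key, inputs, outputs,
                                 /*control_outputs=*/{}));
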
@@ -3399,16 +3479,28 @@ Status SavedModelSignatureDefImporter::ConvertSignature( sub_block->begin(), mlir::Block::iterator(sub_block->getTerminator())); return Status::OK(); -} +} // namespace Status SavedModelSignatureDefImporter::LiftVariables() { VarGlobalMap var_globals; + llvm::SmallVector init_vars; - auto walker = [&var_globals](mlir::Operation* op) { - if (auto var_handle_op = llvm::dyn_cast(op)) - var_globals[var_handle_op.shared_name()].second.push_back(var_handle_op); - else if (op->getName().getStringRef() == "tf.VariableV2") + auto session_initializer = + mlir::tf_saved_model::GetSessionInitializerOp(*module_); + + auto walker = [&var_globals, &init_vars, + &session_initializer](mlir::Operation* op) { + if (auto var_handle_op = llvm::dyn_cast(op)) { + if (session_initializer && + session_initializer.initializer() == + var_handle_op.getParentOfType().getName()) + init_vars.push_back(var_handle_op); + else + var_globals[var_handle_op.shared_name()].second.push_back( + var_handle_op); + } else if (op->getName().getStringRef() == "tf.VariableV2") { return mlir::WalkResult::interrupt(); + } return mlir::WalkResult::advance(); }; bool contains_ref_variable = module_->walk(walker).wasInterrupted(); @@ -3425,9 +3517,51 @@ Status SavedModelSignatureDefImporter::LiftVariables() { for (VarHandleOp var_handle : it.second.second) LiftVariable(var_handle, it.second.first); + for (auto op : init_vars) RemoveVariable(op); + return Status::OK(); } +Status SavedModelSignatureDefImporter::CoarsenIslands() { + mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); + + mlir::PassManager pm(module_->getContext()); + pm.addNestedPass( + mlir::tf_executor::CreateTFExecutorIslandCoarseningPass()); + if (mlir::failed(pm.run(*module_))) + return diag_handler.Combine( + errors::Internal("failed to coarsening islands.")); + + return Status::OK(); +} + +void SavedModelSignatureDefImporter::RemoveVariable(VarHandleOp op) { + llvm::SmallVector work_list; + work_list.push_back(op); + while (!work_list.empty()) { + auto* op = work_list.back(); + work_list.pop_back(); + + for (mlir::Value res : op->getResults()) { + for (mlir::Operation* user : res.getUsers()) { + work_list.push_back(user); + } + } + + for (auto& use : op->getOpOperands()) { + if (mlir::Value value = use.get()) { + mlir::Operation* def = value.getDefiningOp(); + work_list.push_back(def); + } + } + + op->dropAllReferences(); + op->dropAllDefinedValueUses(); + + op->erase(); + } +} + void SavedModelSignatureDefImporter::LiftVariable( VarHandleOp op, GlobalTensorOp global_tensor) { mlir::OpBuilder builder(&module_->getBodyRegion()); @@ -3460,12 +3594,7 @@ void SavedModelSignatureDefImporter::LiftVariable( // Add the newly added function param to entry block's arguments. auto new_value = func_op.front().addArgument(resource_type); - // Remove the VarHandleOp also updating the containing island's return type. 
- DCHECK(llvm::isa(op.getParentOp())); - DCHECK(llvm::cast(op.getParentOp()) - .WrapsSingleOp()); op.getOperation()->replaceAllUsesWith(llvm::ArrayRef(new_value)); - op.getParentOp()->getResult(0).setType(resource_type); op.getOperation()->erase(); } diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc index 29f98de6448..78019119d9d 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_functional_to_executor.cc @@ -46,13 +46,13 @@ struct FunctionalToExecutorDialectConversion } // end anonymous namespace void FunctionalToExecutorDialectConversion::runOnFunction() { - if (getFunction().getBlocks().size() != 1) { + if (!llvm::hasSingleElement(getFunction())) { LLVM_DEBUG(llvm::dbgs() << "Expect single block function, skip conversion " "to tf_executor dialect\n"); return; } auto loc = getFunction().getLoc(); - mlir::Block& body = getFunction().getBody().front(); + mlir::Block& body = getFunction().front(); // Find region of interest and ReturnOp. auto copy_range = body.without_terminator(); if (copy_range.begin() != copy_range.end() && diff --git a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc index bd3fe9876ff..5236bdeffbf 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/translate_tf_dialect_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/STLExtras.h" #include "llvm/Support/ToolOutputFile.h" #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project @@ -26,12 +27,12 @@ static mlir::Operation* ExtractOnlyOp(mlir::ModuleOp module) { mlir::FuncOp fn = module.lookupSymbol("main"); if (!fn) return nullptr; - if (fn.getBlocks().size() != 1) return nullptr; + if (!llvm::hasSingleElement(fn)) return nullptr; // Here, modules with exactly two operations in the only basic block are // supported. The last operation should be a terminator operation and the // other operation is the operation of interest. - auto& block = fn.getBlocks().front(); + auto& block = fn.front(); if (block.getOperations().size() != 2) return nullptr; if (!block.back().isKnownTerminator()) return nullptr; diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index fd1ba3b1901..dac2fea87e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -267,9 +267,6 @@ Status ConvertMLIRToXlaComputation( const XlaCompiler::ShapeRepresentationFn shape_representation_fn, std::vector> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); - // Mark main function as public, and other functions as private. 
- tf2xla.addPass( - mlir::TF::CreateMarkOnlyMainFunctionWithPublicVisibilityPass()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc index 797687ea658..febf2bc096d 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc @@ -41,7 +41,7 @@ std::string MakeUniqueFilename(string name) { static NameCounts& instance = *new NameCounts; // Remove illegal characters from `name`. - for (int i = 0; i < name.size(); ++i) { + for (int i = 0, e = name.size(); i < e; ++i) { char ch = name[i]; if (ch == '/' || ch == '[' || ch == ']' || ch == '*' || ch == '?' || ch == '\\') { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index cebfa7cd9d4..80b597d962d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -44,8 +44,9 @@ tf_cc_binary( visibility = ["//tensorflow/core/kernels/cubin_headers:__pkg__"], deps = [ ":cubin_creator", - "//tensorflow/core:framework_internal", + "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/core:lib", "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", ], ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index b534b5a5604..85a53e042e1 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -135,11 +135,11 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { return Status::OK(); } -struct PropagateStaticKnowledge - : public mlir::PassWrapper> { - explicit PropagateStaticKnowledge(mlir::FunctionType type, - llvm::ArrayRef same_shape_) + explicit PropagateTensorFlowABIKnowledge(mlir::FunctionType type, + llvm::ArrayRef same_shape_) : func_type(type), same_shape(same_shape_) {} void runOnOperation() override { @@ -148,6 +148,11 @@ struct PropagateStaticKnowledge // we insert constants into the code and replace usages accordingly. // We do not change the signature so that we keep a somewhat stable ABI // that is easy to undertand by tools. + // We also know that tensorflow aligns all allocated pointers by 16, so + // we pass this on. Furthermore, we know that arguments never alias. More + // precicely, they may only alias (due to reuse) if the kernel does not + // read from a position it previously has written to. We express this with + // the noalias attribute. mlir::LLVM::LLVMFuncOp func = getOperation(); // This only works if the function is local and we can rewrite it. @@ -172,6 +177,9 @@ struct PropagateStaticKnowledge return; } positions.push_back(arg_pos); + // Set alignment and aliasing on the pointers. + func.setArgAttr(arg_pos + 1, "llvm.noalias", b.getBoolAttr(true)); + func.setArgAttr(arg_pos + 1, "llvm.align", b.getIndexAttr(16)); // Replace the offset with zero. Offset is argument number 3. func.getArgument(arg_pos + 2).replaceAllUsesWith(zero); // Forward over base_ptr, aligned_ptr, offset, size and stride arguments. 
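The hunk above tags the aligned-pointer component of each memref argument with LLVM-dialect argument attributes. A minimal standalone sketch of that idiom, assuming func is the mlir::LLVM::LLVMFuncOp being rewritten and arg_pos points at the base pointer of an expanded memref descriptor, as in the pass above:

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"  // mlir::LLVM::LLVMFuncOp
#include "mlir/IR/Builders.h"                 // mlir::OpBuilder

// Sketch: TensorFlow aligns allocated buffers to 16 bytes and kernel
// arguments do not alias, so advertise both facts to LLVM. arg_pos + 1 is
// the aligned pointer within the expanded memref descriptor.
void AnnotateMemrefArg(mlir::LLVM::LLVMFuncOp func, unsigned arg_pos) {
  mlir::OpBuilder b(func.getContext());
  func.setArgAttr(arg_pos + 1, "llvm.noalias", b.getBoolAttr(true));
  func.setArgAttr(arg_pos + 1, "llvm.align", b.getIndexAttr(16));
}
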
@@ -213,7 +221,7 @@ struct PropagateStaticKnowledge llvm::ArrayRef same_shape; }; -Status PropagateStaticShapeKnowledgeToKernel( +Status PropagateTensorFlowABIKnowledgeToKernel( mlir::ModuleOp module, llvm::ArrayRef same_shape) { // Grab the original signature from the single function. auto func = *module.getBody()->op_begin(); @@ -228,7 +236,8 @@ Status PropagateStaticShapeKnowledgeToKernel( /*printAfterOnlyOnChange=*/false, llvm::dbgs()); auto& kernel_pm = pm.nest<::mlir::gpu::GPUModuleOp>(); kernel_pm.addNestedPass( - absl::make_unique(func.getType(), same_shape)); + absl::make_unique(func.getType(), + same_shape)); if (failed(pm.run(module))) { return InternalError("Static knowledge propagation failed."); @@ -259,11 +268,12 @@ StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( options.tile_sizes = tile_sizes; options.unroll_factors = unroll_factors; options.collapse_parallel_loops = false; + options.use_approximations = true; TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options)); } TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get())); TF_RETURN_IF_ERROR( - PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape)); + PropagateTensorFlowABIKnowledgeToKernel(module.get(), same_shape)); mlir::OwningModuleRef kernel_module = xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie(); @@ -278,10 +288,15 @@ StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( xla::HloModuleConfig config; config.set_debug_options(xla::GetDebugOptionsFromFlags()); + auto enable_fusion = [](llvm::TargetMachine* target) { + target->Options.AllowFPOpFusion = llvm::FPOpFusion::FPOpFusionMode::Fast; + }; + TF_ASSIGN_OR_RETURN(std::string libdevice_dir, GetLibdeviceDir(config)); - TF_ASSIGN_OR_RETURN(std::string ptx, xla::gpu::nvptx::CompileToPtx( - llvmModule.get(), compute_capability, - config, libdevice_dir)); + TF_ASSIGN_OR_RETURN( + std::string ptx, + xla::gpu::nvptx::CompileToPtx(llvmModule.get(), compute_capability, + config, libdevice_dir, enable_fusion)); VLOG(1) << ptx; #if GOOGLE_CUDA diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc index 66fcabde0ac..96831689600 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tf_to_cubin.cc @@ -21,77 +21,37 @@ #include #include -#include "absl/strings/numbers.h" -#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "llvm/Support/CommandLine.h" +#include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.h" #include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace { -bool ParseStringList(std::string string_list, std::vector* result) { - result->clear(); - uint32_t item; - auto items = absl::StrSplit(string_list, ','); - for (const auto& item_str : items) { - if (!absl::SimpleAtoi(item_str, &item)) { - LOG(ERROR) << "Expected token " << item_str << " to be an integer"; - return false; - } - result->push_back(item); - } - return true; -} -} // namespace int main(int argc, char** argv) { - std::string input_file = "foo.mlir"; - std::string output_file = "foo.bin"; - int32_t architecture = 50; - std::vector tile_sizes; - std::vector unroll_factors; - std::vector same_shape; + llvm::cl::opt input_file("input", llvm::cl::desc("input file"), + 
llvm::cl::value_desc("filename"), + llvm::cl::init("foo.mlir")); + llvm::cl::opt output_file( + "output", llvm::cl::desc("output file"), llvm::cl::value_desc("filename"), + llvm::cl::init("foo.bin")); + llvm::cl::opt architecture( + "arch", llvm::cl::desc("target architecture (e.g. 50 for sm_50)"), + llvm::cl::init(50)); + llvm::cl::list tile_sizes( + "tile_sizes", llvm::cl::desc("tile sizes to use"), llvm::cl::ZeroOrMore, + llvm::cl::CommaSeparated); + llvm::cl::list unroll_factors( + "unroll_factors", + llvm::cl::desc("factors to unroll by, separated by commas"), + llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); + llvm::cl::list same_shape( + "same_shape", + llvm::cl::desc("arguments with same shape, separated by commas"), + llvm::cl::ZeroOrMore, llvm::cl::CommaSeparated); - auto parse_tile_sizes = [&tile_sizes](std::string tile_sizes_str) { - if (!ParseStringList(tile_sizes_str, &tile_sizes)) { - return false; - } - // Initialize with the default. - if (tile_sizes.empty()) { - tile_sizes.push_back(16); - tile_sizes.push_back(64); - } - return true; - }; - - auto parse_unroll_factors = - [&unroll_factors](std::string unroll_factors_str) { - return ParseStringList(unroll_factors_str, &unroll_factors); - }; - - auto parse_same_shape = [&same_shape](std::string same_shape_str) { - return ParseStringList(same_shape_str, &same_shape); - }; - - std::vector flag_list = { - tensorflow::Flag("input", &input_file, "input file"), - tensorflow::Flag("output", &output_file, "output file"), - tensorflow::Flag("arch", &architecture, - "target architecture (e.g. 50 for sm_50)"), - tensorflow::Flag("tile_sizes", parse_tile_sizes, "16,64", - "tile sizes to use"), - tensorflow::Flag("unroll_factors", parse_unroll_factors, "", - "factors to unroll by, separated by commas"), - tensorflow::Flag("same_shape", parse_same_shape, "", - "arguments with same shape, separated by commas"), - }; - bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); - tensorflow::port::InitMain("usage", &argc, &argv); - if (!parse_ok) { - return 1; - } + tensorflow::InitMlir y(&argc, &argv); + llvm::cl::ParseCommandLineOptions(argc, argv, "TF op GPU kernel generator\n"); std::pair compute_capability(architecture / 10, architecture % 10); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 43458aab2d3..d089f80d571 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -515,6 +515,24 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_legalize_tanh_to_approximation", + srcs = ["transforms/legalize_tanh_to_approximation.cc"], + hdrs = [ + "transforms/passes.h", + "transforms/rewriters.h", + ], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], + alwayslink = 1, +) + gentbl( name = "xla_lower_complex_inc_gen", tbl_outs = [ @@ -946,6 +964,7 @@ cc_library( ":xla_hlo_fusion", ":xla_hlo_to_lhlo_with_xla", ":xla_legalize_control_flow", + ":xla_legalize_tanh_to_approximation", ":xla_legalize_tf", ":xla_legalize_tf_with_tf2xla", ":xla_legalize_to_linalg", diff --git a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc index 26db4549a2a..3408f3ed0cc 100644 --- a/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/chlo_ops.cc @@ -49,7 +49,7 @@ static Type GetBroadcastType(Type x, Type y, Type element_type, if (shape_x.size() 
== shape_y.size()) { llvm::SmallVector out_shape(shape_x.size()); - for (int i = 0; i < shape_x.size(); i++) { + for (int i = 0, e = shape_x.size(); i < e; i++) { auto x_val = shape_x[i]; auto y_val = shape_y[i]; if (x_val == -1 || y_val == -1) { diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc index 38bff6c2ca7..e0fa1da93b8 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.cc @@ -106,53 +106,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices, return GetI64ElementsAttr(slice_limits, builder); } -// Returns the padding value of the given position. If padding_attr is a -// nullptr, returns 0. -static int64_t GetPaddingValue(DenseIntElementsAttr padding_attr, - ArrayRef index) { - if (!padding_attr) return 0; - return padding_attr.getValue(index); -} - -static bool IsOnlyPaddingSpatialDims(Value lhs, - ConvDimensionNumbers dimension_numbers, - DenseIntElementsAttr edge_padding_low, - DenseIntElementsAttr edge_padding_high) { - const int64_t batch_dim = dimension_numbers.input_batch_dimension().getInt(); - const int64_t feature_dim = - dimension_numbers.input_feature_dimension().getInt(); - if (edge_padding_low.getValue(batch_dim) || - edge_padding_high.getValue(batch_dim)) - return false; - if (edge_padding_low.getValue(feature_dim) || - edge_padding_high.getValue(feature_dim)) - return false; - return true; -} - -DenseIntElementsAttr BuildConvPaddingAttrs( - DenseIntElementsAttr edge_padding_low, - DenseIntElementsAttr edge_padding_high, DenseIntElementsAttr padding_attr, - ConvDimensionNumbers dimension_numbers, Builder* builder) { - SmallVector padding_low, padding_high; - for (const auto& dim : dimension_numbers.input_spatial_dimensions()) { - unsigned i = dim.getZExtValue(); - padding_low.push_back(edge_padding_low.getValue(i)); - padding_high.push_back(edge_padding_high.getValue(i)); - } - - int rank = padding_low.size(); - SmallVector padding; - for (unsigned i = 0; i < rank; ++i) { - padding.push_back(GetPaddingValue(padding_attr, {i, 0}) + padding_low[i]); - padding.push_back(GetPaddingValue(padding_attr, {i, 1}) + padding_high[i]); - } - // padding_attr.getType() doesn't work because it is an optional attribute, - // which can be a nullptr. - auto type = RankedTensorType::get({rank, 2}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(type, padding); -} - #include "tensorflow/compiler/mlir/xla/transforms/generated_canonicalize.inc" } // namespace @@ -891,7 +844,7 @@ static Attribute foldConcatenateHelper(ConcatenateOp* op, auto shape = type.getShape(); size_t top_size = 1; - for (int i = 0; i < axis; i++) { + for (int i = 0, e = axis; i < e; i++) { top_size = top_size * shape[i]; } @@ -1169,7 +1122,7 @@ static LogicalResult Verify(MapOp op) { // increasing. 
auto values = op.dimensions().getValues(); auto dimensions = std::vector{values.begin(), values.end()}; - for (int i = 0; i < dimensions.size(); ++i) { + for (int i = 0, e = dimensions.size(); i < e; ++i) { if (dimensions[i] != i) return op.emitOpError() << "requires monotonically increasing dimension " "numbers, but got: " @@ -2153,14 +2106,5 @@ LogicalResult deriveShapeFromFirstOperand( return success(); } -//===----------------------------------------------------------------------===// -// ConvOp -//===----------------------------------------------------------------------===// - -void ConvOp::getCanonicalizationPatterns(OwningRewritePatternList& results, - MLIRContext* context) { - results.insert(context); -} - } // namespace xla_hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td index b1745c73fbf..f92d1c5b85c 100644 --- a/tensorflow/compiler/mlir/xla/ir/hlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/hlo_ops.td @@ -929,8 +929,6 @@ def HLO_ConvOp : HLO_Op<"convolution", [NoSideEffect]>, BASE_HLO_ConvOp { ); let results = (outs HLO_Tensor); - - let hasCanonicalizer = 1; } def HLO_CopyOp: HLO_Op<"copy", [NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CopyOp { diff --git a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td index aed7c83570e..95ad97118ef 100644 --- a/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td +++ b/tensorflow/compiler/mlir/xla/ir/lhlo_ops.td @@ -760,15 +760,6 @@ def LHLO_SortOp: LHLO_Op<"sort", []>, BASE_HLO_SortOp { let regions = (region SizedRegion<1>:$comparator); } -def LHLO_TupleSelectOp: LHLO_Op<"tuple_select", [SameOperandsShape]> { - let arguments = (ins - Arg:$pred, - Arg:$on_true, - Arg:$on_false, - Arg:$output - ); -} - //===----------------------------------------------------------------------===// // Late operations //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 21b1ac5f0ea..3c11d8e590d 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -132,6 +132,22 @@ StatusOr MlirHloBuilder::FftInternal( return MakeXlaOp(op); } +StatusOr MlirHloBuilder::CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { + if (operand_shapes_with_layout.has_value()) + return Unimplemented( + "CustomCall doesn't support operands shapes with layout"); + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + shape, builder_)); + auto op = builder_.create( + loc_, ty, GetValues(operands), builder_.getStringAttr(call_target_name), + /*has_side_effect=*/builder_.getBoolAttr(false), + builder_.getStringAttr(opaque)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::ReduceInternal( const Shape& shape, absl::Span all_operands, const XlaComputation& computation, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 4b28c32db99..4d7d93af7a7 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -124,6 +124,12 @@ class MlirHloBuilder : public XlaBuilder { FftType fft_type, absl::Span fft_length) override; + StatusOr CustomCallInternal(const string& call_target_name, + absl::Span operands, + const 
Shape& shape, const string& opaque, + absl::optional> + operand_shapes_with_layout) override; + StatusOr ReduceInternal( const Shape& shape, absl::Span all_operands, const XlaComputation& computation, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 60d9a698731..7a576780c61 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -1148,13 +1148,13 @@ LogicalResult ConvertToHloModule::LowerFunctionCall( LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { if (lowered_computation_.count(f)) return success(); - if (f.getBlocks().size() != 1) { + if (!llvm::hasSingleElement(f)) { return f.emitError("only single block Function supported"); } // Create a sub-builder if this is not the main function. std::unique_ptr builder_up; - bool entry_function = f.getName().str() == "main"; + bool entry_function = f.getName() == "main"; if (!entry_function) builder_up = module_builder_.CreateSubBuilder(f.getName().str()); auto& builder = entry_function ? module_builder_ : *builder_up; diff --git a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir index ef0f8c4d200..1954c3344df 100644 --- a/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/xla/tests/canonicalize.mlir @@ -415,71 +415,6 @@ func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> { return %0 : tensor<1x4xf32> } -// CHECK-LABEL: func @fold_pad_into_conv_f32 -func @fold_pad_into_conv_f32(%arg0 : tensor<1x32x32x3xf32>, - %arg1 : tensor<7x7x3x64xf32>) - -> tensor<1x16x16x64xf32> { - // CHECK-NOT: xla_hlo.pad - // CHECK: xla_hlo.convolution - // CHECK-SAME: padding = dense<3> : tensor<2x2xi64> - %0 = xla_hlo.constant dense<0.000000e+00> : tensor - %1 = "xla_hlo.pad"(%arg0, %0) { - edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>, - edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>, - interior_padding = dense<0> : tensor<4xi64> - } : (tensor<1x32x32x3xf32>, tensor) -> tensor<1x38x38x3xf32> - %2 = "xla_hlo.convolution"(%1, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = { - input_batch_dimension = 0 : i64, - input_feature_dimension = 3 : i64, - input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, - kernel_input_feature_dimension = 2 : i64, - kernel_output_feature_dimension = 3 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, - output_batch_dimension = 0 : i64, - output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> - }, - feature_group_count = 1 : i64, - padding = dense<0> : tensor<2x2xi64>, - window_strides = dense<2> : tensor<2xi64> - } : (tensor<1x38x38x3xf32>, tensor<7x7x3x64xf32>) -> tensor<1x16x16x64xf32> - return %2 : tensor<1x16x16x64xf32> -} - -// CHECK-LABEL: func @fold_pad_into_conv_i32 -func @fold_pad_into_conv_i32(%arg0 : tensor<1x32x32x3xi32>, - %arg1 : tensor<7x7x3x64xi32>) - -> tensor<1x16x16x64xi32> { - // CHECK-NOT: xla_hlo.pad - // CHECK: xla_hlo.convolution - // CHECK-SAME: padding = dense<3> : tensor<2x2xi64> - %0 = xla_hlo.constant dense<0> : tensor - %1 = "xla_hlo.pad"(%arg0, %0) { - edge_padding_high = dense<[0, 3, 3, 0]> : tensor<4xi64>, - edge_padding_low = dense<[0, 3, 3, 0]> : tensor<4xi64>, - interior_padding = dense<0> : tensor<4xi64> - } : (tensor<1x32x32x3xi32>, tensor) -> tensor<1x38x38x3xi32> - %2 = "xla_hlo.convolution"(%1, %arg1) { - batch_group_count = 1 : i64, - dimension_numbers = { - 
input_batch_dimension = 0 : i64, - input_feature_dimension = 3 : i64, - input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, - kernel_input_feature_dimension = 2 : i64, - kernel_output_feature_dimension = 3 : i64, - kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, - output_batch_dimension = 0 : i64, - output_feature_dimension = 3 : i64, - output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64> - }, - feature_group_count = 1 : i64, - window_strides = dense<2> : tensor<2xi64> - } : (tensor<1x38x38x3xi32>, tensor<7x7x3x64xi32>) -> tensor<1x16x16x64xi32> - return %2 : tensor<1x16x16x64xi32> -} - // CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> { // CHECK: xla_hlo.reshape diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index b8a6df54519..86a7f2b9e09 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -35,7 +35,7 @@ func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor // CHECK-LABEL: unranked_operand func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: tf.Abs - // expected-remark@+1 {{lowering requires static shaped operands}} + // expected-remark@+1 {{lowering requires static shaped tensor operands}} %0 = "tf.Abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> @@ -44,12 +44,20 @@ func @unranked_operand(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK-LABEL: dynamic_operand func @dynamic_operand(%arg0: tensor) -> tensor { // CHECK: tf.Abs - // expected-remark@+1 {{lowering requires static shaped operands}} + // expected-remark@+1 {{lowering requires static shaped tensor operands}} %0 = "tf.Abs"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: tuple_type +func @tuple_type(%arg0: tuple, tensor>) -> tensor { + // Verifies that the pass can handle operands of non-tensor type like tuple + // from non TensorFlow ops. 
+ %0 = "xla_hlo.get_tuple_element"(%arg0) {index = 0 : i32} : (tuple, tensor>) -> tensor + return %0 : tensor +} + // CHECK-LABEL: unsupported_dtype func @unsupported_dtype(%arg0: tensor<2x!tf.variant>) -> tensor<2x!tf.variant> { // CHECK: tf.AddN diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 8d6969dd669..2cd98ea3f6b 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -814,6 +814,13 @@ func @preventgradient(%arg0: tensor<1xi32>) -> tensor<1xi32> { return %0: tensor<1xi32> } +// CHECK-LABEL: func @checkNumerics +func @checkNumerics(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-NEXT: return %arg0 : tensor<1xf32> + %0 = "tf.CheckNumerics"(%arg0) {message = "check numerics"} : (tensor<1xf32>) -> tensor<1xf32> + return %0: tensor<1xf32> +} + //===----------------------------------------------------------------------===// // InfeedDequeueTuple legalization //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir b/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir new file mode 100644 index 00000000000..a8286c9b5a9 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/legalize_tanh_to_approximation.mlir @@ -0,0 +1,134 @@ +// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s + +func @tanh_f64(%arg0 : f64) -> f64 { + %res = tanh %arg0 : f64 + return %res : f64 +} + +// CHECK-LABEL: @tanh_f64 +// CHECK: tanh + +// ----- + +func @tanh_f32(%arg0 : f32) -> f32 { + %res = tanh %arg0 : f32 + return %res : f32 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// CHECK: module { + +// CHECK-LABEL: func @tanh_f32( +// CHECK-SAME: %[[VAL_0:.*]]: f32) -> f32 { +// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32 +// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_16:.*]] = absf %[[VAL_0]] : f32 +// CHECK: %[[VAL_17:.*]] = copysign %[[VAL_2]], %[[VAL_0]] : f32 +// CHECK: %[[VAL_18:.*]] = cmpf "ult", %[[VAL_16]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_19:.*]] = cmpf "olt", %[[VAL_16]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_20:.*]] = cmpf "ule", %[[VAL_16]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_21:.*]] = copysign %[[VAL_4]], %[[VAL_0]] : f32 +// CHECK: %[[VAL_22:.*]] = select %[[VAL_20]], %[[VAL_0]], %[[VAL_21]] : f32 +// CHECK: %[[VAL_23:.*]] = mulf %[[VAL_22]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_25:.*]] = addf %[[VAL_24]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_26:.*]] = mulf %[[VAL_23]], %[[VAL_25]] : f32 +// CHECK: %[[VAL_27:.*]] = addf %[[VAL_26]], %[[VAL_7]] : 
f32 +// CHECK: %[[VAL_28:.*]] = mulf %[[VAL_23]], %[[VAL_27]] : f32 +// CHECK: %[[VAL_29:.*]] = addf %[[VAL_28]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_30:.*]] = mulf %[[VAL_23]], %[[VAL_29]] : f32 +// CHECK: %[[VAL_31:.*]] = addf %[[VAL_30]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_32:.*]] = mulf %[[VAL_23]], %[[VAL_31]] : f32 +// CHECK: %[[VAL_33:.*]] = addf %[[VAL_32]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_34:.*]] = mulf %[[VAL_23]], %[[VAL_33]] : f32 +// CHECK: %[[VAL_35:.*]] = addf %[[VAL_34]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_36:.*]] = mulf %[[VAL_22]], %[[VAL_35]] : f32 +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_38:.*]] = addf %[[VAL_37]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_39:.*]] = mulf %[[VAL_23]], %[[VAL_38]] : f32 +// CHECK: %[[VAL_40:.*]] = addf %[[VAL_39]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_41:.*]] = mulf %[[VAL_23]], %[[VAL_40]] : f32 +// CHECK: %[[VAL_42:.*]] = addf %[[VAL_41]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_43:.*]] = divf %[[VAL_36]], %[[VAL_42]] : f32 +// CHECK: %[[VAL_44:.*]] = select %[[VAL_19]], %[[VAL_0]], %[[VAL_43]] : f32 +// CHECK: %[[VAL_45:.*]] = select %[[VAL_18]], %[[VAL_44]], %[[VAL_17]] : f32 +// CHECK: return %[[VAL_45]] : f32 +// CHECK: } +// CHECK: } + +// ----- + +func @tanh_f16(%arg0 : f16) -> f16 { + %res = tanh %arg0 : f16 + return %res : f16 +} + +// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py +// CHECK: module { + +// CHECK-LABEL: func @tanh_f16( +// CHECK-SAME: %[[VAL_0:.*]]: f16) -> f16 { +// CHECK: %[[VAL_1:.*]] = constant 2.000000e+01 : f32 +// CHECK: %[[VAL_2:.*]] = constant 1.000000e+00 : f32 +// CHECK: %[[VAL_3:.*]] = constant 4.000000e-04 : f32 +// CHECK: %[[VAL_4:.*]] = constant 9.000000e+00 : f32 +// CHECK: %[[VAL_5:.*]] = constant -2.76076837E-16 : f32 +// CHECK: %[[VAL_6:.*]] = constant 2.00018794E-13 : f32 +// CHECK: %[[VAL_7:.*]] = constant -8.60467184E-11 : f32 +// CHECK: %[[VAL_8:.*]] = constant 5.12229725E-8 : f32 +// CHECK: %[[VAL_9:.*]] = constant 1.48572235E-5 : f32 +// CHECK: %[[VAL_10:.*]] = constant 6.37261954E-4 : f32 +// CHECK: %[[VAL_11:.*]] = constant 0.00489352457 : f32 +// CHECK: %[[VAL_12:.*]] = constant 1.19825836E-6 : f32 +// CHECK: %[[VAL_13:.*]] = constant 1.18534706E-4 : f32 +// CHECK: %[[VAL_14:.*]] = constant 0.00226843474 : f32 +// CHECK: %[[VAL_15:.*]] = constant 0.00489352504 : f32 +// CHECK: %[[VAL_16:.*]] = fpext %[[VAL_0]] : f16 to f32 +// CHECK: %[[VAL_17:.*]] = absf %[[VAL_16]] : f32 +// CHECK: %[[VAL_18:.*]] = copysign %[[VAL_2]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_19:.*]] = cmpf "ult", %[[VAL_17]], %[[VAL_1]] : f32 +// CHECK: %[[VAL_20:.*]] = cmpf "olt", %[[VAL_17]], %[[VAL_3]] : f32 +// CHECK: %[[VAL_21:.*]] = cmpf "ule", %[[VAL_17]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_22:.*]] = copysign %[[VAL_4]], %[[VAL_16]] : f32 +// CHECK: %[[VAL_23:.*]] = select %[[VAL_21]], %[[VAL_16]], %[[VAL_22]] : f32 +// CHECK: %[[VAL_24:.*]] = mulf %[[VAL_23]], %[[VAL_23]] : f32 +// CHECK: %[[VAL_25:.*]] = mulf %[[VAL_24]], %[[VAL_5]] : f32 +// CHECK: %[[VAL_26:.*]] = addf %[[VAL_25]], %[[VAL_6]] : f32 +// CHECK: %[[VAL_27:.*]] = mulf %[[VAL_24]], %[[VAL_26]] : f32 +// CHECK: %[[VAL_28:.*]] = addf %[[VAL_27]], %[[VAL_7]] : f32 +// CHECK: %[[VAL_29:.*]] = mulf %[[VAL_24]], %[[VAL_28]] : f32 +// CHECK: %[[VAL_30:.*]] = addf %[[VAL_29]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_31:.*]] = mulf %[[VAL_24]], %[[VAL_30]] : f32 +// CHECK: %[[VAL_32:.*]] = addf %[[VAL_31]], %[[VAL_9]] : f32 +// CHECK: %[[VAL_33:.*]] = mulf %[[VAL_24]], %[[VAL_32]] : f32 +// CHECK: 
%[[VAL_34:.*]] = addf %[[VAL_33]], %[[VAL_10]] : f32 +// CHECK: %[[VAL_35:.*]] = mulf %[[VAL_24]], %[[VAL_34]] : f32 +// CHECK: %[[VAL_36:.*]] = addf %[[VAL_35]], %[[VAL_11]] : f32 +// CHECK: %[[VAL_37:.*]] = mulf %[[VAL_23]], %[[VAL_36]] : f32 +// CHECK: %[[VAL_38:.*]] = mulf %[[VAL_24]], %[[VAL_12]] : f32 +// CHECK: %[[VAL_39:.*]] = addf %[[VAL_38]], %[[VAL_13]] : f32 +// CHECK: %[[VAL_40:.*]] = mulf %[[VAL_24]], %[[VAL_39]] : f32 +// CHECK: %[[VAL_41:.*]] = addf %[[VAL_40]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_42:.*]] = mulf %[[VAL_24]], %[[VAL_41]] : f32 +// CHECK: %[[VAL_43:.*]] = addf %[[VAL_42]], %[[VAL_15]] : f32 +// CHECK: %[[VAL_44:.*]] = divf %[[VAL_37]], %[[VAL_43]] : f32 +// CHECK: %[[VAL_45:.*]] = select %[[VAL_20]], %[[VAL_16]], %[[VAL_44]] : f32 +// CHECK: %[[VAL_46:.*]] = select %[[VAL_19]], %[[VAL_45]], %[[VAL_18]] : f32 +// CHECK: %[[VAL_47:.*]] = fptrunc %[[VAL_46]] : f32 to f16 +// CHECK: return %[[VAL_47]] : f16 +// CHECK: } +// CHECK: } + diff --git a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir index 0ed8b36466e..1e803da4ac6 100644 --- a/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/lhlo_ops.mlir @@ -964,23 +964,3 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, }) : (memref<16x16xf32>, memref<16x16xf16>, tuple, memref<16x16xf16>>) -> () return } - -// ----- - -// CHECK-LABEL: func @tuple_select_memrefs -func @tuple_select_memrefs(%pred: memref<20xi1>, %true_values: memref<20xf32>, - %false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () { - "xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out) - : (memref<20xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> () - return -} - -// ----- - -func @tuple_select_memrefs(%pred: memref<10xi1>, %true_values: memref<20xf32>, - %false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () { - // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out) - : (memref<10xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> () - return -} diff --git a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td index b788cb80380..c319551d92a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/xla/transforms/canonicalize.td @@ -28,54 +28,3 @@ def UnaryEinsumToEinsum : Pat< (HLO_UnaryEinsumOp $operand, $equation), (HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)), $operand, (UnaryToBinaryEinsumEq $equation))>; - -//===----------------------------------------------------------------------===// -// Conv op patterns. 
-//===----------------------------------------------------------------------===// - -def IsZero : Attr() &&" - "$_self.cast().isSplat() &&" - "$_self.cast().getSplatValue()" - ".getValue().isZero()) ||" - "($_self.isa() &&" - "$_self.cast().isSplat() &&" - "$_self.cast().getSplatValue()" - ".getInt() == 0)">>; - -def IsOnlyPaddingSpatialDims - : Constraint>; - -def BuildConvPaddingAttrs : NativeCodeCall< - "BuildConvPaddingAttrs($0, $1, $2, $3, &$_builder)">; - -def FoldPadIntoConv : Pat< - (HLO_ConvOp - (HLO_PadOp $lhs, - (HLO_ConstOp IsZero:$padding_value), - $edge_padding_low, - $edge_padding_high, - IsZero:$interior_padding), - $rhs, - $window_strides, - $padding, - $lhs_dilation, - $rhs_dilation, - $dimension_numbers, - $feature_group_count, - $batch_group_count, - $precision_config), - (HLO_ConvOp - $lhs, - $rhs, - $window_strides, - (BuildConvPaddingAttrs $edge_padding_low, $edge_padding_high, $padding, - $dimension_numbers), - $lhs_dilation, - $rhs_dilation, - $dimension_numbers, - $feature_group_count, - $batch_group_count, - $precision_config), - [(IsOnlyPaddingSpatialDims $lhs, $dimension_numbers, $edge_padding_low, - $edge_padding_high)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc index 1cfe0c12e20..7cdc0d92207 100644 --- a/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/xla/transforms/hlo_legalize_to_lhlo.cc @@ -230,10 +230,10 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { auto loc = op.getLoc(); // TODO(b/137624192) Implement variadic reduce. if (op.getNumResults() != 1) return failure(); - if (op.getParentRegion()->getBlocks().size() != 1) { - op.emitOpError() << "tensor to buffer conversion expects a single block " - "in the region containing the operation"; - return failure(); + if (!llvm::hasSingleElement(op.body())) { + return op.emitOpError() + << "tensor to buffer conversion expects a single block " + "in the region containing the operation"; } const auto& original_results = op.getResults(); SmallVector buffer_args(operands.begin(), operands.end()); @@ -389,10 +389,13 @@ struct HloLegalizeToLhlo target.addLegalOp(); target.addLegalOp(); target.addIllegalDialect(); + + BufferAssignmentTypeConverter converter; target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); - return std::all_of(inputs.begin(), inputs.end(), - [](Type input) { return input.isa(); }); + return llvm::all_of(inputs, + [](Type input) { return input.isa(); }) && + converter.isLegal(&op.getBody()); }); target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { return std::all_of(returnOp.operand_type_begin(), @@ -401,8 +404,7 @@ struct HloLegalizeToLhlo }); auto module = getOperation(); - BufferAssignmentTypeConverter converter; - module.walk([&](FuncOp func) { + module.walk([&](FuncOp func) -> WalkResult { BufferAssignmentPlacer bufferAssignment(func); OwningRewritePatternList patterns; populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment, @@ -418,8 +420,7 @@ struct HloLegalizeToLhlo /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, &converter, &patterns); } - return WalkResult( - applyPartialConversion(func, target, patterns, &converter)); + return applyPartialConversion(func, target, patterns); }); } @@ -463,6 +464,7 @@ void populateHLOToLHLOConversionPattern( HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, + HloToLhloOpConverter, HloToLhloOpConverter, 
HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc new file mode 100644 index 00000000000..9696db377da --- /dev/null +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tanh_to_approximation.cc @@ -0,0 +1,167 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file implements logic for lowering the tanh standard ops to an +// approximation. + +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/transforms/passes.h" +#include "tensorflow/compiler/mlir/xla/transforms/rewriters.h" + +namespace mlir { +namespace xla { +namespace { + +/// Emits the fast tanh approximation that is also used by XLA. +static Value EmitTanhApproximation(Value input, Value abs_value, Location loc, + PatternRewriter &rewriter) { + // For small values of x, we can approximate tanh(x)=x. For extremely small + // values of x (|x| < 1e-37), the other approximation would evaluate + // tanh(x) = 0. + constexpr float kCanUseApprox = 0.0004; + Value can_use_approx = + rewriter.create(loc, rewriter.getF32FloatAttr(kCanUseApprox)); + Value return_input = rewriter.create(loc, CmpFPredicate::OLT, + abs_value, can_use_approx); + + // Clamp the input to [-9, 9]. 
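  // (Editorial note, not part of the patch: for |x| >= 9 the true tanh differs
  // from +/-1 by less than 1e-7, so clamping the value fed into the rational
  // approximation below to [-9, 9] does not visibly change the f32 result; it
  // only keeps the powers of x computed below in a numerically safe range.)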
+ Value plus_nine = + rewriter.create(loc, rewriter.getF32FloatAttr(9.0)); + Value smaller_than_nine = + rewriter.create(loc, CmpFPredicate::ULE, abs_value, plus_nine); + Value input_clamped = rewriter.create( + loc, smaller_than_nine, input, + rewriter.create(loc, plus_nine, input)); + + static constexpr std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + static constexpr std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + Value input_squared = + rewriter.create(loc, input_clamped, input_clamped); + Value numerator = rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[0])); + for (int i = 1; i < numerator_coeffs.size(); i++) { + numerator = rewriter.create( + loc, rewriter.create(loc, input_squared, numerator), + rewriter.create( + loc, rewriter.getF32FloatAttr(numerator_coeffs[i]))); + } + + numerator = rewriter.create(loc, input_clamped, numerator); + + Value denominator = rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[0])); + for (int i = 1; i < denominator_coeffs.size(); i++) { + denominator = rewriter.create( + loc, rewriter.create(loc, input_squared, denominator), + rewriter.create( + loc, rewriter.getF32FloatAttr(denominator_coeffs[i]))); + } + + Value approx = rewriter.create(loc, numerator, denominator); + + return rewriter.create(loc, return_input, input, approx); +} + +class ApproximateTanhLowering : public OpRewritePattern { + public: + explicit ApproximateTanhLowering(MLIRContext *ctx) + : OpRewritePattern(ctx, 100) {} + + LogicalResult matchAndRewrite(TanhOp tanhOp, + PatternRewriter &rewriter) const override { + Type operand_type = tanhOp.getType(); + + if (operand_type.isF64()) { + // Similar to XLA, do not rewrite f64 as precision might matter. + return failure(); + } + + Location loc = tanhOp.getLoc(); + Value input = tanhOp.operand(); + if (operand_type.isF16()) { + input = rewriter.create(loc, input, rewriter.getF32Type()); + } + + // If we still do not have f32, fail. + if (!input.getType().isF32()) { + return failure(); + } + + // For |operand| > 20.0, we just return -1/1. + constexpr double kMaxValue = 20.0; + Value max_value = + rewriter.create(loc, rewriter.getF32FloatAttr(kMaxValue)); + Value abs_value = rewriter.create(loc, input); + + Value one = rewriter.create(loc, rewriter.getF32FloatAttr(1.0)); + Value one_with_sign = rewriter.create(loc, one, input); + + Value smaller_than_twenty = + rewriter.create(loc, CmpFPredicate::ULT, abs_value, max_value); + + // Otherwise, we use the approximation. + Value approx = EmitTanhApproximation(input, abs_value, loc, rewriter); + + Value result = rewriter.create(loc, smaller_than_twenty, approx, + one_with_sign); + + // Truncate back if needed. + if (operand_type.isF16()) { + result = rewriter.create(loc, result, rewriter.getF16Type()); + } + + rewriter.replaceOp(tanhOp, {result}); + return success(); + } +}; + +struct LegalizeTanhToApproximation + : public PassWrapper { + /// Perform the lowering of standard dialect operations to approximations. 
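  // (Editorial sketch, not part of the patch: in plain scalar C++ terms the
  // pattern above computes roughly the following; the helper name is invented
  // here purely for illustration and the coefficients are copied from the
  // arrays defined above.
  //
  //   float ApproxTanhF32(float x) {
  //     if (std::abs(x) > 20.0f) return std::copysign(1.0f, x);  // saturated
  //     if (std::abs(x) < 4e-4f) return x;                       // tanh(x) ~= x
  //     float c = std::abs(x) <= 9.0f ? x : std::copysign(9.0f, x);  // clamp
  //     float c2 = c * c;
  //     float p = -2.76076847742355e-16f;   // numerator, Horner form
  //     for (float k : {2.00018790482477e-13f, -8.60467152213735e-11f,
  //                     5.12229709037114e-08f, 1.48572235717979e-05f,
  //                     6.37261928875436e-04f, 4.89352455891786e-03f})
  //       p = p * c2 + k;
  //     p *= c;
  //     float q = 1.19825839466702e-06f;    // denominator, Horner form
  //     for (float k : {1.18534705686654e-04f, 2.26843463243900e-03f,
  //                     4.89352518554385e-03f})
  //       q = q * c2 + k;
  //     return p / q;
  //   }
  //
  // f16 inputs are extended to f32 first and truncated back at the end; f64 is
  // deliberately left to the existing tanh lowering.)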
+ void runOnFunction() override { + OwningRewritePatternList patterns; + PopulateTanhToApproximationPatterns(&getContext(), &patterns); + applyPatternsAndFoldGreedily(getFunction(), patterns); + } +}; + +} // anonymous namespace + +std::unique_ptr> +createLegalizeTanhToApproximationPass() { + return std::make_unique(); +} + +void PopulateTanhToApproximationPatterns(mlir::MLIRContext *context, + OwningRewritePatternList *patterns) { + patterns->insert(context); +} + +static PassRegistration legalize_pass( + "xla-legalize-tanh-to-approximation", + "Legalize tanh from standard dialect to an approximation"); + +} // namespace xla +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index b7cad554043..1788cd1b270 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -5238,8 +5238,8 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, // Fully qualify ReturnOp here as xla_hlo dialect also defines a ReturnOp. target.addLegalOp(); DenseSet nonlegalized_ops; - LogicalResult result = applyPartialConversion( - op, target, patterns, /*converter=*/nullptr, &nonlegalized_ops); + LogicalResult result = + applyPartialConversion(op, target, patterns, &nonlegalized_ops); // In order to enforce that the conversion result is fully converted, // fail if there are any nonlegalized ops in the set. if (failed(result) || !nonlegalized_ops.empty()) { diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index f3c432f38bd..df7b887fcad 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -321,7 +321,10 @@ def : Pat<(TF_PadV2Op $input, (TF_ConstOp $padding), $c), foreach src = [TF_IdentityOp, TF_StopGradientOp] in def : Pat<(src $op), (replaceWithValue $op)>; -def : Pat<(TF_PreventGradientOp $op, $msg), (replaceWithValue $op)>; + +// TODO(b/32223192): Support CheckNumerics in HLO. +foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in + def : Pat<(src $op, $msg), (replaceWithValue $op)>; //===----------------------------------------------------------------------===// // MatMul op patterns. diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index e57d6938efb..54453406ef7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "absl/strings/string_view.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -88,6 +89,9 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -127,6 +131,7 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -157,10 +162,14 @@ static bool IsOpWhitelisted(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -312,13 +321,14 @@ LogicalResult FuncLegalizer::PrepareParams() { } LogicalResult FuncLegalizer::Legalize() { + if (func_.empty()) return success(); + // TensorFlow functions don't use CFGs. - if (func_.getBlocks().size() > 1) { + if (!llvm::hasSingleElement(func_)) { emitError(func_.getLoc()) << "requires at most one block in a TF function"; return failure(); } - if (func_.getBlocks().empty()) return success(); - Block& block = func_.getBlocks().front(); + Block& block = func_.front(); std::vector ops; ops.reserve(block.getOperations().size()); @@ -337,9 +347,9 @@ LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { // Only static shaped operands are supported in XLA builders for now. for (Type ty : op->getOperandTypes()) { - auto ranked_ty = ty.cast(); - if (!ranked_ty.hasStaticShape()) { - op->emitRemark() << "lowering requires static shaped operands"; + auto ranked_ty = ty.dyn_cast(); + if (!ranked_ty || !ranked_ty.hasStaticShape()) { + op->emitRemark() << "lowering requires static shaped tensor operands"; return success(); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc index e16ab571b4d..f0971fdf76e 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_fuse_linalg.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project @@ -44,7 +45,7 @@ class LhloFuseLinalg : public PassWrapper { auto func = getFunction(); // TODO(pifon): Remove assumption that the function has a single block. 
- if (func.getBlocks().size() != 1) { + if (!llvm::hasSingleElement(func)) { emitError(func.getLoc(), "The function needs to have a single block."); signalPassFailure(); return; @@ -58,7 +59,7 @@ class LhloFuseLinalg : public PassWrapper { for (auto func_arg : func.getArguments()) { result_buffers.insert(func_arg); } - for (auto& block : func.getBlocks()) { + for (auto& block : func) { auto returnOp = mlir::dyn_cast(block.getTerminator()); if (!returnOp) continue; for (auto operand : returnOp.getOperands()) { diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc index 56b9f5879f6..904a30e847a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_affine.cc @@ -31,6 +31,17 @@ namespace mlir { namespace xla_lhlo { namespace { +// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit +// steps, and populates the body of the innermost loop using "body_builder". +static void BuildBoundedAffineLoopNest( + OpBuilder& builder, Location location, ArrayRef upper_bounds, + function_ref body_builder) { + SmallVector lower_bounds(upper_bounds.size(), /*Value=*/0); + SmallVector steps(upper_bounds.size(), /*Value=*/1); + buildAffineLoopNest(builder, location, lower_bounds, upper_bounds, steps, + body_builder); +} + struct DotOpConverter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -48,37 +59,29 @@ struct DotOpConverter : public OpRewritePattern { if ((lhs_type.getRank() != 2) || (rhs_type.getRank() != 2)) { return failure(); } - SmallVector lhs_indices, rhs_indices, result_indices; - const auto& loc = op.getLoc(); - // Create the canonical ijk form of matmul. - auto forOp = rewriter.create(loc, 0, shape_lhs[0]); - lhs_indices.push_back(forOp.getInductionVar()); - result_indices.push_back(forOp.getInductionVar()); + LogicalResult map_status = success(); + auto body_builder = [&](OpBuilder& builder, Location loc, ValueRange ivs) { + SmallVector lhs_indices{ivs[0], ivs[2]}, + rhs_indices{ivs[2], ivs[1]}, result_indices{ivs[0], ivs[1]}; - rewriter.setInsertionPointToStart(forOp.getBody()); - forOp = rewriter.create(loc, 0, shape_rhs.back()); - result_indices.push_back(forOp.getInductionVar()); - rhs_indices.resize(2); - rhs_indices[1] = forOp.getInductionVar(); + auto l = builder.create(loc, lhs, lhs_indices); + auto r = builder.create(loc, rhs, rhs_indices); + auto result = + rewriter.create(loc, op.output(), result_indices); + Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + op, element_type, {l, r, result}, &builder); + map_status = success(op_result != nullptr); + if (failed(map_status)) return; + builder.create(loc, op_result, op.output(), + result_indices); + }; - rewriter.setInsertionPointToStart(forOp.getBody()); - forOp = rewriter.create(loc, 0, shape_rhs.front()); - lhs_indices.push_back(forOp.getInductionVar()); - rhs_indices[0] = forOp.getInductionVar(); + BuildBoundedAffineLoopNest(rewriter, op.getLoc(), + {shape_lhs[0], shape_rhs[1], shape_rhs[0]}, + body_builder); + if (failed(map_status)) return failure(); - // Construct the innermost loop body. 
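      // (Editorial note, not part of the patch: the body_builder lambda above
      // receives ivs = (i, j, k) over the bounds (M, N, K) = (shape_lhs[0],
      // shape_rhs[1], shape_rhs[0]). It loads lhs[i, k] and rhs[k, j],
      // combines them with the current output[i, j] via XlaOpToStdScalarOp,
      // and stores the result back, which is exactly what the hand-rolled
      // AffineForOp nest being deleted here did.)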
- rewriter.setInsertionPointToStart(forOp.getBody()); - auto l = rewriter.create(loc, lhs, lhs_indices); - auto r = rewriter.create(loc, rhs, rhs_indices); - auto result = - rewriter.create(loc, op.output(), result_indices); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( - op, element_type, {l, r, result}, &rewriter); - if (op_result == nullptr) { - return failure(); - } - rewriter.create(loc, op_result, op.output(), result_indices); rewriter.eraseOp(op); return success(); } @@ -99,22 +102,22 @@ struct BinaryOpConverter : public OpRewritePattern { if (lhs_type.getShape() != rhs_type.getShape()) { return failure(); } - const auto& shape = lhs_type.getShape(); - SmallVector induction_vars; - const auto loc = op.getLoc(); - for (int i = 0; i < shape.size(); ++i) { - auto forOp = rewriter.create(loc, 0, shape[i]); - induction_vars.push_back(forOp.getInductionVar()); - rewriter.setInsertionPointToStart(forOp.getBody()); - } - auto l = rewriter.create(loc, lhs, induction_vars); - auto r = rewriter.create(loc, rhs, induction_vars); - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( - op, element_type, {l, r}, &rewriter); - if (opResult == nullptr) { - return failure(); - } - rewriter.create(loc, opResult, op.out(), induction_vars); + + LogicalResult map_status = success(); + auto body_builder = [&](OpBuilder& builder, Location loc, + ValueRange induction_vars) { + auto l = builder.create(loc, lhs, induction_vars); + auto r = builder.create(loc, rhs, induction_vars); + Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + op, element_type, {l, r}, &builder); + map_status = success(op_result != nullptr); + if (failed(map_status)) return; + rewriter.create(loc, op_result, op.out(), induction_vars); + }; + + BuildBoundedAffineLoopNest(rewriter, op.getLoc(), lhs_type.getShape(), + body_builder); + if (failed(map_status)) return failure(); rewriter.eraseOp(op); return success(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc index f0eb3cc1a0f..c23b8b49268 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc @@ -177,7 +177,7 @@ struct LhloLegalizeToGpu : public PassWrapper { target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); - if (failed(applyPartialConversion(func, target, patterns, nullptr))) { + if (failed(applyPartialConversion(func, target, patterns))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc index 99d2c08aa98..78a77dc3b4d 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm.cc @@ -129,7 +129,7 @@ struct DynamicMemRefCastOpConverter void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter, OwningRewritePatternList *patterns) { patterns->insert( - *converter); + *converter, LowerToLLVMOptions()); } } // namespace xla_lhlo diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc index 9b809049290..63265c4a7e7 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_llvm_pass.cc @@ -43,7 +43,7 @@ class TestLhloToLLVMPass target.addLegalOp(); 
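    // (Editorial note, not part of the patch: the variants of
    // applyFullConversion/applyPartialConversion used here no longer take a
    // TypeConverter argument; type legality is driven by the ConversionTarget
    // and the patterns themselves, which is why the converter argument is
    // dropped at every call site in this change.)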
target.addIllegalDialect(); - if (failed(applyFullConversion(m, target, patterns, &converter))) { + if (failed(applyFullConversion(m, target, patterns))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc index b3112d49103..65962c5b7a5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_parallel_loops.cc @@ -711,7 +711,7 @@ struct LhloLegalizeToParallelLoops target.addIllegalOp(); - if (failed(applyPartialConversion(func, target, patterns, nullptr))) { + if (failed(applyPartialConversion(func, target, patterns))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h index 4b9397795a1..8d5f27474a5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/xla/transforms/map_hlo_to_lhlo_op.h @@ -61,6 +61,7 @@ MAP_HLO_TO_LHLO(MulOp); MAP_HLO_TO_LHLO(NegOp); MAP_HLO_TO_LHLO(RealOp); MAP_HLO_TO_LHLO(ReduceOp); +MAP_HLO_TO_LHLO(ReshapeOp); MAP_HLO_TO_LHLO(RemOp); MAP_HLO_TO_LHLO(RsqrtOp); MAP_HLO_TO_LHLO(SelectOp); diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index a2af8124786..3db0bc3b474 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -115,6 +115,13 @@ std::unique_ptr createLhloCopyRemovalPass(); std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); } // namespace xla_lhlo + +namespace xla { + +/// Lowers the standard TanhOp to an approximation that does not use intrinsics. +std::unique_ptr> createLegalizeTanhToApproximationPass(); + +} // namespace xla } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/rewriters.h b/tensorflow/compiler/mlir/xla/transforms/rewriters.h index 59347198fe4..7303b87be75 100644 --- a/tensorflow/compiler/mlir/xla/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/xla/transforms/rewriters.h @@ -91,6 +91,14 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, } // namespace xla_chlo +namespace xla { + +// Populates a pattern that translates the standard TanhOp to an approximation +// that does not use intrinsics. +void PopulateTanhToApproximationPatterns(MLIRContext *context, + OwningRewritePatternList *patterns); + +} // namespace xla } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_REWRITERS_H_ diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_fusion.cc b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_fusion.cc index c4eb0e143d2..5d3eda0bea5 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_hlo_fusion.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_hlo_fusion.cc @@ -487,7 +487,7 @@ struct XlaHloFusion : public mlir::PassWrapper { } // process each block and do fusion within a block. 
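    // (Editorial note, not part of the patch: FuncOp exposes the blocks of its
    // body directly, so `for (Block& block : func)` iterates the same range as
    // `func.getBlocks()`; the shorter form is used consistently in this
    // change.)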
- for (Block& block : func.getBlocks()) { + for (Block& block : func) { SmallVector op_list; for (Operation& op : block) { op_list.push_back(&op); diff --git a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc index ad78a01100b..e7bb5df8233 100644 --- a/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/xla/transforms/xla_legalize_to_linalg.cc @@ -125,32 +125,19 @@ class PointwiseToLinalgConverter : public OpConversionPattern { opResultTypes.push_back(shapedType); } + int64_t args_count = bodyArgTypes.size(); + int64_t results_count = bodyResultTypes.size(); auto linalgOp = rewriter.create( - loc, opResultTypes, args, - /*inputCount=*/bodyArgTypes.size(), - /*outputCount=*/bodyResultTypes.size(), indexing_maps, - GetNParallelLoopsAttrs(nloops)); - - // Add a block to the region. - auto* region = &linalgOp.region(); - auto* block = rewriter.createBlock(region, region->end()); - block->addArguments(bodyArgTypes); - if (isLHLO) block->addArguments(bodyResultTypes); - - SmallVector bodyArgs; - for (int i = 0, e = bodyArgTypes.size(); i < e; ++i) { - bodyArgs.push_back(block->getArgument(i)); - } - - rewriter.setInsertionPointToEnd(block); - // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. That - // method needs to be moved out of there. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( - op, bodyResultTypes, bodyArgs, &rewriter); - if (!opResult) { - return failure(); - } - rewriter.create(loc, opResult); + loc, opResultTypes, args, args_count, results_count, indexing_maps, + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. + // That method needs to be moved out of there. + Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + op, bodyResultTypes, + llvm::to_vector<2>(args.take_front(args_count)), &rewriter); + nestedBuilder.create(loc, opResult); + }); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); return success(); } @@ -301,27 +288,20 @@ class DataMovementOpConverter : public OpConversionPattern { OpTy op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { if (!verifyXLAOpBufferOrTensorSemantics(op)) return failure(); - auto operandType = op.operand().getType().template cast(); auto resultType = getXLAOpResultType(op); SmallVector indexing_maps = Derived::getIndexingMaps(op, &rewriter); if (indexing_maps.empty()) return failure(); - OpBuilder::InsertionGuard linalgOpGuard(rewriter); auto nloops = resultType.getRank(); auto loc = op.getLoc(); auto linalgOp = rewriter.create( loc, isLHLO ? 
ArrayRef{} : resultType, args, /*inputCount=*/1, - /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops)); - - auto* region = &linalgOp.region(); - auto* block = rewriter.createBlock(region, region->end()); - block->addArguments(operandType.getElementType()); - if (isLHLO) block->addArgument(resultType.getElementType()); - - rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, block->getArgument(0)); + /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(loc, *args.begin()); + }); rewriter.replaceOp(op, linalgOp.getOperation()->getResults()); return success(); @@ -437,36 +417,26 @@ class LhloBroadcastInDimConverter Value zero = rewriter.create(loc, 0); Value val = rewriter.create(loc, operand, llvm::makeArrayRef({zero})); - auto linalgOp = rewriter.create( + rewriter.create( loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), /*inputCount=*/0, /*outputCount=*/1, llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), - GetNParallelLoopsAttrs(nloops)); + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(loc, val); + }); - auto* region = &linalgOp.region(); - auto* block = rewriter.createBlock(region, region->end()); - block->addArgument(result_type.getElementType()); - - rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, val); } else { auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape, operand_type, &rewriter); - - OpBuilder::InsertionGuard linalgOpGuard(rewriter); - auto linalgOp = rewriter.create( + rewriter.create( loc, llvm::None, llvm::makeArrayRef({operand, operand_adaptor.output()}), /*inputCount=*/1, /*outputCount=*/1, indexing_maps, - GetNParallelLoopsAttrs(nloops)); - - auto* region = &linalgOp.region(); - auto* block = rewriter.createBlock(region, region->end()); - block->addArguments(operand_type.getElementType()); - block->addArgument(result_type.getElementType()); - - rewriter.setInsertionPointToEnd(block); - rewriter.create(loc, block->getArgument(0)); + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { + nestedBuilder.create(loc, *args.begin()); + }); } rewriter.replaceOp(op, llvm::None); return success(); @@ -686,32 +656,26 @@ class IotaConverter : public OpConversionPattern { // Construct the indexing maps needed for linalg.generic ops. unsigned nloops = resultMemrefType.getRank(); - auto loc = iotaOp.getLoc(); - auto linalgOp = rewriter.create( - loc, ArrayRef{}, args, + rewriter.create( + iotaOp.getLoc(), ArrayRef{}, args, 0, // args_in 1, // args_out llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)), - GetNParallelLoopsAttrs(nloops)); + GetNParallelLoopsAttrs(nloops), + [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange ivs, + ValueRange args) { + Value castOp = nestedBuilder.create( + nestedLoc, ivs[iotaOp.iota_dimension().getZExtValue()], + nestedBuilder.getIntegerType( + resultElementType.getIntOrFloatBitWidth())); + if (resultElementType.isa()) { + castOp = nestedBuilder.create(nestedLoc, castOp, + resultElementType); + } + nestedBuilder.create(nestedLoc, castOp); + }); - // Add a block to the region. 
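      // (Editorial note, not part of the patch: as in the converters above,
      // the manual region and block construction removed here is replaced by
      // the linalg::GenericOp builder overload that takes a body-builder
      // lambda; the callback is handed the freshly created block arguments,
      // so the pattern no longer creates the block or moves insertion points
      // itself.)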
- auto* region = &linalgOp.region(); - auto* block = rewriter.createBlock(region, region->end()); - for (unsigned i = 0; i < nloops; ++i) { - block->addArgument(rewriter.getIndexType()); - } - block->addArguments(llvm::makeArrayRef(resultElementType)); - - rewriter.setInsertionPointToEnd(block); - Operation* castOp = rewriter.create( - loc, block->getArgument(iotaOp.iota_dimension().getZExtValue()), - rewriter.getIntegerType(resultElementType.getIntOrFloatBitWidth())); - if (resultElementType.isa()) { - castOp = rewriter.create(loc, castOp->getResult(0), - resultElementType); - } - rewriter.create(loc, castOp->getResult(0)); - rewriter.eraseOp(iotaOp); + rewriter.replaceOp(iotaOp, llvm::None); return success(); } }; diff --git a/tensorflow/compiler/mlir/lite/tests/tf_device_index_selector.mlir b/tensorflow/compiler/tensorflow/tests/tf_device_index_selector.mlir similarity index 94% rename from tensorflow/compiler/mlir/lite/tests/tf_device_index_selector.mlir rename to tensorflow/compiler/tensorflow/tests/tf_device_index_selector.mlir index 1ac7f30d644..7fc2b210f91 100644 --- a/tensorflow/compiler/mlir/lite/tests/tf_device_index_selector.mlir +++ b/tensorflow/compiler/tensorflow/tests/tf_device_index_selector.mlir @@ -1,6 +1,6 @@ // Test DeviceIndex selector. -// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s +// RUN: tf-opt --tf-device-index-selector %s | FileCheck %s // CHECK-LABEL: func @select func @select(%arg0: tensor, %arg1: tensor) -> (tensor, tensor) { diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index b574622efce..42353451408 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -770,6 +770,7 @@ tf_xla_py_test( size = "small", timeout = "long", srcs = ["image_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 10, tags = [ @@ -1452,6 +1453,26 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "case_test", + size = "small", + srcs = ["case_test.py"], + disabled_backends = ["cpu_ondemand"], + python_version = "PY3", + tags = [ + "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip + ], + use_xla_device = False, # Uses tf.function(experimental_compile=True) + deps = [ + ":xla_test", + "//tensorflow/compiler/tf2xla/python:xla", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "gather_test", size = "medium", diff --git a/tensorflow/compiler/tests/case_test.py b/tensorflow/compiler/tests/case_test.py new file mode 100644 index 00000000000..3b2dff537da --- /dev/null +++ b/tensorflow/compiler/tests/case_test.py @@ -0,0 +1,87 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for case statements in XLA."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import image_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import test
+
+
+class CaseTest(xla_test.XLATestCase):
+
+  def testCaseBasic(self):
+
+    @def_function.function(experimental_compile=True)
+    def switch_case_test(branch_index):
+
+      def f1():
+        return array_ops.constant(17)
+
+      def f2():
+        return array_ops.constant(31)
+
+      def f3():
+        return array_ops.constant(-1)
+
+      return control_flow_ops.switch_case(
+          branch_index, branch_fns={
+              0: f1,
+              1: f2
+          }, default=f3)
+
+    with ops.device(self.device):
+      self.assertEqual(switch_case_test(array_ops.constant(0)).numpy(), 17)
+      self.assertEqual(switch_case_test(array_ops.constant(1)).numpy(), 31)
+      self.assertEqual(switch_case_test(array_ops.constant(2)).numpy(), -1)
+      self.assertEqual(switch_case_test(array_ops.constant(3)).numpy(), -1)
+
+  def testBranchIsPruned(self):
+
+    @def_function.function(experimental_compile=True)
+    def switch_case_test():
+      branch_index = array_ops.constant(0)
+
+      def f1():
+        return array_ops.constant(17)
+
+      def f2():
+        # Some operations that XLA cannot compile.
+        image_ops.decode_image(io_ops.read_file('/tmp/bmp'))
+        return array_ops.constant(31)
+
+      # This tests that we do not try to compile all branches if the branch
+      # index is trivially constant.
+ return control_flow_ops.switch_case( + branch_index, branch_fns={ + 0: f1, + 1: f2 + }, default=f2) + + with ops.device(self.device): + self.assertEqual(switch_case_test().numpy(), 17) + + +if __name__ == '__main__': + ops.enable_eager_execution() + test.main() diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 4a8599e29f6..368cb5af2ed 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -79,6 +79,15 @@ tf_cuda_cc_test( ]), ) +cc_library( + name = "common_utils", + hdrs = ["common/utils.h"], + copts = tf_copts(), + deps = [ + "//tensorflow/core/platform:logging", + ] + if_tensorrt([":tensorrt_lib"]), +) + cc_library( name = "trt_op_kernels", srcs = [ @@ -95,6 +104,7 @@ cc_library( ":trt_plugins", ":trt_resources", ":utils", + ":common_utils", "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@local_config_cuda//cuda:cuda_headers", @@ -240,6 +250,7 @@ tf_cuda_library( hdrs = ["utils/trt_logger.h"], visibility = ["//visibility:public"], deps = [ + ":common_utils", ":logger_registry", "//tensorflow/core:lib_proto_parsing", ] + if_tensorrt([":tensorrt_lib"]), @@ -375,6 +386,7 @@ tf_cuda_library( "convert/trt_optimization_pass.h", ], deps = [ + ":common_utils", ":logger_registry", ":segment", ":trt_allocator", @@ -488,6 +500,7 @@ cc_library( ], copts = tf_copts(), deps = [ + ":common_utils", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -575,6 +588,7 @@ cc_library( hdrs = ["utils/py_utils.h"], copts = tf_copts(), deps = if_tensorrt([ + ":common_utils", ":tensorrt_lib", "//tensorflow/stream_executor/platform:dso_loader", ]), diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.h b/tensorflow/compiler/tf2tensorrt/common/utils.h new file mode 100644 index 00000000000..b428733ecd4 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/common/utils.h @@ -0,0 +1,33 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ +#define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ + +#if GOOGLE_CUDA && GOOGLE_TENSORRT + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace tensorrt { + +#define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: " + +} // namespace tensorrt +} // namespace tensorflow + +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT + +#endif // TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 414d27477bc..5429aaf3362 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -25,6 +25,7 @@ limitations under the License. 
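// (Editorial note, not part of the patch: LOG_WARNING_WITH_PREFIX, introduced
// in common/utils.h above, expands to `LOG(WARNING) << "TF-TRT Warning: "`,
// so existing call sites keep streaming their message with `<<` unchanged and
// only gain the common prefix, as the rewrites in this file show.)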
#include #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" @@ -52,8 +53,7 @@ limitations under the License. #include "tensorflow/core/util/device_name_utils.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -276,8 +276,9 @@ Status GetEngineInfo(const Graph* g, if (segment_devices.size() == 1) { info->device = *segment_devices.begin(); } else if (segment_devices.size() > 1) { - LOG(WARNING) << "Detected multiple (" << segment_devices.size() - << ") devices for the segment. Picking first one to continue."; + LOG_WARNING_WITH_PREFIX + << "Detected multiple (" << segment_devices.size() + << ") devices for the segment. Picking first one to continue."; info->device = *segment_devices.begin(); } else { TfGpuId tf_gpu_id; @@ -663,7 +664,7 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, StrAppend(&msg, engine.device, "': "); for (auto d : devices) StrAppend(&msg, d->name(), ", "); StrAppend(&msg, ". Will get the allocator from first one."); - LOG(WARNING) << msg; + LOG_WARNING_WITH_PREFIX << msg; } AllocatorAttributes alloc_attr; cuda_device_id = devices[0]->tensorflow_gpu_device_info()->gpu_id; @@ -671,8 +672,8 @@ std::pair GetDeviceAndAllocator(const ConversionParams& params, VLOG(1) << "Using allocator " << dev_allocator->Name() << " and cuda_device_id " << cuda_device_id; } else { - LOG(WARNING) << "Cluster is set but device '" << engine.device - << "' is not found in the cluster"; + LOG_WARNING_WITH_PREFIX << "Cluster is set but device '" << engine.device + << "' is not found in the cluster"; } return std::make_pair(cuda_device_id, dev_allocator); } @@ -770,8 +771,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { Status status = GetEngineInfo(&graph, static_graph_properties, curr_segment, node_map, reverse_topo_order, &curr_engine); if (!status.ok()) { - LOG(WARNING) << "Failed to get engine info for segment " << t << ": " - << status; + LOG_WARNING_WITH_PREFIX << "Failed to get engine info for segment " << t + << ": " << status; continue; } curr_engine.precision_mode = params.precision_mode; @@ -784,8 +785,9 @@ Status ConvertAfterShapes(const ConversionParams& params) { &graph, curr_engine.engine_name); if (!status.ok()) { - LOG(WARNING) << "Failed to register segment graphdef to the library " << t - << ": " << status; + LOG_WARNING_WITH_PREFIX + << "Failed to register segment graphdef to the library " << t << ": " + << status; continue; } @@ -836,7 +838,8 @@ Status ConvertAfterShapes(const ConversionParams& params) { alloc.reset(new TRTDeviceAllocator(device_alloc.second)); } else { // Setting allocator as nullptr should get revert to the cudamalloc - LOG(WARNING) << "Can't identify the cuda device. Running on device 0 "; + LOG_WARNING_WITH_PREFIX + << "Can't identify the cuda device. Running on device 0 "; } cudaSetDevice(cuda_device_id); auto status = @@ -850,9 +853,9 @@ Status ConvertAfterShapes(const ConversionParams& params) { LOG(INFO) << "Replaced " << msg << "."; } else { // Graph is not modified. 
- LOG(WARNING) << "Cannot replace " << msg - << " reason: " << status.error_message() - << " (keeping original segment)."; + LOG_WARNING_WITH_PREFIX << "Cannot replace " << msg + << " reason: " << status.error_message() + << " (keeping original segment)."; } if (VLOG_IS_ON(1)) { msg = "Segment consists of nodes: "; @@ -880,5 +883,4 @@ Status ConvertAfterShapes(const ConversionParams& params) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h index 53ab84a6fa9..d3897e864fa 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h @@ -24,8 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -66,7 +65,6 @@ Status RegisterGraphToFunctionLibrary(const GraphDef& segment_graph_def, } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_GRAPH_H_ diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc index a1f523d6bfa..54fb1d56441 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc @@ -34,8 +34,7 @@ limitations under the License. #include "tensorflow/core/protobuf/config.pb.h" // NOLINT #include "tensorflow/core/public/session.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -231,5 +230,4 @@ TEST_F(ConvertAfterShapesTest, DirectlyConnectedEngines) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index 20ee5ffd8f8..2ec616ba621 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -31,6 +31,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h" @@ -58,8 +59,7 @@ limitations under the License. 
#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/strided_slice_op.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" #include "third_party/tensorrt/NvInferPlugin.h" @@ -1214,15 +1214,16 @@ static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { nvinfer1::IPluginCreator* const* trt_plugin_creator_list = getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); if (!trt_plugin_creator_list) { - LOG(WARNING) << "Can not find any TensorRT plugins in registry."; + LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry."; } else { VLOG(1) << "Found the following " << num_trt_plugins << " TensorRT plugins in registry:"; for (int i = 0; i < num_trt_plugins; ++i) { if (!trt_plugin_creator_list[i]) { - LOG(WARNING) << "TensorRT plugin at index " << i - << " is not accessible (null pointer returned by " - "getPluginCreatorList for this plugin)"; + LOG_WARNING_WITH_PREFIX + << "TensorRT plugin at index " << i + << " is not accessible (null pointer returned by " + "getPluginCreatorList for this plugin)"; } else { VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); } @@ -1827,9 +1828,9 @@ void Converter::MaybeApplyQuantizationRanges() { // are tensors which are created internally by TF-TRT. The ranges for // these unnamed ITensors are always inferred from user provided ranges, // thus there will also be a warning for the range(s) the user missed. - LOG(WARNING) << "Quantization range was not found for " - << tensor->getName() << ". " - << "Setting invalid quantization range."; + LOG_WARNING_WITH_PREFIX << "Quantization range was not found for " + << tensor->getName() << ". " + << "Setting invalid quantization range."; // Set the range to something unusable so the engine will fail if it // tries to actually use the tensor's range. tensor->setDynamicRange(0, 0); @@ -4424,8 +4425,13 @@ Status ConvertSquare(OpConverterParams* params) { const auto& inputs = params->inputs; const auto& node_def = params->node_def; TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}})); +#if IS_TRT_VERSION_GE(6, 0, 1, 0) + TF_RETURN_IF_ERROR(AllowDataTypes( + *params, {DataType::DT_FLOAT, DataType::DT_HALF, DataType::DT_INT32})); +#else TF_RETURN_IF_ERROR( AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF})); +#endif if (params->validation_only) return Status::OK(); // Constant 2 with same rank as input @@ -4893,10 +4899,11 @@ Status ConvertFusedBatchNorm(OpConverterParams* params) { // Trying to use batchnorm in training mode is a very common problem. // Because the error message will only be printed in VLOG(1) by the // segmenter, we issue a special warning so that users will actually see it. - LOG(WARNING) << node_def.op() << " only supports is_training=false. If you " - << "are using Keras, please call " - << "keras.backend.set_learning_phase(0) before constructing " - << "your model. At " << node_def.name(); + LOG_WARNING_WITH_PREFIX + << node_def.op() << " only supports is_training=false. If you " + << "are using Keras, please call " + << "keras.backend.set_learning_phase(0) before constructing " + << "your model. 
At " << node_def.name(); return errors::Unimplemented(node_def.op(), " only supports is_training=false, at ", node_def.name()); @@ -6034,7 +6041,7 @@ Status ConvertGraphDefToEngine( const string error_message = StrCat("Validation failed for ", node_name, " and input slot ", slot_number, ": ", status.error_message()); - LOG(WARNING) << error_message; + LOG_WARNING_WITH_PREFIX << error_message; return Status(status.code(), error_message); } VLOG(2) << "Adding engine input tensor " << node_name << " with shape " @@ -6250,5 +6257,4 @@ bool OutputEdgeValidator::operator()(const Edge* out_edge) const { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h index 7a1276c645c..a621735fad1 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h @@ -33,8 +33,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/stream_executor/lib/statusor.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -694,7 +693,6 @@ BinaryOperationMap(); } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_CONVERT_NODES_H_ diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index 450831910f6..53ec9ee7ada 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -21,8 +21,7 @@ limitations under the License. 
#include #include -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include #include @@ -1811,7 +1810,9 @@ class ParameterizedOpConverterTestBase const int batch_size = input_data_[0].tensor.shape().dim_size(0); Status stat = OpConverterTest::BuildAndRun(input_data_, &output_data, batch_size); - ASSERT_EQ(expected_runtime_status, stat); + ASSERT_EQ(expected_runtime_status.ok(), stat.ok()) + << "expected status: " << expected_runtime_status + << ", actual status: " << stat; if (expected_runtime_status.ok() && stat.ok()) { for (int i = 0; i < n_output; i++) { // Check the shape of the actual output tensors @@ -2754,58 +2755,40 @@ TEST_F(OpConverterTest, ConvertQuantize) { } } -template -void TestConvertSquare(OpConverterTest* test) { - test->Reset(); - typedef typename EnumToDataType::Type CType; - - Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), dtype); - auto square = ops::Square(s.WithOpName("my_square"), input); - NodeDef node_def = square.operation.node()->def(); - - test->AddTestTensor("input", {1, 20}, /*batch_size=*/1, - TfDataTypeToTrt(dtype)); - test->RunValidationAndConversion(node_def); - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_square", &output)); - ASSERT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions()); - - const int num_inputs = 20; - std::vector inputs(num_inputs); - std::vector expected_outputs(num_inputs); - for (int i = 0; i < num_inputs; ++i) { - const CType value = CType(i - 9); - inputs[i] = value; - expected_outputs[i] = value * value; - } - const DataVec input_data{{"input", test->AsTensor(inputs)}}; - // Engine outputs are converted to FP16 automatically if we set FP16 mode in - // the builder. - DataVec output_data{{"my_square", test->ConstructTensor(num_inputs)}}; - TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); - ExpectArrayNear(expected_outputs, GetSpanForData(output_data[0])); -} - -TEST_F(OpConverterTest, ConvertSquare) { +TEST_P(OpConverterTest2, ConvertSquare) { { // Input is weights, should fail. Reset(); Scope s = Scope::NewRootScope(); - auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); auto square = ops::Square(s.WithOpName("my_square"), input); NodeDef node_def = square.operation.node()->def(); - AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6}, tf_type); RunValidationAndConversion( node_def, error::UNIMPLEMENTED, "The input \"x\" for Square must be a tensor, at my_square"); } - // OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't - // test DT_INT32 type here. 
- TestConvertSquare(this); - TestConvertSquare(this); + Reset(); + + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), tf_type); + auto square = ops::Square(s.WithOpName("my_square"), input); + NodeDef node_def = square.operation.node()->def(); + + const int num_inputs = 20; + std::vector inputs(num_inputs); + std::vector expected_outputs(num_inputs); + + for (int i = 0; i < num_inputs; ++i) { + const float value = (i - 9); + inputs[i] = value; + expected_outputs[i] = value * value; + } + AddTestTensor("input", {1, 1, 20}, tf_type, inputs); + + TestOpConverter("my_square", node_def, {1, 1, 20}, Status::OK(), Status::OK(), + ArrayFloatNear(expected_outputs, 0)); } #if IS_TRT_VERSION_GE(5, 1, 0, 0) @@ -6359,87 +6342,70 @@ NodeDef GetSquaredDifferenceNodeDef(DataType dtype) { return squared_diff.operation.node()->def(); } -template -void TestConvertSquaredDifference(OpConverterTest* test) { - typedef typename EnumToDataType::Type CType; - - struct TestParams { - std::vector dims_x; - std::vector dims_y; - std::vector value_x; - std::vector value_y; - std::vector expected_output_dims; - std::vector expected_output; - }; - - const std::vector common_input = InitTestVector(6); - std::vector params = { - { - /*dims_x=*/{1, 2, 3}, - /*dims_y=*/{1, 2, 3}, - /*value_x=*/common_input, - /*value_y=*/CastTestVector({0, -1, 3, 0, 10, -7}), - /*expected_output_dims=*/{1, 2, 3}, - /*expected_output=*/CastTestVector({0, 4, 1, 9, 36, 144}), - }, - { - /*dims_x=*/{1, 2, 3}, - /*dims_y=*/{1, 1, 3}, - /*value_x=*/common_input, - /*value_y=*/CastTestVector({0, 1, 2}), - /*expected_output_dims=*/{1, 2, 3}, - /*expected_output=*/CastTestVector({0, 0, 0, 9, 9, 9}), - }, - }; - - for (int i = 0; i < params.size(); ++i) { - test->Reset(); - - NodeDef node_def = GetSquaredDifferenceNodeDef(dtype); - test->AddTestTensor("x", params[i].dims_x, 1, TfDataTypeToTrt(dtype)); - test->AddTestTensor("y", params[i].dims_y, 1, TfDataTypeToTrt(dtype)); - test->RunValidationAndConversion(node_def); - - TRT_TensorOrWeights output; - TF_EXPECT_OK(test->GetTensorOrWeights("my_squared_diff", &output)); - EXPECT_TRUE(output.is_tensor()); - ExpectTrtDimsEqualsArray(params[i].expected_output_dims, - output.tensor()->getDimensions()); - - DataVec input_data{{"x", test->AsTensor(params[i].value_x)}, - {"y", test->AsTensor(params[i].value_y)}}; - DataVec output_data{ - {"my_squared_diff", - test->ConstructTensor(params[i].expected_output.size())}}; - TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data)); - EXPECT_THAT(GetSpanForData(output_data[0]), - ElementsAreArray(params[i].expected_output)); - } -} - -TEST_F(OpConverterTest, ConvertSquaredDifference) { +TEST_P(OpConverterTest2, ConvertSquaredDifference) { { // Input is a weight, should fail. Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); AddTestWeights("x", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); - AddTestTensor("y", {1, 2, 3}); + AddTestTensor("y", {1, 1, 2, 3}); RunValidationAndConversion(node_def, error::UNIMPLEMENTED, "The input \"x\" for SquaredDifference must be " "a tensor, at my_squared_diff"); } - { - // Shapes are not broadcastable, should fail. 
- Reset(); - NodeDef node_def = GetSquaredDifferenceNodeDef(DT_FLOAT); - AddTestTensor("x", {2, 3}); - AddTestTensor("y", {7, 5}); - RunValidationAndConversion(node_def, error::INVALID_ARGUMENT, - "Infeasible broadcast scheme"); - } - TestConvertSquaredDifference(this); - TestConvertSquaredDifference(this); + struct TestParams { + std::vector dims_x; + std::vector dims_y; + std::vector value_x; + std::vector value_y; + std::vector expected_output_dims; + std::vector expected_output; + Status status; + Status runtime_status; + }; + + const std::vector common_input = InitTestVector(6); + std::vector params = { + {/*dims_x=*/{1, 2, 3}, + /*dims_y=*/{1, 7, 5}, + /*value_x=*/common_input, + /*value_y=*/std::vector(7 * 5, 0), + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/common_input, + trt_mode == TrtTestMode::kDynamicShape + ? Status::OK() + : errors::InvalidArgument("Infeasible broadcast scheme"), + errors::Internal( + "Binding index out of range. This can happen if profile is not set, " + "or the network is invalid for the current profile.")}, + { + /*dims_x=*/{1, 1, 2, 3}, + /*dims_y=*/{1, 1, 2, 3}, + /*value_x=*/common_input, + /*value_y=*/{0, -1, 3, 0, 10, -7}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{0, 4, 1, 9, 36, 144}, + }, + { + /*dims_x=*/{1, 1, 2, 3}, + /*dims_y=*/{1, 1, 1, 3}, + /*value_x=*/common_input, + /*value_y=*/{0, 1, 2}, + /*expected_output_dims=*/{1, 1, 2, 3}, + /*expected_output=*/{0, 0, 0, 9, 9, 9}, + }, + }; + + for (auto p : params) { + Reset(); + NodeDef node_def = GetSquaredDifferenceNodeDef(tf_type); + AddTestTensor("x", p.dims_x, p.value_x); + AddTestTensor("y", p.dims_y, p.value_y); + TestOpConverter("my_squared_diff", node_def, p.expected_output_dims, + p.status, p.runtime_status, + ElementsAreArray(p.expected_output)); + } } #if IS_TRT_VERSION_GE(6, 0, 0, 0) @@ -6669,5 +6635,4 @@ TEST_F(OpConverterTest, ConvertPad) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc b/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc index 82e68cbb28d..07c9c2f1ea0 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/logger_registry.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h" @@ -58,5 +57,4 @@ LoggerRegistry* GetLoggerRegistry() { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h b/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h index 45b302742d0..2a265cf7caa 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h +++ b/tensorflow/compiler/tf2tensorrt/convert/logger_registry.h @@ -19,7 +19,8 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA +#if GOOGLE_CUDA && GOOGLE_TENSORRT + #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -53,5 +54,5 @@ class RegisterLogger { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_LOGGER_REGISTRY_H_ diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc index 72f4fe5ef9b..1cf98d135cb 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc @@ -28,8 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stacktrace.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { namespace convert { @@ -302,5 +301,4 @@ static VerboseCustomGraphOptimizerRegistrar TRTOptimizationPass_Registrar( } // namespace tensorrt } // namespace tensorflow -#endif -#endif +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h index f79048bb5f6..e0aaa5500ab 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h +++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h @@ -23,8 +23,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" #include "tensorflow/core/platform/logging.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -83,6 +82,5 @@ class TRTOptimizationPass : public grappler::CustomGraphOptimizer { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_CONVERT_TRT_OPTIMIZATION_PASS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc index 3143b06817e..76fb40b9520 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_calibration_data_op.cc @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/refcount.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -67,5 +66,4 @@ REGISTER_KERNEL_BUILDER(Name("GetCalibrationDataOp").Device(DEVICE_GPU), } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index d9b8e198f4f..1094555a622 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h" #include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h" @@ -44,10 +45,10 @@ limitations under the License. #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/stream_executor/lib/statusor.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/tensorrt/NvInfer.h" @@ -520,6 +521,17 @@ Status TRTEngineOp::VerifyInputShapes( return Status::OK(); } +static bool AllowEngineNativeSegmentExecution() { + bool value; + Status status = + ReadBoolFromEnvVar("TF_TRT_ALLOW_ENGINE_NATIVE_SEGMENT_EXECUTION", + /*default_value=*/true, &value); + if (!status.ok()) { + LOG(ERROR) << status; + } + return value; +} + void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, AsyncOpKernel::DoneCallback done) { auto helper = new AsyncHelper(done); @@ -604,17 +616,31 @@ void TRTEngineOp::ComputeAsync(OpKernelContext* ctx, EngineContext* engine_context = status.ValueOrDie().first; int trt_context_idx = status.ValueOrDie().second; + auto may_execute_native_segment = [&] { + if (!AllowEngineNativeSegmentExecution()) { + ctx->CtxFailure( + errors::Aborted("User disallowed engine native segment execution")); + return false; + } + return true; + }; if (!engine_context->cuda_engine) { - VLOG(1) << "Engine retrieval for input shapes: " - << TensorShapeUtils::ShapeListString(input_concrete_shapes) - << " failed. Running native segment for " << name(); - ExecuteNativeSegment(ctx, helper); + LOG_WARNING_WITH_PREFIX + << "Engine retrieval for input shapes: " + << TensorShapeUtils::ShapeListString(input_concrete_shapes) + << " failed. Running native segment for " << name(); + if (may_execute_native_segment()) { + ExecuteNativeSegment(ctx, helper); + } return; } Status stat = ExecuteTrtEngine(ctx, engine_context, trt_context_idx); if (!stat.ok()) { - LOG(WARNING) << "Failed to execute engine: " << stat - << " Retrying with native segment for " << name(); + LOG_WARNING_WITH_PREFIX << "Failed to execute engine: " << stat + << " Retrying with native segment for " << name(); + if (!may_execute_native_segment()) { + return; + } // Release any outputs that are allocated, ExecuteNativeSegment will // re-allocate them and fail if they are currently allocated. for (int i = 0; i < ctx->num_outputs(); i++) { @@ -727,9 +753,9 @@ StatusOr> TRTEngineOp::BuildEngine( calibrator, &engine, use_calibration, use_implicit_batch_, nullptr, &cache_resource->profiles_); if (!status.ok()) { - LOG(WARNING) << "Engine creation for " << name() << " failed. " - << "The native segment will be used instead. " - << "Reason: " << status; + LOG_WARNING_WITH_PREFIX << "Engine creation for " << name() << " failed. " + << "The native segment will be used instead. " + << "Reason: " << status; // Store an empty engine in the cache for these input shapes so we don't try // to build the same failing engine again. 
cache_resource->cache_.emplace(input_concrete_shapes, @@ -791,8 +817,9 @@ StatusOr> TRTEngineOp::GetEngine( FunctionDefToGraphDef(func_handle_, lib, &segment_graph_def_); } if (!status.ok()) { - LOG(WARNING) << "Getting segment graph for " << name() << " failed. " - << "Reason: " << status; + LOG_WARNING_WITH_PREFIX << "Getting segment graph for " << name() + << " failed. " + << "Reason: " << status; } } auto result = BuildEngine(input_concrete_shapes, batch_size, @@ -851,10 +878,11 @@ StatusOr> TRTEngineOp::GetEngine( // If cache does not have a compatible engine then create a new engine. if (engine_contexts == nullptr) { if (!allow_build_at_runtime_) { - LOG(WARNING) << "Found no engine in cache matching input shapes. " - << "Not building a new engine because " - << "allow_build_at_runtime=False. " - << "The native segment will be used instead."; + LOG_WARNING_WITH_PREFIX + << "Found no engine in cache matching input shapes. " + << "Not building a new engine because " + << "allow_build_at_runtime=False. " + << "The native segment will be used instead."; // Store an empty engine in the cache for these input shapes so we don't // try to build the same failing engine again. cache.emplace(input_concrete_shapes, absl::make_unique()); @@ -980,5 +1008,4 @@ REGISTER_KERNEL_BUILDER(Name("TRTEngineOp").Device(DEVICE_GPU), TRTEngineOp); } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index a06010de1c7..71193dc24cf 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -50,8 +50,7 @@ limitations under the License. #include "tensorflow/core/platform/status.h" #include "tensorflow/core/public/version.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -306,5 +305,4 @@ TYPED_TEST(TRTEngineOpTest, Basic) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc index 2c5821df6ac..3b6e7e91d3b 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops.cc @@ -33,8 +33,7 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -251,5 +250,4 @@ REGISTER_KERNEL_BUILDER(Name("SerializeTRTResource").Device(DEVICE_GPU), } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc index 4a24160569d..6a073ee24d0 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_resource_ops_test.cc @@ -48,8 +48,7 @@ limitations under the License. 
#include "tensorflow/core/platform/tstring.h" #include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -246,5 +245,4 @@ TEST_F(TRTEngineResourceOpsTest, Basic) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc b/tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc index 573172b92e6..2af3164c3e2 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/get_calibration_data_op.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" @@ -34,5 +33,4 @@ Returns calibration data for the given resource name } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index bd3c2b299a9..2527fe9b910 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" @@ -59,5 +58,4 @@ REGISTER_OP("TRTEngineOp") .Attr("static_engine: bool = true"); } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc index 01911de66ec..3141092de03 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_resource_ops.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" @@ -46,5 +45,4 @@ REGISTER_OP("SerializeTRTResource") } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc b/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc index 4c0d8b0392a..141a7d1f462 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc @@ -17,8 +17,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" #include "tensorflow/core/platform/logging.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #define EIGEN_USE_GPU // For definition of Eigen::GpuDevice. #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "tensorflow/core/util/gpu_kernel_helper.h" @@ -234,5 +233,4 @@ REGISTER_TFTRT_PLUGIN(CastPluginCreator); } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc index 563ce724f43..83d5f9b5965 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.cc @@ -17,8 +17,7 @@ limitations under the License. #include -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -30,5 +29,4 @@ const char* kTfTrtPluginNamespace = "TF"; } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h index bdb046e6c71..600ac6683da 100644 --- a/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h +++ b/tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h @@ -20,8 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -90,7 +89,6 @@ class TrtPluginRegistrar { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_PLUGIN_TRT_PLUGIN_H_ diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index 749335f1b09..d9080b6f69a 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/algorithm.h" @@ -34,8 +35,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/env_var.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -748,9 +748,10 @@ Status SegmentGraph(const Graph* tf_graph, exclude_node(status.error_message()); } else if (tftrt_op_blacklist.count(node->tf_node()->type_string())) { // WARNING verbosity since the user explicitly requests this behavior. 
- LOG(WARNING) << "Blacklisted as TF-TRT candidate, " - << "(Op type: " << node->tf_node()->type_string() << "), " - << "(Op name: " << node->name() << ")"; + LOG_WARNING_WITH_PREFIX + << "Blacklisted as TF-TRT candidate, " + << "(Op type: " << node->tf_node()->type_string() << "), " + << "(Op name: " << node->name() << ")"; exclude_node("Blacklisted with the env var TF_TRT_OP_BLACKLIST"); } else { VLOG(2) << "Accepted as a TF-TRT candidate, " @@ -1038,7 +1039,7 @@ Status SegmentGraph(const Graph* tf_graph, for (const auto& dev : dev_itr->second) { StrAppend(&s, dev, ", "); } - LOG(WARNING) << s; + LOG_WARNING_WITH_PREFIX << s; } segments->emplace_back(segment_nodes); @@ -1060,5 +1061,4 @@ Status SegmentGraph(const Graph* tf_graph, } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.h b/tensorflow/compiler/tf2tensorrt/segment/segment.h index 7295c8f0d9d..3f79983cfd2 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.h +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.h @@ -25,8 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -67,7 +66,6 @@ Status SegmentGraph(const Graph* tf_graph, } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_SEGMENT_H_ diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index 2437481a9c4..f3bc5bfbee6 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -26,8 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -522,5 +521,4 @@ TEST_F(SegmentTest, IncompatibleBatchSizes) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index 70e83c12fca..b53615ec019 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -19,8 +19,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/types/optional.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -217,7 +216,6 @@ UnionFind* UnionFind::FindRoot() { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_SEGMENT_UNION_FIND_H_ diff --git a/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc index 510591bfe00..e994d20df33 100644 --- a/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc +++ b/tensorflow/compiler/tf2tensorrt/tensorrt_test.cc @@ -18,8 +18,7 @@ limitations under the License. 
#include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/test.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/gpus/cuda/include/cuda.h" #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/tensorrt/NvInfer.h" @@ -164,5 +163,4 @@ TEST(TensorrtTest, BasicFunctions) { } // namespace } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index 885f58cd70c..a8e24aa8983 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/stream_executor/platform/dso_loader.h" #include "third_party/tensorrt/NvInfer.h" #endif @@ -27,9 +28,10 @@ bool IsGoogleTensorRTEnabled() { #if GOOGLE_CUDA && GOOGLE_TENSORRT auto handle_or = se::internal::DsoLoader::TryDlopenTensorRTLibraries(); if (!handle_or.ok()) { - LOG(WARNING) << "Cannot dlopen some TensorRT libraries. If you would like " - "to use Nvidia GPU with TensorRT, please make sure the " - "missing libraries mentioned above are installed properly."; + LOG_WARNING_WITH_PREFIX + << "Cannot dlopen some TensorRT libraries. If you would like " + "to use Nvidia GPU with TensorRT, please make sure the " + "missing libraries mentioned above are installed properly."; return false; } else { return true; diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc index 617ea7fad5c..d4f3a524577 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.cc @@ -17,11 +17,9 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -52,8 +50,7 @@ void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space) { } // namespace tensorrt } // namespace tensorflow -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -113,5 +110,4 @@ void TRTDeviceAllocator::free(void* memory) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h index 4ab8b52f523..d219a8a14e8 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_allocator.h @@ -20,11 +20,9 @@ limitations under the License. 
#include "tensorflow/core/framework/allocator.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -33,8 +31,7 @@ void* Align(uint64_t alignment, uint64_t size, void*& ptr, uint64_t& space); } // namespace tensorrt } // namespace tensorflow -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT namespace tensorflow { namespace tensorrt { @@ -69,6 +66,5 @@ class TRTDeviceAllocator : public TRTBaseAllocator { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ALLOCATOR_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc index 213c1732e59..8ccfb8b06f0 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.cc @@ -25,8 +25,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/errors.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -46,6 +45,14 @@ Status GetTrtBindingShape(const nvinfer1::ICudaEngine* cuda_engine, // Get dims from context instead of engine in explicit batch mode because // the engine might have dynamic shapes. dims = execution_context->getBindingDimensions(binding_index); + if (dims.nbDims == -1) { + // Invalid dimensions. There can be multiple reasons for this. If we have + // incompatible input shapes (network invalid for the current profile) + // that can trigger this error. + return errors::Internal( + "Binding index out of range. This can happen if profile is not set, " + "or the network is invalid for the current profile."); + } #else return errors::Internal( "Explicit batch mode is only supported with TensorRT 6 and above."); @@ -249,5 +256,4 @@ Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context, } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h index a471749877a..1ea4fe28cb4 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_engine_utils.h @@ -24,8 +24,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -91,7 +90,6 @@ Status TrtEnqueue(nvinfer1::IExecutionContext* execution_context, } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_ENGINE_UTILS_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc index 554c127fa37..24271e352a7 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.cc @@ -20,8 +20,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/gpus/cuda/include/cuda_runtime_api.h" namespace tensorflow { @@ -147,5 +146,4 @@ TRTInt8Calibrator::~TRTInt8Calibrator() { } // namespace tensorrt } // namespace tensorflow -#endif -#endif +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h index 06b39716490..4c670e85f52 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_int8_calibrator.h @@ -22,8 +22,7 @@ limitations under the License. #include #include "tensorflow/core/platform/mutex.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/gpus/cuda/include/cuda_runtime_api.h" #include "third_party/tensorrt/NvInfer.h" @@ -101,6 +100,5 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { } // namespace tensorrt } // namespace tensorflow -#endif -#endif +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_INT8_CALIBRATOR_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc index 6bb6f1f9dd8..e34bf5e7397 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc @@ -15,8 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/convert/logger_registry.h" #include "tensorflow/core/platform/logging.h" @@ -35,7 +35,7 @@ void Logger::log(Severity severity, const char* msg) { break; } case Severity::kWARNING: { - LOG(WARNING) << name_ << " " << msg; + LOG_WARNING_WITH_PREFIX << name_ << " " << msg; break; } case Severity::kERROR: { @@ -67,5 +67,4 @@ REGISTER_TENSORRT_LOGGER("DefaultLogger", Logger::GetLogger()); } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_CUDA -#endif // GOOGLE_TENSORRT +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h index 2ade1b48f47..ce6552e8fe9 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.h @@ -18,8 +18,7 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -40,7 +39,6 @@ class Logger : public nvinfer1::ILogger { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_LOGGER_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc index fbcdaad52c0..ee7e6272372 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.cc @@ -23,8 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/mutex.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" namespace tensorflow { @@ -141,5 +140,4 @@ EngineContext* TRTEngineCacheResource::GetEngineContext(const int profile_id) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h index 8e345254f75..991b9a949e4 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_lru_cache.h @@ -115,8 +115,7 @@ class LRUCache { } }; -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT struct EngineContext { EngineContext() {} // Creates an empty context. @@ -223,8 +222,7 @@ class TRTEngineCacheResource : public ResourceBase { TrtShapeOptimizationProfile profiles_; }; -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h index 40c7f5dcf31..fc688b14139 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles.h @@ -29,8 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include "third_party/tensorrt/NvInfer.h" @@ -173,6 +172,5 @@ class TrtShapeOptimizationProfile { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT #endif // TENSORFLOW_COMPILER_TF2TENSORRT_UTILS_TRT_SHAPE_OPTIMIZATION_PROFILES_H_ diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc index 501810587e0..32c2200fb71 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/trt_shape_optimization_profiles_test.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#if GOOGLE_CUDA -#if GOOGLE_TENSORRT +#if GOOGLE_CUDA && GOOGLE_TENSORRT #include @@ -214,5 +213,4 @@ TEST_F(TrtShapeOptimizationProfileTest, Dynamic) { } // namespace tensorrt } // namespace tensorflow -#endif // GOOGLE_TENSORRT -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index bfdfe38305b..e072225566d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -32,6 +32,7 @@ tf_kernel_library( "data_format_ops.cc", "depthtospace_op.cc", "dequantize_op.cc", + "device_index_op.cc", "diag_op.cc", "dynamic_slice_ops.cc", "dynamic_stitch_op.cc", @@ -316,6 +317,7 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.cc b/tensorflow/compiler/tf2xla/kernels/case_op.cc index 1b15c09f7e3..fbd54f1ef39 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/case_op.cc @@ -21,13 +21,14 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &branches_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &unpruned_branches_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_)); if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { @@ -41,12 +42,29 @@ XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { } } +std::pair, xla::XlaOp> +XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) { + xla::Literal branch_index_literal; + bool branch_index_is_constant = + ctx->ConstantInput(0, &branch_index_literal).ok(); + + if (!branch_index_is_constant) { + return {unpruned_branches_, ctx->Input(0)}; + } + + int32 branch_index = branch_index_literal.Get({}); + if (branch_index < 0 || branch_index >= unpruned_branches_.size()) { + branch_index = unpruned_branches_.size() - 1; + } + + std::vector pruned_branch = {unpruned_branches_[branch_index]}; + return {pruned_branch, xla::ZerosLike(ctx->Input(0))}; +} + // TODO(b/35949885): There is duplication here with the handling of the // while_op/if_op. Refactor the common code out/rework. 
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { - xla::XlaBuilder* b = ctx->builder(); - int num_branches = branches_.size(); - OP_REQUIRES(ctx, num_branches >= 1, + OP_REQUIRES(ctx, !unpruned_branches_.empty(), errors::InvalidArgument("Must provide at least one case branch")); OP_REQUIRES(ctx, input_type(0) == DT_INT32, errors::InvalidArgument( @@ -55,6 +73,18 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { errors::InvalidArgument( "branch_index argument must be scalar for XLA compilation")); + xla::XlaBuilder* b = ctx->builder(); + + // We opportunistically prune out branches if the branch index is a + // compile-time constant. This is important in the context of the DeviceIndex + // ops (and other such ops that may come later) since we may have a Case with + // trivially unselected branches that cannot be compiled into HLO. + std::vector branches; + xla::XlaOp branch_index; + std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx); + + int num_branches = branches.size(); + VLOG(1) << "Building Case: " << input_types_.size() << " inputs"; std::vector arguments(input_types_.size()); @@ -94,7 +124,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { std::vector case_bodies(num_branches); for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) { OP_REQUIRES_OK(ctx, FindMustBeConstNodes( - ctx, branches_[branch_idx], + ctx, branches[branch_idx], &case_branch_must_be_const_nodes[branch_idx], &case_bodies[branch_idx])); } @@ -133,7 +163,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { std::vector branch_results_p(num_branches); for (int j = 0; j < num_branches; ++j) { OP_REQUIRES_OK(ctx, - compiler->CompileFunction(options, branches_[j], arguments, + compiler->CompileFunction(options, branches[j], arguments, &branch_results[j])); branch_results_p[j] = &branch_results[j]; } @@ -171,7 +201,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { for (int j = 0; j < num_branches; ++j) { branch_results[j] = {}; OP_REQUIRES_OK(ctx, - compiler->CompileFunction(options, branches_[j], arguments, + compiler->CompileFunction(options, branches[j], arguments, &branch_results[j])); } } @@ -277,7 +307,7 @@ void XlaCaseOp::Compile(XlaOpKernelContext* ctx) { auto input_tuple = xla::Tuple(b, inputs); xla::XlaOp outputs = - xla::Conditional(ctx->Input(0), absl::MakeSpan(result_computations), + xla::Conditional(branch_index, absl::MakeSpan(result_computations), std::vector(num_branches, input_tuple)); // Sets non-variable outputs. for (int i = 0; i < output_types_.size(); ++i) { diff --git a/tensorflow/compiler/tf2xla/kernels/case_op.h b/tensorflow/compiler/tf2xla/kernels/case_op.h index 4a61707864e..4d22a3db830 100644 --- a/tensorflow/compiler/tf2xla/kernels/case_op.h +++ b/tensorflow/compiler/tf2xla/kernels/case_op.h @@ -50,7 +50,16 @@ class XlaCaseOp : public XlaOpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(XlaCaseOp); - std::vector branches_; + // If the branch_index input is a constant: prunes out all but the branch + // corresponding to that constant branch index, and returns that branch and + // the literal 0 (as the first and second component of the pair). + // + // If the branch_index input is not a constant: returns unpruned_branches_ and + // the branch_index input.
+ std::pair, xla::XlaOp> GetPrunedBranchesAndIndex( + XlaOpKernelContext* ctx); + + std::vector unpruned_branches_; DataTypeVector input_types_; DataTypeVector output_types_; bool has_token_input_output_; diff --git a/tensorflow/compiler/tf2xla/kernels/device_index_op.cc b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc new file mode 100644 index 00000000000..ff058f92cd7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/device_index_op.cc @@ -0,0 +1,51 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/core/framework/kernel_def_builder.h" + +namespace tensorflow { +namespace { + +class DeviceIndexOp : public XlaOpKernel { + public: + explicit DeviceIndexOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("device_names", &device_names_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + // When compiling we are not executing on any physical device, so we return + // a sentinel value (size of the list of devices). + ctx->SetOutput( + 0, xla::ConstantR0(ctx->builder(), device_names_.size())); + } + + private: + std::vector device_names_; +}; + +REGISTER_XLA_OP(Name("DeviceIndex"), DeviceIndexOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 2684c982600..784b790767c 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -350,6 +350,30 @@ class StridedSliceGradOp : public XlaOpKernel { grad = xla::Rev(grad, dimensions_to_reverse); } grad = xla::Pad(grad, zero, padding_config); + + xla::XlaOp dynamic_shape = ctx->Input(0); + xla::Shape grad_shape = ctx->builder()->GetShape(grad).ValueOrDie(); + ctx->set_dynamic_dimension_is_minus_one(true); + std::vector dynamic_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(0, &dynamic_size)); + // Input of strided_slice_op has to have the same shape as output. + DCHECK_EQ(grad_shape.rank(), input_shape.dims()); + for (int64 dim = 0; dim < input_shape.dims(); ++dim) { + DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim)); + if (dynamic_size[dim] == -1) { + // Input is a dynamic dimension, set the same dynamic dimension size in + // the output. 
+ auto dim_size = xla::Slice(dynamic_shape, {dim}, {dim + 1}, {1}); + auto dim_size_scalar = + xla::Reshape(xla::ShapeUtil::MakeScalarShape(xla::S32), dim_size); + grad = xla::SetDimensionSize(grad, dim_size_scalar, dim); + } else if (grad_shape.is_dynamic_dimension(dim)) { + // Input is static but output is dynamic, respect input and remove any + // dynamic dim in the output. + grad = xla::RemoveDynamicDimension(grad, dim); + } + } + ctx->SetOutput(0, grad); } diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 43793be56a7..60d1f3da0c5 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -165,11 +165,6 @@ Status ConvertGraphDefToXlaViaMlir( device_set.AddDevice(&device); AddDevicesToOp(*module, &device_set); - if (failed(mlir::TF::MarkFunctionVisibilityUsingEntryFunctionSpecification( - *module))) { - return errors::Internal("Problem with mark function visibility"); - } - TF_RETURN_IF_ERROR(mlir::TF::RunBridgeWithStandardPipeline( *module, /*enable_logging=*/VLOG_IS_ON(1), /*enable_inliner=*/true)); diff --git a/tensorflow/compiler/tf2xla/python/xla.py b/tensorflow/compiler/tf2xla/python/xla.py index c59c47e92fb..0ebca2d546f 100644 --- a/tensorflow/compiler/tf2xla/python/xla.py +++ b/tensorflow/compiler/tf2xla/python/xla.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_random_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops # TODO(phawkins): provide wrappers for all XLA operators. Currently the missing # ops include: @@ -103,8 +104,8 @@ sign = _unary_op(math_ops.sign) tanh = _unary_op(math_ops.tanh) # Bessel -bessel_i0e = _unary_op(math_ops.bessel_i0e) -bessel_i1e = _unary_op(math_ops.bessel_i1e) +bessel_i0e = _unary_op(special_math_ops.bessel_i0e) +bessel_i1e = _unary_op(special_math_ops.bessel_i1e) # Binary operators diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 1cf3e10b774..c1aef3ff690 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -268,6 +268,7 @@ Status BuildComputation( return a->arg_num() < b->arg_num(); }); + std::vector aliases; for (const XlaResource* resource : arg_resources) { DCHECK_LT(resource->arg_num(), args.size()); const XlaCompiler::Argument& arg = args[resource->arg_num()]; @@ -289,20 +290,19 @@ Status BuildComputation( update.type = resource->type(); update.shape = resource->shape(); update.modified = modified; - if (is_entry_computation && always_return_tuple && + if (is_entry_computation && arg.resource_kind != XlaResource::kTensorArray && alias_resource_update) { // Assuming tuple arg and results are used. - int64 output_index = elems.size(); - if (use_tuple_arg) { - builder->SetUpAlias(/*output_index=*/{output_index}, - /*param_number=*/0, - /*param_index=*/{update.input_index}); - } else { - builder->SetUpAlias(/*output_index=*/{output_index}, - /*param_number=*/update.input_index, - /*param_index=*/{}); - } + xla::ShapeIndex param_index = + use_tuple_arg ? xla::ShapeIndex({update.input_index}) + : xla::ShapeIndex{}; + int param_number = use_tuple_arg ? 
0 : update.input_index; + int64 output_index_num = elems.size(); + xla::ShapeIndex output_index = xla::ShapeIndex({output_index_num}); + VLOG(3) << "Storing alias: " << output_index.ToString() << ": (" + << param_number << ", " << param_index.ToString() << ")"; + aliases.push_back({output_index, param_number, param_index}); } for (const auto& grad : resource->tensor_array_gradients()) { update.tensor_array_gradients_accessed.insert(grad.first); @@ -381,8 +381,25 @@ Status BuildComputation( xla::XlaScopedShardingAssignment assign_sharding(builder, op_sharding); tuple = xla::Tuple(builder, elems); } - if (!always_return_tuple && elems.size() == 1) { + bool returns_tuple = always_return_tuple || elems.size() != 1; + VLOG(3) << "Computation returns a tuple=" << returns_tuple; + if (!returns_tuple) { xla::GetTupleElement(tuple, 0); + + for (xla::XlaBuilder::InputOutputAlias& alias : aliases) { + if (alias.output_index == xla::ShapeIndex({0})) { + VLOG(3) << "For aliased parameter " << alias.param_number << ": " + << alias.param_index.ToString() + << " normalizing output_index from {0} to {}, as a scalar is " + "returned from the cluster"; + alias.output_index = xla::ShapeIndex({}); + } + } + } + + for (xla::XlaBuilder::InputOutputAlias& alias : aliases) { + builder->SetUpAlias(alias.output_index, alias.param_number, + alias.param_index); } xla::StatusOr computation_status = builder->Build(); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index afe115deda8..5fc9909fa2a 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -176,15 +176,23 @@ StatusOr LocalExecutable::Run( for (const ShapedBuffer* const arg : arguments) { argument_shapes.push_back(&arg->on_host_shape()); } - TF_ASSIGN_OR_RETURN(auto options_and_stream, - RunHelper(argument_shapes, run_options)); - ExecutableRunOptions options = options_and_stream.first.run_options(); - options.set_device_ordinal(-1); - auto result = RunAsync(arguments, options); - Status block_status = options.stream()->BlockHostUntilDone(); - TF_RETURN_IF_ERROR(result.status()); - TF_RETURN_IF_ERROR(block_status); - return result; + return AsyncCallAndBlockHostUntilDone( + argument_shapes, run_options, [&](const ExecutableRunOptions& options) { + return RunAsync(arguments, options); + }); +} + +StatusOr LocalExecutable::Run( + std::vector arguments, ExecutableRunOptions run_options) { + std::vector argument_shapes; + argument_shapes.reserve(arguments.size()); + for (const ExecutionInput& arg : arguments) { + argument_shapes.push_back(&arg.shape()); + } + return AsyncCallAndBlockHostUntilDone( + argument_shapes, run_options, [&](const ExecutableRunOptions& options) { + return RunAsync(argument_shapes, std::move(arguments), options); + }); } static std::shared_ptr DumpArguments( @@ -312,6 +320,16 @@ StatusOr LocalExecutable::RunAsync( return std::move(outputs); } +StatusOr LocalExecutable::RunAsync( + std::vector arguments, ExecutableRunOptions run_options) { + std::vector argument_shapes; + argument_shapes.reserve(arguments.size()); + for (const ExecutionInput& arg : arguments) { + argument_shapes.push_back(&arg.shape()); + } + return RunAsync(argument_shapes, std::move(arguments), run_options); +} + se::Platform* LocalClient::platform() const { return local_service_->backend().platform(); } diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 7cdeb9dcbf6..8b91f4a1739 100644 --- 
a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -51,6 +51,11 @@ class LocalExecutable { const absl::Span arguments, ExecutableRunOptions run_options); + // Similar to Run(), but allows for donating argument buffers to the + // executable. + StatusOr Run(std::vector arguments, + ExecutableRunOptions run_options); + // Similar to Run(), but need not block the host waiting for the computation // to complete before returning. StatusOr RunAsync( @@ -63,6 +68,9 @@ class LocalExecutable { absl::Span argument_host_shapes, std::vector arguments, ExecutableRunOptions run_options); + StatusOr RunAsync(std::vector arguments, + ExecutableRunOptions run_options); + // Return the options used to build the executable. const ExecutableBuildOptions& build_options() const { return build_options_; } @@ -90,6 +98,22 @@ class LocalExecutable { // Backend::devices_equivalent). int build_device_ordinal() const { return build_options_.device_ordinal(); } + template + StatusOr AsyncCallAndBlockHostUntilDone( + absl::Span argument_shapes, + const ExecutableRunOptions& run_options, + std::function(const ExecutableRunOptions&)> async_callback) { + TF_ASSIGN_OR_RETURN(auto options_and_stream, + RunHelper(argument_shapes, run_options)); + ExecutableRunOptions options = options_and_stream.first.run_options(); + options.set_device_ordinal(-1); + StatusOr result = async_callback(options); + Status block_status = options.stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(result.status()); + TF_RETURN_IF_ERROR(block_status); + return result; + } + // Compiled computation. std::unique_ptr executable_; diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index bfba48862f6..56e9aba6112 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1564,16 +1564,12 @@ XlaOp XlaBuilder::CustomCall( const Shape& shape, const string& opaque, absl::optional> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { return InvalidArgument( "Invalid custom_call_target \"%s\": Call targets that start with '$' " "are reserved for internal use.", call_target_name); } - *instr.mutable_shape() = shape.ToProto(); - instr.set_custom_call_target(call_target_name); - instr.set_backend_config(opaque); if (operand_shapes_with_layout.has_value()) { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument( @@ -1586,7 +1582,6 @@ XlaOp XlaBuilder::CustomCall( "with constrained layout; given %d shapes, expected %d", operand_shapes_with_layout->size(), operands.size()); } - instr.set_constrain_layout(true); int64 operand_num = 0; for (const Shape& operand_shape : *operand_shapes_with_layout) { if (!LayoutUtil::HasLayout(operand_shape)) { @@ -1595,14 +1590,31 @@ XlaOp XlaBuilder::CustomCall( "constrained layout.", operand_num); } - *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); ++operand_num; } } - return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); + return CustomCallInternal(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); }); } +StatusOr XlaBuilder::CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape, const string& opaque, + absl::optional> operand_shapes_with_layout) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + 
instr.set_custom_call_target(call_target_name); + instr.set_backend_config(opaque); + if (operand_shapes_with_layout.has_value()) { + instr.set_constrain_layout(true); + for (const Shape& operand_shape : *operand_shapes_with_layout) { + *instr.add_operand_shapes_with_layout() = operand_shape.ToProto(); + } + } + return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); +} + XlaOp XlaBuilder::CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape, const string& opaque, @@ -2727,13 +2739,34 @@ XlaOp XlaBuilder::GetDimensionSize(XlaOp operand, int64 dimension) { }); } -XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) { +XlaOp XlaBuilder::RemoveDynamicDimension(XlaOp operand, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); - TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSetDimensionSizeShape( - *operand_shape, dimension)); + Shape shape = *operand_shape; + shape.set_dynamic_dimension(dimension, false); + // Setting an op's dynamic dimension to its static size removes the dynamic + // dimension. + XlaOp static_size = + ConstantR0(this, operand_shape->dimensions(dimension)); + + *instr.mutable_shape() = shape.ToProto(); + instr.add_dimensions(dimension); + return AddInstruction(std::move(instr), HloOpcode::kSetDimensionSize, + {operand, static_size}); + }); +} + +XlaOp XlaBuilder::SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + TF_ASSIGN_OR_RETURN(const Shape* val_shape, GetShapePtr(val)); + + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferSetDimensionSizeShape( + *operand_shape, *val_shape, dimension)); // Setting an op's dynamic dimension to the static size is a noop. TF_ASSIGN_OR_RETURN(const HloInstructionProto* val_proto, LookUpInstruction(val)); @@ -3827,4 +3860,8 @@ XlaOp SetDimensionSize(const XlaOp operand, const XlaOp val, int64 dimension) { return operand.builder()->SetDimensionSize(operand, val, dimension); } +XlaOp RemoveDynamicDimension(const XlaOp operand, int64 dimension) { + return operand.builder()->RemoveDynamicDimension(operand, dimension); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index ffa6a7c3439..3fc26747468 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -527,6 +527,14 @@ class XlaBuilder { const Shape& shape_with_layout, const string& opaque, absl::optional> operand_shapes_with_layout); + // Internal version of CustomCall without computation that doesn't do op + // specific error handling and expects arguments to be legal. CustomCall + // method above calls this method after error handling. 
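For reference, the validation that now precedes CustomCallInternal is what the public layout-constrained wrapper feeds. A hypothetical call site, assuming the existing CustomCallWithLayout free function and a call target named "my_custom_kernel" (both names illustrative):

xla::XlaBuilder b("custom_call_layout_example");
xla::Shape f32_8x8 =
    xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {8, 8}, {1, 0});
xla::XlaOp arg = xla::Parameter(&b, 0, f32_8x8, "arg");
// Supplying operand_shapes_with_layout makes the builder set
// constrain_layout=true, the branch handled inside CustomCallInternal.
xla::CustomCallWithLayout(&b, "my_custom_kernel", {arg},
                          /*shape_with_layout=*/f32_8x8,
                          /*operand_shapes_with_layout=*/{f32_8x8},
                          /*opaque=*/"");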
+ virtual StatusOr CustomCallInternal( + const string& call_target_name, absl::Span operands, + const Shape& shape_with_layout, const string& opaque, + absl::optional> operand_shapes_with_layout); + XlaOp CustomCall( const string& call_target_name, absl::Span operands, const XlaComputation& computation, const Shape& shape_with_layout, @@ -704,6 +712,8 @@ class XlaBuilder { XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension); + XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension); + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, absl::Span operands = {}); @@ -1151,6 +1161,7 @@ class XlaBuilder { friend XlaOp GetDimensionSize(XlaOp operand, int64 dimension); friend XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension); + friend XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension); protected: // Returns OK status if the given op was built using this builder. Otherwise, @@ -2149,6 +2160,9 @@ XlaOp GetDimensionSize(XlaOp operand, int64 dimension); XlaOp SetDimensionSize(XlaOp operand, XlaOp val, int64 dimension); +// Returns the same op but with dynamic dimension removed. +XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension); + // Implementation details below this point. // diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc index 4fa47077fca..7011c946203 100644 --- a/tensorflow/compiler/xla/client/xla_builder_test.cc +++ b/tensorflow/compiler/xla/client/xla_builder_test.cc @@ -556,6 +556,32 @@ TEST_F(XlaBuilderTest, DynamicParameter) { EXPECT_TRUE(param_shape.is_dynamic_dimension(0)); } +TEST_F(XlaBuilderTest, SetDimensionSize) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1"); + auto set_dim_size = SetDimensionSize(p0, p1, 0); + TF_ASSERT_OK_AND_ASSIGN(auto module, + BuildHloModule(&b, /*root=*/set_dim_size)); + const Shape& root_shape = + module->entry_computation()->root_instruction()->shape(); + EXPECT_TRUE(root_shape.is_dynamic_dimension(0)); +} + +TEST_F(XlaBuilderTest, RemoveDimensionSize) { + XlaBuilder b(TestName()); + auto p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {10}), "p0"); + auto p1 = Parameter(&b, 1, ShapeUtil::MakeShape(S32, {}), "p1"); + auto set_dim_size = SetDimensionSize(p0, p1, 0); + auto remove_dim_size = RemoveDynamicDimension(set_dim_size, 0); + TF_ASSERT_OK_AND_ASSIGN(auto module, + BuildHloModule(&b, /*root=*/remove_dim_size)); + const Shape& root_shape = + module->entry_computation()->root_instruction()->shape(); + // Dynamic dimension has been removed. + EXPECT_FALSE(root_shape.is_dynamic_dimension(0)); +} + TEST_F(XlaBuilderTest, DynamicUnary) { XlaBuilder b(TestName()); Shape tuple_param_shape = ShapeUtil::MakeTupleShape( diff --git a/tensorflow/compiler/xla/debug_options_flags.cc b/tensorflow/compiler/xla/debug_options_flags.cc index 81655101701..8ca6e2b294c 100644 --- a/tensorflow/compiler/xla/debug_options_flags.cc +++ b/tensorflow/compiler/xla/debug_options_flags.cc @@ -64,7 +64,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { opts.set_xla_cpu_fast_math_honor_division(true); // By default, copy TF's Eigen style min_max behavior with nans. 
- opts.set_xla_cpu_enable_fast_min_max(false); + opts.set_xla_cpu_enable_fast_min_max(true); opts.set_xla_gpu_enable_fast_min_max(true); diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index dd50d0577d4..695ba9dee93 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -141,12 +141,15 @@ cc_library( "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options", "//tensorflow/core:allocator", "//tensorflow/core:lib", + "//tensorflow/core/profiler/lib:connected_traceme", "//tensorflow/core/profiler/lib:traceme", + "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/stream_executor:event", "//tensorflow/stream_executor:stream", "//tensorflow/stream_executor/host:host_platform_id", "//tensorflow/stream_executor/lib", "@com_google_absl//absl/base", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", @@ -154,6 +157,7 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", + "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", ], ) diff --git a/tensorflow/compiler/xla/pjrt/distributed/service.h b/tensorflow/compiler/xla/pjrt/distributed/service.h index 725a76791ce..9ecbdb3cc7c 100644 --- a/tensorflow/compiler/xla/pjrt/distributed/service.h +++ b/tensorflow/compiler/xla/pjrt/distributed/service.h @@ -54,15 +54,15 @@ class DistributedRuntimeServiceImpl final absl::Mutex mu_; enum class State { kInitializing, kRunning }; - State state_ GUARDED_BY(mu_) = State::kInitializing; + State state_ ABSL_GUARDED_BY(mu_) = State::kInitializing; - std::vector local_topologies_ GUARDED_BY(mu_); - GlobalTopologyProto topology_ GUARDED_BY(mu_); + std::vector local_topologies_ ABSL_GUARDED_BY(mu_); + GlobalTopologyProto topology_ ABSL_GUARDED_BY(mu_); struct Node { bool present = false; }; - int num_nodes_present_ GUARDED_BY(mu_) = 0; - std::vector nodes_ GUARDED_BY(mu_); + int num_nodes_present_ ABSL_GUARDED_BY(mu_) = 0; + std::vector nodes_ ABSL_GUARDED_BY(mu_); KeyValueStore key_value_store_; }; diff --git a/tensorflow/compiler/xla/pjrt/local_device_state.cc b/tensorflow/compiler/xla/pjrt/local_device_state.cc index d173c891c95..a229e56001e 100644 --- a/tensorflow/compiler/xla/pjrt/local_device_state.cc +++ b/tensorflow/compiler/xla/pjrt/local_device_state.cc @@ -127,11 +127,15 @@ std::unique_ptr LocalDeviceState::BorrowStreamFromPool() { } else { std::unique_ptr stream = std::move(usage_stream_pool_.top()); usage_stream_pool_.pop(); + stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented + QCHECK(stream->ok()); return stream; } } void LocalDeviceState::ReturnStreamToPool(std::unique_ptr stream) { + stream->RefreshStatus().IgnoreError(); // Can return error::Unimplemented + QCHECK(stream->ok()); absl::MutexLock lock(&mu_); usage_stream_pool_.push(std::move(stream)); } diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index de760af8fd9..edffaf6c877 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -169,7 +169,7 @@ class NcclIdStore { const std::shared_ptr client_; absl::Mutex mu_; - absl::flat_hash_map cache_ GUARDED_BY(mu_); + absl::flat_hash_map cache_ ABSL_GUARDED_BY(mu_); }; StatusOr NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& 
key) { diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index ccb72b7ce30..e341a11d64f 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -76,11 +76,13 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" #include "absl/time/time.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/cpu_function_runtime.h" #include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/pjrt/distributed/protocol.pb.h" @@ -98,7 +100,9 @@ limitations under the License. #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/profiler/lib/connected_traceme.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/profiler/lib/traceme_encode.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/device_memory_allocator.h" #include "tensorflow/stream_executor/event.h" @@ -749,16 +753,22 @@ StatusOr> PjRtBuffer::FromHostLiteral( // memory that has already been allocated, and a possible Event // allocation. + se::Stream* h2d_stream = local_device->host_to_device_stream(); ShapedBuffer buffer = device_buffer->AsShapedBuffer( compact_shape, on_device_shape, client->client()->platform()); TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync( - local_device->host_to_device_stream(), literal, buffer)); + h2d_stream, literal, buffer)); std::shared_ptr event = device_buffer->definition_events()[0]; TF_CHECK_OK(AddDestinationBufferSynchronization( - local_device, std::move(device_buffer), event, - local_device->host_to_device_stream())); + local_device, std::move(device_buffer), event, h2d_stream)); + + // This can sometimes catch the case where the literal memory has been + // freed before the H2D transfer was issued. + h2d_stream->RefreshStatus() + .IgnoreError(); // Can return error::Unimplemented + QCHECK(h2d_stream->ok()); }; client->h2d_transfer_pool()->Schedule(transfer_h2d); return py_buffer; @@ -853,10 +863,10 @@ StatusOr> PjRtBuffer::Release( if (device_buffer_ == nullptr) { return std::shared_ptr(); } - // Set host_value_ and device_buffer_ to null now so that no other thread - // can add a hold while we are in WaitForOutstandingUsageHolds() + // Clear host_values_ and set device_buffer_ to null now so that no other + // thread can add a hold while we are in WaitForOutstandingUsageHolds() // below. - host_value_ = nullptr; + host_values_.clear(); std::swap(device_buffer_, device_buffer); WaitForOutstandingUsageHolds(); // Now that all holds have completed and no more can be added, we can get @@ -991,7 +1001,7 @@ void PjRtBuffer::ConfirmDonation(TrackedDeviceBuffer* device_buffer) { device_buffer->ReleaseDeviceMemory(); // Make *this invalid so it can't be used again. Any threads blocking in // Release or GetBufferWithHold will see an invalid buffer and return. - host_value_ = nullptr; + host_values_.clear(); device_buffer_.reset(); } // Unblock another thread, if any, trying to get a donation hold. 
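The host_values_ map that replaces the single host_value_ above caches one host copy per requested layout; the layout-aware ToLiteral/CopyToHostAsync overloads in the following hunks are its consumers. A hypothetical call, assuming buffer is a PjRtBuffer* holding a 2-D array and the caller sits in a function where TF_ASSIGN_OR_RETURN is usable:

// Transfer in an explicit row-major layout and drop the cached host copy once
// the literal has been handed out.
xla::Layout row_major = xla::LayoutUtil::MakeLayout({1, 0});
TF_ASSIGN_OR_RETURN(
    std::shared_ptr<xla::Literal> literal,
    buffer->ToLiteral(/*discard_cached_copy=*/true, row_major));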
@@ -1011,7 +1021,14 @@ void PjRtBuffer::DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer) { } } -Status PjRtBuffer::CopyToHostAsync() { +Status PjRtBuffer::CopyToHostAsync(absl::optional layout) { + return CopyToHostAsyncInternal(/*discard_cached_copy=*/false, layout) + .status(); +} + +StatusOr> +PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy, + absl::optional layout) { if (IsEmptyTuple()) { return InvalidArgument("CopyToHostAsync called on empty tuple"); } @@ -1019,6 +1036,8 @@ Status PjRtBuffer::CopyToHostAsync() { std::shared_ptr host_value; LocalDeviceState* local_device = device_->local_device_state(); se::Stream* stream = local_device->GetDeviceToHostStream(); + const xla::Layout& host_layout = + layout.has_value() ? layout.value() : on_host_shape_.layout(); { absl::MutexLock lock(&mu_); // We can't perform any other action while a donation hold is in progress. @@ -1026,17 +1045,36 @@ Status PjRtBuffer::CopyToHostAsync() { if (device_buffer_ == nullptr) { return InvalidArgument("CopyToHostAsync() called on invalid buffer."); } - if (host_value_) { - // The host value has already been requested or is available. - return Status::OK(); + if (discard_cached_copy) { + auto it = host_values_.find(host_layout); + if (it != host_values_.end()) { + host_value = it->second; + host_values_.erase(it); + return host_value; + } else { + host_value = std::make_shared(); + } + } else { + std::shared_ptr& host_value_ref = host_values_[host_layout]; + if (host_value_ref) { + return host_value_ref; + } + host_value = host_value_ref = std::make_shared(); } - host_value = host_value_ = std::make_shared(); AcquireHoldLocked(&device_buffer); } WaitForBufferDefinitionEventsOnStream(*device_buffer, stream); - host_value->value = std::make_shared(on_host_shape_); + Shape host_shape; + if (layout.has_value()) { + host_shape = ShapeUtil::MakeShape(on_host_shape_.element_type(), + on_host_shape_.dimensions()); + *host_shape.mutable_layout() = host_layout; + } else { + host_shape = on_host_shape_; + } + host_value->value = std::make_shared(host_shape); ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer( - on_host_shape_, on_device_shape_, client_->client()->platform()); + host_shape, on_device_shape_, client_->client()->platform()); client_->client()->backend().transfer_manager()->TransferLiteralFromDevice( stream, shaped_buffer, host_value->value.get(), [host_value](Status done_status) { @@ -1066,17 +1104,14 @@ Status PjRtBuffer::CopyToHostAsync() { RecordUsage(std::move(device_buffer), local_device, local_device, usage_event, stream, /*prefer_to_retain_reference=*/true); - return Status::OK(); + return host_value; } -StatusOr> PjRtBuffer::ToLiteral() { +StatusOr> PjRtBuffer::ToLiteral( + const bool discard_cached_copy, absl::optional layout) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral"); - TF_RETURN_IF_ERROR(CopyToHostAsync()); - std::shared_ptr host_value; - { - absl::MutexLock lock(&mu_); - host_value = host_value_; - } + TF_ASSIGN_OR_RETURN(std::shared_ptr host_value, + CopyToHostAsyncInternal(discard_cached_copy, layout)); if (host_value == nullptr) { return InvalidArgument("ToLiteral called on invalid buffer"); } @@ -1429,10 +1464,9 @@ StatusOr PjRtExecutable::EnqueueExecution( int executable_idx, const RunId& run_id, const ExecuteOptions& options, Device* device, std::vector* device_buffers) const { int device_ordinal = device->local_device_state()->device_ordinal(); - tensorflow::profiler::TraceMe traceme([&] { - return 
absl::StrCat("LocalExecutable::Execute#run_id=", run_id.ToInt(), - "#"); - }); + tensorflow::profiler::TraceMeConsumer activity( + "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, + run_id.ToInt()); VLOG(3) << "Replica " << replica << ", partition " << partition << " mapped to device ordinal for execution: " << device_ordinal; @@ -1721,10 +1755,9 @@ PjRtExecutable::ExecuteOnLocalDevices( absl::Span> argument_handles, const ExecuteOptions& options) const { RunId run_id; - tensorflow::profiler::TraceMe traceme([&] { - return absl::StrCat( - "LocalExecutable::ExecuteOnLocalDevices#run_id=", run_id.ToInt(), "#"); - }); + tensorflow::profiler::TraceMeProducer activity( + "LocalExecutable::ExecuteOnLocalDevices", + tensorflow::profiler::ContextType::kPjRt, run_id.ToInt()); const int num_local_devices = local_devices_.size(); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 754eb19bec6..c50d09f631c 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -20,15 +20,18 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" #include "absl/container/inlined_vector.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" #include "absl/synchronization/notification.h" +#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/layout.h" #include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" #include "tensorflow/compiler/xla/service/computation_placer.h" @@ -478,13 +481,20 @@ class PjRtBuffer { // Returns the buffer's value as an XLA Literal. If the value has previously // been prefetched to the host, then returns the prefetched version, otherwise - // copies the buffer to the host. Blocks until the value is ready. - StatusOr> ToLiteral(); + // copies the buffer to the host. Blocks until the value is ready. If + // `discard_cached_copy` is true then buffer will no longer keep hold of a + // cached copy of the literal (i.e. The reference to the host value will be + // removed.) If a layout is passed than a literal with this layout will be + // returned. + StatusOr> ToLiteral( + bool discard_cached_copy = false, + absl::optional layout = {}); // Initiates a copy of the buffer to the host. Does not block waiting for // the transfer to complete. The value can be retrieved by a later call to - // ToLiteral(). - Status CopyToHostAsync(); + // ToLiteral(). If a layout is passed then a cached copy with this layout will + // be created. + Status CopyToHostAsync(absl::optional layout = {}); // Drops the buffer's reference to its associated device memory, leaving the // buffer in an invalid state. The memory will be freed lazily when all async @@ -592,6 +602,14 @@ class PjRtBuffer { // successfully donated to an execution. void ConfirmDonation(TrackedDeviceBuffer* device_buffer); + // Initiates a copy of the buffer to the host. Does not block waiting for + // the transfer to complete. A host value is returned and if + // `discard_cached_copy` is false stored in an internal buffer so that future + // transfers don't have to transfer the data from host again. 
If a layout is + // passed then a literal of this layout will be returned and possibly cached. + StatusOr> CopyToHostAsyncInternal( + bool discard_cached_copy, absl::optional layout); + // Drops a hold without taking any other action. Does a sanity check that // buffer==device_buffer_ or device_buffer_==nullptr. void DropHold(ScopedHold::Type type, TrackedDeviceBuffer* buffer); @@ -610,6 +628,8 @@ class PjRtBuffer { mutable absl::Mutex mu_; std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); + absl::flat_hash_map> host_values_ + TF_GUARDED_BY(mu_); std::shared_ptr host_value_ TF_GUARDED_BY(mu_); // Count of holds on the buffer. std::array holds_ TF_GUARDED_BY(mu_); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index acd35cbc153..10e2d7e65d1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1202,6 +1202,9 @@ cc_library( srcs = ["transfer_manager.cc"], hdrs = ["transfer_manager.h"], deps = [ + ":compiler", + ":executable", + ":maybe_owning_device_memory", ":shaped_buffer", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -1210,8 +1213,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/service:executable", - "//tensorflow/compiler/xla/service:maybe_owning_device_memory", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor:device_memory", @@ -1679,6 +1680,7 @@ cc_library( hdrs = ["multi_output_fusion.h"], deps = [ ":hlo", + ":hlo_dce", ":hlo_pass", ":hlo_reachability", "//tensorflow/compiler/xla:debug_options_flags", diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index ce2a801fccd..130661bf1cd 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -2815,6 +2815,28 @@ Status AlgebraicSimplifierVisitor::HandleCompare(HloInstruction* compare) { HloInstruction* lhs; HloInstruction* rhs; CHECK(Match(compare, m::Compare(m::Op(&lhs), m::Op(&rhs)))); + { + // compare(broadcast(a) + x, broadcast(b)) ==> + // compare(x, broadcast(b-a)) + HloInstruction *x, *a, *b; + if (Match(compare, + m::Compare( + m::AddAnyOrder(m::Op(&x), m::Broadcast(m::Op(&a).WithShape( + m::Shape().IsScalar()))), + m::Broadcast(m::Op(&b).WithShape(m::Shape().IsScalar()))))) { + if (ShapeUtil::ElementIsSigned(x->shape())) { + HloInstruction* sub = + computation_->AddInstruction(HloInstruction::CreateBinary( + b->shape(), HloOpcode::kSubtract, b, a)); + HloInstruction* broadcast = computation_->AddInstruction( + HloInstruction::CreateBroadcast(x->shape(), sub, {})); + HloInstruction* new_compare = computation_->AddInstruction( + HloInstruction::CreateCompare(compare->shape(), x, broadcast, + compare->comparison_direction())); + return ReplaceInstruction(compare, new_compare); + } + } + } if (compare->comparison_direction() == ComparisonDirection::kLt && lhs->opcode() == HloOpcode::kIota && IsAll(rhs, 0)) { diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index eecdcc851e9..6db4c3eb6d4 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -106,7 +106,6 @@ class BranchVisitor { boundaries_.emplace_back(operand, i, inst); continue; } 
- worklist_.push_back(operand); visited_.insert(operand); } @@ -197,6 +196,7 @@ bool WorthHoisting(HloInstruction* instruction) { case HloOpcode::kMultiply: case HloOpcode::kDivide: case HloOpcode::kTuple: + case HloOpcode::kSqrt: case HloOpcode::kGetTupleElement: return true; default: @@ -206,10 +206,11 @@ bool WorthHoisting(HloInstruction* instruction) { // Compare if the instructions to be visited at each branches are identical. bool InstructionWithinBranchIdentical( - const std::vector& instructions, bool is_layout_senstive) { + const std::vector& instructions, + bool is_layout_sensitive) { // Identical includes the shape of each operands are equal. auto eq_operand = [&](const HloInstruction* a, const HloInstruction* b) { - bool eq_operands = is_layout_senstive + bool eq_operands = is_layout_sensitive ? ShapeUtil::Equal(a->shape(), b->shape()) : ShapeUtil::Compatible(a->shape(), b->shape()); return eq_operands; @@ -233,7 +234,7 @@ bool InstructionWithinBranchIdentical( auto old_channel_id = instruction->channel_id(); instruction->set_channel_id(instructions[0]->channel_id()); bool eq_instructions = instructions[0]->Identical( - *instruction, eq_operand, eq_computations, is_layout_senstive); + *instruction, eq_operand, eq_computations, is_layout_sensitive); instruction->set_channel_id(old_channel_id); return eq_instructions; }); @@ -243,7 +244,7 @@ bool InstructionWithinBranchIdentical( [&](HloInstruction* instruction) { return instructions[0]->Identical( *instruction, eq_operand, eq_computations, - is_layout_senstive); + is_layout_sensitive); }); } @@ -354,12 +355,228 @@ Status RemoveInstructionFromComputation( return Status::OK(); } +// Identify converts to be hoisted/rematerialized out of the branch +// computations. +absl::flat_hash_set FindSpecialConverts(HloInstruction* old_root, + int branch_count, + HloInstruction* conditional, + bool is_layout_sensitive) { + absl::flat_hash_set kspecial_convert; + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + if (old_root->operand(operand_num)->opcode() != HloOpcode::kConvert) { + continue; + } + bool replica = true; + HloInstruction* kspecial_convert_candidate = + old_root->mutable_operand(operand_num); + // Check whether an identical candidate appears in other branches + for (int others = 1; others < branch_count; ++others) { + HloInstruction* others_root = + conditional->branch_computation(others)->root_instruction(); + bool eq_shape = + is_layout_sensitive + ? ShapeUtil::Equal(others_root->operand(operand_num)->shape(), + kspecial_convert_candidate->shape()) + : ShapeUtil::Compatible( + others_root->operand(operand_num)->shape(), + kspecial_convert_candidate->shape()); + if ((others_root->operand(operand_num)->opcode() == + HloOpcode::kConvert) && + eq_shape) { + // Nothing to be done. + } else { + replica = false; + break; + } + } + if (replica) { + kspecial_convert.insert(operand_num); + } + } + return kspecial_convert; +} + +// Restructuring the conditional instruction as follows: +// i.e., %result = conditional() becomes +// x = conditional() +// y.{0..n} = gte(x, {0..n}) +// z = tuple(y.0, y.1, ...y.n) +// Doing so ensures that we can accommodate the possible shape-change of the +// conditional when the instructions are hoisted. 
+Status RestructureConditionalInstruction(HloComputation* computation, + HloInstruction* conditional) { + HloInstruction* old_root = computation->root_instruction(); + std::vector new_operands; + int cur_index = 0; + for (; cur_index < ShapeUtil::TupleElementCount(conditional->shape()); + ++cur_index) { + new_operands.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(conditional->shape(), cur_index), + conditional, cur_index))); + } + HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_operands)); + if (old_root == conditional) { + computation->set_root_instruction(new_tuple); + } else { + std::vector new_tuple_users; + for (auto conditional_user : conditional->users()) { + auto is_new_gte = absl::c_find_if( + new_operands, + [&](HloInstruction* instr) { return instr == conditional_user; }); + if (is_new_gte == new_operands.end()) { + new_tuple_users.push_back(conditional_user); + } + } + for (auto new_tuple_user : new_tuple_users) { + TF_RETURN_IF_ERROR( + conditional->ReplaceUseWith(new_tuple_user, new_tuple)); + } + } + VLOG(2) << "computation after root restructure:\n" << computation->ToString(); + return Status::OK(); +} + +StatusOr ConvertSpecialMove(HloInstruction* conditional, + bool is_layout_sensitive) { + int branch_count = conditional->branch_count(); + if (branch_count <= 0) { + return false; + } + + HloInstruction* old_root = + conditional->branch_computation(0)->root_instruction(); + if (old_root->opcode() != HloOpcode::kTuple) { + return false; + } else { + VLOG(2) << "BEFORE :" << conditional->parent()->parent()->ToString(); + // Identify the gte using `index'. + auto find_gte = [](const HloInstruction* conditional_result, + int64 index) -> HloInstruction* { + for (HloInstruction* instr : conditional_result->users()) { + if (instr->opcode() != HloOpcode::kGetTupleElement) { + return nullptr; + } + if (instr->tuple_index() == index) { + return instr; + } + } + return nullptr; + }; + + // Captures tuple indices refering to converts to be rematerialized/hoisted. + absl::flat_hash_set kspecial_convert = FindSpecialConverts( + old_root, branch_count, conditional, is_layout_sensitive); + + // Exit if we cannot find any converts to be hoisted. + if (kspecial_convert.empty()) { + return false; + } + + TF_RETURN_IF_ERROR( + RestructureConditionalInstruction(conditional->parent(), conditional)); + + for (int branch = 0; branch < branch_count; branch++) { + old_root = conditional->branch_computation(branch)->root_instruction(); + absl::flat_hash_map map_inst_to_tuple_index; + std::vector new_operands(old_root->operand_count()); + std::unordered_set to_hoist_set; + + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + map_inst_to_tuple_index[old_root->mutable_operand(operand_num)] = + operand_num; + } + for (int64 operand_num = 0; operand_num < old_root->operand_count(); + ++operand_num) { + HloInstruction* hoist = old_root->mutable_operand(operand_num); + if (!kspecial_convert.contains(operand_num)) { + new_operands[operand_num] = old_root->mutable_operand(operand_num); + continue; + } + + to_hoist_set.insert(hoist); + int64 new_tuple_count = old_root->operand_count(); + + // Replace the hoisted instr in the tuple with the operand/operands. + // We will replace at least one of the operands of the hoist at the + // tuple place; the rest will be added at the end. 
+ bool inplace = true; + CHECK(!hoist->operands().empty()); + for (HloInstruction* prod : hoist->operands()) { + if (inplace) { + map_inst_to_tuple_index[prod] = map_inst_to_tuple_index[hoist]; + new_operands[map_inst_to_tuple_index[hoist]] = prod; + inplace = false; + } else { + map_inst_to_tuple_index[prod] = new_tuple_count++; + new_operands.push_back(prod); + } + } + } + + // Create the new root instruction. + HloComputation* cur_branch = conditional->branch_computation(branch); + HloInstruction* new_branch_root = + cur_branch->AddInstruction(HloInstruction::CreateTuple(new_operands)); + // The shape can vary since the operands to convert are now + // being returned through the branches' root. + cur_branch->set_root_instruction(new_branch_root, true /*new shape*/); + TF_CHECK_OK(cur_branch->RemoveInstruction(old_root)); + + // Only one of the branches needs to change the conditional->parent(). + if (branch != 0) { + continue; + } + HloComputation* conditional_parent = conditional->parent(); + HloInstruction* newconditional = + conditional_parent->AddInstruction(HloInstruction::CreateConditional( + cur_branch->root_instruction()->shape(), + conditional->mutable_operand(0), + absl::MakeSpan(conditional->branch_computations()), + absl::MakeSpan(conditional->operands()).subspan(1))); + // Ensure that all the users of conditional refer to the new one. + TF_RETURN_IF_ERROR( + conditional->ReplaceAllUsesWithDifferentShape(newconditional)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(conditional)); + conditional = newconditional; + // Add the hoisted instructions in the parent. + for (HloInstruction* hoist : to_hoist_set) { + VLOG(2) << "Hoisting instruction:" << hoist->ToString(); + int64 hoist_index = map_inst_to_tuple_index[hoist]; + // Find out the gte that captured the hoisted instr result. + HloInstruction* gte_hoist = find_gte(conditional, hoist_index); + CHECK(gte_hoist != nullptr); + std::vector new_operands; + for (HloInstruction* op : hoist->operands()) { + HloInstruction* gte = conditional_parent->AddInstruction( + HloInstruction::CreateGetTupleElement( + op->shape(), conditional, map_inst_to_tuple_index[op])); + new_operands.push_back(gte); + } + HloInstruction* hoisted = conditional_parent->AddInstruction( + hoist->CloneWithNewOperands(hoist->shape(), new_operands)); + VLOG(2) << "Hoisted instruction in parent:" << hoisted->ToString(); + TF_RETURN_IF_ERROR(gte_hoist->ReplaceAllUsesWith(hoisted)); + TF_CHECK_OK(conditional_parent->RemoveInstruction(gte_hoist)); + } + // No need to explicitly delete a hoisted instruction since if its dead + // then the subsequent DCE will remove it. + } + } + VLOG(2) << "AFTER :" << conditional->parent()->parent()->ToString(); + return true; +} + // Hoist identical ops out of the conditional. The definition of identical // are the shape of the operands are identical and their properties are // identical. Will start from the root instruction of each branch and get // the identical ops to hoist. 
StatusOr MergeIdenticalElements(HloInstruction* conditional, bool is_layout_sensitive) { + VLOG(1) << " visiting conditional:" << conditional->ToString(); int branch_count = conditional->branch_count(); if (branch_count <= 0) { return false; @@ -399,7 +616,7 @@ StatusOr MergeIdenticalElements(HloInstruction* conditional, } } - if (visitors[0].HoistInstructionSize() <= 1) { + if (visitors[0].HoistInstructionSize() < 1) { return false; } @@ -442,7 +659,6 @@ StatusOr MergeIdenticalElements(HloInstruction* conditional, RemoveInstructionFromComputation(visitors[i].instructions_to_hoist(), conditional->branch_computation(i))); } - return true; } @@ -451,26 +667,55 @@ StatusOr MergeIdenticalElements(HloInstruction* conditional, StatusOr ConditionalCodeMotion::Run(HloModule* module) { bool changed = false; - // Gather all the conditional ops in our module. We do this ahead of time so - // we don't have to worry about mutating the lists of computations or - // instructions as we iterate. - std::vector conditional_ops; - for (auto* comp : module->MakeComputationPostOrder()) { - for (auto* instr : comp->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kConditional) { - conditional_ops.push_back(instr); + if (pursue_full_conditional_code_motion_) { + std::vector conditional_ops; + for (auto* comp : module->MakeComputationPostOrder()) { + for (auto* instr : comp->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } } } + + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN( + bool result, + MergeIdenticalElements(conditional_op, is_layout_sensitive_)); + changed |= result; + } + + if (changed) { + HloPassPipeline subpipeline("after_conditional_code_motion"); + subpipeline.AddPass(); + subpipeline.AddPass(); + subpipeline.AddPass(); + TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); + changed |= cleanup_changed; + } } - for (HloInstruction* conditional_op : conditional_ops) { - TF_ASSIGN_OR_RETURN(bool result, MergeIdenticalElements( - conditional_op, is_layout_sensitive_)); - changed |= result; + // handling convert rematerialization/hoisting + { + std::vector conditional_ops; + for (auto* comp : module->MakeComputationPostOrder()) { + for (auto* instr : comp->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kConditional) { + conditional_ops.push_back(instr); + } + } + } + for (HloInstruction* conditional_op : conditional_ops) { + TF_ASSIGN_OR_RETURN( + bool convert_result, + ConvertSpecialMove(conditional_op, is_layout_sensitive_)); + changed |= convert_result; + } } if (changed) { - HloPassPipeline subpipeline("after_conditional_code_motion"); + HloPassPipeline subpipeline( + "after_conditional_code_motion_after_convert_hoisting"); + subpipeline.AddPass(); subpipeline.AddPass(); subpipeline.AddPass(); TF_ASSIGN_OR_RETURN(bool cleanup_changed, subpipeline.Run(module)); diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.h b/tensorflow/compiler/xla/service/conditional_code_motion.h index 1197a8b3620..95f02833e15 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.h +++ b/tensorflow/compiler/xla/service/conditional_code_motion.h @@ -23,7 +23,11 @@ limitations under the License. namespace xla { -// HLO pass that moves identical ops out of conditional. +// ConditionalCodeMotion specializes in hoisting/rematerializing +// unconditional converts in the default mode. 
+// When pursue_full_conditional_code_motion_ is set to true, the +// full HLO pass moves identical ops out of a conditional in addition to moving +// converts. // - The definition of identical are the shape of the operands are identical // and their properties are identical. // - Currently, only some types of instructions is supported. @@ -35,13 +39,18 @@ class ConditionalCodeMotion : public HloModulePass { public: // If is_layout_sensitive is true, then the hoist process preserves layout // during identical comparison. Otherwise, layout is ignored. - explicit ConditionalCodeMotion(bool is_layout_sensitive = true) - : is_layout_sensitive_(is_layout_sensitive) {} + explicit ConditionalCodeMotion( + bool is_layout_sensitive = true, + bool pursue_full_conditional_code_motion = false) + : is_layout_sensitive_(is_layout_sensitive), + pursue_full_conditional_code_motion_( + pursue_full_conditional_code_motion) {} absl::string_view name() const override { return "conditional-code-motion"; } StatusOr Run(HloModule* module) override; private: const bool is_layout_sensitive_; + const bool pursue_full_conditional_code_motion_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index 4a52303a42a..38b2b515fa0 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -38,7 +38,86 @@ namespace { using ConditionalCodeMotionTest = HloTestBase; namespace op = xla::testing::opcode_matchers; -TEST_F(ConditionalCodeMotionTest, DoNotMoveConvertOut) { +TEST_F(ConditionalCodeMotionTest, MoveSubsetTupleOut) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.2894, %reshape.8493) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) + %add = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(%convert.3604, %add) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + conditional = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false + get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 + get-first-index.2 = f32[2,512,364]{2,1,0} get-tuple-element(conditional), index=1 + ROOT result = (bf16[2,512,364]{2,1,0}, f32[2,512,364]{2,1,0}) tuple(get-first-index, get-first-index.2) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + 
ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert(), op::GetTupleElement()))); +} + +TEST_F(ConditionalCodeMotionTest, MoveConvertOutConditionalRoot) { + absl::string_view hlo_string = + R"( +HloModule RemoveDotOpOut + +on_true { + %arg_tuple.1 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.1 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.1), index=0 + %reshape.8493 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.1) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.8493, f32[2,512,364]{2,1,0} %reshape.8493) + %convert.2894 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %add.8493) + ROOT %tuple.1 = ( bf16[2,512,364]{2,1,0}) tuple(%convert.2894) +} + +on_false { + %arg_tuple.2 = (f32[93184,4]{1,0}) parameter(0) + %get-tuple-element.3 = f32[93184,4]{1,0} get-tuple-element(%arg_tuple.2), index=0 + %reshape.9717 = f32[2,512,364]{2,1,0} reshape(f32[93184,4]{1,0} %get-tuple-element.3) + %add.8493 = f32[2,512,364]{2,1,0} add(f32[2,512,364]{2,1,0} %reshape.9717, f32[2,512,364]{2,1,0} %reshape.9717) + %sub.8493 = f32[2,512,364]{2,1,0} subtract(f32[2,512,364]{2,1,0} %add.8493, f32[2,512,364]{2,1,0} %reshape.9717) + %convert.3604 = bf16[2,512,364]{2,1,0} convert(f32[2,512,364]{2,1,0} %reshape.9717), metadata={op_type="Cast" op_name="gradients/Cast_125_grad/Cast"} + ROOT %tuple.2 = (bf16[2,512,364]{2,1,0}) tuple(%convert.3604) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + arg_tuple.11 = (f32[93184,4]{1,0}) parameter(1) + arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) + ROOT conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Convert()))); +} + +TEST_F(ConditionalCodeMotionTest, MoveConvertOut) { absl::string_view hlo_string = R"( HloModule RemoveDotOpOut @@ -65,12 +144,16 @@ ENTRY main { arg_tuple.22 = (f32[93184,4]{1,0}) parameter(2) conditional = (bf16[2,512,364]{2,1,0}) conditional(pred.1, arg_tuple.11, arg_tuple.22), true_computation=on_true, false_computation=on_false get-first-index = bf16[2,512,364]{2,1,0} get-tuple-element(conditional), index=0 - ROOT result = (bf16[2,512,364]{2,1,0}) tuple(get-first-index) + add.1 = bf16[2,512,364]{2,1,0} add(bf16[2,512,364]{2,1,0} get-first-index, bf16[2,512,364]{2,1,0} get-first-index) + ROOT result = (bf16[2,512,364]{2,1,0}) tuple(add.1) } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass; - ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Tuple(op::Add(op::Convert(), op::Convert())))); } TEST_F(ConditionalCodeMotionTest, UserShareOperandCannotBeMoved) { @@ -123,7 +206,7 @@ ENTRY main { } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass; + ConditionalCodeMotion pass(true, true); ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); const HloInstruction* conditional = @@ -181,7 +264,7 @@ ENTRY main { } )"; auto module = 
ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass; + ConditionalCodeMotion pass(true, true); ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); @@ -245,7 +328,7 @@ ENTRY main { } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass; + ConditionalCodeMotion pass(true, true); ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); const HloInstruction* conditional = @@ -317,7 +400,7 @@ ENTRY main { )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass; + ConditionalCodeMotion pass(true, true); ASSERT_FALSE(pass.Run(&*module).ValueOrDie()); } @@ -390,7 +473,7 @@ ENTRY main { } )"; auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); - ConditionalCodeMotion pass; + ConditionalCodeMotion pass(true, true); ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); const HloInstruction* conditional = FindInstruction(module.get(), "conditional"); diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index ad023efae59..e12c67f2357 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -30,6 +30,15 @@ filegroup( ]), ) +cc_library( + name = "test_header_helper", + testonly = True, + hdrs = ["test_target_triple_helper.h"], + deps = [ + "//tensorflow/core:test", + ], +) + filegroup( name = "single_threaded_runtime_srcs", srcs = [ @@ -1071,6 +1080,7 @@ tf_cc_test( deps = [ ":cpu_compiler", ":cpu_transfer_manager", + ":test_header_helper", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index b2416ac2799..31b9fe1c920 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -277,12 +277,12 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass( cost_model, /*convert_batch_groups_only=*/false); - pipeline.AddPass(); pipeline.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(target_machine_features); { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 4552d7b5ba9..d095d220b97 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -299,12 +299,11 @@ StatusOr CpuExecutable::ExecuteAsyncOnStream( const Shape& expected_shape = entry_comp->parameter_instruction(i)->shape(); const Shape& actual_shape = arguments[i].Buffers().shape(); - CHECK( - Shape::Equal().IgnoreDynamicDimension()(expected_shape, actual_shape)) - << absl::StreamFormat( - "Shape mismatch on argument %d. Expected %s, but was %s.", i, - expected_shape.ToString(/*print_layout=*/true), - actual_shape.ToString(/*print_layout=*/true)); + TF_RET_CHECK( + ShapeUtil::DynamicShapeIsCompatible(actual_shape, expected_shape)) + << "Shape mismatch on argument " << i << ", " + << expected_shape.ToString(/*print_layout=*/true) << " vs. 
" + << actual_shape.ToString(/*print_layout=*/true); } } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index 97e0a518499..9460cc55e10 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -94,9 +94,8 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, // Cost condition: not fuse (simple, expensive producers) and (consumers who // reuse operand elements). - if (producer->opcode() != HloOpcode::kFusion && - consumer->ReusesOperandElements(operand_index) && - is_expensive(*producer)) { + if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) && + consumer->ReusesOperandElements(operand_index)) { VLOG(2) << "Fusion is not profitable."; return false; } diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 5c9f6677ab3..4c3167e16d9 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -50,6 +50,10 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { return ir_emitter_->EmitThreadLocalCall(callee, parameters, name); } + bool fast_min_max() override { + return hlo_module_config_.debug_options().xla_cpu_enable_fast_min_max(); + } + IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index f62769cc615..8d9229c1223 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -318,7 +318,9 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf); // Cut off denormalized stuff. - llvm::Value* tmp0 = vsl.Max(min_norm_pos, input); + // Always allow fast max because we are checking for the nan above. + llvm::Value* tmp0 = + vsl.Max(min_norm_pos, input, /*enable_fast_min_max=*/true); // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. diff --git a/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h b/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h new file mode 100644 index 00000000000..857de4a8143 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_TEST_TARGET_TRIPLE_HELPER_H_ +#define TENSORFLOW_TEST_TARGET_TRIPLE_HELPER_H_ + +#if (defined(__powerpc__) || \ + defined(__ppc__) && (__BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__)) +static const char kTargetCpuForHost[] = "ppc"; +static const char kTargetTripleForHost[] = "ppc64le-ibm-linux-gnu"; +#else +static const char kTargetCpuForHost[] = ""; +static const char kTargetTripleForHost[] = "x86_64-pc-linux"; +#endif + +#endif diff --git a/tensorflow/compiler/xla/service/cpu/tests/BUILD b/tensorflow/compiler/xla/service/cpu/tests/BUILD index 7c17b1339d1..d7c50dce3ca 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/cpu/tests/BUILD @@ -41,6 +41,7 @@ tf_cc_test( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -135,6 +136,7 @@ tf_cc_test( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -215,6 +217,7 @@ tf_cc_test( "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -228,6 +231,7 @@ tf_cc_test( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -241,6 +245,7 @@ tf_cc_test( deps = [ "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:test_header_helper", "//tensorflow/compiler/xla/service/cpu/tests:cpu_codegen_test", "//tensorflow/core:lib", "//tensorflow/core:test", diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc index 46249caa0c7..ce892ad34ae 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_dyn_shape_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" namespace xla { @@ -46,7 +47,8 @@ TEST_F(CpuDynamicShapeTest, DynamicShapeR2) { )"; CpuAotCompilationOptions options{ - /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc index 8b7f843582b..b233ee7df81 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_eigen_dot_operation_test.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/tests/test_utils.h" @@ -45,7 +46,8 @@ class CpuEigenDotOperationTest void CompileAndCheck(std::unique_ptr entry_computation, const string& filecheck_lines) { CpuAotCompilationOptions options{ - /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc index f3b7b91b2b5..b897f7a1522 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" namespace xla { @@ -48,7 +49,8 @@ CHECK: call void @__xla_cpu_runtime_KeyValueSort TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ - /*triple=*/"x86_64", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc index fc670201125..fb48cfe50e2 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_literal_caching_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" #include "tensorflow/compiler/xla/service/hlo_parser.h" @@ -64,7 +65,8 @@ CHECK-NOT: private unnamed_addr constant [48 x i8] ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ - /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; @@ -112,7 +114,8 @@ CHECK-NOT: private unnamed_addr constant [8 x i8] ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ - /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc index ad83c485998..b2ed9bd5f31 100644 --- a/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc +++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_outfeed_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" #include "tensorflow/compiler/xla/service/cpu/tests/cpu_codegen_test.h" namespace xla { @@ -46,7 +47,8 @@ CHECK: private unnamed_addr constant [48 x i8] TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ - /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; @@ -73,7 +75,8 @@ CHECK: Outfeed TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo_text)); CpuAotCompilationOptions options{ - /*triple=*/"x86_64-pc-linux", /*cpu_name=*/"", /*features=*/"", + /*triple=*/kTargetTripleForHost, /*cpu_name=*/kTargetCpuForHost, + /*features=*/"", /*entry_point_name=*/"entry", /*relocation_model=*/CpuAotCompilationOptions::RelocationModel::Static}; diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index b15ad1e162d..0d2eab9fd42 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -80,10 +80,11 @@ llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { return b()->CreateFSub(lhs, rhs); } -llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { +llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs, + bool enable_fast_min_max) { AssertCorrectTypes({lhs, rhs}); if (scalar_type_->isFloatingPointTy()) { - return llvm_ir::EmitFloatMax(lhs, rhs, b_); + return llvm_ir::EmitFloatMax(lhs, rhs, b_, enable_fast_min_max); } else { LOG(FATAL) << "Max for integers is unimplemented"; } diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h 
b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index cbbc4d7bf34..f1a0b0a4406 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -78,9 +78,11 @@ class VectorSupportLibrary { llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) { return Sub(lhs, GetConstantFloat(lhs->getType(), rhs)); } - llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs); - llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) { - return Max(GetConstantFloat(rhs->getType(), lhs), rhs); + llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs, + bool enable_fast_min_max); + llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs, + bool enable_fast_min_max) { + return Max(GetConstantFloat(rhs->getType(), lhs), rhs, enable_fast_min_max); } llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs); diff --git a/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc index 754885d8744..bf38450a386 100644 --- a/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc +++ b/tensorflow/compiler/xla/service/cpu/vectorized_reduce_with_no_vector_registers_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "llvm/Support/TargetRegistry.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h" +#include "tensorflow/compiler/xla/service/cpu/test_target_triple_helper.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -74,8 +75,9 @@ ENTRY main { module_group->push_back(std::move(hlo_module)); // Check that the GetTargetVectorRegisterByteSize is itself working. 
- TF_ASSERT_OK_AND_ASSIGN(unsigned vector_register_byte_size_for_x86_64, - GetTargetVectorRegisterByteSize("x86_64-pc-linux")); + TF_ASSERT_OK_AND_ASSIGN( + unsigned vector_register_byte_size_for_x86_64, + GetTargetVectorRegisterByteSize(kTargetTripleForHost)); ASSERT_EQ(vector_register_byte_size_for_x86_64, 16); std::string triple = "i686-none-android"; diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index e4097b0c06f..4b6c30cadc4 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1313,12 +1313,12 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) { - return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_); + return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_, fast_min_max()); } llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value) { - return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_); + return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_, fast_min_max()); } StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index e39d2dd99ec..365e3f56b85 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -245,6 +245,8 @@ class ElementalIrEmitter : public IrBuilderMixin { std::vector initial_value_generators, const llvm_ir::IrArray::Index& index); + virtual bool fast_min_max() = 0; + llvm::IRBuilder<>* const b_; llvm::Module* module_; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index ebf7cc440dd..61ce6200a28 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -28,10 +28,57 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/stream_executor/device_description.h" namespace xla { +ExecutionInput::~ExecutionInput() { + for (auto& index : unowned_indices_) { + auto buffer = buffers_.mutable_element(index)->Release(); + if (buffer) { + buffer->Release(); + } + } +} + +Status ExecutionInput::SetDynamicShape(Shape dynamic_shape) { + const Shape& input_shape = shape(); + if (!ShapeUtil::DynamicShapeIsCompatible(input_shape, dynamic_shape)) { + return tensorflow::errors::InvalidArgument( + "Cannot set dynamic shape: ", input_shape.DebugString(), " vs. 
", + dynamic_shape.DebugString()); + } + dynamic_shape_ = absl::make_unique(std::move(dynamic_shape)); + return Status::OK(); +} + +void ExecutionInput::SetUnownedBuffer(const ShapeIndex& index, + MaybeOwningDeviceMemory buffer) { + *buffers_.mutable_element(index) = std::move(buffer); + unowned_indices_.insert(index); +} + +xla::StatusOr ExecutionInput::ToShapedBuffer( + se::DeviceMemoryAllocator* allocator, int device_ordinal) const { + const Shape& input_shape = shape(); + xla::ShapedBuffer shaped_buffer(input_shape, input_shape, + allocator->platform(), device_ordinal); + for (const auto& index_buffer : Buffers()) { + const tensorflow::se::OwningDeviceMemory* mem = + index_buffer.second.AsOwningDeviceMemory(); + if (mem != nullptr && (mem->allocator() != allocator || + mem->device_ordinal() != device_ordinal)) { + return tensorflow::errors::InvalidArgument( + "Device buffer at index ", index_buffer.first.ToString(), + " has mismatching allocator/device"); + } + shaped_buffer.set_buffer(index_buffer.second.AsDeviceMemoryBase(), + index_buffer.first); + } + return std::move(shaped_buffer); +} + StatusOr Executable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, absl::Span arguments, diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 2c979662d24..6881f6dd68a 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ #include +#include #include #include @@ -65,31 +66,32 @@ class ExecutionInput { : buffers_(std::move(buffers)) {} ExecutionInput(ExecutionInput&&) = default; - ~ExecutionInput() { - for (auto& index : unowned_indices_) { - auto buffer = buffers_.mutable_element(index)->Release(); - if (buffer) { - buffer->Release(); - } - } - } + ~ExecutionInput(); ExecutionInput& operator=(ExecutionInput&&) = default; - const Shape& shape() const { return buffers_.shape(); } + const Shape& shape() const { + return dynamic_shape_ != nullptr ? *dynamic_shape_ : buffers_.shape(); + } + + Status SetDynamicShape(Shape dynamic_shape); + + xla::StatusOr ToShapedBuffer( + se::DeviceMemoryAllocator* allocator, int device_ordinal) const; void SetBuffer(const ShapeIndex& index, MaybeOwningDeviceMemory buffer) { *buffers_.mutable_element(index) = std::move(buffer); } void SetUnownedBuffer(const ShapeIndex& index, - MaybeOwningDeviceMemory buffer) { - *buffers_.mutable_element(index) = std::move(buffer); - unowned_indices_.push_back(index); - } + MaybeOwningDeviceMemory buffer); void SetUnownedIndex(const ShapeIndex& index) { - unowned_indices_.push_back(index); + unowned_indices_.insert(index); + } + + void ClearUnownedIndex(const ShapeIndex& index) { + unowned_indices_.erase(index); } const ShapeTree& Buffers() const { return buffers_; } @@ -106,9 +108,10 @@ class ExecutionInput { private: ShapeTree buffers_; - // (Unordered) set of indices of buffers that should be returned to the - // caller if an error occurs when enqueuing the computation. - std::vector unowned_indices_; + // Set of indices of buffers that should be returned to the caller if an error + // occurs when enqueuing the computation. 
+ std::set unowned_indices_; + std::unique_ptr dynamic_shape_; }; // ExecutionOutput encapsulates the output buffers of a execution and the @@ -145,7 +148,6 @@ class ExecutionOutput { to_be_released_.push_back(std::move(mem)); } - // Should be called once it is known that the execute operation succeeded, // before returning the ExecutionOutput to the caller. ExecutionOutput& Commit() { diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 0eb82128159..472d2117a2c 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -1174,6 +1174,7 @@ cc_library( ":reduction_degenerate_dim_remover", ":reduction_dimension_grouper", ":reduction_layout_normalizer", + ":reduction_splitter", ":stream_assignment", ":stream_executor_util", ":target_constants", @@ -1819,6 +1820,33 @@ cc_library( ], ) +cc_library( + name = "reduction_splitter", + srcs = ["reduction_splitter.cc"], + hdrs = ["reduction_splitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + ], +) + +tf_cc_test( + name = "reduction_splitter_test", + srcs = ["reduction_splitter_test.cc"], + deps = [ + ":reduction_splitter", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "reduction_layout_normalizer", srcs = ["reduction_layout_normalizer.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index a3056b1ddad..766a4c84df5 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -96,6 +96,10 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { llvm::Value* EmitThreadId() override; + bool fast_min_max() override { + return hlo_module_config_.debug_options().xla_gpu_enable_fast_min_max(); + } + private: // Emits IR for op, which must have opcode kPower. StatusOr EmitPowerOp(const HloInstruction* op, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index cddbee92874..156cb112285 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -65,6 +65,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/reduction_degenerate_dim_remover.h" #include "tensorflow/compiler/xla/service/gpu/reduction_dimension_grouper.h" #include "tensorflow/compiler/xla/service/gpu/reduction_layout_normalizer.h" +#include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h" #include "tensorflow/compiler/xla/service/gpu/stream_assignment.h" #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h" #include "tensorflow/compiler/xla/service/gpu/target_constants.h" @@ -371,6 +372,7 @@ Status GpuCompiler::OptimizeHloPostLayoutAssignment( pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass>(); // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. 
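The reduction-splitter pass registered in the post-layout-assignment pipeline above rewrites a reduce over non-contiguous dimensions into two consecutive reduces, peeling off the reduce dimension with the largest input extent first, provided that dimension has at least 128 elements (see reduction_splitter.cc and reduction_splitter_test.cc later in this patch). A minimal HLO sketch of the rewrite, adapted from the SplitReductionAtDimensionTwo test; the instruction names and the omitted layouts are illustrative only:

// Before: one reduce over the non-contiguous dimensions {0,2} of f32[6,16,512,64].
reduce = f32[16,64] reduce(input, zero), dimensions={0,2}, to_apply=add_computation

// After: first reduce the largest dimension (2, of size 512), then the remaining dimension 0.
pre_reduce = f32[6,16,64] reduce(input, zero), dimensions={2}, to_apply=add_computation
reduce = f32[16,64] reduce(pre_reduce, zero), dimensions={0}, to_apply=add_computation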
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index b97aa3651c6..01bcf456f75 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -226,6 +226,11 @@ bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) { dims_to_keep.push_back(dim); } } + + // We support fast codegen for three cases: + // 1) Row reduction: (K, R) + // 2) Column reduction: (K, R, K) + // 3) "Batched" row reduction: (R, K, R) if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), dims_to_keep) && !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(), diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 937a0ea5bbc..74aad5f5bd5 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1418,6 +1418,13 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { AddThunkToThunkSequence( absl::make_unique(std::move(thunks), sort)); + if (sort->operand_count() > 1) { + // Emit the tuple as part of the last stage of sorting. + // We are currently in the block sorted.in_bounds.after. + b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); + llvm_ir::EmitTuple(GetIrArray(*sort, *sort), + ConstructIrArrayForOutputs(*sort), &b_); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 88351881f3a..25acabb239b 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -77,8 +77,6 @@ class KernelThunk : public Thunk { // Will be set by IrEmitterUnnested. LaunchDimensions launch_dimensions_; - // Describes how to load this kernel. ExecuteOnStream reuses this loader - // specification for all executions. mutable tensorflow::mutex mutex_; // Loaded kernels for each `StreamExecutor`. Requires pointer stability of diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 497dcda4361..d2126a8d17d 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -492,9 +492,10 @@ void NVPTXBackendInit(const HloModuleConfig& hlo_module_config) { namespace nvptx { -StatusOr CompileToPtx(llvm::Module* module, GpuVersion gpu_version, - const HloModuleConfig& hlo_module_config, - const string& libdevice_dir_path) { +StatusOr CompileToPtx( + llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path, + std::function configure_target) { static absl::once_flag backend_init_flag; absl::call_once(backend_init_flag, NVPTXBackendInit, hlo_module_config); @@ -525,6 +526,11 @@ StatusOr CompileToPtx(llvm::Module* module, GpuVersion gpu_version, std::unique_ptr target_machine = NVPTXGetTargetMachine( default_target_triple, *compute_capability, hlo_module_config); + // Apply target machine configuration from call-back if available. + if (configure_target) { + configure_target(target_machine.get()); + } + // Link with libdevice, and optimize the LLVM module. 
TF_RETURN_IF_ERROR(LinkAndOptimizeModule( module, gpu_version, hlo_module_config, libdevice_dir_path, diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h index 526621de7a5..33ef9280c7a 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h @@ -22,6 +22,7 @@ limitations under the License. #include "absl/strings/string_view.h" #include "llvm/IR/Module.h" +#include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/gpu/gpu_types.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/statusor.h" @@ -38,9 +39,10 @@ namespace nvptx { // The Compile.* interfaces each create their own llvm::LLVMContext objects for // thread safety, but note that LLVM's multithreaded support is very // preliminary; multithreaded use is not recommended at this time. -StatusOr CompileToPtx(llvm::Module* module, GpuVersion gpu_version, - const HloModuleConfig& hlo_module_config, - const string& libdevice_dir_path); +StatusOr CompileToPtx( + llvm::Module* module, GpuVersion gpu_version, + const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path, + std::function configure_target = nullptr); } // namespace nvptx namespace amdgpu { diff --git a/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc b/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc new file mode 100644 index 00000000000..b68213ec35f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_splitter.cc @@ -0,0 +1,117 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h" + +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { +namespace gpu { + +class ReductionSplitterVisitor : public DfsHloRewriteVisitor { + public: + Status HandleReduce(HloInstruction *reduce) override { + VLOG(4) << "Input: " << reduce->ToString(); + + // Reductions with contiguous dimensions are lowered to efficient code. No + // need to split such ops. + if (IsReductionFromOrToContiguousDimensions(*reduce)) { + return Status::OK(); + } + if (reduce->dimensions().size() < 2) { + return Status::OK(); + } + if (!reduce->shape().IsArray()) { + // TODO(cheshire): Handle variadic reduction. 
+ return Status::OK(); + } + + HloInstruction *operand = reduce->mutable_operand(0); + const Shape &shape = operand->shape(); + CHECK(shape == LayoutUtil::GetWithDefaultLayout(shape)) + << "Default layout should be enforced on reduction operand"; + // Verify that contiguous dimensions have been grouped by the + // ReductionDimensionGrouper pass. + for (int64 i = 0; i < reduce->dimensions().size(); ++i) { + for (int64 j = i + 1; j < reduce->dimensions().size(); ++j) { + CHECK(abs(reduce->dimensions(i) - reduce->dimensions(j)) > 1) + << "Reduction dimensions must not be consecutive"; + } + } + + // The reduce op has non-contiguous dimensions. Look for the dimension with + // the largest shape dimension. Reducing along this dimension first will + // reduce the output size most effectively. + int64 max_shape_dim = 0; + int64 max_reduce_dim = 0; + const auto &input_shape = reduce->operand(0)->shape(); + for (int64 i = 0; i < reduce->dimensions().size(); ++i) { + if (input_shape.dimensions(reduce->dimensions(i)) > max_shape_dim) { + max_reduce_dim = reduce->dimensions(i); + max_shape_dim = input_shape.dimensions(max_reduce_dim); + } + } + // TODO(tjoerg): Run microbenchmarks to tune this threshold. + if (max_shape_dim < 128) { + return Status::OK(); + } + + // Split the reduction into a pre-reduction and a final reduction. + VLOG(3) << "Splitting reduction " << reduce->name() << " at dimension " + << max_reduce_dim; + std::vector pre_reduce_dims; + pre_reduce_dims.push_back(max_reduce_dim); + std::vector pre_reduce_shape_dims(input_shape.dimensions().begin(), + input_shape.dimensions().end()); + pre_reduce_shape_dims.erase(pre_reduce_shape_dims.begin() + max_reduce_dim); + Shape pre_reduce_shape = ShapeUtil::MakeShape( + reduce->shape().element_type(), pre_reduce_shape_dims); + std::unique_ptr pre_reduce = HloInstruction::CreateReduce( + pre_reduce_shape, reduce->mutable_operand(0), + reduce->mutable_operand(1), pre_reduce_dims, reduce->to_apply()); + pre_reduce->set_metadata(reduce->metadata()); + + std::vector final_reduce_dims(reduce->dimensions().begin(), + reduce->dimensions().end()); + final_reduce_dims.erase( + std::remove(final_reduce_dims.begin(), final_reduce_dims.end(), + max_reduce_dim), + final_reduce_dims.end()); + for (int64 i = 0; i < final_reduce_dims.size(); ++i) { + if (final_reduce_dims[i] > max_reduce_dim) { + final_reduce_dims[i]--; + } + } + std::unique_ptr final_reduce = HloInstruction::CreateReduce( + reduce->shape(), + reduce->parent()->AddInstruction(std::move(pre_reduce)), + reduce->mutable_operand(1), final_reduce_dims, reduce->to_apply()); + return ReplaceWithNewInstruction(reduce, std::move(final_reduce)); + } +}; + +StatusOr ReductionSplitter::Run(HloModule *module) { + TF_ASSIGN_OR_RETURN(bool changed, + ReductionSplitterVisitor().RunOnModule(module)); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/reduction_splitter.h b/tensorflow/compiler/xla/service/gpu/reduction_splitter.h new file mode 100644 index 00000000000..f161b579eb8 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_splitter.h @@ -0,0 +1,49 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Splits a reduce op into two consecutive reduce ops if +// * the reduce dimensions are not contiguous and +// * at least one reduce dimension is large (i.e. corresponds to a large input +// shape dimension). +// +// Reductions with non-contiguous dimensions are emitted as simple element-wise +// loops. This is inefficient when reducing large input shape dimensions. +// Splitting such reductions allows using more efficient reduction emitters. +// +// This pass splits reduce ops into two consecutive reduce ops. Run it to a +// fixpoint to split reduce ops along multiple large dimensions. +// +// Precondition: ReductionDimensionGrouper has been run and adjacent reduce +// dimensions have been grouped. Reduction layouts have been normalized. + +class ReductionSplitter : public HloModulePass { + public: + absl::string_view name() const override { return "reduction-splitter"; } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_REDUCTION_SPLITTER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/reduction_splitter_test.cc b/tensorflow/compiler/xla/service/gpu/reduction_splitter_test.cc new file mode 100644 index 00000000000..1be55b84204 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/reduction_splitter_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/reduction_splitter.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace gpu { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class ReductionSplitterTest : public HloTestBase {}; + +TEST_F(ReductionSplitterTest, SplitReductionAtDimensionTwo) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test + + add_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY entry_computation { + param_0 = f16[6,16,512,64]{3,2,1,0} parameter(0) + transpose.1781 = f16[6,512,16,64]{3,1,2,0} transpose(param_0), dimensions={0,2,1,3} + convert.6986 = f32[6,512,16,64]{3,1,2,0} convert(transpose.1781) + bitcast.2136 = f32[6,16,512,64]{3,2,1,0} bitcast(convert.6986) + constant_11111 = f32[] constant(0) + ROOT reduce.982 = f32[16,64]{1,0} reduce(bitcast.2136, constant_11111), dimensions={0,2}, to_apply=add_computation + } + )") + .ValueOrDie(); + ASSERT_TRUE(ReductionSplitter().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root_reduction = + module->entry_computation()->root_instruction(); + ASSERT_THAT(root_reduction, op::Reduce(op::Reduce(), op::Constant())); + + auto* pre_reduction = root_reduction->operand(0); + EXPECT_THAT(pre_reduction->dimensions(), std::vector({2})); + EXPECT_THAT(pre_reduction->shape(), ShapeUtil::MakeShape(F32, {6, 16, 64})); + EXPECT_THAT(root_reduction->dimensions(), std::vector({0})); + EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64})); +} + +TEST_F(ReductionSplitterTest, SplitReductionAtDimensionZero) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test + + add_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY entry_computation { + param_0 = f32[1024,16,512,64,128]{4,3,2,1,0} parameter(0) + constant_11111 = f32[] constant(0) + ROOT reduce.982 = f32[16,64]{1,0} reduce(param_0, constant_11111), dimensions={2,0,4}, to_apply=add_computation + } + )") + .ValueOrDie(); + ASSERT_TRUE(ReductionSplitter().Run(module.get()).ValueOrDie()); + SCOPED_TRACE(module->ToString()); + const HloInstruction* root_reduction = + module->entry_computation()->root_instruction(); + ASSERT_THAT(root_reduction, op::Reduce(op::Reduce(), op::Constant())); + + auto* pre_reduction = root_reduction->operand(0); + EXPECT_THAT(pre_reduction->dimensions(), std::vector({0})); + EXPECT_THAT(pre_reduction->shape(), + ShapeUtil::MakeShape(F32, {16, 512, 64, 128})); + EXPECT_THAT(root_reduction->dimensions(), std::vector({1, 3})); + EXPECT_THAT(root_reduction->shape(), ShapeUtil::MakeShape(F32, {16, 64})); +} + +TEST_F(ReductionSplitterTest, DontSplitReductionWithSmallDimensions) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test + + add_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY entry_computation { + param_0 = f32[8,1024,8]{2,1,0} parameter(0) + constant_11111 = f32[] constant(0) + ROOT reduce.982 = f32[1024]{0} reduce(param_0, constant_11111), dimensions={2,0}, to_apply=add_computation + } + )") + 
.ValueOrDie(); + EXPECT_FALSE(ReductionSplitter().Run(module.get()).ValueOrDie()); +} + +TEST_F(ReductionSplitterTest, DontSplitReductionsWithContiguousDimensions) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule test + + add_computation { + x = f32[] parameter(0) + y = f32[] parameter(1) + ROOT add = f32[] add(x, y) + } + + ENTRY entry_computation { + param_0 = f32[128,128,64,128]{3,2,1,0} parameter(0) + constant_11111 = f32[] constant(0) + // The dimensions to keep (1 and 2) are contiguous. + ROOT reduce.982 = f32[128,64]{1,0} reduce(param_0, constant_11111), dimensions={3,0}, to_apply=add_computation + } + )") + .ValueOrDie(); + EXPECT_FALSE(ReductionSplitter().Run(module.get()).ValueOrDie()); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc index 282f7b24a31..9b58457d129 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_ftz_test.cc @@ -92,26 +92,23 @@ TEST_F(GpuFtzDisabledTest, MultiplyFtz) { } // In NVPTX, exp(float) is implemented in libdevice, and consults __nvvm_reflect -// to determine whether or not ftz is enabled. The implementation uses two -// calls to ex2.approx. When ftz is on, we get two calls to the ftz version; -// when ftz is off, we get one call to the ftz version and one call to the -// regular version. +// to determine whether or not ftz is enabled. +// The implementation in CUDA 11 uses one ex2.approx.ftz, irrespective of ftz +// being enabled or not. In previous CUDA versions, there is a leading +// ex2.approx that does obey the ftz setting. +// Instead of pattern matching implementation details, it might be better to +// value-test the actual result instead. TODO(csigg): change to value-test.
TEST_F(GpuFtzEnabledTest, ExpFtz) { CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( CHECK-NOT: ex2.approx.f32 CHECK: ex2.approx.ftz.f32 CHECK-NOT: ex2.approx.f32 - CHECK: ex2.approx.ftz.f32 - CHECK-NOT: ex2.approx.f32 - CHECK-NOT: ex2.approx.ftz.f32 )"); } TEST_F(GpuFtzDisabledTest, ExpFtz) { CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp), R"( - CHECK-NOT: ex2.approx.f32 - CHECK-DAG: ex2.approx.ftz.f32 - CHECK-DAG: ex2.approx.f32 + CHECK: ex2.approx.ftz.f32 CHECK-NOT: ex2.approx.f32 CHECK-NOT: ex2.approx.ftz.f32 )"); diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc index 2c5e704d7c2..92f558ee98d 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_degenerate_dim_remover_test.cc @@ -37,6 +37,7 @@ class ReductionDegenerateDimRemoverTest : public GpuCodegenTest { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.add_xla_disable_hlo_passes("reduction-layout-normalizer"); debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper"); + debug_options.add_xla_disable_hlo_passes("reduction-splitter"); debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter"); return debug_options; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc index d06385480e5..b65c2842320 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/tests/reduction_layout_normalizer_test.cc @@ -33,6 +33,7 @@ class ReductionLayoutNormalizerTest : public GpuCodegenTest { DebugOptions GetDebugOptionsForTest() override { DebugOptions debug_options = GpuCodegenTest::GetDebugOptionsForTest(); debug_options.add_xla_disable_hlo_passes("reduction-dimension-grouper"); + debug_options.add_xla_disable_hlo_passes("reduction-splitter"); debug_options.add_xla_disable_hlo_passes("layout-assignment"); debug_options.add_xla_disable_hlo_passes("gpu-tree-reduction-rewriter"); return debug_options; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index f19882c9347..a46d20d5808 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -1007,6 +1007,8 @@ void HloDataflowAnalysis::OptimizePhiValues() { HloValue::Id phi_id = values[0]->id(); HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id); if (new_id != phi_id) { + VLOG(1) << "Replacing " << values[0]->ToString() << " with " + << GetValue(new_id).ToString(); value_set->Clear(); const HloValue& new_value = GetValue(new_id); value_set->AddValue(&new_value); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3dd6d82784f..ae8f49df4b4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -274,6 +274,13 @@ StatusOr HloEvaluator::Evaluate( engine_.seed(seed_); TF_RETURN_IF_ERROR(computation.Accept(this)); + + if (VLOG_IS_ON(100)) { + for (const HloInstruction* instr : computation.instructions()) { + VLOG(100) << instr->name() << " = " << GetEvaluatedLiteralFor(instr); + } + } + return 
GetEvaluatedLiteralFor(computation.root_instruction()).Clone(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index cfa21b95dd2..6de76c1cc63 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -3908,6 +3908,10 @@ const string& HloInstruction::outfeed_config() const { return Cast(this)->outfeed_config(); } +void HloInstruction::set_outfeed_config(const string& config) { + return Cast(this)->set_outfeed_config(config); +} + const std::vector& HloInstruction::replica_groups() const { return Cast(this)->replica_groups(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 7a5d506b681..f3bb59ff625 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1755,6 +1755,9 @@ class HloInstruction { // Returns the config for the Outfeed instruction. const string& outfeed_config() const; + // Delegates to HloOutfeedInstruction::set_outfeed_config. + void set_outfeed_config(const string& config); + // Returns the shape for the Outfeed instruction. const Shape& outfeed_shape() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 6da01dc088e..f5a963ef063 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1141,6 +1141,7 @@ class HloOutfeedInstruction : public HloInstruction { const Shape& outfeed_shape() const { return outfeed_shape_; } // Returns the config for the Outfeed instruction. const string& outfeed_config() const { return outfeed_config_; } + void set_outfeed_config(const string& config) { outfeed_config_ = config; } // Returns a serialized representation of this instruction. HloInstructionProto ToProto() const override; diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.cc b/tensorflow/compiler/xla/service/hlo_phi_graph.cc index 9b69771dab2..a2cba3d1bff 100644 --- a/tensorflow/compiler/xla/service/hlo_phi_graph.cc +++ b/tensorflow/compiler/xla/service/hlo_phi_graph.cc @@ -20,10 +20,11 @@ limitations under the License. namespace xla { HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) { Node* node = value_id_to_node_[value.id()]; + CHECK(!node->mark_as_dead); return node->value_id; } -// Returns true if the input to a hlo value is the same as `inputs`. +// Returns true if the inputs to a hlo value are the same as `inputs`. bool PhiGraph::InputsEqualTo(const HloValue& value, absl::Span inputs) { auto iter = value_id_to_node_.find(value.id()); @@ -42,6 +43,7 @@ bool PhiGraph::InputsEqualTo(const HloValue& value, HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) { auto iter = value_id_to_node_.find(id); CHECK(iter != value_id_to_node_.end()); + CHECK(!iter->second->mark_as_dead); return iter->second->value_id; } @@ -66,6 +68,17 @@ PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) { void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) { // Update users. CHECK(node->is_phi); + if (node->mark_as_dead) { + // The node has already been replaced with another. + return; + } + if (replace->mark_as_dead) { + // The node we are replacing with has already been replaced with another node.
+ auto iter = value_id_to_node_.find(replace->value_id); + CHECK(iter != value_id_to_node_.end()); + return ReplaceNodeWith(node, iter->second); + } + CHECK(!replace->mark_as_dead); for (Node* user : node->users) { absl::c_replace(user->operands, node, replace); } @@ -74,6 +87,7 @@ void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) { for (Node* operand : node->operands) { absl::c_replace(operand->users, node, replace); } + for (HloValue::Id value_id : node_to_value_id_[node]) { CHECK(value_id_to_node_.contains(value_id)); value_id_to_node_[value_id] = replace; @@ -115,6 +129,8 @@ std::string PhiGraph::ToString() { } void PhiGraph::Optimize() { + VLOG(2) << "Optimizing phi graph:"; + XLA_VLOG_LINES(2, ToString()); // Set up users for each node. for (auto& node : node_storage_) { for (Node* input : node->operands) { @@ -141,6 +157,8 @@ void PhiGraph::Optimize() { Node* node_ptr = node.get(); + VLOG(2) << "Optimizing: " << node_ptr->value_id; + CHECK_GE(node_ptr->operands.size(), 1); // Remove self-referencing ids from users and operands. @@ -167,6 +185,9 @@ void PhiGraph::Optimize() { [&](Node* elem) { return elem == node_ptr->operands[0]; }); if (all_inputs_are_same) { + VLOG(1) << "All inputs to node " << node_ptr->value_id + << " are the same, replacing it with " + << node_ptr->operands[0]->value_id; ReplaceNodeWith(node_ptr, node_ptr->operands[0]); changed = true; continue; @@ -223,6 +244,8 @@ void PhiGraph::Optimize() { CHECK_EQ(node, non_phi); continue; } + VLOG(1) << "Replace node " << node->value_id + << " in the closure with node " << non_phi->value_id; ReplaceNodeWith(node, non_phi); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.h b/tensorflow/compiler/xla/service/hlo_phi_graph.h index a0eb994438e..ca0d5c5009c 100644 --- a/tensorflow/compiler/xla/service/hlo_phi_graph.h +++ b/tensorflow/compiler/xla/service/hlo_phi_graph.h @@ -90,7 +90,7 @@ class PhiGraph { // to that phi. absl::flat_hash_map> node_to_value_id_; - // A mapping between a HloValue and node in the phi graph. + // A mapping from a HloValue to node in the phi graph. absl::flat_hash_map value_id_to_node_; std::vector> node_storage_; }; diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc index 41f0454fe55..ee7300b160b 100644 --- a/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc +++ b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc @@ -82,5 +82,30 @@ TEST_F(PhiGraphTest, CircularPhi) { EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id())); } +TEST_F(PhiGraphTest, NestedPhiReduction) { + // def A = phi(B, C) + // def B = phi(C, E) + // def C = phi(A, B) + // def D = non-phi + // def E = Phi(D, D) + // 1. Replace E with D + // 2. 
Replace A B and C with E/D + PhiGraph phi_graph; + HloValue A = NewHloValue(true); + HloValue B = NewHloValue(true); + HloValue C = NewHloValue(true); + HloValue D = NewHloValue(false); + HloValue E = NewHloValue(true); + phi_graph.RegisterPhi(A, {&B, &C}); + phi_graph.RegisterPhi(B, {&E, &C}); + phi_graph.RegisterPhi(C, {&A, &B}); + phi_graph.RegisterPhi(E, {&D, &D}); + phi_graph.Optimize(); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(A.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(B.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(E.id())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 4661b8fd9e3..d8baebd6fdd 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -1123,7 +1123,8 @@ Status ShapeVerifier::HandleGetDimensionSize(HloInstruction* get_size) { Status ShapeVerifier::HandleSetDimensionSize(HloInstruction* set_size) { return CheckShape(set_size, ShapeInference::InferSetDimensionSizeShape( - set_size->operand(0)->shape(), set_size->dimension())); + set_size->operand(0)->shape(), + set_size->operand(1)->shape(), set_size->dimension())); } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index a2e46ba2afe..616fd031c47 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -35,7 +35,6 @@ XlaInterpreterExecutor::~XlaInterpreterExecutor() {} DeviceMemoryBase XlaInterpreterExecutor::Allocate(uint64 size, int64 memory_space) { - CHECK_EQ(memory_space, 0); return DeviceMemoryBase(new char[size], size); } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 307fd82069e..a35ba140e86 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -951,12 +951,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { if (!Shape::Equal() .IgnoreDynamicDimension() .MinorToMajorOnlyInLayout()(instruction_subshape, - buffer->shape()) && - // TODO(mingyao): Use explicit linear layout tiling to - // detect and allow special bitcast. - instruction->opcode() != HloOpcode::kBitcast && - instruction->opcode() != HloOpcode::kGetTupleElement && - instruction->opcode() != HloOpcode::kTuple) { + buffer->shape())) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", @@ -1803,6 +1798,13 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // potential bugs in the layout assignment pass that may accidentally use the // existing layout. for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString()); + } // Some instructions carry mandatory layouts in their shape. 
if (instruction->opcode() != HloOpcode::kInfeed && !IsLayoutConstrainedCustomCall(instruction) && diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6e575247e6b..304a80c7a52 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -814,6 +814,27 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } +TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + builder.AddInstruction( + HloInstruction::CreateBitcast(constant0->shape(), constant0)); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + LayoutAssignment layout_assignment(&computation_layout); + Status error_status = layout_assignment.Run(m.get()).status(); + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT( + error_status.error_message(), + ::testing::HasSubstr( + "Unexpected bitcast operation seen during layout assignment")); +} + TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { // Pin non matching layouts to parameter and root. const char* module_str = R"( diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index e4ca08f972b..b01ae2efe43 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -91,10 +91,8 @@ llvm::CallInst* EmitCallToIntrinsic( } llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* b) { - // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. - if (b->getFastMathFlags().noNaNs() || - GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { + llvm::IRBuilder<>* b, bool enable_fast_min_max) { + if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) { auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { @@ -106,10 +104,8 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, } llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* b) { - // TODO(tpopp): Pass this information down from the HLO's ModuleConfig. - if (b->getFastMathFlags().noNaNs() || - GetDebugOptionsFromFlags().xla_cpu_enable_fast_min_max()) { + llvm::IRBuilder<>* b, bool enable_fast_min_max) { + if (b->getFastMathFlags().noNaNs() || enable_fast_min_max) { auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value); } else { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 691898011ed..642965b6470 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -108,12 +108,12 @@ llvm::CallInst* EmitCallToIntrinsic( // Emit float max. Emit maxnum intrinsic is fast math is disabled, or // fcmp+select otherwise llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* b); + llvm::IRBuilder<>* b, bool enable_fast_min_max); // Emit float min. 
Emit minnum intrinsic is fast math is disabled, or // fcmp+select otherwise llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, - llvm::IRBuilder<>* b); + llvm::IRBuilder<>* b, bool enable_fast_min_max); // Convenience methods for emitting a GEP instruction that indexes into a buffer // (1-dimensional array), equivalent to array[index]. The type is automatically diff --git a/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc index 333a2e8f612..0604cb848d2 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/math_ops.cc @@ -31,9 +31,13 @@ llvm::Value* EmitFastTanh(llvm::IRBuilder<>* b, llvm::Value* input) { b->CreateFCmpOLT(abs_x, llvm::ConstantFP::get(type, kCanUseApprox)); // Clamp the input to [-9, 9]. + // + // To simplify the code base until it's an issue, don't have a slow min/max in + // this approximation. llvm::Value* input_clamped = llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b), - llvm::ConstantFP::get(type, 9.0), b); + llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(type, -9.0), b, + /*enable_fast_min_max=*/true), + llvm::ConstantFP::get(type, 9.0), b, /*enable_fast_min_max=*/true); static constexpr std::array numerator_coeffs{ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc index c4bf48bcc00..c7505f5fa4a 100644 --- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" + #include "absl/types/variant.h" + namespace xla { tensorflow::se::DeviceMemoryBase MaybeOwningDeviceMemory::AsDeviceMemoryBase() @@ -38,4 +40,10 @@ MaybeOwningDeviceMemory::Release() { return std::move(absl::get(mem_)); } +const tensorflow::se::OwningDeviceMemory* +MaybeOwningDeviceMemory::AsOwningDeviceMemory() const { + return HasOwnership() ? &absl::get(mem_) + : nullptr; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h index 7d23d178130..0b56fed0a72 100644 --- a/tensorflow/compiler/xla/service/maybe_owning_device_memory.h +++ b/tensorflow/compiler/xla/service/maybe_owning_device_memory.h @@ -57,6 +57,10 @@ class MaybeOwningDeviceMemory { // A nullopt is returned if the HasOwnership() == false; absl::optional Release(); + // If the device memory is owned, returns a pointer to the internal + // OwningDeviceMemory, otherwise nullptr is returned. + const tensorflow::se::OwningDeviceMemory* AsOwningDeviceMemory() const; + // Returns true if the device_memory has ownership over underlying memory. bool HasOwnership() const; diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.cc b/tensorflow/compiler/xla/service/memory_space_propagation.cc index 80eb4017477..2eb15b14eaf 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation.cc +++ b/tensorflow/compiler/xla/service/memory_space_propagation.cc @@ -29,36 +29,78 @@ StatusOr MemorySpacePropagation::Run(HloModule* module) { // Propagate the operand subshapes. 
for (int operand_idx = 0; operand_idx < instruction->operand_count(); ++operand_idx) { - modified |= - PropagateSubshapes(instruction->operand(operand_idx)->shape(), - instruction->fused_parameter(operand_idx)); + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes( + instruction->operand(operand_idx)->shape())) { + int64 memory_space = indexed_shape.shape.layout().memory_space(); + modified |= Propagate(indexed_shape.index, + instruction->fused_parameter(operand_idx), + memory_space); + } } // Propagate output subshapes. - modified |= PropagateSubshapes(instruction->shape(), - instruction->fused_expression_root()); + for (const ShapeUtil::IndexedShape& indexed_shape : + ShapeUtil::GetLeafShapes(instruction->shape())) { + int64 memory_space = indexed_shape.shape.layout().memory_space(); + modified |= + Propagate(indexed_shape.index, + instruction->fused_expression_root(), memory_space); + } } } } return modified; } -bool MemorySpacePropagation::PropagateSubshapes( - const Shape& caller_shape, const HloInstruction* callee_instruction) const { +bool MemorySpacePropagation::Propagate(ShapeIndexView index, + const HloInstruction* callee_instruction, + int64 memory_space) const { bool modified = false; - for (const ShapeUtil::IndexedShape& indexed_shape : - ShapeUtil::GetLeafShapes(caller_shape)) { - int64 memory_space = indexed_shape.shape.layout().memory_space(); - const HloValue& value = dataflow_analysis_->GetUniqueValueAt( - callee_instruction, indexed_shape.index); + const HloValue& value = dataflow_analysis_->GetUniqueValueAt( + callee_instruction, index.ToShapeIndex()); - for (const HloPosition& position : value.positions()) { - Shape* shape = ShapeUtil::GetMutableSubshape( - position.instruction->mutable_shape(), position.index); - if (shape->layout().memory_space() != memory_space) { - shape->mutable_layout()->set_memory_space(memory_space); - modified = true; - } + for (const HloPosition& position : value.positions()) { + HloInstruction* instruction = position.instruction; + Shape* shape = ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), + position.index); + if (shape->layout().memory_space() == memory_space) { + continue; + } + shape->mutable_layout()->set_memory_space(memory_space); + modified = true; + + // For fusion outputs, propagate the memory space to the fusion root. + if (instruction->opcode() == HloOpcode::kFusion) { + Propagate(position.index, instruction->fused_expression_root(), + memory_space); + } + + const HloInstruction* parent_fusion = + instruction->parent()->FusionInstruction(); + // For nested fusion roots, pop one level up and propagate the memory space + // to the output of the calling fusion instruction. + if (instruction == instruction->parent()->root_instruction() && + parent_fusion->parent()->IsFusionComputation()) { + Propagate(position.index, parent_fusion, memory_space); + } + + // For nested fusion parameters, pop one level up and propagate the memory + // space to the operand of the calling fusion instruction. + if (instruction->opcode() == HloOpcode::kParameter && + parent_fusion->parent()->IsFusionComputation()) { + const HloInstruction* fusion_operand = + parent_fusion->operand(instruction->parameter_number()); + Propagate(position.index, fusion_operand, memory_space); + } + } + + for (const HloUse& use : value.uses()) { + // For fusion uses, propagate the memory space to the fusion parameter. 
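+ // (that is, the parameter with the same operand number inside the fused
+ // computation, at the same shape index).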
+ if (use.instruction->opcode() == HloOpcode::kFusion) { + modified |= Propagate( + use.operand_index, + use.instruction->fused_parameter(use.operand_number), memory_space); } } return modified; diff --git a/tensorflow/compiler/xla/service/memory_space_propagation.h b/tensorflow/compiler/xla/service/memory_space_propagation.h index 65a1dfd14a6..510e9e69f79 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation.h +++ b/tensorflow/compiler/xla/service/memory_space_propagation.h @@ -31,12 +31,11 @@ class MemorySpacePropagation : public HloModulePass { StatusOr Run(HloModule* module) override; private: - // Given the caller shape (operand or output) and its corresponding - // insturction in the fused computation (parameter or root), propagates the - // memory space to all the subshapes in the callee side. Returns true if the - // module is modified. - bool PropagateSubshapes(const Shape& caller_shape, - const HloInstruction* callee_instruction) const; + // Given the shape index (operand or output) and its corresponding instruction + // in the fused computation (parameter or root), propagates the memory space + // in the callee side. Returns true if the module is modified. + bool Propagate(ShapeIndexView index, const HloInstruction* callee_instruction, + int64 memory_space) const; std::unique_ptr dataflow_analysis_; }; diff --git a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc index 8d74958f6aa..de45af5a190 100644 --- a/tensorflow/compiler/xla/service/memory_space_propagation_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_propagation_test.cc @@ -199,5 +199,153 @@ TEST_F(MemorySpacePropagationTest, TupleOutput) { EXPECT_EQ(module->Hash(), ref->Hash()); } +TEST_F(MemorySpacePropagationTest, NestedInputFusion) { + // Tests propagating the memory space to nested fusions on the input side. 
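+ // Here the outer fusion's operand 0 (%arg0) is placed in memory space S(1),
+ // so the pass must push S(1) into %param_0.1 and, through the nested call to
+ // %bitcast_fusion, into %bf_param and %bitcast as well.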
+ absl::string_view hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[3,2]{0,1:T(128)} parameter(0) + ROOT %bitcast = s32[6]{0:T(128)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[3,2]{0,1:T(128)} parameter(0) + %fusion.1 = s32[6]{0:T(128)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion + ROOT %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %fusion.1) + } + + ENTRY %entry { + %param0 = s32[3,2]{0,1:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[3,2]{0,1:T(128)S(1)} parameter(0) + ROOT %bitcast = s32[6]{0:T(128)S(1)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[3,2]{0,1:T(128)S(1)} parameter(0) + %fusion.1 = s32[6]{0:T(128)S(1)} fusion(%param_0.1), kind=kLoop, calls=bitcast_fusion + ROOT %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %fusion.1) + } + + ENTRY %entry { + %param0 = s32[3,2]{0,1:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[3,2]{0,1:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[6]{0:T(128)S(1)} fusion(s32[3,2]{0,1:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[6]{0:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + +TEST_F(MemorySpacePropagationTest, NestedOutputFusion) { + // Tests propagating the memory space to nested fusions on the output side. 
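+ // Here the outer fusion's result is placed in memory space S(1); propagation
+ // has to descend through the nested root %fusion.1 (calls=bitcast_fusion) so
+ // that %bitcast and %bf_param inside it pick up S(1) as well.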
+ absl::string_view hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[6]{0:T(128)} parameter(0) + ROOT %bitcast = s32[3,2]{0,1:T(128)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)} parameter(0) + %add.0 = s32[6]{0:T(128)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)} %param_0.1) + ROOT %fusion.1 = s32[3,2]{0,1:T(128)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion) + } + )"; + absl::string_view expected_hlo_string = R"( + HloModule NestedFusion + + %bitcast_fusion { + %bf_param = s32[6]{0:T(128)S(1)} parameter(0) + ROOT %bitcast = s32[3,2]{0,1:T(128)S(1)} bitcast(%bf_param) + } + + %fused_computation { + %param_1.3 = s32[1]{0:T(128)} parameter(1) + %constant.2 = s32[]{:T(128)} constant(-2147483648) + %pad.2 = s32[6]{0:T(128)} pad(s32[1]{0:T(128)} %param_1.3, s32[]{:T(128)} %constant.2), padding=0_5 + %param_2.3 = s32[5]{0:T(128)S(1)} parameter(2) + %pad.3 = s32[6]{0:T(128)} pad(s32[5]{0:T(128)} %param_2.3, s32[]{:T(128)} %constant.2), padding=1_0 + %maximum.1 = s32[6]{0:T(128)} maximum(s32[6]{0:T(128)} %pad.2, s32[6]{0:T(128)} %pad.3) + %param_0.1 = s32[6]{0:T(128)S(1)} parameter(0) + %add.0 = s32[6]{0:T(128)S(1)} add(s32[6]{0:T(128)} %maximum.1, s32[6]{0:T(128)S(1)} %param_0.1) + ROOT %fusion.1 = s32[3,2]{0,1:T(128)S(1)} fusion(%add.0), kind=kLoop, calls=bitcast_fusion + } + + ENTRY %entry { + %param0 = s32[6]{0:T(128)} parameter(0) + %param1 = s32[1]{0:T(128)} parameter(1) + %param2 = s32[5]{0:T(128)} parameter(2) + %arg0 = s32[6]{0:T(128)S(1)} copy(%param0) + %arg1 = s32[1]{0:T(128)} copy(%param1) + %arg2 = s32[5]{0:T(128)S(1)} copy(%param2) + %fusion = s32[3,2]{0,1:T(128)S(1)} fusion(s32[6]{0:T(128)S(1)} %arg0, s32[1]{0:T(128)} %arg1, s32[5]{0:T(128)S(1)} %arg2), kind=kLoop, calls=%fused_computation + ROOT %root = s32[3,2]{0,1:T(128)} copy(%fusion) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnUnverifiedModule(hlo_string)); + MemorySpacePropagation memory_space_propagation; + EXPECT_TRUE(memory_space_propagation.Run(module.get()).ValueOrDie()); + TF_EXPECT_OK(Verify(module.get())); + TF_ASSERT_OK_AND_ASSIGN(auto ref, + ParseAndReturnVerifiedModule(expected_hlo_string)); + EXPECT_EQ(module->Hash(), ref->Hash()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index ce45d937424..efe69450846 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -167,6 +167,7 @@ cc_library( "//tensorflow/compiler/mlir/xla:lhlo_legalize_to_affine", 
"//tensorflow/compiler/mlir/xla:lhlo_legalize_to_gpu", "//tensorflow/compiler/mlir/xla:xla_dialect_registration", + "//tensorflow/compiler/mlir/xla:xla_legalize_tanh_to_approximation", "//tensorflow/compiler/mlir/xla:xla_legalize_to_linalg", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index b0cbddcdb92..196ea218ef3 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -301,7 +301,7 @@ struct RewriteKernelSignature signalPassFailure(); return; } - if (func.getBlocks().size() != 1) { + if (!llvm::hasSingleElement(func)) { func.emitError() << "surrounding function has more than one block"; signalPassFailure(); return; @@ -505,6 +505,16 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Some basic cleanup. pm.addNestedPass<::mlir::FuncOp>(::mlir::createCanonicalizerPass()); pm.addNestedPass<::mlir::FuncOp>(::mlir::createCSEPass()); + // Make loops with min bounds into a conditional plus static bounds. + // Only do this if we unrolled in the first place. + if (!options.unroll_factors.empty()) { + pm.addNestedPass<::mlir::FuncOp>(mlir::createForLoopSpecializationPass()); + } + // Approximate of requested. + if (options.use_approximations) { + pm.addNestedPass<::mlir::FuncOp>( + ::mlir::xla::createLegalizeTanhToApproximationPass()); + } // Move scalar operations into the launch to ensure smaller signatures. pm.addPass(absl::make_unique()); // Take launches to launches with kernels. @@ -547,7 +557,7 @@ class LowerToNVVMPass // TODO(csigg): Remove once we support replacing non-root ops. target.addLegalOp<::mlir::gpu::GPUModuleOp, ::mlir::gpu::ModuleEndOp, ::mlir::gpu::YieldOp>(); - if (failed(mlir::applyFullConversion(m, target, patterns, &converter))) { + if (failed(mlir::applyFullConversion(m, target, patterns))) { signalPassFailure(); } } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h index 77cf75b9e47..bd633bb06cb 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.h @@ -28,6 +28,7 @@ struct LowerLHLOToGPUOptions { llvm::ArrayRef unroll_factors = {}; bool collapse_parallel_loops = true; bool rewrite_signature = true; + bool use_approximations = false; }; Status LowerLHLOToGPU(mlir::ModuleOp module, diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index b95b27d6291..a21cec538d1 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_reachability.h" @@ -126,6 +127,10 @@ StatusOr MultiOutputFusion::Run(HloModule* module) { candidates_index_.clear(); all_fusion_candidates_.clear(); reachability_.reset(); + if (changed) { + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } return changed; } diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 75a80747c1d..bb4a38ded1e 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2248,12 +2248,17 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, } /* static */ StatusOr ShapeInference::InferSetDimensionSizeShape( - const Shape& shape, int64 dimension) { + const Shape& shape, const Shape& val_shape, int64 dimension) { if (dimension < 0 || dimension >= shape.rank()) { return InvalidArgument("SetDimensionSize dimension out of bounds: %d.", dimension); } + if (val_shape.rank() != 0 || val_shape.element_type() != S32) { + return InvalidArgument( + "SetDimensionSize's value has to be S32 scalar, got %s", + val_shape.ToString()); + } // TODO(b/119580730): Remove this restriction when very large dimension size // is needed. if (shape.dimensions(dimension) > std::numeric_limits::max()) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 2cb5930d098..d47d96ab52d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -303,10 +303,13 @@ class ShapeInference { const Shape& updates_shape, const ProgramShape& to_apply_shape, const ScatterDimensionNumbers& scatter_dim_numbers); + // Helper that validates the given input shape to GetDimensionSize. static StatusOr InferGetDimensionSizeShape(const Shape& shape, int64 dimension); - static StatusOr InferSetDimensionSizeShape(const Shape& shape, + // Helper that validates the given input shape to SetDimensionSize. + static StatusOr InferSetDimensionSizeShape(const Shape& operand_shape, + const Shape& val_shape, int64 dimension); // Helper function for creating a Window proto from user-supplied data. 
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index b5ecf6e583e..916d3ab15c8 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -1365,6 +1365,28 @@ TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) { EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape)); } +TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) { + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape val_shape = ShapeUtil::MakeShape(S32, {1}); + auto inferred_status = ShapeInference::InferSetDimensionSizeShape( + arg_shape, val_shape, /*dimension=*/0); + + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("value has to be S32 scalar")); +} + +TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) { + Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); + Shape val_shape = ShapeUtil::MakeShape(U32, {}); + auto inferred_status = ShapeInference::InferSetDimensionSizeShape( + arg_shape, val_shape, /*dimension=*/0); + + EXPECT_FALSE(inferred_status.ok()); + EXPECT_THAT(inferred_status.status().error_message(), + HasSubstr("value has to be S32 scalar")); +} + // BatchMatMul with different batch dimension sizes fails. TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) { Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index 4658aebd571..0fd64209152 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "absl/strings/str_cat.h" +#include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -33,6 +34,7 @@ limitations under the License. using absl::StrCat; namespace xla { + /* static */ tensorflow::mutex TransferManager::platform_transfer_manager_mutex_( tensorflow::LINKER_INITIALIZED); @@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice( std::move(done), transfer_metadata); } +Status TransferManager::ReadDynamicShapes(se::Stream* stream, + ShapedBuffer* device_buffer, + Shape* host_shape, + Shape* device_shape) { + DCHECK(device_shape->is_dynamic()); + Shape original_device_shape = *device_shape; + Shape original_host_shape = *host_shape; + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + + TF_ASSIGN_OR_RETURN(auto compiler, + Compiler::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) { + const Shape& buffer_shape = + ShapeUtil::GetSubshape(*device_shape, index); + if (buffer_shape.IsTuple()) { + return Status::OK(); + } + Shape& host_sub_shape = + *ShapeUtil::GetMutableSubshape(host_shape, index); + Shape& device_sub_shape = + *ShapeUtil::GetMutableSubshape(device_shape, index); + if (device_sub_shape.is_static()) { + return Status::OK(); + } + + // Read the dynamic shape metadata from the device stream. 
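+ // The metadata sits immediately after the statically-shaped payload of the
+ // buffer: its offset is the byte size of the static version of the shape, and
+ // its length is whatever the full bounded shape needs beyond that, i.e. one
+ // S32 per dimension holding that dimension's runtime size.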
+ auto shape_size_fn = compiler->ShapeSizeBytesFunction(); + Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape); + const int64 offset = shape_size_fn(buffer_shape_static); + int64 metadata_size = shape_size_fn(buffer_shape) - offset; + if (metadata_size == 0) { + return InvalidArgument("Dynamic shape metadata size should not be 0"); + } + auto buffer_8 = se::DeviceMemory(*buffer); + auto metadata_buffer = + stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); + TF_ASSIGN_OR_RETURN( + auto metadata, + TransferArrayFromDevice( + stream, + ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}), + metadata_buffer)); + + // Update shape size from metadata. + for (int64 i = 0; i < metadata.element_count(); ++i) { + host_sub_shape.mutable_dimensions()[i] = metadata.Get({i}); + device_sub_shape.mutable_dimensions()[i] = metadata.Get({i}); + } + return Status::OK(); + })); + host_shape->clear_dynamic_dimensions(); + device_shape->clear_dynamic_dimensions(); + + TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape, + original_device_shape)); + TF_RET_CHECK( + ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape)); + return Status::OK(); +} + /* static */ void TransferManager::RegisterTransferManager( se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { @@ -355,7 +418,9 @@ StatusOr TransferManager::AllocateScopedShapedBuffer( ShapeUtil::GetSubshape(shaped_buffer.on_device_shape(), index); TF_ASSIGN_OR_RETURN(auto memory, allocator->Allocate(shaped_buffer.device_ordinal(), - GetByteSizeRequirement(subshape))); + GetByteSizeRequirement(subshape), + /*retry_on_failure=*/true, + subshape.layout().memory_space())); // Move the allocated buffer into the ScopedShapedBuffer, which owns it. memory_base = memory.Release(); } diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index e3f8ceacc42..c0670d26eee 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -184,6 +184,15 @@ class TransferManager { const se::DeviceMemoryBase& source, const TransferMetadata* transfer_metadata = nullptr); + // Read from a device buffer and update the dynamic dimension sizes of + // `host_shape` and `device_shape`. The function takes in bounded dynamic + // shapes, and returns static shapes with dynamic shapes updated. + // The shape of the buffer also have to be compatible with the host shape and + // device shape. + virtual Status ReadDynamicShapes(se::Stream* stream, + ShapedBuffer* device_buffer, + Shape* host_shape, Shape* device_shape); + // Transfers the given literal into the Infeed interface of the device, // using the given executor. 
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor, diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ab46e49b181..bce40578132 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -1461,7 +1461,7 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( return shape; } -/* static */ bool ShapeUtil::DynamicShapeIsCompatible( +/* static */ bool ShapeUtil::DynamicArrayShapeIsCompatible( const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { if (dynamic_shape.rank() != bounded_shape.rank()) { return false; @@ -1474,6 +1474,36 @@ ShapeUtil::ReshapeLeavesDimensionsUnmodified( return true; } +/* static */ bool ShapeUtil::DynamicShapeIsCompatible( + const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape) { + bool compatible = true; + xla::ShapeUtil::ForEachSubshape(dynamic_shape, [&](const Shape& sub_shape, + const ShapeIndex& index) { + if (compatible) { + auto subshape_result = TryGetSubshape(bounded_shape, index); + if (subshape_result.ok()) { + const Shape* bounded_sub_shape = subshape_result.ConsumeValueOrDie(); + if (sub_shape.IsTuple()) { + if (!bounded_sub_shape->IsTuple()) { + compatible = false; + } + } else { + if (bounded_sub_shape->IsTuple()) { + compatible = false; + } else if (!sub_shape.is_static() && + !DynamicArrayShapeIsCompatible(sub_shape, + *bounded_sub_shape)) { + compatible = false; + } + } + } else { + compatible = false; + } + } + }); + return compatible; +} + /* static */ Shape ShapeUtil::FilterDimensions( const std::function& p, Shape shape) { CHECK(shape.IsArray()); diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index dde56587482..fe1a8acf6e4 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -657,7 +657,11 @@ class ShapeUtil { Shape shape); // Returns true if `dynamic_shape` has dimensions that are less-equal to the - // "bounded_shape". + // "bounded_shape". Shapes must be arrays. + static bool DynamicArrayShapeIsCompatible(const xla::Shape& dynamic_shape, + const xla::Shape& bounded_shape); + + // Same as DynamicArrayShapeIsCompatible() but supports tuples. static bool DynamicShapeIsCompatible(const xla::Shape& dynamic_shape, const xla::Shape& bounded_shape); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e1863a8a4cf..9b36117602b 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -52,16 +52,26 @@ cc_library( name = "test_macros_header", testonly = True, hdrs = ["test_macros.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/core:test", - "@com_google_absl//absl/strings", - ], ) # Generate a test_macros_${BACKEND} library per backend with the proper copts. 
generate_backend_test_macros() +cc_library( + name = "manifest_checking_test", + testonly = True, + srcs = ["manifest_checking_test.cc"], + hdrs = ["manifest_checking_test.h"], + deps = [ + ":test_macros_header", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + "//tensorflow/core/platform:logging", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "test_utils", srcs = ["test_utils.cc"], @@ -136,6 +146,7 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":manifest_checking_test", ":test_utils", ":verified_hlo_module", "//tensorflow/compiler/xla:debug_options_flags", @@ -193,6 +204,7 @@ cc_library( srcs = ["client_library_test_base.cc"], hdrs = ["client_library_test_base.h"], deps = [ + ":manifest_checking_test", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", @@ -273,6 +285,7 @@ cc_library( hdrs = ["local_client_test_base.h"], deps = [ ":client_library_test_base", + ":manifest_checking_test", ":verified_hlo_module", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index c0c0751b0de..94d870aa2ef 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -266,11 +266,6 @@ def generate_backend_test_macros(backends = []): "-DXLA_DISABLED_MANIFEST=\\\"%s\\\"" % manifest, ], deps = [ - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:lib", - "//tensorflow/core:regexp_internal", - "//tensorflow/core:test", + "//tensorflow/core/platform:logging", ], ) diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 790497f888e..17bb70bdb42 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/bitmap.h" @@ -62,7 +63,7 @@ std::vector ExpandUseBfloat16( } // A client library test establishes an in-process XLA client connection. -class ClientLibraryTestBase : public ::testing::Test { +class ClientLibraryTestBase : public ManifestCheckingTest { protected: explicit ClientLibraryTestBase(se::Platform* platform = nullptr); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 85b1876dd3c..17c2a55ba5b 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -32,6 +32,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -67,7 +68,7 @@ namespace xla { // ) // // For a more detailed example, see "../tests/sample_text_test.cc". -class HloTestBase : public ::testing::Test { +class HloTestBase : public ManifestCheckingTest { public: // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 53c0d84854e..3e9a3ec2314 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -71,6 +71,8 @@ int main(int argc, char** argv) { triple_string = "aarch64-none-linux-gnu"; } else if (target_cpu == "x64_windows") { triple_string = "x86_64-pc-windows-msvc19"; + } else if (target_cpu == "ppc") { + triple_string = "ppc64le-ibm-linux-gnu"; } else if (target_cpu == "local") { triple_string = llvm::sys::getDefaultTargetTriple(); } else { diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index ea457024618..c1951ad1021 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" #include "tensorflow/compiler/xla/tests/verified_hlo_module.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/mutex.h" @@ -75,7 +76,7 @@ class TestAllocator : public se::StreamExecutorMemoryAllocator { }; // A base class for tests which exercise the LocalClient interface. -class LocalClientTestBase : public ::testing::Test { +class LocalClientTestBase : public ManifestCheckingTest { protected: struct EigenThreadPoolWrapper; explicit LocalClientTestBase(se::Platform* platform = nullptr); diff --git a/tensorflow/compiler/xla/tests/manifest_checking_test.cc b/tensorflow/compiler/xla/tests/manifest_checking_test.cc new file mode 100644 index 00000000000..ac6204f9df9 --- /dev/null +++ b/tensorflow/compiler/xla/tests/manifest_checking_test.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/manifest_checking_test.h" + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_split.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { + +namespace { + +// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is +// disabled - a sequence of regexps. +using ManifestT = absl::flat_hash_map>; + +ManifestT ReadManifest() { + ManifestT manifest; + + absl::string_view path = absl::NullSafeStringView(*DisabledManifestPath()); + if (path.empty()) { + return manifest; + } + + // Note: parens are required to disambiguate vs function decl. + std::ifstream file_stream((std::string(path))); + std::string contents((std::istreambuf_iterator(file_stream)), + std::istreambuf_iterator()); + + std::vector lines = absl::StrSplit(contents, '\n'); + for (std::string& line : lines) { + auto comment = line.find("//"); + if (comment != std::string::npos) { + line = line.substr(0, comment); + } + if (line.empty()) { + continue; + } + absl::StripTrailingAsciiWhitespace(&line); + std::vector pieces = absl::StrSplit(line, ' '); + CHECK_GE(pieces.size(), 1); + auto& platforms = manifest[pieces[0]]; + for (size_t i = 1; i < pieces.size(); ++i) { + platforms.push_back(pieces[i]); + } + } + return manifest; +} + +} // namespace + +void ManifestCheckingTest::SetUp() { + const testing::TestInfo* test_info = + testing::UnitTest::GetInstance()->current_test_info(); + absl::string_view test_case_name = test_info->test_suite_name(); + absl::string_view test_name = test_info->name(); + VLOG(1) << "test_case_name: " << test_case_name; + VLOG(1) << "test_name: " << test_name; + + // Remove the type suffix from the test case name. + if (const char* type_param = test_info->type_param()) { + VLOG(1) << "type_param: " << type_param; + size_t last_slash = test_case_name.rfind('/'); + test_case_name = test_case_name.substr(0, last_slash); + VLOG(1) << "test_case_name: " << test_case_name; + } + + // Remove the test instantiation name if it is present. + auto first_slash = test_case_name.find('/'); + if (first_slash != test_case_name.npos) { + test_case_name.remove_prefix(first_slash + 1); + VLOG(1) << "test_case_name: " << test_case_name; + } + + ManifestT manifest = ReadManifest(); + + // If the test name ends with a slash followed by one or more characters, + // strip that off. + auto last_slash = test_name.rfind('/'); + if (last_slash != test_name.npos) { + test_name = test_name.substr(0, last_slash); + VLOG(1) << "test_name: " << test_name; + } + + // First try full match: test_case_name.test_name + // If that fails, try to find just the test_case_name; this would disable all + // tests in the test case. + auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); + if (it == manifest.end()) { + it = manifest.find(test_case_name); + if (it == manifest.end()) { + return; + } + } + + // Expect a full match vs. one of the platform regexps to disable the test. + const std::vector& disabled_platforms = it->second; + auto platform_string = *TestPlatform(); + for (const auto& s : disabled_platforms) { + if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { + GTEST_SKIP(); + return; + } + } + + // We didn't hit in the disabled manifest entries, so don't disable it. 
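+ //
+ // For reference, the manifest read above is a plain text file in which each
+ // line names a test (or a whole test case) followed by platform regexps; the
+ // entries below are hypothetical:
+ //
+ //   ClientLibraryTestBase.AddTwoNumbers cpu gpu
+ //   TupleHloTest interpreter
+ //
+ // The first token is looked up as "TestCase.TestName" (falling back to the
+ // test case name alone); the remaining tokens are regexps matched in full
+ // against *TestPlatform().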
+} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/manifest_checking_test.h b/tensorflow/compiler/xla/tests/manifest_checking_test.h new file mode 100644 index 00000000000..4f44ed76a3e --- /dev/null +++ b/tensorflow/compiler/xla/tests/manifest_checking_test.h @@ -0,0 +1,35 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_ + +#include "tensorflow/core/platform/test.h" + +namespace xla { + +// This class allows us to intercept the test name and use an arbitrary +// heuristic to decide whether the test case should be disabled. We +// determine whether the test case should be disabled by resolving the (test +// case name, test name) in a manifest file. +class ManifestCheckingTest : public ::testing::Test { + protected: + // This method runs before each test runs. + void SetUp() override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_MANIFEST_CHECKING_TEST_H_ diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc index 2b19aaded9c..2231fc6feab 100644 --- a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -45,7 +45,8 @@ void CompileAndExecute( xla::ClientLibrary::GetXlaService(client->platform()) ->backend() .memory_allocator()); - StatusOr result = executable->Run({}, execute_options); + StatusOr result = + executable->Run(absl::Span(), execute_options); { absl::MutexLock lock(results_mutex); results->emplace_back(device_ordinal, std::move(result)); diff --git a/tensorflow/compiler/xla/tests/test_macros.cc b/tensorflow/compiler/xla/tests/test_macros.cc index dc9ac7b684a..eecbb89b877 100644 --- a/tensorflow/compiler/xla/tests/test_macros.cc +++ b/tensorflow/compiler/xla/tests/test_macros.cc @@ -15,93 +15,18 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" -#include -#include -#include - -#include "absl/container/flat_hash_map.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_split.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" namespace xla { -namespace { -// Mapping from test name; i.e. MyTest.MyTestCase to platforms on which it is -// disabled - a sequence of regexps. -using ManifestT = absl::flat_hash_map>; - -ManifestT ReadManifest() { - ManifestT manifest; - - string path = XLA_DISABLED_MANIFEST; - if (path.empty()) { - return manifest; - } - - std::ifstream file_stream(path); - // Note: parens are required to disambiguate vs function decl. 
- string contents((std::istreambuf_iterator(file_stream)), - std::istreambuf_iterator()); - - std::vector lines = absl::StrSplit(contents, '\n'); - for (string& line : lines) { - auto comment = line.find("//"); - if (comment != string::npos) { - line = line.substr(0, comment); - } - if (line.empty()) { - continue; - } - absl::StripTrailingAsciiWhitespace(&line); - std::vector pieces = absl::StrSplit(line, ' '); - CHECK_GE(pieces.size(), 1); - auto& platforms = manifest[pieces[0]]; - for (int64 i = 1; i < pieces.size(); ++i) { - platforms.push_back(pieces[i]); - } - } - return manifest; +static bool InitModule() { + *DisabledManifestPath() = XLA_DISABLED_MANIFEST; + VLOG(1) << "DisabledManifestPath: " << *DisabledManifestPath(); + *TestPlatform() = XLA_PLATFORM; + VLOG(1) << "TestPlatform: " << *TestPlatform(); + return false; } -} // namespace - -std::string PrependDisabledIfIndicated(absl::string_view test_case_name, - absl::string_view test_name) { - ManifestT manifest = ReadManifest(); - - // If the test name ends with a slash followed by one or more digits, strip - // that off; this is just a shard number, and matching on this would be - // unstable even if someone wanted to do it. - static LazyRE2 shard_num_pattern = {R"(/\d+$)"}; - absl::string_view suffix; - if (RE2::PartialMatch(test_name, *shard_num_pattern, &suffix)) { - test_name.remove_suffix(suffix.size()); - } - - // First try full match: test_case_name.test_name - // If that fails, try to find just the test_case_name; this would disable all - // tests in the test case. - auto it = manifest.find(absl::StrCat(test_case_name, ".", test_name)); - if (it == manifest.end()) { - it = manifest.find(test_case_name); - if (it == manifest.end()) { - return std::string(test_name); - } - } - - // Expect a full match vs. one of the platform regexps to disable the test. - const std::vector& disabled_platforms = it->second; - string platform_string = XLA_PLATFORM; - for (const auto& s : disabled_platforms) { - if (RE2::FullMatch(/*text=*/platform_string, /*re=*/s)) { - return absl::StrCat("DISABLED_", test_name); - } - } - - // We didn't hit in the disabled manifest entries, so don't disable it. - return std::string(test_name); -} +static bool module_initialized = InitModule(); } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 33d2dff9721..16cc9ff6feb 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -28,12 +28,6 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ #define TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/test.h" - #define DISABLED_ON_CPU(X) X #define DISABLED_ON_GPU(X) X #define DISABLED_ON_GPU_ROCM(X) X @@ -79,117 +73,22 @@ limitations under the License. namespace xla { -// Reads a disabled manifest file to resolve whether test cases should be -// disabled on a particular platform. For a test that should be disabled, -// returns DISABLED_ prepended to its name; otherwise returns the test name -// unmodified. 
-std::string PrependDisabledIfIndicated(absl::string_view test_case_name, - absl::string_view test_name); +inline const char** DisabledManifestPath() { + static const char* disabled_manifest_path = nullptr; + return &disabled_manifest_path; +} + +inline const char** TestPlatform() { + static const char* test_platform = nullptr; + return &test_platform; +} } // namespace xla -// This is the internal "gtest" class instantiation -- it is identical to the -// GTEST_TEST_ macro, except that we intercept the test name for potential -// modification by PrependDisabledIfIndicated. That file can use an arbitrary -// heuristic to decide whether the test case should be disabled, and we -// determine whether the test case should be disabled by resolving the (test -// case name, test name) in a manifest file. -#define XLA_GTEST_TEST_(test_case_name, test_name, parent_class) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public parent_class { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - \ - private: \ - virtual void TestBody(); \ - static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - \ - ::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::test_info_ = \ - ::testing::RegisterTest( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - nullptr, nullptr, __FILE__, __LINE__, []() -> parent_class* { \ - return new GTEST_TEST_CLASS_NAME_(test_case_name, test_name)(); \ - }); \ - void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() +#define XLA_TEST_F(test_fixture, test_name) TEST_F(test_fixture, test_name) -// This is identical to the TEST_F macro from "gtest", but it potentially -// disables the test based on an external manifest file, DISABLED_MANIFEST. -// -// Per usual, you can see what tests are available via --gunit_list_tests and -// choose to run tests that have been disabled via the manifest via -// --gunit_also_run_disabled_tests. -#define XLA_TEST_F(test_fixture, test_name) \ - XLA_GTEST_TEST_(test_fixture, test_name, test_fixture) +#define XLA_TEST_P(test_case_name, test_name) TEST_P(test_case_name, test_name) -// Likewise, this is identical to the TEST_P macro from "gtest", but -// potentially disables the test based on the DISABLED_MANIFEST file. -// -// We have to wrap this in an outer layer so that any DISABLED_ON_* macros will -// be properly expanded before the stringification occurs. 
-#define XLA_TEST_P_IMPL_(test_case_name, test_name) \ - class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ - : public test_case_name { \ - public: \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \ - virtual void TestBody(); \ - \ - private: \ - static int AddToRegistry() { \ - ::testing::UnitTest::GetInstance() \ - ->parameterized_test_registry() \ - .GetTestCasePatternHolder( \ - #test_case_name, \ - ::testing::internal::CodeLocation(__FILE__, __LINE__)) \ - ->AddTestPattern( \ - #test_case_name, \ - ::xla::PrependDisabledIfIndicated(#test_case_name, #test_name) \ - .c_str(), \ - new ::testing::internal::TestMetaFactory()); \ - return 0; \ - } \ - static int gtest_registering_dummy_ GTEST_ATTRIBUTE_UNUSED_; \ - GTEST_DISALLOW_COPY_AND_ASSIGN_(GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)); \ - }; \ - int GTEST_TEST_CLASS_NAME_(test_case_name, \ - test_name)::gtest_registering_dummy_ = \ - GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \ - void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() - -#define XLA_TEST_P(test_case_name, test_name) \ - XLA_TEST_P_IMPL_(test_case_name, test_name) - -// This is identical to the TEST_F macro from "gtest", but it potentially -// disables the test based on an external manifest file, DISABLED_MANIFEST. -#define XLA_TYPED_TEST(CaseName, TestName) \ - template \ - class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ - : public CaseName { \ - private: \ - typedef CaseName TestFixture; \ - typedef gtest_TypeParam_ TypeParam; \ - virtual void TestBody(); \ - }; \ - bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \ - ::testing::internal::TypeParameterizedTest< \ - CaseName, \ - ::testing::internal::TemplateSel, \ - GTEST_TYPE_PARAMS_(CaseName)>:: \ - Register( \ - "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \ - #CaseName, \ - ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \ - 0); \ - template \ - void GTEST_TEST_CLASS_NAME_(CaseName, \ - TestName)::TestBody() +#define XLA_TYPED_TEST(CaseName, TestName) TYPED_TEST(CaseName, TestName) #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 9ef589e5511..b6ad44497e6 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -577,5 +577,37 @@ XLA_TEST_F(TupleHloTest, EXPECT_TRUE(LiteralTestUtil::Equal(expected, literal)); } +XLA_TEST_F(TupleHloTest, TupleSelectOfSort) { + const char* testcase = R"( + HloModule sort + + compare { + p.1.lhs = s32[] parameter(2) + p.1.rhs = s32[] parameter(3) + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT + } + + ENTRY Sort { + keys = f32[2]{0} iota(), iota_dimension=0 + values = s32[2]{0} iota(), iota_dimension=0 + preds = pred[] constant(true) + alt = (f32[2], s32[2]) parameter(0) + + sorted = (f32[2]{0}, s32[2]{0}) sort(keys, values), dimensions={0}, + to_apply=compare + ROOT selected = (f32[2], s32[2]) tuple-select(preds, sorted, alt) + } + )"; + auto module = ParseAndReturnVerifiedModule(testcase).ValueOrDie(); + auto param = LiteralUtil::MakeTupleOwned(LiteralUtil::CreateR1({2, 3}), + LiteralUtil::CreateR1({3, 4})); + auto expected = LiteralUtil::MakeTupleOwned( + LiteralUtil::CreateR1({0, 1}), LiteralUtil::CreateR1({0, 1})); + auto result = ExecuteAndTransfer(std::move(module), {¶m}); + EXPECT_TRUE(LiteralTestUtil::Equal(expected, 
result)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index d575bbb1f3e..8e8c3605cc7 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -1324,14 +1324,16 @@ void BM_WhileLoop(int num_iters) { options.set_allocator(&allocator); const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run(absl::Span(), options); ASSERT_TRUE(result.ok()); } // Run benchmark. tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run(absl::Span(), options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xrt/BUILD b/tensorflow/compiler/xrt/BUILD index 332c8ff9a14..6a704be4adb 100644 --- a/tensorflow/compiler/xrt/BUILD +++ b/tensorflow/compiler/xrt/BUILD @@ -74,6 +74,7 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:executable", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc index 2fc599e42df..bfd48bd1442 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc @@ -51,12 +51,6 @@ namespace tensorflow { namespace { -struct InputBuffers { - std::vector> input_tuples; - std::vector input_allocations; - std::vector input_pointers; -}; - uint32 InitialRandomSeed() { // Support plumbing the TF seed through to XLA is being worked on. // If a user wants deterministic behavior, their best option @@ -80,75 +74,51 @@ uint32 GetXLARandomSeed() { return counter.fetch_add(2); } -xla::StatusOr GetInputBuffers( - XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, - const std::vector& input_coords, bool release_inputs) { - InputBuffers input_buffers; - input_buffers.input_tuples.reserve(input_coords.size()); - input_buffers.input_allocations.reserve(input_coords.size()); - input_buffers.input_pointers.reserve(input_coords.size()); - for (size_t i = 0; i < input_coords.size(); ++i) { - TF_RETURN_IF_ERROR( - working_set->LookupAndPin(backend, input_coords[i].handle)); - auto tuple = working_set->PinnedTuples().back(); - input_buffers.input_tuples.emplace_back(tuple); - if (release_inputs) { - // We are holding a reference to the tuple, so we can safely delete it - // from the resource manager here. 
- TF_RETURN_IF_ERROR( - working_set->MemoryManager()->Release(input_coords[i].handle)); - VLOG(2) << "Released allocation handle " << input_coords[i].handle; - } - if (input_coords[i].index.empty()) { - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, - tuple->ToShapedBuffer()); - input_buffers.input_allocations.emplace_back(std::move(shaped_buffer)); - } else { - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, - tuple->ToShapedBuffer()); - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer sub_shaped_buffer, - shaped_buffer.SubShapedBuffer(input_coords[i].index)); - input_buffers.input_allocations.emplace_back( - std::move(sub_shaped_buffer)); - } +std::vector GetDynamicInputInfo( + const xla::ComputationLayout& computation_layout) { + std::vector input_is_dynamic; + input_is_dynamic.reserve(computation_layout.parameter_count()); + for (int64 i = 0; i < computation_layout.parameter_count(); ++i) { + input_is_dynamic.push_back( + !computation_layout.parameter_shape(i).is_static()); } - for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) { - input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]); - } - return std::move(input_buffers); + return input_is_dynamic; } -xla::StatusOr GetChainedOpInputs( +xla::StatusOr>> GetInputTuples( + xla::LocalExecutable* executable, XRTMemoryManager::WorkingSet* working_set, + xla::Backend* backend, const std::vector& input_coords, + bool release_inputs) { + const xla::ComputationLayout& computation_layout = + executable->executable()->module_config().entry_computation_layout(); + + return GetInputTupleAllocations( + input_coords, working_set, backend, computation_layout.parameter_count(), + [&](int64 i) { return computation_layout.parameter_shape(i); }, + release_inputs); +} + +xla::StatusOr>> GetChainedOpInputTuples( const xrt::XRTChainedExecuteOp& op, absl::Span> op_inputs) { - InputBuffers input_buffers; - input_buffers.input_tuples.reserve(op.inputs_size()); - input_buffers.input_allocations.reserve(op.inputs_size()); - input_buffers.input_pointers.reserve(op.inputs_size()); + std::vector> input_tuples; + input_tuples.reserve(op.inputs_size()); for (int i = 0; i < op.inputs_size(); ++i) { auto& input = op.inputs(i); - input_buffers.input_tuples.emplace_back(op_inputs[i]); // Thanks to the greatness of proto3, there is no way to query for // explicitly set fields, so the default for output_index (zero) means no // sub-index. As consequence, the real index is output_index - 1. 
if (input.output_index() == 0) { - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, - input_buffers.input_tuples.back()->ToShapedBuffer()); - input_buffers.input_allocations.emplace_back(std::move(shaped_buffer)); + input_tuples.emplace_back(op_inputs[i]); } else { - TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, - input_buffers.input_tuples.back()->ToShapedBuffer()); - TF_ASSIGN_OR_RETURN( - xla::ShapedBuffer sub_shaped_buffer, - shaped_buffer.SubShapedBuffer({input.output_index() - 1})); - input_buffers.input_allocations.emplace_back( - std::move(sub_shaped_buffer)); + XRTTupleAllocation* sub_tuple; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + op_inputs[i].get(), {input.output_index() - 1}, &sub_tuple, + /*alias_parent_allocation=*/true)); + input_tuples.emplace_back(sub_tuple); } } - for (size_t i = 0; i < input_buffers.input_allocations.size(); ++i) { - input_buffers.input_pointers.push_back(&input_buffers.input_allocations[i]); - } - return std::move(input_buffers); + return input_tuples; } // Given a shape, returns a byte array representing the shape metadata of the @@ -228,12 +198,11 @@ Status UpdateMetadata(se::Stream* stream, se::DeviceMemory* buffer, // As we can't expand the size of an existing memory allocation, a reallocation // is required. A list of new allocations are returned after this function. The // caller is reponsible for maintaining those allocations. -xla::StatusOr> UpdateDynamicInputs( +Status UpdateDynamicInputs( se::Stream* stream, se::DeviceMemoryAllocator* allocator, - std::vector runtime_inputs, + std::vector* execution_inputs, const std::vector& compile_time_shapes) { - std::vector new_allocations; - TF_RET_CHECK(runtime_inputs.size() == compile_time_shapes.size()); + TF_RET_CHECK(execution_inputs->size() == compile_time_shapes.size()); TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( stream->parent()->platform())); auto shape_size_fn = compiler->ShapeSizeBytesFunction(); @@ -242,146 +211,103 @@ xla::StatusOr> UpdateDynamicInputs( if (compile_time_shape.is_static()) { continue; } - auto* runtime_input = runtime_inputs[i]; - + xla::ExecutionInput* execution_input = &(*execution_inputs)[i]; bool element_modified = false; TF_RETURN_IF_ERROR(xla::ShapeUtil::ForEachSubshapeWithStatus( compile_time_shape, - [&](const xla::Shape& compile_time_shape, + [&](const xla::Shape& sub_shape, const xla::ShapeIndex& index) -> Status { - if (compile_time_shape.IsTuple() || compile_time_shape.is_static()) { + if (sub_shape.IsTuple() || sub_shape.is_static()) { return Status::OK(); } - const xla::Shape& runtime_shape = xla::ShapeUtil::GetSubshape( - runtime_input->on_device_shape(), index); - TF_RET_CHECK(!runtime_shape.IsTuple()); - TF_RET_CHECK(xla::ShapeUtil::DynamicShapeIsCompatible( - runtime_shape, compile_time_shape)); - se::DeviceMemoryBase* static_input = - runtime_input->buffers().mutable_element(index); TF_ASSIGN_OR_RETURN( - auto dynamic_input, + const xla::Shape* runtime_shape, + xla::ShapeUtil::TryGetSubshape(execution_input->shape(), index)); + TF_RET_CHECK(!runtime_shape->IsTuple()); + TF_RET_CHECK(xla::ShapeUtil::DynamicArrayShapeIsCompatible( + *runtime_shape, sub_shape)); + TF_ASSIGN_OR_RETURN( + se::OwningDeviceMemory dynamic_input, allocator->Allocate(stream->parent()->device_ordinal(), - shape_size_fn(compile_time_shape))); - new_allocations.emplace_back(std::move(dynamic_input)); - se::DeviceMemory* dynamic_input_base = - new_allocations.back().ptr(); + shape_size_fn(sub_shape))); + + se::DeviceMemoryBase 
static_input = + execution_input->Buffer(index).AsDeviceMemoryBase(); + se::DeviceMemory* dynamic_input_base = dynamic_input.ptr(); // Send the original data to the new location. - stream->ThenMemcpyD2D(dynamic_input_base, *static_input, - static_input->size()); + stream->ThenMemcpyD2D(dynamic_input_base, static_input, + static_input.size()); TF_RETURN_IF_ERROR(UpdateMetadata(stream, dynamic_input_base, - compile_time_shape, runtime_shape)); + sub_shape, *runtime_shape)); // Modify the memory location in the input shape tree to point to the // new input. - runtime_input->set_buffer(*dynamic_input_base, index); + execution_input->SetBuffer( + index, xla::MaybeOwningDeviceMemory(std::move(dynamic_input))); + execution_input->ClearUnownedIndex(index); element_modified = true; return Status::OK(); })); if (element_modified) { - runtime_input->set_shapes(compile_time_shape, compile_time_shape); + TF_RETURN_IF_ERROR(execution_input->SetDynamicShape(compile_time_shape)); + TF_ASSIGN_OR_RETURN(xla::ShapedBuffer shaped_buffer, + execution_input->ToShapedBuffer( + allocator, stream->parent()->device_ordinal())); // The input location has been modified, need to fix tuple table to // point to the correct address. TF_ASSIGN_OR_RETURN( auto transfer_manager, xla::TransferManager::GetForPlatform(stream->parent()->platform())); TF_RETURN_IF_ERROR( - transfer_manager->WriteTupleIndexTablesAsync(stream, *runtime_input)); + transfer_manager->WriteTupleIndexTablesAsync(stream, shaped_buffer)); } } - return std::move(new_allocations); -} - -xla::StatusOr ReadMetadataLiteral( - se::Stream* stream, se::DeviceMemoryBase* buffer, - const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) { - TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform( - stream->parent()->platform())); - auto shape_size_fn = compiler->ShapeSizeBytesFunction(); - xla::Shape buffer_shape_static = - xla::ShapeUtil::MakeStaticShape(buffer_shape); - const int64 offset = shape_size_fn(buffer_shape_static); - int64 metadata_size = shape_size_fn(buffer_shape) - offset; - TF_RET_CHECK(metadata_size != 0); - auto buffer_8 = se::DeviceMemory(*buffer); - auto metadata_buffer = - stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size); - return transfer_manager->TransferArrayFromDevice( - stream, - xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}), - metadata_buffer); -} - -// For each subshape in the result buffer that's dynamic, read the dynamic -// dimension sizes from the metadata, and update output shapes. The result shape -// is a static and concrete shape. 
-xla::Status UpdateDynamicOutputs(se::Stream* stream, - xla::ShapedBuffer* shaped_buffer, - xla::Shape* output_host_shape, - xla::Shape* output_device_shape) { - DCHECK(output_device_shape->is_dynamic()); - TF_ASSIGN_OR_RETURN( - auto transfer_manager, - xla::TransferManager::GetForPlatform(stream->parent()->platform())); - TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); - TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( - [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) { - const xla::Shape& buffer_shape = - xla::ShapeUtil::GetSubshape(*output_device_shape, index); - if (buffer_shape.IsTuple()) { - return Status::OK(); - } - xla::Shape& host_shape = - *xla::ShapeUtil::GetMutableSubshape(output_host_shape, index); - xla::Shape& device_shape = - *xla::ShapeUtil::GetMutableSubshape(output_device_shape, index); - if (device_shape.is_static()) { - return Status::OK(); - } - TF_ASSIGN_OR_RETURN(auto metadata, - ReadMetadataLiteral(stream, buffer, buffer_shape, - transfer_manager)); - // Update shape size from metadata. - for (int64 i = 0; i < metadata.element_count(); ++i) { - host_shape.mutable_dimensions()[i] = metadata.Get({i}); - device_shape.mutable_dimensions()[i] = metadata.Get({i}); - } - return Status::OK(); - })); - output_host_shape->clear_dynamic_dimensions(); - output_device_shape->clear_dynamic_dimensions(); return Status::OK(); } -// Create output tuple from run_result. xla::StatusOr> CreateOutputTuple( - se::Stream* stream, xla::ScopedShapedBuffer run_result, - xla::Backend* backend, int device_ordinal) { + se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend, + int device_ordinal) { XRTTupleAllocation* output_tuple; - xla::ShapedBuffer shaped_buffer = run_result.release(); - if (shaped_buffer.on_device_shape().is_dynamic()) { + xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult(); + if (shaped_buffer->on_device_shape().is_dynamic()) { // Update dynamic shapes from output buffer, and create a XRT tensor with // dimension sizes read from metadata. - xla::Shape output_host_shape = shaped_buffer.on_host_shape(); - xla::Shape output_device_shape = shaped_buffer.on_device_shape(); - TF_RETURN_IF_ERROR(UpdateDynamicOutputs( - stream, &shaped_buffer, &output_host_shape, &output_device_shape)); + xla::Shape output_host_shape = shaped_buffer->on_host_shape(); + xla::Shape output_device_shape = shaped_buffer->on_device_shape(); + TF_ASSIGN_OR_RETURN( + auto transfer_manager, + xla::TransferManager::GetForPlatform(stream->parent()->platform())); + TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes( + stream, shaped_buffer, &output_host_shape, &output_device_shape)); TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, output_host_shape, output_device_shape, backend, + *shaped_buffer, output_host_shape, output_device_shape, backend, device_ordinal, &output_tuple)); } else { // Fast-path: Don't copy shapes of output buffer. TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer( - shaped_buffer, backend, device_ordinal, &output_tuple)); + *shaped_buffer, backend, device_ordinal, &output_tuple)); } + // After the output tuple is created, we can release the output result + // buffers, to make sure they won't be cleared by its destructor. 
+ (void)run_result.ConsumeResult().release(); return RefPtr(output_tuple); } xla::StatusOr> RunExecutable( OpKernelContext* context, XRTGenericDeviceAccessor::ScopedRef* device_ref, - xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed, + xla::LocalExecutable* executable, + absl::Span> input_tuples, + bool release_inputs, se::Stream* stream, int rng_seed, const xrt::CommonExecutionConfig& config) { - VLOG(2) << "Executing computation."; + const xla::ComputationLayout& computation_layout = + executable->executable()->module_config().entry_computation_layout(); + std::vector input_is_dynamic = GetDynamicInputInfo(computation_layout); + TF_ASSIGN_OR_RETURN( + std::vector execution_inputs, + GetArgumentsBuffers( + executable->executable()->module().input_output_alias_config(), + input_tuples, input_is_dynamic, release_inputs)); + xla::ExecutableRunOptions run_options; run_options.set_stream(stream); run_options.set_allocator(device_ref->backend()->memory_allocator()); @@ -419,51 +345,28 @@ xla::StatusOr> RunExecutable( } run_options.set_gpu_executable_run_options(&gpu_options); - Env* env = Env::Default(); - auto start_time = env->NowMicros(); const std::vector& shape_layouts = executable->executable() ->module_config() .entry_computation_layout() .parameter_layouts(); - TF_ASSIGN_OR_RETURN(auto new_allocations, - UpdateDynamicInputs(stream, run_options.allocator(), - input_buffers.input_pointers, - shape_layouts)); - auto new_allocations_ptr = - std::make_shared>( - std::move(new_allocations)); + TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, run_options.allocator(), + &execution_inputs, shape_layouts)); TF_ASSIGN_OR_RETURN( - xla::ScopedShapedBuffer run_result, - executable->Run(input_buffers.input_pointers, run_options)); - // Retain the new allocation for input memory until the end of execution. - stream->ThenDoHostCallback([new_allocations_ptr]() { return Status::OK(); }); - - auto elapsed = env->NowMicros() - start_time; - VLOG(2) << "Elapsed time: " << elapsed << "us"; + xla::ExecutionOutput run_result, + executable->Run(std::move(execution_inputs), run_options)); TF_ASSIGN_OR_RETURN( RefPtr output_tuple_ptr, CreateOutputTuple(stream, std::move(run_result), device_ref->backend(), device_ref->device_ordinal())); - // The ScopedShapedBuffer returned by the executable Run() API, in case of // input/output buffer aliasing, might have holes in it, which need to be // filled using the proper input tuples buffers which are the source of // aliasing. - const xla::HloInputOutputAliasConfig& input_output_alias = - executable->executable()->module().input_output_alias_config(); - auto alias_function = - [&](const xla::ShapeIndex& output_index, - const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { - TF_RET_CHECK(alias.parameter_number < input_buffers.input_tuples.size()); - return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias - ? 
output_tuple_ptr->AliasBufferFrom( - *input_buffers.input_tuples[alias.parameter_number], - alias.parameter_index, output_index) - : Status::OK(); - }; - TF_RETURN_IF_ERROR(input_output_alias.ForEachAliasWithStatus(alias_function)); + TF_RETURN_IF_ERROR(RebuildOutputAliases( + output_tuple_ptr, input_tuples, + executable->executable()->module().input_output_alias_config())); return std::move(output_tuple_ptr); } @@ -471,12 +374,13 @@ xla::StatusOr> RunExecutable( xla::StatusOr> ExecuteComputation( OpKernelContext* context, XRTMemoryManager* memory_manager, XRTGenericDeviceAccessor::ScopedRef* device_ref, - xla::LocalExecutable* executable, const InputBuffers& input_buffers, - se::Stream* stream, int rng_seed, + xla::LocalExecutable* executable, + absl::Span> input_tuples, + bool release_inputs, se::Stream* stream, int rng_seed, const xrt::CommonExecutionConfig& config) { auto runfn = [&]() { - return RunExecutable(context, device_ref, executable, input_buffers, stream, - rng_seed, config); + return RunExecutable(context, device_ref, executable, input_tuples, + release_inputs, stream, rng_seed, config); }; // We pass zero as requested_free_size as there is no simple way to get the @@ -495,12 +399,13 @@ xla::StatusOr> ExecuteComputation( se::Stream* stream, int rng_seed, const xrt::CommonExecutionConfig& config) { XRTMemoryManager::WorkingSet working_set(memory_manager); - TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, - GetInputBuffers(&working_set, device_ref->backend(), - input_coords, release_inputs)); + TF_ASSIGN_OR_RETURN( + std::vector> input_tuples, + GetInputTuples(executable, &working_set, device_ref->backend(), + input_coords, release_inputs)); return ExecuteComputation(context, memory_manager.get(), device_ref, - executable, input_buffers, stream, rng_seed, - config); + executable, input_tuples, release_inputs, stream, + rng_seed, config); } // XRTExecuteOp @@ -653,16 +558,16 @@ Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) { auto execute_op = [&](const xrt::XRTChainedExecuteOp& op, absl::Span> op_inputs) -> xla::StatusOr> { - TF_ASSIGN_OR_RETURN(InputBuffers input_buffers, - GetChainedOpInputs(op, op_inputs)); - std::unique_ptr entry; TF_RETURN_IF_ERROR(cache->Lookup(op.computation_handle(), &entry)); xla::LocalExecutable* executable = entry->get().get_executable(); - return ExecuteComputation(context, memory_manager.get(), &device_ref, - executable, input_buffers, stream, rng_seed, - config.common_config()); + TF_ASSIGN_OR_RETURN(std::vector> input_tuples, + GetChainedOpInputTuples(op, op_inputs)); + + return ExecuteComputation( + context, memory_manager.get(), &device_ref, executable, input_tuples, + /*release_inputs=*/false, stream, rng_seed, config.common_config()); }; return ExecuteChained(context, memory_manager, device_ref.backend(), diff --git a/tensorflow/compiler/xrt/xrt_util.cc b/tensorflow/compiler/xrt/xrt_util.cc index b8a0afc92c5..926ba23c7af 100644 --- a/tensorflow/compiler/xrt/xrt_util.cc +++ b/tensorflow/compiler/xrt/xrt_util.cc @@ -221,6 +221,140 @@ xla::StatusOr> GetComputationInputs( return std::move(input_coords); } +bool InputShapeMatches(const xla::Shape& parameter_shape, + const xla::Shape& input_shape) { + auto shape_checker = [&](const xla::Shape& pshape, + const xla::ShapeIndex& index) { + if (pshape.IsArray()) { + TF_ASSIGN_OR_RETURN(const xla::Shape* ishape, + xla::ShapeUtil::TryGetSubshape(input_shape, index)); + if (pshape.rank() != ishape->rank() || + pshape.element_type() != ishape->element_type()) { + return 
errors::InvalidArgument("Mismatching shapes"); + } + if (pshape.is_static() && pshape.layout() != ishape->layout()) { + return errors::InvalidArgument("Mismatching layouts"); + } + for (int64 dim = 0; dim < pshape.rank(); ++dim) { + if (pshape.is_dynamic_dimension(dim)) { + if (pshape.dimensions(dim) < ishape->dimensions(dim)) { + return errors::InvalidArgument("Mismatching shapes"); + } + } else if (pshape.dimensions(dim) != ishape->dimensions(dim)) { + return errors::InvalidArgument("Mismatching shapes"); + } + } + } + return Status::OK(); + }; + return xla::ShapeUtil::ForEachSubshapeWithStatus(parameter_shape, + shape_checker) + .ok(); +} + +xla::StatusOr>> GetInputTupleAllocations( + const std::vector& input_coords, + XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, + int64 num_input_shapes, + const std::function& shape_getter, bool release_inputs) { + if (input_coords.size() != num_input_shapes) { + return errors::InvalidArgument( + "Number of inputs does not match executable proto input shapes: ", + input_coords.size(), " vs. ", num_input_shapes); + } + std::vector> input_tuples; + input_tuples.reserve(input_coords.size()); + for (size_t i = 0; i < input_coords.size(); ++i) { + TF_RETURN_IF_ERROR( + working_set->LookupAndPin(backend, input_coords[i].handle)); + auto tuple = working_set->PinnedTuples().back(); + if (release_inputs) { + // We are holding a reference to the tuple, so we can safely delete it + // from the resource manager here. + TF_RETURN_IF_ERROR( + working_set->MemoryManager()->Release(input_coords[i].handle)); + VLOG(2) << "Released allocation handle " << input_coords[i].handle; + } + xla::Shape input_shape = shape_getter(i); + if (!InputShapeMatches(input_shape, tuple->on_host_shape())) { + return errors::InvalidArgument( + "Run-time shape mismatch for XRTExecute argument[", i, "] (", + input_coords[i].handle, "). Expected ", input_shape.DebugString(), + "; got ", tuple->on_host_shape().DebugString()); + } + if (input_coords[i].index.empty()) { + input_tuples.emplace_back(std::move(tuple)); + } else { + XRTTupleAllocation* sub_tuple; + TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer( + tuple.get(), input_coords[i].index, &sub_tuple, + /*alias_parent_allocation=*/true)); + input_tuples.emplace_back(sub_tuple); + } + } + return std::move(input_tuples); +} + +Status RebuildOutputAliases( + const RefPtr& output_tuple, + absl::Span> input_tuples, + const xla::HloInputOutputAliasConfig& input_output_alias) { + auto alias_function = + [&](const xla::ShapeIndex& output_index, + const xla::HloInputOutputAliasConfig::Alias& alias) -> Status { + TF_RET_CHECK(alias.parameter_number < input_tuples.size()); + return alias.kind == xla::HloInputOutputAliasConfig::AliasKind::kUserAlias + ? output_tuple->AliasBufferFrom( + *input_tuples[alias.parameter_number], + alias.parameter_index, output_index) + : Status::OK(); + }; + return input_output_alias.ForEachAliasWithStatus(alias_function); +} + +xla::StatusOr> GetArgumentsBuffers( + const xla::HloInputOutputAliasConfig& input_output_alias, + absl::Span> input_tuples, + const std::vector& input_is_dynamic, bool release_inputs) { + auto is_dynamic = [&](size_t arg) { + return arg < input_is_dynamic.size() && input_is_dynamic[arg]; + }; + std::vector arguments; + // Don't alias dynamic input -- Due to the underlying implementation, + // aliased inputs have two owners: XRTAllocation and return value of + // this function. 
If an argument is dynamic and the ownership is + // released to output of this function, TPUExecute will free it and + // reallocate a new one, which creates a double freeing issue where + // XRTAllocation also attempts to release the buffer. + bool alias_outputs = release_inputs && input_tuples.size() == 1 && + input_tuples[0]->IsExclusiveOwner() && !is_dynamic(0); + arguments.reserve(input_tuples.size()); + for (int64 i = 0; i < input_tuples.size(); ++i) { + auto alias_checker = + [&](const xla::ShapeIndex& index) -> xla::StatusOr { + // Only the buffers which the caller explicitly marked as aliased + // (kUserAlias), should create aliases. + // The XLA compiler might create opportunistic aliases (kSystemAlias) + // which need a different handling. With a system alias we know that XLA + // is going to reuse a given input parameter buffer for a given output, so + // unless it is known at call site that the input buffer has no more uses, + // a copy needs to be made at call site. With user specified alias the + // caller tells us that he expects a given output to land over the buffers + // of a given parametter. + if (input_output_alias.ParameterAliasKind(i, index) == + xla::HloInputOutputAliasConfig::AliasKind::kUserAlias) { + TF_RET_CHECK(!is_dynamic(i)); + return true; + } + return alias_outputs; + }; + TF_ASSIGN_OR_RETURN(xla::ExecutionInput exec_input, + input_tuples[i]->ToExecutionInput(alias_checker)); + arguments.emplace_back(std::move(exec_input)); + } + return std::move(arguments); +} + Status CreateExecuteOutput(OpKernelContext* context, XRTMemoryManager* memory_manager, RefPtr output_tuple, diff --git a/tensorflow/compiler/xrt/xrt_util.h b/tensorflow/compiler/xrt/xrt_util.h index cc1480fdb00..832c106621f 100644 --- a/tensorflow/compiler/xrt/xrt_util.h +++ b/tensorflow/compiler/xrt/xrt_util.h @@ -23,6 +23,8 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -69,6 +71,25 @@ xla::DebugOptions BuildXlaDebugOptions(const xla::DebugOptions& ref_options); xla::StatusOr> GetComputationInputs( OpKernelContext* context, const char* input_name); +bool InputShapeMatches(const xla::Shape& parameter_shape, + const xla::Shape& input_shape); + +xla::StatusOr>> GetInputTupleAllocations( + const std::vector& input_coords, + XRTMemoryManager::WorkingSet* working_set, xla::Backend* backend, + int64 num_input_shapes, + const std::function& shape_getter, bool release_inputs); + +Status RebuildOutputAliases( + const RefPtr& output_tuple, + absl::Span> input_tuples, + const xla::HloInputOutputAliasConfig& input_output_alias); + +xla::StatusOr> GetArgumentsBuffers( + const xla::HloInputOutputAliasConfig& input_output_alias, + absl::Span> input_tuples, + const std::vector& input_is_dynamic, bool release_inputs); + // Create the XRT execute output tensor given the computation result // (output_tuple). The return_exploded_tuple tells whether a tuple result should // be returned as vector of handles representing each tuple child. 
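// ----------------------------------------------------------------------------
// Aside: a minimal, self-contained sketch of the buffer-donation rule that the
// GetArgumentsBuffers() change above implements. Everything below (AliasKind,
// ArgumentInfo, ShouldDonateBuffer) is a hypothetical stand-in, not part of
// this patch and not the actual XLA/XRT API; it only restates the decision
// logic, assuming the behavior described in the comments above:
//   * an explicit user alias (kUserAlias) donates the parameter buffer, but
//     never for a dynamically shaped parameter, since the runtime may
//     reallocate a dynamic input and the XRT allocation would then be freed
//     twice;
//   * otherwise donation happens only when the caller releases the inputs,
//     there is a single input tuple, it is exclusively owned, and it is not
//     dynamic (the alias_outputs condition in the patch).
#include <cassert>
#include <cstddef>
#include <vector>

namespace xrt_aliasing_sketch {

enum class AliasKind { kNoAlias, kSystemAlias, kUserAlias };

struct ArgumentInfo {
  bool is_dynamic = false;         // parameter has a dynamic shape
  bool exclusively_owned = false;  // no other live references to the tuple
};

// Returns true if the buffer backing argument `arg` may be handed to the
// executable with ownership (i.e. donated and possibly reused for an output).
bool ShouldDonateBuffer(const std::vector<ArgumentInfo>& args, std::size_t arg,
                        AliasKind declared_alias, bool release_inputs) {
  assert(arg < args.size());
  if (declared_alias == AliasKind::kUserAlias) {
    // User-declared aliases always donate; dynamic parameters must not be
    // user-aliased at all.
    assert(!args[arg].is_dynamic);
    return true;
  }
  // A system (compiler-chosen) alias does not force donation: unless the
  // input is known to be dead, the runtime would have to copy instead.
  // Fall back to the conservative whole-call rule from the patch.
  return release_inputs && args.size() == 1 && args[0].exclusively_owned &&
         !args[0].is_dynamic;
}

}  // namespace xrt_aliasing_sketch

// Example: a single, exclusively owned, statically shaped input with
// release_inputs=true donates even without a declared alias:
//   ShouldDonateBuffer({{false, true}}, 0, AliasKind::kNoAlias, true) -> true
// ----------------------------------------------------------------------------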
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 50f1f2527a5..d0be6ee9597 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -72,6 +72,7 @@ load( "if_ios", "if_mobile", "if_not_windows", + "if_tpu", "tf_android_core_proto_headers", "tf_cc_test", "tf_cc_test_mkl", @@ -1093,6 +1094,8 @@ cc_library( ]) + if_tensorrt([ "//tensorflow/compiler/tf2tensorrt:trt_engine_resource_op_kernels", "//tensorflow/compiler/tf2tensorrt:trt_op_kernels", + ]) + if_tpu([ + "//tensorflow/core/tpu/kernels", ]), ) @@ -1861,7 +1864,9 @@ cc_library( "//tensorflow/core/lib/io:random_inputstream", "//tensorflow/core/lib/io:record_reader", "//tensorflow/core/lib/io:record_writer", + "//tensorflow/core/lib/io:snappy_compression_options", "//tensorflow/core/lib/io:snappy_inputbuffer", + "//tensorflow/core/lib/io:snappy_inputstream", "//tensorflow/core/lib/io:snappy_outputbuffer", "//tensorflow/core/lib/io:table", "//tensorflow/core/lib/io:table_options", @@ -2254,7 +2259,7 @@ tf_cuda_library( "//tensorflow/core/platform/default/build_config:platformlib", "//tensorflow/core/profiler/lib:annotated_traceme", "//tensorflow/core/profiler/lib:traceme", - "//tensorflow/core/tpu:tpu_library_loader", + "//tensorflow/core/tpu:tpu_api_dlsym_initializer", "//tensorflow/core/util:einsum_op_util", "//tensorflow/core/util:padding", "//tensorflow/core/util:port", diff --git a/tensorflow/core/api_def/base_api/api_def_BandedTriangularSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_BandedTriangularSolve.pbtxt new file mode 100644 index 00000000000..ba5e1bdcaf2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BandedTriangularSolve.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "BandedTriangularSolve" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_BesselI0.pbtxt b/tensorflow/core/api_def/base_api/api_def_BesselI0.pbtxt new file mode 100644 index 00000000000..2c47960429c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BesselI0.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "BesselI0" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt b/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt index 08313cebb99..7965af4916e 100644 --- a/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_BesselI0e.pbtxt @@ -1,10 +1,4 @@ op { graph_op_name: "BesselI0e" - summary: "Computes the Bessel i0e function of `x` element-wise." - description: <