diff --git a/.bazelrc b/.bazelrc index 1b9f5e87c6b..e765c302c28 100644 --- a/.bazelrc +++ b/.bazelrc @@ -461,12 +461,12 @@ build:rbe_linux_cuda11.0_nvcc_py3.6 --config=rbe_linux_cuda11.0_nvcc_base --repo build:rbe_linux_cuda11.0_nvcc_py3.7 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.7" build:rbe_linux_cuda11.0_nvcc_py3.8 --config=rbe_linux_cuda11.0_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu18.04-gcc7_manylinux2010-cuda11.0-cudnn8-tensorrt7.1_config_python3.8" -# Map default to CUDA 10.1. +# Map default to CUDA 11 for PY35 and greater. build:rbe_linux_cuda_nvcc_py27 --config=rbe_linux_cuda10.1_nvcc_py2.7 -build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda10.1_nvcc_py3.5 -build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda10.1_nvcc_py3.6 -build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda10.1_nvcc_py3.7 -build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda10.1_nvcc_py3.8 +build:rbe_linux_cuda_nvcc_py35 --config=rbe_linux_cuda11.0_nvcc_py3.5 +build:rbe_linux_cuda_nvcc_py36 --config=rbe_linux_cuda11.0_nvcc_py3.6 +build:rbe_linux_cuda_nvcc_py37 --config=rbe_linux_cuda11.0_nvcc_py3.7 +build:rbe_linux_cuda_nvcc_py38 --config=rbe_linux_cuda11.0_nvcc_py3.8 # Deprecated configs that people might still use. build:rbe_linux_cuda_nvcc --config=rbe_linux_cuda_nvcc_py36 @@ -583,9 +583,9 @@ build:release_cpu_macos --config=avx_linux build:release_gpu_common --config=release_common build:release_gpu_common --config=cuda build:release_gpu_common --config=tensorrt -build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1" -build:release_gpu_common --action_env=TF_CUDA_VERSION="10" -build:release_gpu_common --action_env=TF_CUDNN_VERSION="7" +build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-11.0" +build:release_gpu_common --action_env=TF_CUDA_VERSION="11" +build:release_gpu_common --action_env=TF_CUDNN_VERSION="8" build:release_gpu_common --action_env=TF_NEED_TENSORRT="1" build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70" build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt" @@ -595,8 +595,7 @@ build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5" build:release_gpu_linux --config=release_gpu_common build:release_gpu_linux --config=avx_linux -build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain - +build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda11:toolchain build:release_windows_common --config=release_common build:release_windows_common --define=no_tensorflow_py_deps=true build:release_windows_common --announce_rc diff --git a/RELEASE.md b/RELEASE.md index d4b5b27630e..7057657c340 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -81,6 +81,12 @@ server and set `dispatcher_fault_tolerance=True`. The dispatcher will store its state to `work_dir`, so that on restart it can continue from its previous state after restart. + * Added tf.data service support for sharing dataset graphs via shared + filesystem instead of over RPC. This reduces load on the dispatcher, + improving performance of distributing datasets. For this to work, the + dispatcher's `work_dir` must be accessible from workers. If the worker + fails to read from the `work_dir`, it falls back to using RPC for dataset + graph transfer. 
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is the complement of `select_cols`; at most one of these should be specified. * We have implemented an optimization which reorders data-discarding @@ -88,6 +94,7 @@ dataset when it is safe to do so. The optimization can be disabled via the `experimental_optimization.reorder_data_discarding_ops` dataset option. + * `tf.data.Options` were previously immutable and can now be overridden. * `tf.image`: * Added deterministic `tf.image.stateless_random_*` functions for each `tf.image.random_*` function. Added a new op @@ -106,7 +113,8 @@ * Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand. * `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape` as an alternative to accepting a `callable` loss. - * Added `beta` parameter to FTRL optimizer to match paper. + * Added `beta` hyperparameter to FTRL optimizer classes (Keras and others) + to match FTRL paper (https://research.google.com/pubs/archive/41159.pdf). * Added `mobilenet_v3` to keras application model. * `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for customization of how gradients are aggregated across devices, as well as @@ -155,6 +163,14 @@ * * Tracing and Debugging: * +* `tf.train.Checkpoint`: + * Now accepts a `root` argument in the initialization, which generates a + checkpoint with a root object. This allows users to create a `Checkpoint` + object that is compatible with Keras `model.save_weights()` and + `model.load_weights()`. The checkpoint is also compatible with the + checkpoint saved in the `variables/` folder in the SavedModel. + * When restoring, `save_path` can be a path to a SavedModel. The function + will automatically find the checkpoint in the SavedModel. * Other: * We have replaced uses of "whitelist" and "blacklist" with "allowlist" and "denylist" where possible. Please see @@ -251,6 +267,7 @@ stjohnso98, , , , , * Mutable tables now restore checkpointed values when loaded from SavedModel. * GPU * TF 2.3 includes PTX kernels only for [compute capability](https://developer.nvidia.com/cuda-gpus) 7.0 to reduce the TF pip binary size. Earlier releases included PTX for a variety of older compute capabilities. + * Removed environment variable `TF_USE_CUDNN`. * Others * Retain parent namescope for ops added inside `tf.while_loop`/`tf.cond`/`tf.switch_case`. * Update `tf.vectorized_map` to support vectorizing `tf.while_loop` and TensorList operations. @@ -1582,6 +1599,7 @@ Yuan (Terry) Tang, Yuchen Ying, Yves-Noel Weweler, zhangyujing, zjjott, zyeric, color palette of the frame. This has been fixed now * image.resize now considers proper pixel centers and has new kernels (incl. anti-aliasing). + * Added an isotonic regression solver (tf.nn.isotonic_regression). * Performance * Turn on MKL-DNN contraction kernels by default.
MKL-DNN dynamically dispatches the best kernel implementation based on CPU vector diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index fb5a0d250bb..9d8032aca52 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -58,9 +58,9 @@ filegroup( visibility = ["//visibility:public"], ) -filegroup( +cc_library( name = "pywrap_required_hdrs", - srcs = [ + textual_hdrs = [ "c_api_internal.h", "c_api_macros.h", "conversion_macros.h", @@ -220,6 +220,7 @@ cc_library( name = "logging", srcs = ["logging.cc"], hdrs = ["logging.h"], + visibility = ["//visibility:public"], deps = [ ":c_api_macros", "//tensorflow/core/platform:logging", diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 47452c245dc..ce2e2382309 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -240,6 +240,7 @@ tf_cuda_cc_test( "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:array_grad", "//tensorflow/c/experimental/gradients:math_grad", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/cc/profiler", @@ -255,6 +256,72 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "mnist_gradients_testutil", + srcs = [ + "mnist_gradients_testutil.cc", + ], + hdrs = [ + "mnist_gradients_testutil.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_operation", + ":abstract_tensor_handle", + ":c_api_unified_internal", + ":gradients_internal", + ":tape", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core/common_runtime/eager:attr_builder", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + ], +) + +tf_cuda_cc_test( + name = "mnist_gradients_test", + size = "small", + srcs = [ + "mnist_gradients_test.cc", + ], + args = ["--heap_check=local"], + extra_copts = tfe_xla_copts(), + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags() + ["nomac"], + deps = [ + ":abstract_tensor_handle", + ":c_api_experimental", + ":c_api_test_util", + ":c_api_unified_internal", + ":gradients_internal", + ":mnist_gradients_testutil", + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/c/experimental/gradients:math_grad", + "//tensorflow/c/experimental/gradients:nn_grad", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + cc_library( name = "abstract_tensor_handle", hdrs = ["abstract_tensor_handle.h"], diff --git a/tensorflow/c/eager/c_api_remote_function_test.cc b/tensorflow/c/eager/c_api_remote_function_test.cc index d3f9826635c..a9bbd5b694f 100644 --- a/tensorflow/c/eager/c_api_remote_function_test.cc +++ b/tensorflow/c/eager/c_api_remote_function_test.cc @@ -30,18 +30,26 @@ TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) { TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true, /*heavy_load_on_streaming_rpc=*/false); } +TEST(CAPI, 
RemoteExecuteSilentCopiesFuncRemoteOutputs) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/true, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} +TEST(CAPI, RemoteExecuteSilentCopiesAsyncFuncRemoteOutputs) { + TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true, + /*heavy_load_on_streaming_rpc=*/false, + /*remote_func_outputs=*/true); +} TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) { TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false, /*heavy_load_on_streaming_rpc=*/false); } -// TODO(b/162618595): Enable this test once we remove the check of remote -// outputs in ProcessFunctionLibraryRuntime. -TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) { +TEST(CAPI, RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) { TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false, /*heavy_load_on_streaming_rpc=*/false, /*remote_func_outputs=*/true); } -TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) { +TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) { TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false, /*heavy_load_on_streaming_rpc=*/false, /*remote_func_outputs=*/true); diff --git a/tensorflow/c/eager/c_api_remote_test_util.cc b/tensorflow/c/eager/c_api_remote_test_util.cc index 0ae5b74553a..159fa442a73 100644 --- a/tensorflow/c/eager/c_api_remote_test_util.cc +++ b/tensorflow/c/eager/c_api_remote_test_util.cc @@ -169,6 +169,13 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func, ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr)); } + if (remote_func_outputs) { + const string backing_device = + TFE_TensorHandleBackingDeviceName(retvals[0], status); + ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); + EXPECT_EQ(backing_device, task2_name); + } + auto* retval_task0 = TFE_TensorHandleCopyToDevice( retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status); ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status); diff --git a/tensorflow/c/eager/c_api_test_util.cc b/tensorflow/c/eager/c_api_test_util.cc index 192f10533a6..fd68866f502 100644 --- a/tensorflow/c/eager/c_api_test_util.cc +++ b/tensorflow/c/eager/c_api_test_util.cc @@ -102,6 +102,32 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, return th; } +TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims) { + TF_Status* status = TF_NewStatus(); + TF_Tensor* t = + TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) { constexpr int64_t dims[] = {100, 100}; constexpr int num_elements = dims[0] * dims[1]; diff --git 
a/tensorflow/c/eager/c_api_test_util.h b/tensorflow/c/eager/c_api_test_util.h index fcf407aa9c3..2f77ae5cf44 100644 --- a/tensorflow/c/eager/c_api_test_util.h +++ b/tensorflow/c/eager/c_api_test_util.h @@ -40,6 +40,14 @@ TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx, float data[], int64_t dims[], int num_dims); +// Get a Matrix TensorHandle with given float values and dimensions +TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[], + int64_t dims[], int num_dims); + +// Get a Matrix TensorHandle with given int values and dimensions +TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[], + int64_t dims[], int num_dims); + // Return a tensor handle containing a 100x100 matrix of floats TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx); diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 7bda3aed76d..9d064039141 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -85,7 +85,11 @@ class GraphOperation : public TracingOperation { return errors::FailedPrecondition( "GraphOperation::Reset must be called before calling SetOpName."); } - op_.reset(TF_NewOperation(g_, op_type_.c_str(), op_name)); + // TODO(b/145674566): We use Graph::NewName to get a unique name here but + // this may not be consistent with python's naming policy. + mutex_lock l(g_->mu); + op_.reset(new TF_OperationDescription(g_, op_type_.c_str(), + g_->graph.NewName(op_name).c_str())); return Status::OK(); } const string& Name() const override { return op_type_; } diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index c669ff4cf96..7b3a497a0c5 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -557,7 +557,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { auto* add_op = TF_NewAbstractOp(graph_ctx); TF_AbstractOpSetOpType(add_op, "Add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_AbstractOpSetOpName(add_op, "my_add1", s); + TF_AbstractOpSetOpName(add_op, "my_add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_AbstractTensor* inputs[2] = {arg0, arg1}; TF_OutputList* add_outputs = TF_NewOutputList(); @@ -579,7 +579,7 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { auto* add_op = TF_NewAbstractOp(graph_ctx); TF_AbstractOpSetOpType(add_op, "Add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_AbstractOpSetOpName(add_op, "my_add2", s); + TF_AbstractOpSetOpName(add_op, "my_add", s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_AbstractTensor* inputs[2] = {arg1, arg1}; TF_OutputList* add_outputs = TF_NewOutputList(); diff --git a/tensorflow/c/eager/gradients.cc b/tensorflow/c/eager/gradients.cc index 39cadd421e2..9bcd0d0fea0 100644 --- a/tensorflow/c/eager/gradients.cc +++ b/tensorflow/c/eager/gradients.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/eager/gradients.h" #include "absl/strings/str_cat.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/core/common_runtime/eager/attr_builder.h" @@ -23,25 +24,97 @@ limitations under the License. 
namespace tensorflow { namespace gradients { -Status GradientRegistry::Register(const string& op_name, - GradientFunctionFactory factory) { +namespace { +Status ZerosLike(AbstractContext* ctx, AbstractTensorHandle* t, + AbstractTensorHandle** result) { + AbstractOperationPtr op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); + if (isa(op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(op.get())->SetOpName( + absl::StrCat("ZerosLike", ToId(t)).c_str())); + } + TF_RETURN_IF_ERROR(op->AddInput(t)); + int num_outputs = 1; + std::vector outputs(num_outputs); + TF_RETURN_IF_ERROR( + op->Execute(absl::Span(outputs), &num_outputs)); + *result = outputs[0]; + return Status::OK(); +} +} // namespace + +class IncomingGradientsImpl : public IncomingGradients { + public: + explicit IncomingGradientsImpl( + absl::Span grad_inputs, Context* ctx, + DefaultGradientFunction* default_gradients) + : grad_inputs_(grad_inputs), + ctx_(ctx), + default_gradients_(default_gradients) {} + AbstractTensorHandle* operator[](int i) const override { + return default_gradients_->get(ctx_, grad_inputs_, i); + } + size_t size() const override { return grad_inputs_.size(); } + + private: + absl::Span grad_inputs_; + Context* ctx_; + DefaultGradientFunction* default_gradients_; +}; + +AllZerosDefaultGradients::AllZerosDefaultGradients(const ForwardOperation& op) + : outputs_(op.outputs) { + for (auto output : outputs_) { + output->Ref(); + } +} +AbstractTensorHandle* AllZerosDefaultGradients::get( + Context* ctx, absl::Span grad_inputs, int i) { + if (grad_inputs[i]) { + return grad_inputs[i]; + } + if (cached_default_grads_[i]) { + return cached_default_grads_[i].get(); + } + AbstractTensorHandle* result = nullptr; + Status s = ZerosLike(ctx->ctx, outputs_[i], &result); + if (!s.ok()) { + if (result) { + result->Unref(); + } + VLOG(1) << "Failed to create ZerosLike for index " << i; + return nullptr; + } + cached_default_grads_[i].reset(result); + return result; +} + +PassThroughDefaultGradients::PassThroughDefaultGradients( + const ForwardOperation& op) {} +AbstractTensorHandle* PassThroughDefaultGradients::get( + Context* ctx, absl::Span grad_inputs, int i) { + return grad_inputs[i]; +} + +Status GradientRegistry::Register( + const string& op_name, BackwardFunctionFactory backward_function_factory) { auto iter = registry_.find(op_name); if (iter != registry_.end()) { const string error_msg = "Gradient already exists for op: " + op_name + "."; return errors::AlreadyExists(error_msg); } - registry_.insert({op_name, factory}); + registry_.insert({op_name, backward_function_factory}); return Status::OK(); } Status GradientRegistry::Lookup( const ForwardOperation& op, - std::unique_ptr* grad_fn) const { + std::unique_ptr* backward_function) const { auto iter = registry_.find(op.op_name); if (iter == registry_.end()) { const string error_msg = "No gradient defined for op: " + op.op_name + "."; return errors::NotFound(error_msg); } - grad_fn->reset(iter->second(op)); + backward_function->reset(iter->second(op)); return Status::OK(); } @@ -92,33 +165,8 @@ AbstractTensorHandle* TapeTensor::OnesLike() const { } return outputs[0]; } -AbstractTensorHandle* TapeTensor::ZerosLike() const { - AbstractOperationPtr op(ctx_->CreateOperation()); - // TODO(srbs): Consider adding a TF_RETURN_NULLPTR_IF_ERROR. 
- Status s = op->Reset("ZerosLike", /*raw_device_name=*/nullptr); - if (!s.ok()) { - return nullptr; - } - if (isa(op.get())) { - s = dyn_cast(op.get())->SetOpName( - absl::StrCat("ZerosLike", ToId(handle_)).c_str()); - if (!s.ok()) { - return nullptr; - } - } - s = op->AddInput(handle_); - if (!s.ok()) { - return nullptr; - } - int num_outputs = 1; - // TODO(srbs): Figure out who is in charge of releasing this. - std::vector outputs(num_outputs); - s = op->Execute(absl::Span(outputs), &num_outputs); - if (!s.ok()) { - return nullptr; - } - return outputs[0]; -} + +AbstractTensorHandle* TapeTensor::ZerosLike() const { return nullptr; } // Returns the number of elements in the gradient tensor. int64 TapeVSpace::NumElements(AbstractTensorHandle* tensor) const { @@ -159,13 +207,16 @@ AbstractTensorHandle* TapeVSpace::AggregateGradients( // Calls the passed-in backward function. Status TapeVSpace::CallBackwardFunction( - GradientFunction* backward_function, + BackwardFunction* backward_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, std::vector* result) const { if (backward_function == nullptr) return Status::OK(); Context ctx = {ctx_}; - return backward_function->Compute(&ctx, output_gradients, result); + IncomingGradientsImpl incoming_gradients( + output_gradients, &ctx, backward_function->GetDefaultGradientFunction()); + return backward_function->GetGradientFunction()->Compute( + &ctx, incoming_gradients, result); } // Looks up the ID of a Gradient. @@ -373,15 +424,15 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx, } tape->RecordOperation( op_->Name(), tape_tensors, input_ids, input_dtypes, - [registry, forward_op_]() -> GradientFunction* { - std::unique_ptr grad_fn; - Status s = registry.Lookup(*forward_op_, &grad_fn); + [registry, forward_op_]() -> BackwardFunction* { + std::unique_ptr backward_fn; + Status s = registry.Lookup(*forward_op_, &backward_fn); if (!s.ok()) { return nullptr; } - return grad_fn.release(); + return backward_fn.release(); }, - [](GradientFunction* ptr) { + [](BackwardFunction* ptr) { if (ptr) { delete ptr; } diff --git a/tensorflow/c/eager/gradients.h b/tensorflow/c/eager/gradients.h index 267ee5b7ab2..04e11291404 100644 --- a/tensorflow/c/eager/gradients.h +++ b/tensorflow/c/eager/gradients.h @@ -55,18 +55,25 @@ struct Context { public: AbstractContext* ctx; }; + +class IncomingGradients { + public: + virtual AbstractTensorHandle* operator[](int i) const = 0; + virtual size_t size() const = 0; + virtual ~IncomingGradients() {} +}; + class GradientFunction { public: // TODO(srbs): How we support CompositeTensors e.g. IndexedSlices in // `grad_inputs`. - virtual Status Compute(Context* ctx, - absl::Span grad_inputs, + virtual Status Compute(Context* ctx, const IncomingGradients& grad_inputs, std::vector* grad_outputs) = 0; virtual ~GradientFunction() {} }; // Metadata from the forward operation that is made available to the -// gradient registerer to instantiate a GradientFunction. +// gradient registerer to instantiate a BackwardFunction. struct ForwardOperation { public: string op_name; @@ -76,18 +83,86 @@ struct ForwardOperation { AbstractContext* ctx; }; -using GradientFunctionFactory = - std::function; - -// Map from op name to a `GradientFunctionFactory`. -class GradientRegistry { +// Interface for building default zeros gradients for op outputs which are +// missing incoming gradients. 
Custom implementations of this can be used to +// control which of the forward op's output tensors/their metadata needs to +// be kept around in memory to build the default zeros grad. +// +// Some common helper implementations are provided below. +class DefaultGradientFunction { public: - Status Register(const string& op, GradientFunctionFactory factory); - Status Lookup(const ForwardOperation& op, - std::unique_ptr* grad_fn) const; + virtual AbstractTensorHandle* get( + Context* ctx, absl::Span grad_inputs, + int i) = 0; + virtual ~DefaultGradientFunction() {} +}; + +// Returns zeros for any `nullptr` in `grad_inputs`. +// +// This may require keeping track of all of forward op's output +// tensors and hence may incur a higher memory footprint. Use sparingly. +// +// Multiple calls to `AllZerosDefaultGradients::get` return the same tensor +// handle. +// +// The destructor of this class `Unref`'s any cached tensor handles so users of +// those tensor handles should `Ref` them in order to keep them alive if needed. +class AllZerosDefaultGradients : public DefaultGradientFunction { + public: + explicit AllZerosDefaultGradients(const ForwardOperation& op); + AbstractTensorHandle* get(Context* ctx, + absl::Span grad_inputs, + int i) override; private: - absl::flat_hash_map registry_; + // TODO(srbs): We do not always need to keep the tensors around. In immediate + // execution mode we just need to store the shape and dtype. During tracing + // we may need to keep the tensor around if the shape is not full defined. + std::vector outputs_; + std::vector cached_default_grads_; +}; + +// Passes through `grad_inputs` as-is. The `GradientFunction` +// will be expected to deal with nullptr in `grad_inputs` if any. +class PassThroughDefaultGradients : public DefaultGradientFunction { + public: + explicit PassThroughDefaultGradients(const ForwardOperation& op); + AbstractTensorHandle* get(Context* ctx, + absl::Span grad_inputs, + int i) override; +}; + +// A `BackwardFunction` wraps a `GradientFunction` and a +// `DefaultGradientFunction`. Both are owned by this class' instance. +class BackwardFunction { + public: + BackwardFunction(GradientFunction* gradient_function, + DefaultGradientFunction* default_gradients) + : gradient_function_(gradient_function), + default_gradients_(default_gradients) {} + GradientFunction* GetGradientFunction() { return gradient_function_.get(); } + DefaultGradientFunction* GetDefaultGradientFunction() { + return default_gradients_.get(); + } + + private: + std::unique_ptr gradient_function_; + std::unique_ptr default_gradients_; +}; + +using BackwardFunctionFactory = + std::function; + +// Map from op name to a `BackwardFunctionFactory`. +class GradientRegistry { + public: + Status Register(const string& op, + BackwardFunctionFactory backward_function_factory); + Status Lookup(const ForwardOperation& op, + std::unique_ptr* backward_function) const; + + private: + absl::flat_hash_map registry_; }; // Returns a unique id for the tensor which is used by the tape to build @@ -106,9 +181,16 @@ int64 ToId(AbstractTensorHandle* t); // allow us to trace the data dependencies between operations and hence compute // gradients. // -// This also implements `ZerosLike` and `OnesLike` to create the default +// This also implements `OnesLike` to create the default // incoming gradients for tensors which do not already have an incoming // gradient. +// +// `ZerosLike` is not expected to be called and returns a nullptr. 
The creation +// of default zeros grads is handled by the `DefaultGradientFunction` registered +// for each op. +// TODO(srbs): We need to define `ZerosLike` here to keep the compiler happy. +// Figure out a way to avoid this. +// TODO(srbs): Should ZerosLike check-fail instead of returning nullptr? class TapeTensor { public: TapeTensor(AbstractTensorHandle* handle, AbstractContext* ctx); @@ -123,7 +205,7 @@ class TapeTensor { private: AbstractTensorHandle* handle_; - // The context where OnesLike and ZerosLike ops are to be created. + // The context where OnesLike ops are to be created. AbstractContext* ctx_; }; @@ -132,7 +214,7 @@ class TapeTensor { // gradient and for performing gradient aggregation. // See `tensorflow::eager::VSpace` for more details. class TapeVSpace - : public eager::VSpace { + : public eager::VSpace { public: explicit TapeVSpace(AbstractContext* ctx) : ctx_(ctx) {} ~TapeVSpace() override {} @@ -147,7 +229,7 @@ class TapeVSpace // Calls the passed-in backward function. Status CallBackwardFunction( - GradientFunction* backward_function, + BackwardFunction* backward_function, const std::vector& unneeded_gradients, gtl::ArraySlice output_gradients, std::vector* result) const override; @@ -168,8 +250,14 @@ class TapeVSpace }; // A tracing/immediate-execution agnostic tape. +// +// Gradient functions defined for this library support handling null incoming +// gradients. `Tape::ComputeGradient` should be called with +// `build_default_zeros_grads=false`. Calling with +// `build_default_zeros_grads=true` (the default) is equivalent but just results +// in extra work because `TapeTensor::ZerosLike` returns a `nullptr` anyway. using Tape = tensorflow::eager::GradientTape; + BackwardFunction, TapeTensor>; } // namespace gradients } // namespace tensorflow diff --git a/tensorflow/c/eager/gradients_test.cc b/tensorflow/c/eager/gradients_test.cc index 944b10c000b..80b1f157074 100644 --- a/tensorflow/c/eager/gradients_test.cc +++ b/tensorflow/c/eager/gradients_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/gradients/array_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/tf_status_helper.h" @@ -50,6 +51,7 @@ class CppGradients Status RegisterGradients(GradientRegistry* registry) { TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer)); return Status::OK(); } @@ -94,6 +96,26 @@ Status Exp(AbstractContext* ctx, Tape* tape, registry); } +// Computes `IdentityN(inputs)` and records it on the tape. 
+Status IdentityN(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr identity_n_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR(Reset(identity_n_op.get(), "IdentityN", + /*raw_device_name=*/nullptr, &forward_op)); + if (isa(identity_n_op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(identity_n_op.get()) + ->SetOpName("my_identity_n")); + } + TF_RETURN_IF_ERROR(AddInputList(identity_n_op.get(), inputs, &forward_op)); + int num_retvals = outputs.size(); + return Execute(identity_n_op.get(), ctx, outputs, &num_retvals, &forward_op, + tape, registry); +} + // Computes // y = inputs[0] + inputs[1] // return grad(y, {inputs[0], inputs[1]}) @@ -116,7 +138,8 @@ Status AddGradModel(AbstractContext* ctx, vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads)); + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); for (auto add_output : add_outputs) { add_output->Unref(); } @@ -146,7 +169,8 @@ Status ExpGradModel(AbstractContext* ctx, TF_RETURN_IF_ERROR(tape->ComputeGradient( vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])}, /*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets, - /*output_gradients=*/{}, &out_grads)); + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); for (auto exp_output : exp_outputs) { exp_output->Unref(); } @@ -155,6 +179,41 @@ Status ExpGradModel(AbstractContext* ctx, return Status::OK(); } +// Computes +// ignored, y = IdentityN(inputs[0], inputs[1]) +// return grad(y, {inputs[0], inputs[1]}) +// This should return [nullptr, 1]. +Status IdentityNGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); + tape->Watch(ToId(inputs[1])); + + vector identity_n_outputs(2); + TF_RETURN_IF_ERROR(IdentityN(ctx, tape, inputs, + absl::MakeSpan(identity_n_outputs), registry)); + + std::unordered_map + source_tensors_that_are_targets; + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(identity_n_outputs[1])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto identity_n_output : identity_n_outputs) { + identity_n_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + AbstractContext* BuildFunction(const char* fn_name) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -389,13 +448,72 @@ TEST_P(CppGradients, TestExpGrad) { result_tensor = nullptr; } -// TODO(b/160888630): Enable this test with mlir after AddInputList is -// supported. It is needed for AddN op which is used for gradient aggregation. +TEST_P(CppGradients, TestIdentityNGrad) { + // Pseudo-code: + // + // tape.watch(x1) + // tape.watch(x2) + // unused, y = IdentityN([x1, x2]) + // outputs = tape.gradient(y, [x1, x2]) + // Expected: [nullptr, 1] + // + // This test is interesting because the current implementation of GradientTape + // would return [0, 1] whereas we use build_default_zeros_grads=false here + // so we get back [nullptr, 1]. 
+ std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x1; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x1.reset(x_raw); + } + AbstractTensorHandlePtr x2; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x2.reset(x_raw); + } + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + std::vector outputs(2); + s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + EXPECT_EQ(outputs[0], nullptr); + TF_Tensor* result_tensor; + s = getValue(outputs[1], &result_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + auto result_value = static_cast(TF_TensorData(result_tensor)); + EXPECT_EQ(*result_value, 1.0); + outputs[1]->Unref(); + TF_DeleteTensor(result_tensor); + result_tensor = nullptr; +} + +// TODO(b/164171226): Enable this test with tfrt after AddInputList is +// supported. It is needed for IdentityN. #ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, CppGradients, ::testing::Combine(::testing::Values("graphdef", "mlir"), - /*tfrt*/ ::testing::Values(true, false), + /*tfrt*/ ::testing::Values(false), /*executing_eagerly*/ ::testing::Values(true, false))); #else INSTANTIATE_TEST_SUITE_P( diff --git a/tensorflow/c/eager/immediate_execution_context.h b/tensorflow/c/eager/immediate_execution_context.h index 6d06d9a8de6..02a3320ef65 100644 --- a/tensorflow/c/eager/immediate_execution_context.h +++ b/tensorflow/c/eager/immediate_execution_context.h @@ -57,15 +57,10 @@ class ImmediateExecutionContext : public AbstractContext { // Create a tensor instance from the given data buffer and description. // `memory_releaser` will be called on destruction, and it's responsible for - // cleaning up the underlying buffer. `convert_string` indicates whether it - // has to handle tstring conversion. Expected to be removed once tstring - // migration is done. - virtual AbstractTensorInterface* CreateTensor(DataType dtype, - const int64_t* dims, - int num_dims, void* data, - size_t len, bool convert_string, - MemoryReleaser memory_releaser, - void* memory_releaser_arg) = 0; + // cleaning up the underlying buffer. + virtual AbstractTensorInterface* CreateTensor( + DataType dtype, const int64_t* dims, int num_dims, void* data, size_t len, + MemoryReleaser memory_releaser, void* memory_releaser_arg) = 0; // Create a handle to wrap and manage a Tensor virtual ImmediateExecutionTensorHandle* CreateLocalHandle( diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc new file mode 100644 index 00000000000..1f8ad138858 --- /dev/null +++ b/tensorflow/c/eager/mnist_gradients_test.cc @@ -0,0 +1,781 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/mnist_gradients_testutil.h" +#include "tensorflow/c/experimental/gradients/math_grad.h" +#include "tensorflow/c/experimental/gradients/nn_grad.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace gradients { +namespace internal { +namespace { + +class CppGradients + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_SetTracingImplementation(std::get<0>(GetParam())); + } +}; + +Status RegisterGradients(GradientRegistry* registry) { + TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer)); + TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer)); + TF_RETURN_IF_ERROR( + registry->Register("SparseSoftmaxCrossEntropyWithLogits", + SparseSoftmaxCrossEntropyLossRegisterer)); + return Status::OK(); +} + +// ========================= Test Util Functions ============================== + +// Get a scalar TensorHandle with given value +Status TestScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +// Get a Matrix TensorHandle with given float values and dimensions +Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +// Get a Matrix TensorHandle with given int values and dimensions +Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor) { + 
std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return Status::OK(); +} + +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return Status::OK(); +} + +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); + A.reset(a_raw); + return A; +} + +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); + A.reset(a_raw); + return A; +} + +// =========================== Start Tests ================================ + +TEST_P(CppGradients, TestMatMulGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f}; + int64_t B_dims[] = {2, 2}; + int num_dims = 2; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); + AbstractTensorHandlePtr B = + GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(A) + * tape.watch(B) + * Y = AB + * outputs = tape.gradient(Y, [A, B]) + */ + + std::vector outputs(2); + s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dA_tensor; + s = GetValue(outputs[0], &dA_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(dA_tensor), + TF_TensorByteSize(dA_tensor)); + + float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dA[j], tolerance); + } + + TF_Tensor* dB_tensor; + s = GetValue(outputs[1], &dB_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(dB_tensor), + TF_TensorByteSize(dB_tensor)); + + float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f}; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dB[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(dA_tensor); + TF_DeleteTensor(dB_tensor); +} + +TEST_P(CppGradients, 
TestMNISTForward) { + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t dims[] = {2, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims); + + // W1 = first weights + float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1}; + int64_t dims_y[] = {2}; + num_dims = sizeof(dims_y) / sizeof(dims_y[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, dims, num_dims); + + GradientRegistry registry; + + // Run the Forward Pass + std::vector outputs(2); + Status s = + RunModel(MNISTForwardModel, ctx.get(), + {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Verify the Results + TF_Tensor* scores_tensor; + s = GetValue(outputs[0], &scores_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(scores_tensor), + TF_TensorByteSize(scores_tensor)); + + float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); + } + + TF_Tensor* loss_vals_tensor; + s = GetValue(outputs[1], &loss_vals_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(loss_vals_tensor), + TF_TensorByteSize(loss_vals_tensor)); + float expected_losses[2] = {9.6f, 27.2f}; + for (int j = 0; j < 2; j++) { + ASSERT_NEAR(result_data[j], expected_losses[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(scores_tensor); + TF_DeleteTensor(loss_vals_tensor); +} + +TEST_P(CppGradients, TestMNISTForward2) { + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int64_t X_dims[] = {3, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // W1 = first weights + float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1, 1}; + int64_t y_dims[] = {3}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + GradientRegistry registry; + + // Run the Forward Pass + std::vector outputs(2); + Status s = + RunModel(MNISTForwardModel, ctx.get(), + {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), + 
/*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Verify the Results + TF_Tensor* scores_tensor; + s = GetValue(outputs[0], &scores_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[6] = {0}; + memcpy(&result_data[0], TF_TensorData(scores_tensor), + TF_TensorByteSize(scores_tensor)); + + float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 6; j++) { + ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); + } + + TF_Tensor* loss_vals_tensor; + s = GetValue(outputs[1], &loss_vals_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(loss_vals_tensor), + TF_TensorByteSize(loss_vals_tensor)); + float expected_losses[3] = {9.6f, 27.2f, 44.8f}; + for (int j = 0; j < 3; j++) { + ASSERT_NEAR(result_data[j], expected_losses[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + TF_DeleteTensor(scores_tensor); + TF_DeleteTensor(loss_vals_tensor); +} + +TEST_P(CppGradients, TestMatMulTranspose) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + int64_t X_dims[] = {2, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // W1 = first weights + float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + GradientRegistry registry; + + // Run the MatMul Op + std::vector outputs(1); + + Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Verify the Results + TF_Tensor* scores_tensor; + s = GetValue(outputs[0], &scores_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[6] = {0}; + memcpy(&result_data[0], TF_TensorData(scores_tensor), + TF_TensorByteSize(scores_tensor)); + + float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 6; j++) { + ASSERT_NEAR(result_data[j], expected_scores[j], tolerance); + } +} + +TEST_P(CppGradients, TestReluGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; + int64_t X_dims[] = {3, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(X) + * Y = Relu(X) + * outputs = tape.gradient(Y, [X]) + */ + std::vector outputs(1); + s = RunModel(ReluGradModel, ctx.get(), {X.get()}, absl::MakeSpan(outputs), + 
/*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dX_tensor; + s = GetValue(outputs[0], &dX_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[9] = {0}; + memcpy(&result_data[0], TF_TensorData(dX_tensor), + TF_TensorByteSize(dX_tensor)); + + float expected_dX[9] = {1.0f, 1.0f, 1.0f, 0.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f}; + float tolerance = 1e-3; + for (int j = 0; j < 9; j++) { + ASSERT_NEAR(result_data[j], expected_dX[j], tolerance); + } + + outputs[0]->Unref(); + TF_DeleteTensor(dX_tensor); +} + +TEST_P(CppGradients, TestSoftmaxLossGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = scores + float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 0.0f, -1.0f}; + int64_t X_dims[] = {3, 3}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // y = labels + int y_vals[] = {1, 0, 1}; + int64_t y_dims[] = {3}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * tape.watch(X) + * tape.watch(labels) + * loss = SoftmaxLoss(X, labels) + * outputs = tape.gradient(loss, [X, labels]) + * + * + */ + + std::vector outputs(2); + s = RunModel(SoftmaxLossGradModel, ctx.get(), {X.get(), y.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dX_tensor; + s = GetValue(outputs[0], &dX_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[9] = {0}; + memcpy(&result_data[0], TF_TensorData(dX_tensor), + TF_TensorByteSize(dX_tensor)); + + float expected_dX[9] = {0.090f, -0.7553f, 0.6652f, -0.9099f, 0.2447f, + 0.6652f, 0.8437f, -0.8858f, 0.0420f}; + float tolerance = 1e-3; + for (int j = 0; j < 9; j++) { + ASSERT_NEAR(result_data[j], expected_dX[j], tolerance); + } + + // Only Unref() first output as 2nd is nullptr grad for labels + outputs[0]->Unref(); + TF_DeleteTensor(dX_tensor); +} + +TEST_P(CppGradients, TestMNISTGrad) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t X_dims[] = {2, 2}; + int num_dims = 2; + AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // W1 = first weights + float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1}; + int64_t y_dims[] = {2}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + 
AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + // Register Grads + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + /* Pseudo-code: + * + * + * tape.watch(W1) + * tape.watch(W2) + * mm = X*W1 + * hidden = Relu(mm) + * scores = W2*hidden + * loss = SoftmaxLoss(scores, y) + * outputs = tape.gradient(loss, [A, B]) + * + */ + + std::vector outputs(3); + s = RunModel(MNISTGradModel, ctx.get(), + {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float tolerance = 1e-3; + TF_Tensor* dW1_tensor; + s = GetValue(outputs[0], &dW1_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(dW1_tensor), + TF_TensorByteSize(dW1_tensor)); + + float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f}; + ; // dLoss + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance); + } + + TF_Tensor* dW2_tensor; + s = GetValue(outputs[1], &dW2_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + memcpy(&result_data[0], TF_TensorData(dW2_tensor), + TF_TensorByteSize(dW2_tensor)); + + float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f}; // dLoss + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance); + } + + outputs[0]->Unref(); + outputs[1]->Unref(); + outputs[2]->Unref(); + TF_DeleteTensor(dW1_tensor); + TF_DeleteTensor(dW2_tensor); +} + +TEST_P(CppGradients, TestScalarMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr eta; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 1.5f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + eta.reset(x_raw); + } + + float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t A_dims[] = {2, 2}; + int num_dims = 2; + + AbstractTensorHandlePtr A = + GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims); + + GradientRegistry registry; + std::vector outputs(1); + Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()}, + absl::MakeSpan(outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + TF_Tensor* dA_tensor; + s = GetValue(outputs[0], &dA_tensor); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + float result_data[4] = {0}; + memcpy(&result_data[0], TF_TensorData(dA_tensor), + TF_TensorByteSize(dA_tensor)); + + float tolerance = 1e-3; + float eta_val = 1.5f; + for (int j = 0; j < 4; j++) { + ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance); + } + + outputs[0]->Unref(); + TF_DeleteTensor(dA_tensor); +} + +TEST_P(CppGradients, TestMNIST_Training) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + // X = data + float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f}; + int64_t X_dims[] = {2, 2}; + int num_dims = 2; 
+ AbstractTensorHandlePtr X = + GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims); + + // TODO(amturati): use random initializer for weights instead of + // constant values. + + // W1 = first weights + float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f}; + int64_t dims[] = {2, 2}; + AbstractTensorHandlePtr W1 = + GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims); + + // W2 = second weights + float W2_vals[] = {.1f, .2f, .3f, -.5f}; + AbstractTensorHandlePtr W2 = + GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims); + + // y = labels + int y_vals[] = {1, 1}; + int64_t y_dims[] = {2}; + num_dims = sizeof(y_dims) / sizeof(y_dims[0]); + AbstractTensorHandlePtr y = + GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims); + + // Register Grads + GradientRegistry registry; + Status s = RegisterGradients(®istry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Prepare for training + std::vector weights; + weights.push_back(W1.get()); + weights.push_back(W2.get()); + + // Set learning rate to be 1e-1 + AbstractTensorHandle* learning_rate = nullptr; + s = TestScalarTensorHandle(ctx.get(), 1e-1, &learning_rate); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Train + int num_iters = 10; + std::vector mnist_outputs(3); + std::vector grads(2); + for (int i = 0; i < num_iters; i++) { + // Run Forward Pass + s = RunModel(MNISTGradModel, ctx.get(), + {X.get(), weights[0], weights[1], y.get()}, + absl::MakeSpan(mnist_outputs), + /*use_function=*/!std::get<2>(GetParam()), registry); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + + // Fill grads + grads[0] = mnist_outputs[0]; + grads[1] = mnist_outputs[1]; + + // Gradient Update + s = UpdateWeights(ctx.get(), grads, weights, learning_rate); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + } + + grads[0]->Unref(); // release W1_grad + grads[1]->Unref(); // release W2_grad + mnist_outputs[2]->Unref(); // release loss +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCAPI, CppGradients, + ::testing::Combine(::testing::Values("graphdef"), + /*tfrt*/ ::testing::Values(false), + /*executing_eagerly*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace internal +} // namespace gradients +} // namespace tensorflow \ No newline at end of file diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc new file mode 100644 index 00000000000..4b2c87c678d --- /dev/null +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -0,0 +1,594 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/c/eager/mnist_gradients_testutil.h" + +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +using std::vector; +using tracing::TracingOperation; + +// ========================== Tape Ops ============================== + +// Computes `inputs[0] + inputs[1]` and records it on the tape. +Status Add(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractOperationPtr add_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(add_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(add_op.get())->SetOpName("my_add")); + } + TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op)); + int num_retvals = 1; + return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. 
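+// `transpose_a`/`transpose_b` are attached to the op as attributes through
+// SetAttrBool on the ForwardOperation, so the recorded tape entry carries
+// them and the gradient function can read them back from the forward attrs.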
+Status MatMul(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a, bool transpose_b, + const GradientRegistry& registry) { + AbstractOperationPtr matmul_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR(Reset(matmul_op.get(), "MatMul", + /*raw_device_name=*/nullptr, &forward_op)); + if (isa(matmul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(matmul_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op)); + TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( + matmul_op.get(), "transpose_a", transpose_a, &forward_op)); + TF_RETURN_IF_ERROR(tensorflow::gradients::internal::SetAttrBool( + matmul_op.get(), "transpose_b", transpose_b, &forward_op)); + + int num_retvals = 1; + return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +Status Mul(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry) { + AbstractOperationPtr mul_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(mul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(mul_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op)); + TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op)); + + int num_retvals = 1; + return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `Relu(inputs[0])` and records it on the tape. +Status Relu(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry) { + AbstractOperationPtr relu_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR( + Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op)); + if (isa(relu_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(relu_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op)); + int num_retvals = 1; + return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the +// tape. 
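+// The wrapped op ("SparseSoftmaxCrossEntropyWithLogits") emits two tensors,
+// the per-example loss and the backprop gradient, so `outputs` must have
+// room for both (num_retvals is 2 below). A usage sketch mirroring the test
+// models in this file:
+//
+//   std::vector<AbstractTensorHandle*> sm_outputs(2);
+//   TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+//       ctx, tape, {scores, y_labels}, absl::MakeSpan(sm_outputs),
+//       "softmax_loss", registry));
+//   // sm_outputs[0]: loss values, sm_outputs[1]: backprop.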
+Status SparseSoftmaxCrossEntropyLoss( + AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry) { + AbstractTensorHandle* scores = inputs[0]; + AbstractTensorHandle* labels = inputs[1]; + + AbstractOperationPtr sm_op(ctx->CreateOperation()); + ForwardOperation forward_op; + forward_op.ctx = ctx; + TF_RETURN_IF_ERROR(Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", + /*raw_device_name=*/nullptr, &forward_op)); + if (isa(sm_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(sm_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op)); + TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op)); + + int num_retvals = 2; // returns loss values and backprop + return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape, + registry); +} + +//===================== Test Models to run ========================= + +// Computes +// y = inputs[0] + inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status AddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + std::vector add_outputs(1); + TF_RETURN_IF_ERROR(Add(ctx, tape, inputs, absl::MakeSpan(add_outputs), + registry)); // Compute x+y. + std::unordered_map + source_tensors_that_are_targets; + + std::vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(add_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto add_output : add_outputs) { + add_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +// Computes +// y = inputs[0] * inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status MatMulGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch x. + tape->Watch(ToId(inputs[1])); // Watch y. + vector mm_outputs(1); + TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute x*y. 
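+  // None of the watched inputs is itself a gradient target here, so the
+  // sources-that-are-targets map stays empty.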
+ + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + for (auto mm_output : mm_outputs) { + mm_output->Unref(); + } + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +// Model to run 2-layer net +Status MNISTForwardModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + /** + * We will trace a 2-layer fully connected network for an MNIST model: + * + * def mnist_forward(X, W1, W2, y_labels): + * mm_out_1 = tf.matmul(X,W1) + * hidden_layer = tf.nn.relu(mm_out_1) + * scores = tf.matmul(hidden_layer,W2) + * softmax = + * tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels) return + * scores, softmax + * + * Use this convention for inputs: + * + * inputs = [X, W1, W2, y_labels] + * + */ + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + AbstractTensorHandle* W2 = inputs[2]; + AbstractTensorHandle* y_labels = inputs[3]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(W1)); // Watch W1. + tape->Watch(ToId(W2)); // Watch W2. + vector temp_outputs(1); + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute X*W1 + + TF_RETURN_IF_ERROR(Relu(ctx, tape, {temp_outputs[0]}, + absl::MakeSpan(temp_outputs), "relu", + registry)); // Compute Relu(X*W1) + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {temp_outputs[0], W2}, + absl::MakeSpan(temp_outputs), "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false, + registry)); // Compute W2*Relu(X*W1) + + AbstractTensorHandle* scores = temp_outputs[0]; + + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( + ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmax_loss", registry)); // Compute Softmax(Scores,labels) + + AbstractTensorHandle* loss_vals = temp_outputs[0]; + + outputs[0] = scores; + outputs[1] = loss_vals; + delete tape; + return Status::OK(); +} + +Status MatMulTransposeModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(X)); + tape->Watch(ToId(W1)); + vector temp_outputs(1); + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/true, + /*transpose_b=*/false, registry)); // Compute X*W1 + + outputs[0] = temp_outputs[0]; + + delete tape; + return Status::OK(); +} + +Status ReluGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch X + vector relu_outputs(1); + TF_RETURN_IF_ERROR(Relu(ctx, tape, inputs, absl::MakeSpan(relu_outputs), + "relu0", registry)); // Relu(X) + + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(relu_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0])}, 
source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + + for (auto relu_output : relu_outputs) { + relu_output->Unref(); + } + + outputs[0] = out_grads[0]; + delete tape; + return Status::OK(); +} + +Status SoftmaxLossGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + tape->Watch(ToId(inputs[0])); // Watch scores. + tape->Watch(ToId(inputs[1])); // Watch labels. + vector sm_outputs(2); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( + ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry)); + + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR(tape->ComputeGradient( + vspace, /*target_tensor_ids=*/{ToId(sm_outputs[0])}, + /*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + + outputs[0] = out_grads[0]; + outputs[1] = out_grads[1]; + delete tape; + return Status::OK(); +} + +Status MNISTGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* X = inputs[0]; + AbstractTensorHandle* W1 = inputs[1]; + AbstractTensorHandle* W2 = inputs[2]; + AbstractTensorHandle* y_labels = inputs[3]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/true); + tape->Watch(ToId(X)); // Watch X. + tape->Watch(ToId(W1)); // Watch W1. + tape->Watch(ToId(W2)); // Watch W1. + vector temp_outputs(1); + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs), + "matmul0", /*transpose_a=*/false, + /*transpose_b=*/false, registry)); // Compute X*W1 + + AbstractTensorHandle* mm = temp_outputs[0]; + + TF_RETURN_IF_ERROR(Relu(ctx, tape, {mm}, + absl::MakeSpan(temp_outputs), // Relu(X*W1) + "relu0", registry)); + + AbstractTensorHandle* hidden = temp_outputs[0]; + + TF_RETURN_IF_ERROR(MatMul(ctx, tape, {hidden, W2}, + absl::MakeSpan(temp_outputs), "matmul1", + /*transpose_a=*/false, /*transpose_b=*/false, + registry)); // W2*Relu(X*W1) + + AbstractTensorHandle* scores = temp_outputs[0]; + + temp_outputs.resize(2); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss( + ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs), + "softmaxloss", registry)); // W2*Relu(X*W1) + + AbstractTensorHandle* loss = temp_outputs[0]; + + std::unordered_map + source_tensors_that_are_targets; + + vector out_grads; + TF_RETURN_IF_ERROR( + tape->ComputeGradient(vspace, /*target_tensor_ids=*/{ToId(loss)}, + /*source_tensor_ids=*/{ToId(W1), ToId(W2)}, + source_tensors_that_are_targets, + /*output_gradients=*/{}, &out_grads, + /*build_default_zeros_grads=*/false)); + + // Only release 2nd temp output as first holds loss values. 
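+  // (temp_outputs[0] aliases `loss`, which is handed back to the caller as
+  // outputs[2] below.)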
+ temp_outputs[1]->Unref(); + + outputs[0] = out_grads[0]; // dW1 + outputs[1] = out_grads[1]; // dW2 + outputs[2] = loss; + + delete tape; + return Status::OK(); +} + +Status ScalarMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { + AbstractTensorHandle* eta = inputs[0]; + AbstractTensorHandle* A = inputs[1]; + + TapeVSpace vspace(ctx); + auto tape = new Tape(/*persistent=*/false); + vector temp_outputs(1); + + TF_RETURN_IF_ERROR(Mul(ctx, tape, {eta, A}, absl::MakeSpan(temp_outputs), + "scalarMul0", registry)); // Compute eta*A + + outputs[0] = temp_outputs[0]; + + delete tape; + return Status::OK(); +} + +// ============================= End Models ================================ + +Status UpdateWeights(AbstractContext* ctx, vector& grads, + vector& weights, + AbstractTensorHandle* learning_rate) { + /* Update weights one by one using gradient update rule: + * + * w -= lr*grad[w] + * + * NOTE: assuming learning rate is positive + */ + + Status s; + int num_grads = grads.size(); + vector temp_outputs(1); + std::string update_str; + + // Negate learning rate for gradient descent + TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, + absl::MakeSpan(temp_outputs), + "neg_lr")); // Compute -lr + learning_rate = temp_outputs[0]; + + for (int i = 0; i < num_grads; i++) { + // Compute dW = -lr * grad(w[i]) + update_str = "update_mul_" + std::to_string(i); + s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs), + update_str.c_str()); + + AbstractTensorHandle* dW = temp_outputs[0]; + + // Compute temp = weights[i] + dW + update_str = "update_add_" + std::to_string(i); + s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs), + update_str.c_str()); + + // Update the weights + weights[i] = temp_outputs[0]; + } + + return Status::OK(); +} + +AbstractContext* BuildFunction(const char* fn_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); + return unwrap(graph_ctx); +} + +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + vector* params) { + tracing::TracingTensorHandle* handle = nullptr; + for (auto input : inputs) { + TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( + input->DataType(), &handle)); + params->emplace_back(handle); + } + return Status::OK(); +} + +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry) { + if (use_function) { + const char* fn_name = "test_fn"; + std::unique_ptr scoped_func; + // Returning null tensors from a tf.function is not supported, so we keep + // track of indices in the model's outputs are nullptr in this set. + // The FunctionDef only outputs the non-null tensors. We later pad the + // function op outputs to have nullptrs at the `null_indices`. 
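+    // For example, if the model writes {t0, nullptr, t2}, the traced function
+    // returns only {t0, t2}, null_indices becomes {1}, and that slot of
+    // `outputs` is simply skipped when fn_outputs are copied back.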
+ absl::flat_hash_set null_indices; + { + AbstractContextPtr func_ctx(BuildFunction(fn_name)); + vector func_inputs; + func_inputs.reserve(inputs.size()); + TF_RETURN_IF_ERROR( + CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); + vector model_outputs; + model_outputs.resize(outputs.size()); + TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), + absl::MakeSpan(model_outputs), registry)); + for (auto func_input : func_inputs) { + func_input->Unref(); + } + AbstractFunction* func = nullptr; + OutputList output_list; + output_list.expected_num_outputs = 0; + output_list.outputs.reserve(outputs.size()); + for (int i = 0; i < model_outputs.size(); i++) { + if (model_outputs[i]) { + output_list.outputs.emplace_back(model_outputs[i]); + output_list.expected_num_outputs += 1; + } else { + null_indices.insert(i); + } + } + TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) + ->Finalize(&output_list, &func)); + scoped_func.reset(func); + for (auto output : output_list.outputs) { + output->Unref(); + } + TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); + } + + AbstractOperationPtr fn_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(fn_op->AddInput(input)); + } + int retvals = outputs.size() - null_indices.size(); + vector fn_outputs(retvals); + TF_RETURN_IF_ERROR(fn_op->Execute( + absl::Span(fn_outputs.data(), fn_outputs.size()), + &retvals)); + int skipped_indices = 0; + for (int i = 0; i < outputs.size(); i++) { + if (!null_indices.contains(i)) { + outputs[i] = fn_outputs[i - skipped_indices]; + } else { + skipped_indices += 1; + } + } + TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); + return Status::OK(); + } else { + return model(ctx, inputs, outputs, registry); + } +} + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_DeleteContextOptions(opts); + return Status::OK(); +} diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h new file mode 100644 index 00000000000..b6de8ff6788 --- /dev/null +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +using namespace tensorflow; +using namespace tensorflow::gradients; +using namespace tensorflow::gradients::internal; + +// ========================== Tape Ops ============================== + +// Computes `inputs[0] + inputs[1]` and records it on the tape. +Status Add(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape. +Status MatMul(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a, bool transpose_b, + const GradientRegistry& registry); + +// Computes `inputs[0] * inputs[1]` and records it on the tape. +Status Mul(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry); + +// Computes `Relu(inputs[0])` and records it on the tape. +Status Relu(AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry); + +// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the +// tape. +Status SparseSoftmaxCrossEntropyLoss( + AbstractContext* ctx, Tape* tape, + absl::Span inputs, + absl::Span outputs, const char* name, + const GradientRegistry& registry); + +// ====================== End Tape Ops ============================ + +// Computes +// y = inputs[0] + inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status AddGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Computes +// y = inputs[0] * inputs[1] +// return grad(y, {inputs[0], inputs[1]}) +Status MatMulGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Computes 2-layer Neural Network with Softmax Loss. +Status MNISTForwardModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Computes MatMul with first matrix tranposed. 
+Status MatMulTransposeModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify ReluGrad functionality +Status ReluGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify SoftmaxGrad functionality +Status SoftmaxLossGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify Multi-grad functionality for MNIST +Status MNISTGradModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Test Model to verify scalar-tensor multiplication Op +Status ScalarMulModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); + +// Updates the weights for a neural network given incoming grads and learning +// rate +Status UpdateWeights(AbstractContext* ctx, + std::vector& grads, + std::vector& weights, + AbstractTensorHandle* learning_rate); + +AbstractContext* BuildFunction(const char* fn_name); + +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + std::vector* params); + +using Model = std::function, + absl::Span, const GradientRegistry&)>; + +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry); + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); diff --git a/tensorflow/c/eager/parallel_device/BUILD b/tensorflow/c/eager/parallel_device/BUILD index 0d0e5ffce10..df5504adce2 100644 --- a/tensorflow/c/eager/parallel_device/BUILD +++ b/tensorflow/c/eager/parallel_device/BUILD @@ -76,10 +76,26 @@ cc_library( "//tensorflow/c/eager:c_api_experimental", "//tensorflow/core:lib", "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], ) +tf_cc_test( + name = "parallel_device_lib_test", + srcs = ["parallel_device_lib_test.cc"], + deps = [ + ":parallel_device_lib", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "parallel_device_testlib", testonly = 1, diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc index 768f686bd88..e270bfcbb80 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.cc +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" +#include "tensorflow/c/tf_status.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" @@ -118,6 +119,9 @@ class DeviceThread { int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_); // Outputs std::vector op_outputs_ TF_GUARDED_BY(execution_mutex_); + // TF_Status is an incomplete type and so can't be stack allocated. To avoid + // unnecessary allocations each Execute call, we keep one heap-allocated + // version for the thread. 
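+  // Join() copies this status out for the caller and then resets it to OK,
+  // so a failed op does not poison later executions on the thread.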
StatusPtr status_ TF_GUARDED_BY(execution_mutex_); const std::string device_; @@ -188,6 +192,9 @@ std::vector DeviceThread::Join(TF_Status* status) { if (TF_GetCode(status_.get()) != TF_OK) { TF_SetStatus(status, TF_GetCode(status_.get()), TF_Message(status_.get())); + // Reset the member `status_` so future op executions (after recovery from + // the bad `status`) start with an OK status. + TF_SetStatus(status_.get(), TF_OK, ""); } execution_state_ = ExecutionState::kIdle; result = std::move(op_outputs_); @@ -255,18 +262,27 @@ std::unique_ptr ParallelDevice::CopyToParallelDevice( status); } -std::unique_ptr ParallelDevice::DeviceIDs( - TFE_Context* context, TF_Status* status) const { +std::unique_ptr ParallelDevice::Vector( + TFE_Context* context, TF_Status* status, + absl::Span values) const { // TODO(allenl): We could cache DeviceIDs (keyed by context). std::vector components; components.reserve(underlying_devices_.size()); - for (int device_index = 0; device_index < underlying_devices_.size(); + + if (values.size() != num_underlying_devices()) { + TF_SetStatus( + status, TF_INVALID_ARGUMENT, + "Number of values did not match number of underlying devices."); + return nullptr; + } + + for (int device_index = 0; device_index < num_underlying_devices(); ++device_index) { - int32_t* device_id = new int32_t; - *device_id = device_index; + int32_t* device_value = new int32_t; + *device_value = values[device_index]; std::unique_ptr tensor( TF_NewTensor( - TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_id, + TF_INT32, /*dims=*/nullptr, /*num_dims=*/0, device_value, sizeof(int32_t), [](void* data, size_t, void* arg) { delete reinterpret_cast(data); @@ -295,6 +311,16 @@ std::unique_ptr ParallelDevice::DeviceIDs( status); } +std::unique_ptr ParallelDevice::DeviceIDs( + TFE_Context* context, TF_Status* status) const { + std::vector ids; + ids.reserve(num_underlying_devices()); + for (int i = 0; i < num_underlying_devices(); ++i) { + ids.push_back(i); + } + return Vector(context, status, ids); +} + absl::optional>> ParallelDevice::Execute(TFE_Context* context, const std::vector& inputs, @@ -319,21 +345,36 @@ ParallelDevice::Execute(TFE_Context* context, std::move(device_inputs), attributes, expected_max_outputs); } + StatusPtr first_bad_status(nullptr); for (int device_index = 0; device_index < underlying_devices_.size(); ++device_index) { DeviceThread* device_thread = device_threads_[device_index].get(); per_device_output_tensors.push_back(device_thread->Join(status)); - if (TF_GetCode(status) != TF_OK) return result; + // We will run every Join even if there are bad statuses in case the user + // wants to recover and continue running ops on the parallel device (which + // would otherwise deadlock). 
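+    // Only the first non-OK status is preserved for the caller; the remaining
+    // Joins still run so every DeviceThread returns to the idle state.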
+ if (TF_GetCode(status) != TF_OK && first_bad_status == nullptr) { + first_bad_status.reset(TF_NewStatus()); + TF_SetStatus(first_bad_status.get(), TF_GetCode(status), + TF_Message(status)); + } + if (device_index == 0) { first_op_output_count = per_device_output_tensors.rbegin()->size(); } else { - if (per_device_output_tensors.rbegin()->size() != first_op_output_count) { - TF_SetStatus(status, TF_INTERNAL, + if (first_bad_status == nullptr && + per_device_output_tensors.rbegin()->size() != first_op_output_count) { + first_bad_status.reset(TF_NewStatus()); + TF_SetStatus(first_bad_status.get(), TF_INTERNAL, "Parallel ops produced different numbers of tensors."); - return result; } } } + if (first_bad_status != nullptr) { + TF_SetStatus(status, TF_GetCode(first_bad_status.get()), + TF_Message(first_bad_status.get())); + return result; + } // For each output of the original operation, pack the per-device // TensorHandles we've computed into a single parallel TensorHandle. std::vector> per_device_outputs; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib.h b/tensorflow/c/eager/parallel_device/parallel_device_lib.h index cbfea31d95f..b3dc47ab088 100644 --- a/tensorflow/c/eager/parallel_device/parallel_device_lib.h +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "absl/types/span.h" #include "absl/types/variant.h" #include "tensorflow/c/c_api.h" #include "tensorflow/c/eager/c_api.h" @@ -61,6 +62,11 @@ class ParallelDevice { TFE_TensorHandle* tensor, TF_Status* status) const; + // Construct a parallel tensor consisting of the scalar values from `values`. + std::unique_ptr Vector( + TFE_Context* context, TF_Status* status, + absl::Span values) const; + // A parallel tensor with scalar integers numbering component devices. std::unique_ptr DeviceIDs(TFE_Context* context, TF_Status* status) const; diff --git a/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc new file mode 100644 index 00000000000..35befe959cb --- /dev/null +++ b/tensorflow/c/eager/parallel_device/parallel_device_lib_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h" + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_experimental.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace parallel_device { + +TEST(PARALLEL_DEVICE_LIB, TestOpWithError) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + std::unique_ptr opts( + TFE_NewContextOptions(), TFE_DeleteContextOptions); + std::unique_ptr config( + TF_CreateConfig( + /*xla*/ false, + /* gpu_memory_allow_growth */ true, /* num_cpu_devices */ + 2), + TF_DeleteBuffer); + TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length, + status.get()); + std::unique_ptr context( + TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + std::vector devices{ + "/job:localhost/replica:0/task:0/device:CPU:0", + "/job:localhost/replica:0/task:0/device:CPU:1"}; + ParallelDevice parallel_device(std::move(devices)); + std::unique_ptr handle_op( + TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT); + TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0, + status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + auto outputs = + parallel_device.Execute(context.get(), std::vector(), + "VarHandleOp", TFE_OpGetAttrs(handle_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + const std::vector>& handles = *outputs; + std::vector handle_inputs; + handle_inputs.reserve(handles.size()); + for (auto& handle : handles) { + handle_inputs.push_back(handle.get()); + } + std::unique_ptr read_op( + TFE_NewOp(context.get(), "ReadVariableOp", status.get()), TFE_DeleteOp); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_OpSetAttrType(read_op.get(), "dtype", TF_FLOAT); + parallel_device.Execute(context.get(), handle_inputs, "ReadVariableOp", + TFE_OpGetAttrs(read_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_FALSE(TF_GetCode(status.get()) == TF_OK); + TF_SetStatus(status.get(), TF_OK, ""); + + // Check that ops still run successfully on the device. + parallel_device.Execute(context.get(), std::vector(), + "VarHandleOp", TFE_OpGetAttrs(handle_op.get()), + /*expected_max_outputs=*/1, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); +} + +} // namespace parallel_device +} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 27629bb3bdf..fcebe973500 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -146,13 +146,16 @@ class GradientTape { // once) and produces the gradient of the target tensors with respect to the // source tensors. The output gradients are used if not empty and not // null. The result is populated with one tensor per target element. + // When running backward functions, builds zeros-like tensors for + // incoming grads which are nullptrs, unless `build_default_zeros_grads` + // is set to false. 
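+  // A call that opts out of the default zeros might look like this (sketch;
+  // `y_id`/`x_id` stand for tensor ids already watched on the tape):
+  //
+  //   std::vector<Gradient*> grads;
+  //   TF_RETURN_IF_ERROR(tape.ComputeGradient(
+  //       vspace, /*target_tensor_ids=*/{y_id}, /*source_tensor_ids=*/{x_id},
+  //       /*sources_that_are_targets=*/{}, /*output_gradients=*/{}, &grads,
+  //       /*build_default_zeros_grads=*/false));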
Status ComputeGradient( const VSpace& vspace, const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, const std::unordered_map& sources_that_are_targets, gtl::ArraySlice output_gradients, - std::vector* result); + std::vector* result, bool build_default_zeros_grads = true); bool IsPersistent() const { return persistent_; } @@ -655,8 +658,8 @@ Status GradientTape::ComputeGradient( const gtl::ArraySlice target_tensor_ids, const gtl::ArraySlice source_tensor_ids, const std::unordered_map& sources_that_are_targets, - gtl::ArraySlice output_gradients, - std::vector* result) { + gtl::ArraySlice output_gradients, std::vector* result, + bool build_default_zeros_grads) { std::unordered_set sources_set(source_tensor_ids.begin(), source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( @@ -717,14 +720,14 @@ Status GradientTape::ComputeGradient( const int64 id = trace.output_tensor_info[i].GetID(); auto grad_it = gradients.find(id); if (grad_it == gradients.end()) { - auto func_name_it = - FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); - if (func_name_it != FunctionsAcceptingNoneForIndicesMap()->end() && - func_name_it->second.find(i) != func_name_it->second.end()) { - out_gradients.push_back(nullptr); - } else { - out_gradients.push_back(nullptr); - zero_indices.push_back(i); + out_gradients.push_back(nullptr); + if (build_default_zeros_grads) { + auto func_name_it = + FunctionsAcceptingNoneForIndicesMap()->find(trace.op_type); + if (func_name_it == FunctionsAcceptingNoneForIndicesMap()->end() || + func_name_it->second.find(i) == func_name_it->second.end()) { + zero_indices.push_back(i); + } } } else { any_gradient_nonzero = true; @@ -745,6 +748,7 @@ Status GradientTape::ComputeGradient( } } std::vector in_gradients; + DCHECK(build_default_zeros_grads || zero_indices.empty()); if (any_gradient_nonzero) { for (const auto i : zero_indices) { out_gradients[i] = trace.output_tensor_info[i].ZerosLike(); diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD index 56bd3b4a75c..a2108d06cbb 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/s3/BUILD @@ -26,6 +26,8 @@ cc_library( }), deps = [ ":aws_crypto", + ":aws_logging", + "//tensorflow/c:logging", "//tensorflow/c:tf_status", "//tensorflow/c/experimental/filesystem:filesystem_interface", "@aws", @@ -45,6 +47,18 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "aws_logging", + srcs = ["aws_logging.cc"], + hdrs = ["aws_logging.h"], + deps = [ + "//tensorflow/c:logging", + "@aws", + "@com_google_absl//absl/synchronization", + ], + alwayslink = 1, +) + tf_cc_test( name = "s3_filesystem_test", srcs = [ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc new file mode 100644 index 00000000000..353b733fd25 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.cc @@ -0,0 +1,159 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h" + +#include +#include +#include + +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/logging.h" + +static constexpr char kAWSLoggingTag[] = "AWSLogging"; + +static const std::map + log_levels_string_to_aws = { + {"off", Aws::Utils::Logging::LogLevel::Off}, + {"fatal", Aws::Utils::Logging::LogLevel::Fatal}, + {"error", Aws::Utils::Logging::LogLevel::Error}, + {"warn", Aws::Utils::Logging::LogLevel::Warn}, + {"info", Aws::Utils::Logging::LogLevel::Info}, + {"debug", Aws::Utils::Logging::LogLevel::Debug}, + {"trace", Aws::Utils::Logging::LogLevel::Trace}}; + +static const std::map + log_levels_tf_to_aws = {{0, Aws::Utils::Logging::LogLevel::Info}, + {1, Aws::Utils::Logging::LogLevel::Warn}, + {2, Aws::Utils::Logging::LogLevel::Error}, + {3, Aws::Utils::Logging::LogLevel::Fatal}}; + +namespace tf_s3_filesystem { + +AWSLogSystem::AWSLogSystem(Aws::Utils::Logging::LogLevel log_level) + : log_level_(log_level) {} + +void AWSLogSystem::LogMessage(Aws::Utils::Logging::LogLevel log_level, + const std::string& message) { + if (message == "Initializing Curl library") return; + switch (log_level) { + case Aws::Utils::Logging::LogLevel::Info: + TF_Log(TF_INFO, message.c_str()); + break; + case Aws::Utils::Logging::LogLevel::Warn: + TF_Log(TF_WARNING, message.c_str()); + break; + case Aws::Utils::Logging::LogLevel::Error: + TF_Log(TF_ERROR, message.c_str()); + break; + case Aws::Utils::Logging::LogLevel::Fatal: + TF_Log(TF_FATAL, message.c_str()); + break; + default: + // this will match for DEBUG, TRACE + TF_Log(TF_INFO, message.c_str()); + break; + } +} + +void AWSLogSystem::Log(Aws::Utils::Logging::LogLevel log_level, const char* tag, + const char* format, ...) { + char buffer[256]; + va_list args; + va_start(args, format); + vsnprintf(buffer, 256, format, args); + va_end(args); + LogMessage(log_level, buffer); +} + +void AWSLogSystem::LogStream(Aws::Utils::Logging::LogLevel log_level, + const char* tag, + const Aws::OStringStream& message_stream) { + LogMessage(log_level, message_stream.rdbuf()->str().c_str()); +} + +void AWSLogSystem::Flush() { return; } + +static Aws::Utils::Logging::LogLevel TfLogLevelToAwsLogLevel(int level) { + // Converts TF Log Levels INFO, WARNING, ERROR and FATAL to the AWS enum + // values for the levels + if (log_levels_tf_to_aws.find(level) != log_levels_tf_to_aws.end()) { + return log_levels_tf_to_aws.at(level); + } else { + // default to fatal + return Aws::Utils::Logging::LogLevel::Fatal; + } +} + +static Aws::Utils::Logging::LogLevel ParseAwsLogLevelFromEnv() { + // defaults to FATAL log level for the AWS SDK + // this is because many normal tensorflow operations are logged as errors in + // the AWS SDK such as checking if a file exists can log an error in AWS SDK + // if the file does not actually exist. Another such case is when reading a + // file till the end, TensorFlow expects to see an InvalidRange exception at + // the end, but this would be an error in the AWS SDK. 
This confuses users, + // hence the default setting. + Aws::Utils::Logging::LogLevel log_level = + Aws::Utils::Logging::LogLevel::Fatal; + + const char* aws_env_var_val = getenv("AWS_LOG_LEVEL"); + if (aws_env_var_val != nullptr) { + std::string maybe_integer_str(aws_env_var_val, strlen(aws_env_var_val)); + std::istringstream ss(maybe_integer_str); + int level; + ss >> level; + if (ss.fail()) { + // wasn't a number + // expecting a string + std::string level_str = maybe_integer_str; + if (log_levels_string_to_aws.find(level_str) != + log_levels_string_to_aws.end()) { + log_level = log_levels_string_to_aws.at(level_str); + } + } else { + // backwards compatibility + // valid number, but this number follows the standard TensorFlow log + // levels need to convert this to AWS SDK logging level number + log_level = TfLogLevelToAwsLogLevel(level); + } + } + return log_level; +} + +static bool initialized = false; +ABSL_CONST_INIT static absl::Mutex s3_logging_mutex(absl::kConstInit); +void AWSLogSystem::InitializeAWSLogging() { + absl::MutexLock l(&s3_logging_mutex); + if (!initialized) { + Aws::Utils::Logging::InitializeAWSLogging(Aws::MakeShared( + kAWSLoggingTag, ParseAwsLogLevelFromEnv())); + initialized = true; + return; + } +} + +void AWSLogSystem::ShutdownAWSLogging() { + absl::MutexLock l(&s3_logging_mutex); + if (initialized) { + Aws::Utils::Logging::ShutdownAWSLogging(); + initialized = false; + return; + } +} + +} // namespace tf_s3_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h new file mode 100644 index 00000000000..afecd7e5e62 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_ + +#include +#include + +#include +#include + +namespace tf_s3_filesystem { + +class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface { + public: + static void InitializeAWSLogging(); + static void ShutdownAWSLogging(); + + explicit AWSLogSystem(Aws::Utils::Logging::LogLevel log_level); + virtual ~AWSLogSystem() = default; + + // Gets the currently configured log level. + Aws::Utils::Logging::LogLevel GetLogLevel(void) const override { + return log_level_; + } + + // Set a new log level. This has the immediate effect of changing the log. + void SetLogLevel(Aws::Utils::Logging::LogLevel log_level) { + log_level_.store(log_level); + } + + // Does a printf style output to ProcessFormattedStatement. Don't use this, + // it's unsafe. See LogStream. + void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag, + const char* format, ...) override; + + // Writes the stream to ProcessFormattedStatement. 
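+  // Both Log() and LogStream() route through LogMessage(), which maps the
+  // AWS log level onto the matching TF_Log severity (Debug and Trace fall
+  // back to TF_INFO).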
+ void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag, + const Aws::OStringStream& messageStream) override; + + // Flushes the buffered messages if the logger supports buffering + void Flush() override; + + private: + void LogMessage(Aws::Utils::Logging::LogLevel log_level, + const std::string& message); + std::atomic log_level_; +}; + +} // namespace tf_s3_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_LOGGING_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc index 7e1b36f2dcc..9ff07633f2a 100644 --- a/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc +++ b/tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.cc @@ -38,6 +38,8 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h" +#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_logging.h" +#include "tensorflow/c/logging.h" #include "tensorflow/c/tf_status.h" // Implementation of a filesystem for S3 environments. @@ -186,6 +188,8 @@ static void GetS3Client(tf_s3_filesystem::S3File* s3_file) { absl::MutexLock l(&s3_file->initialization_lock); if (s3_file->s3_client.get() == nullptr) { + tf_s3_filesystem::AWSLogSystem::InitializeAWSLogging(); + Aws::SDKOptions options; options.cryptoOptions.sha256Factory_create_fn = []() { return Aws::MakeShared( @@ -250,6 +254,7 @@ static void ShutdownClient(Aws::S3::S3Client* s3_client) { delete s3_client; Aws::SDKOptions options; Aws::ShutdownAPI(options); + tf_s3_filesystem::AWSLogSystem::ShutdownAWSLogging(); } } @@ -281,6 +286,7 @@ void Cleanup(TF_RandomAccessFile* file) { static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { + TF_VLog(3, "ReadFile using S3Client\n"); Aws::S3::Model::GetObjectRequest get_object_request; get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object); Aws::String bytes = @@ -306,12 +312,14 @@ static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n, static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { + TF_VLog(3, "Using TransferManager\n"); auto create_download_stream = [&]() { return Aws::New( "S3ReadStream", Aws::New( "S3ReadStream", reinterpret_cast(buffer), n)); }; + TF_VLog(3, "Created stream to read with transferManager\n"); auto handle = s3_file->transfer_manager->DownloadFile( s3_file->bucket, s3_file->object, offset, n, create_download_stream); handle->WaitUntilFinished(); @@ -322,6 +330,10 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE && retries++ < kDownloadRetries) { // Only failed parts will be downloaded again. + TF_VLog( + 1, + "Retrying read of s3://%s/%s after failure. 
Current retry count: %u\n", + s3_file->bucket.c_str(), s3_file->object.c_str(), retries); s3_file->transfer_manager->RetryDownload(handle); handle->WaitUntilFinished(); } @@ -341,6 +353,8 @@ static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n, int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n, char* buffer, TF_Status* status) { auto s3_file = static_cast(file->plugin_file); + TF_VLog(1, "ReadFilefromS3 s3://%s/%s from %u for n: %u\n", + s3_file->bucket.c_str(), s3_file->object.c_str(), offset, n); if (s3_file->use_multi_part_download) return ReadS3TransferManager(s3_file, offset, n, buffer, status); else @@ -416,6 +430,8 @@ void Sync(const TF_WritableFile* file, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); return; } + TF_VLog(1, "WriteFileToS3: s3://%s/%s\n", s3_file->bucket.c_str(), + s3_file->object.c_str()); auto position = static_cast(s3_file->outfile->tellp()); auto handle = s3_file->transfer_manager->UploadFile( s3_file->outfile, s3_file->bucket, s3_file->object, @@ -426,6 +442,10 @@ void Sync(const TF_WritableFile* file, TF_Status* status) { while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED && retries++ < kUploadRetries) { // if multipart upload was used, only the failed parts will be re-sent + TF_VLog(1, + "Retrying upload of s3://%s/%s after failure. Current retry count: " + "%u\n", + s3_file->bucket.c_str(), s3_file->object.c_str(), retries); s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle); handle->WaitUntilFinished(); } @@ -613,6 +633,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, void Stat(const TF_Filesystem* filesystem, const char* path, TF_FileStatistics* stats, TF_Status* status) { + TF_VLog(1, "Stat on path: %s\n", path); Aws::String bucket, object; ParseS3Path(path, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -737,6 +758,8 @@ static void SimpleCopyFile(const Aws::String& source, const Aws::String& bucket_dst, const Aws::String& object_dst, S3File* s3_file, TF_Status* status) { + TF_VLog(1, "SimpleCopyFile from %s to %s/%s\n", bucket_dst.c_str(), + object_dst.c_str()); Aws::S3::Model::CopyObjectRequest copy_object_request; copy_object_request.WithCopySource(source) .WithBucket(bucket_dst) @@ -801,6 +824,8 @@ static void MultiPartCopy(const Aws::String& source, const Aws::String& object_dst, const size_t num_parts, const uint64_t file_size, S3File* s3_file, TF_Status* status) { + TF_VLog(1, "MultiPartCopy from %s to %s/%s\n", bucket_dst.c_str(), + object_dst.c_str()); Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request; create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst); @@ -827,6 +852,8 @@ static void MultiPartCopy(const Aws::String& source, auto chunk_size = s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD]; + TF_VLog(1, "Copying from %s in %u parts of size %u each\n", source.c_str(), + num_parts, chunk_size); size_t retries = 0; while (retries++ < 3) { // Queue up parts. @@ -891,6 +918,9 @@ static void MultiPartCopy(const Aws::String& source, status); } else { // Retry. 
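+          // Count the part as unfinished again so a later pass of the retry
+          // loop re-queues it.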
+ TF_Log(TF_ERROR, + "Retrying failed copy of part %u due to an error with S3\n", + part_number); num_finished_parts--; } } @@ -967,6 +997,7 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst, void DeleteFile(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { + TF_VLog(1, "DeleteFile: %s\n", path); Aws::String bucket, object; ParseS3Path(path, false, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -985,6 +1016,7 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path, void CreateDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { + TF_VLog(1, "CreateDir: %s\n", path); Aws::String bucket, object; ParseS3Path(path, true, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -1026,6 +1058,7 @@ void CreateDir(const TF_Filesystem* filesystem, const char* path, void DeleteDir(const TF_Filesystem* filesystem, const char* path, TF_Status* status) { + TF_VLog(1, "DeleteDir: %s\n", path); Aws::String bucket, object; ParseS3Path(path, false, &bucket, &object, status); if (TF_GetCode(status) != TF_OK) return; @@ -1060,6 +1093,7 @@ void DeleteDir(const TF_Filesystem* filesystem, const char* path, void RenameFile(const TF_Filesystem* filesystem, const char* src, const char* dst, TF_Status* status) { + TF_VLog(1, "RenameFile from: %s to %s\n", src, dst); Aws::String bucket_src, object_src; ParseS3Path(src, false, &bucket_src, &object_src, status); if (TF_GetCode(status) != TF_OK) return; @@ -1120,6 +1154,7 @@ void RenameFile(const TF_Filesystem* filesystem, const char* src, int GetChildren(const TF_Filesystem* filesystem, const char* path, char*** entries, TF_Status* status) { + TF_VLog(1, "GetChildren for path: %s\n", path); Aws::String bucket, prefix; ParseS3Path(path, true, &bucket, &prefix, status); if (TF_GetCode(status) != TF_OK) return -1; diff --git a/tensorflow/c/experimental/gradients/BUILD b/tensorflow/c/experimental/gradients/BUILD index 80c4e8d9791..36a3251def7 100644 --- a/tensorflow/c/experimental/gradients/BUILD +++ b/tensorflow/c/experimental/gradients/BUILD @@ -3,6 +3,24 @@ package( licenses = ["notice"], # Apache 2.0 ) +cc_library( + name = "array_grad", + srcs = ["array_grad.cc"], + hdrs = [ + "array_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients", + "//tensorflow/core/lib/llvm_rtti", + ], +) + cc_library( name = "math_grad", srcs = ["math_grad.cc"], @@ -19,6 +37,28 @@ cc_library( "//tensorflow/c/eager:gradients", "//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/core/lib/llvm_rtti", + ], +) + +cc_library( + name = "nn_grad", + srcs = ["nn_grad.cc"], + hdrs = [ + "nn_grad.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", + "//tensorflow/c/eager:gradients", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/core/lib/llvm_rtti", ], ) diff --git a/tensorflow/c/experimental/gradients/array_grad.cc b/tensorflow/c/experimental/gradients/array_grad.cc new file mode 100644 index 00000000000..069209a4b6b --- 
/dev/null
+++ b/tensorflow/c/experimental/gradients/array_grad.cc
@@ -0,0 +1,48 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/c/experimental/gradients/array_grad.h"
+
+namespace tensorflow {
+namespace gradients {
+namespace {
+using std::vector;
+class IdentityNGradientFunction : public GradientFunction {
+ public:
+  Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
+                 vector<AbstractTensorHandle*>* grad_outputs) override {
+    grad_outputs->resize(grad_inputs.size(), nullptr);
+    for (int i = 0; i < grad_inputs.size(); i++) {
+      auto grad_input = grad_inputs[i];
+      // TODO(srbs): Should we add a copy constructor to AbstractTensorHandle
+      // that takes care of this similar to `Tensor`?
+      if (grad_input) {
+        grad_input->Ref();
+      }
+      (*grad_outputs)[i] = grad_input;
+    }
+    return Status::OK();
+  }
+  ~IdentityNGradientFunction() override {}
+};
+}  // namespace
+
+BackwardFunction* IdentityNRegisterer(const ForwardOperation& op) {
+  auto gradient_function = new IdentityNGradientFunction;
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
+}
+
+}  // namespace gradients
+}  // namespace tensorflow
diff --git a/tensorflow/c/experimental/gradients/array_grad.h b/tensorflow/c/experimental/gradients/array_grad.h
new file mode 100644
index 00000000000..edeeb5fcb4a
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/array_grad.h
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
+#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
+
+#include "tensorflow/c/eager/gradients.h"
+
+namespace tensorflow {
+namespace gradients {
+BackwardFunction* IdentityNRegisterer(const ForwardOperation& op);
+}  // namespace gradients
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_ARRAY_GRAD_H_
diff --git a/tensorflow/c/experimental/gradients/math_grad.cc b/tensorflow/c/experimental/gradients/math_grad.cc
index d8b70848d4e..f298c202046 100644
--- a/tensorflow/c/experimental/gradients/math_grad.cc
+++ b/tensorflow/c/experimental/gradients/math_grad.cc
@@ -15,13 +15,17 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" using std::vector; using tensorflow::ops::Conj; using tensorflow::ops::Identity; +using tensorflow::ops::MatMul; using tensorflow::ops::Mul; +using tensorflow::ops::ZerosLike; namespace tensorflow { namespace gradients { @@ -29,20 +33,23 @@ namespace { class AddGradientFunction : public GradientFunction { public: - Status Compute(Context* ctx, - absl::Span grad_inputs, + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, vector* grad_outputs) override { grad_outputs->resize(2); vector identity_outputs(1); // TODO(b/145674566): Handle name unification in tracing code. // TODO(b/161805092): Support broadcasting. + + std::string name = "Identity_A"; TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), - "Identity0")); + name.c_str())); (*grad_outputs)[0] = identity_outputs[0]; + + name = "Identity_B"; TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]}, absl::MakeSpan(identity_outputs), - "Identity1")); + name.c_str())); (*grad_outputs)[1] = identity_outputs[0]; return Status::OK(); } @@ -54,16 +61,18 @@ class ExpGradientFunction : public GradientFunction { explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) { exp->Ref(); } - Status Compute(Context* ctx, - absl::Span grad_inputs, + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, vector* grad_outputs) override { vector conj_outputs(1); - TF_RETURN_IF_ERROR( - Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj")); + std::string name = "Conj_Exp_Grad"; + TF_RETURN_IF_ERROR(Conj(ctx->ctx, {exp_.get()}, + absl::MakeSpan(conj_outputs), name.c_str())); AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]); grad_outputs->resize(1); + + name = "Mul_Exp_Grad"; TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]}, - absl::MakeSpan(*grad_outputs), "ExpGradMul")); + absl::MakeSpan(*grad_outputs), name.c_str())); return Status::OK(); } ~ExpGradientFunction() override {} @@ -72,14 +81,142 @@ class ExpGradientFunction : public GradientFunction { AbstractTensorHandlePtr exp_; }; +class MatMulGradientFunction : public GradientFunction { + public: + explicit MatMulGradientFunction(vector f_inputs, + AttrBuilder f_attrs) + : forward_inputs(f_inputs), forward_attrs(f_attrs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + /* Given upstream grad U and a matmul op A*B, the gradients are: + * + * dA = U * B.T + * dB = A.T * U + * + * where A.T means `transpose(A)` + */ + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + grad_outputs->resize(2); + + // Get transpose attrs + bool t_a; + forward_attrs.Get("transpose_a", &t_a); + + bool t_b; + forward_attrs.Get("transpose_b", &t_b); + + // Conj each input + vector conj_outputs(1); + std::string name = "Conj_A_MatMul_Grad"; + TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[0]}, + absl::MakeSpan(conj_outputs), name.c_str())); + + AbstractTensorHandle* A = conj_outputs[0]; + + name = "Conj_B_MatMul_Grad"; + TF_RETURN_IF_ERROR(Conj(ctx->ctx, {forward_inputs[1]}, + absl::MakeSpan(conj_outputs), name.c_str())); + + AbstractTensorHandle* B = conj_outputs[0]; + + // Calc Grad + vector matmul_A_outputs(1); + vector 
matmul_B_outputs(1); + std::string name_grad_A = "MatMul_Grad_A"; + std::string name_grad_B = "MatMul_Grad_B"; + if (!t_a && !t_b) { + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ true)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ false)); + } else if (!t_a && t_b) { + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, B}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ false)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ false)); + + } else if (t_a && !t_b) { + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ true)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {A, upstream_grad}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ false, + /*transpose_b = */ false)); + } else { // t_a && t_b + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {B, upstream_grad}, + absl::MakeSpan(matmul_A_outputs), + name_grad_A.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ true)); + + TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, A}, + absl::MakeSpan(matmul_B_outputs), + name_grad_B.c_str(), + /*transpose_a = */ true, + /*transpose_b = */ true)); + } + + // Gradient for A + (*grad_outputs)[0] = matmul_A_outputs[0]; + + // Gradient for B + (*grad_outputs)[1] = matmul_B_outputs[0]; + return Status::OK(); + } + ~MatMulGradientFunction() override {} + + private: + vector forward_inputs; + AttrBuilder forward_attrs; +}; + } // namespace -GradientFunction* AddRegisterer(const ForwardOperation& op) { - return new AddGradientFunction; +BackwardFunction* AddRegisterer(const ForwardOperation& op) { + auto gradient_function = new AddGradientFunction; + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); } -GradientFunction* ExpRegisterer(const ForwardOperation& op) { - return new ExpGradientFunction(op.outputs[0]); +BackwardFunction* ExpRegisterer(const ForwardOperation& op) { + auto gradient_function = new ExpGradientFunction(op.outputs[0]); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* MatMulRegisterer(const ForwardOperation& op) { + auto gradient_function = new MatMulGradientFunction(op.inputs, op.attrs); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. 
+  auto default_gradients = new PassThroughDefaultGradients(op);
+  return new BackwardFunction(gradient_function, default_gradients);
 }
 
 }  // namespace gradients
diff --git a/tensorflow/c/experimental/gradients/math_grad.h b/tensorflow/c/experimental/gradients/math_grad.h
index 6c7242a1a49..205419e1201 100644
--- a/tensorflow/c/experimental/gradients/math_grad.h
+++ b/tensorflow/c/experimental/gradients/math_grad.h
@@ -19,9 +19,10 @@ limitations under the License.
 
 namespace tensorflow {
 namespace gradients {
-GradientFunction* AddRegisterer(const ForwardOperation& op);
-GradientFunction* ExpRegisterer(const ForwardOperation& op);
+BackwardFunction* AddRegisterer(const ForwardOperation& op);
+BackwardFunction* ExpRegisterer(const ForwardOperation& op);
+BackwardFunction* MatMulRegisterer(const ForwardOperation& op);
 }  // namespace gradients
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
+#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
\ No newline at end of file
diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc
new file mode 100644
index 00000000000..3da1e0dc153
--- /dev/null
+++ b/tensorflow/c/experimental/gradients/nn_grad.cc
@@ -0,0 +1,111 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
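For orientation, here is a minimal sketch of how the AddRegisterer, ExpRegisterer, and MatMulRegisterer factories above are typically wired into a gradient registry. This assumes the GradientRegistry::Register API declared in tensorflow/c/eager/gradients.h; the RegisterGradients helper name is illustrative and not part of this patch.

  // Sketch only: registers the backward-function factories defined above.
  // Assumes GradientRegistry::Register(op_name, factory) from
  // tensorflow/c/eager/gradients.h.
  Status RegisterGradients(GradientRegistry* registry) {
    TF_RETURN_IF_ERROR(registry->Register("AddV2", AddRegisterer));
    TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
    TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
    return Status::OK();
  }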
+==============================================================================*/ +#include "tensorflow/c/experimental/gradients/nn_grad.h" + +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" + +using std::vector; +using tensorflow::ops::Conj; +using tensorflow::ops::Identity; +using tensorflow::ops::Mul; +using tensorflow::ops::ReluGrad; +using tensorflow::ops::SparseSoftmaxCrossEntropyLoss; +using tensorflow::ops::ZerosLike; + +namespace tensorflow { +namespace gradients { +namespace { + +class ReluGradientFunction : public GradientFunction { + public: + explicit ReluGradientFunction(vector f_outputs) + : forward_outputs(f_outputs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + AbstractTensorHandle* upstream_grad = grad_inputs[0]; + AbstractTensorHandle* activations = forward_outputs[0]; + grad_outputs->resize(1); + vector relugrad_outputs(1); + + // Calculate Grad + std::string name = "relu_grad"; + + TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations}, + absl::MakeSpan(relugrad_outputs), + name.c_str())); + (*grad_outputs)[0] = relugrad_outputs[0]; + + return Status::OK(); + } + ~ReluGradientFunction() override {} + + private: + vector forward_outputs; +}; + +class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction { + public: + explicit SparseSoftmaxCrossEntropyLossGradientFunction( + vector f_outputs) + : forward_outputs(f_outputs) {} + + Status Compute(Context* ctx, const IncomingGradients& grad_inputs, + vector* grad_outputs) override { + grad_outputs->resize(2); + + // Grad for Softmax Input + std::string name = "Mul_Softmax_Grad"; + vector mul_outputs(1); + TF_RETURN_IF_ERROR( + ops::Mul(ctx->ctx, {grad_inputs[0], forward_outputs[1]}, + absl::MakeSpan(mul_outputs), + name.c_str())); // upstream_grad * local softmax grad + (*grad_outputs)[0] = mul_outputs[0]; + + // Grad for labels is null + (*grad_outputs)[1] = nullptr; + + return Status::OK(); + } + ~SparseSoftmaxCrossEntropyLossGradientFunction() override {} + + private: + vector forward_outputs; +}; + +} // namespace + +BackwardFunction* ReluRegisterer(const ForwardOperation& op) { + auto gradient_function = new ReluGradientFunction(op.outputs); + // For ops with a single output, the gradient function is not called if there + // is no incoming gradient. So we do not need to worry about creating zeros + // grads in this case. + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( + const ForwardOperation& op) { + auto gradient_function = + new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs); + auto default_gradients = new PassThroughDefaultGradients(op); + return new BackwardFunction(gradient_function, default_gradients); +} + +} // namespace gradients +} // namespace tensorflow diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h new file mode 100644 index 00000000000..d002725847f --- /dev/null +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -0,0 +1,28 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
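A brief note on the SparseSoftmaxCrossEntropyLoss gradient above: the underlying SparseSoftmaxCrossEntropyWithLogits op returns both the per-example loss and a backprop tensor, so the gradient function only multiplies the incoming gradient by that saved backprop output (forward_outputs[1]). The labels input receives a null gradient because labels are integer class indices and are not differentiable.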
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ +#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ + +#include "tensorflow/c/eager/gradients.h" + +namespace tensorflow { +namespace gradients { +BackwardFunction* ReluRegisterer(const ForwardOperation& op); +BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( + const ForwardOperation& op); +} // namespace gradients +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_ \ No newline at end of file diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD index d13d7a72d3e..3504737c314 100644 --- a/tensorflow/c/experimental/ops/BUILD +++ b/tensorflow/c/experimental/ops/BUILD @@ -15,7 +15,6 @@ cc_library( "//tensorflow:internal", ], deps = [ - "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", @@ -36,12 +35,30 @@ cc_library( "//tensorflow:internal", ], deps = [ - ":array_ops", - "//tensorflow/c/eager:abstract_context", "//tensorflow/c/eager:abstract_operation", "//tensorflow/c/eager:abstract_tensor_handle", "//tensorflow/c/eager:c_api_unified_internal", - "//tensorflow/core:framework_headers_lib", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + ], +) + +cc_library( + name = "nn_ops", + srcs = [ + "nn_ops.cc", + ], + hdrs = [ + "nn_ops.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + "//tensorflow/c/eager:abstract_operation", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api_unified_internal", "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", ], diff --git a/tensorflow/c/experimental/ops/array_ops.cc b/tensorflow/c/experimental/ops/array_ops.cc index ab2d114d9d9..df0f4639fbd 100644 --- a/tensorflow/c/experimental/ops/array_ops.cc +++ b/tensorflow/c/experimental/ops/array_ops.cc @@ -19,7 +19,7 @@ limitations under the License. namespace tensorflow { namespace ops { -// Creates an Identity op. 
+ Status Identity(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { @@ -35,5 +35,19 @@ Status Identity(AbstractContext* ctx, return identity_op->Execute(outputs, &num_retvals); } +Status ZerosLike(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr z_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(z_op->Reset("ZerosLike", /*raw_device_name=*/nullptr)); + if (isa(z_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(z_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(z_op->AddInput(inputs[0])); + int num_retvals = 1; + return z_op->Execute(outputs, &num_retvals); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/array_ops.h b/tensorflow/c/experimental/ops/array_ops.h index 226461fd286..8dc68db673f 100644 --- a/tensorflow/c/experimental/ops/array_ops.h +++ b/tensorflow/c/experimental/ops/array_ops.h @@ -22,9 +22,15 @@ limitations under the License. namespace tensorflow { namespace ops { + Status Identity(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); + +Status ZerosLike(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index e91acbd6370..82c2f0e8169 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -51,5 +51,60 @@ Status Conj(AbstractContext* ctx, return Status::OK(); } +Status Add(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr add_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr)); + + if (isa(add_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(add_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1])); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status MatMul(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a = false, bool transpose_b = false) { + AbstractOperationPtr matmul_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr)); + + if (isa(matmul_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(matmul_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0])); + TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1])); + + TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a)); + TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b)); + + int num_retvals = 1; + TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +Status Neg(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr neg_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr)); + if (isa(neg_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(neg_op.get())->SetOpName(name)); + } + TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0])); + + int num_retvals = 1; + return neg_op->Execute(outputs, &num_retvals); +} + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h index 4d7c3d838ce..ed1e6c5b3d6 100644 --- 
a/tensorflow/c/experimental/ops/math_ops.h +++ b/tensorflow/c/experimental/ops/math_ops.h @@ -25,6 +25,15 @@ Status Mul(AbstractContext* ctx, absl::Span inputs, Status Conj(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name); +Status Add(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); +Status MatMul(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name, + bool transpose_a, bool transpose_b); +Status Neg(AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + } // namespace ops } // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc new file mode 100644 index 00000000000..8f5f550bb8b --- /dev/null +++ b/tensorflow/c/experimental/ops/nn_ops.cc @@ -0,0 +1,67 @@ + +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/ops/nn_ops.h" + +#include "tensorflow/core/platform/errors.h" + +namespace tensorflow { +namespace ops { + +// Softmax Loss given scores and labels, used by the SoftMaxLossGradient +Status SparseSoftmaxCrossEntropyLoss( + AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr sm_loss_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(sm_loss_op->Reset("SparseSoftmaxCrossEntropyWithLogits", + /*raw_device_name=*/nullptr)); + + if (isa(sm_loss_op.get())) { + TF_RETURN_IF_ERROR( + dyn_cast(sm_loss_op.get())->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[0])); // input scores + TF_RETURN_IF_ERROR(sm_loss_op->AddInput(inputs[1])); // labels + + // Outputs will contain: [loss_vals, gradients]. + int num_retvals = 2; + TF_RETURN_IF_ERROR(sm_loss_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +// Computes Relu gradient given input features +Status ReluGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name) { + AbstractOperationPtr relugrad_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR( + relugrad_op->Reset("ReluGrad", /*raw_device_name=*/nullptr)); + + if (isa(relugrad_op.get())) { + TF_RETURN_IF_ERROR(dyn_cast(relugrad_op.get()) + ->SetOpName(name)); + } + + TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[0])); // upstream grads + TF_RETURN_IF_ERROR(relugrad_op->AddInput(inputs[1])); // relu inputs + + int num_retvals = 1; + TF_RETURN_IF_ERROR(relugrad_op->Execute(outputs, &num_retvals)); + return Status::OK(); +} + +} // namespace ops +} // namespace tensorflow diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h new file mode 100644 index 00000000000..3e618b00869 --- /dev/null +++ b/tensorflow/c/experimental/ops/nn_ops.h @@ -0,0 +1,37 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 
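As a usage illustration for the Add, MatMul, and Neg helpers declared just above in math_ops.h, a hedged sketch follows. Here `ctx`, `a`, and `b` are assumed to be an existing AbstractContext* and two AbstractTensorHandle* operands; none of this is part of the patch itself.

  // Hypothetical call site for the new ops::MatMul wrapper.
  std::vector<AbstractTensorHandle*> matmul_outputs(1);
  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {a, b}, absl::MakeSpan(matmul_outputs),
                                 "example_matmul",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));
  // The caller is responsible for releasing the returned handle.
  AbstractTensorHandle* product = matmul_outputs[0];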
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ + +#include "tensorflow/c/eager/abstract_operation.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" + +namespace tensorflow { +namespace ops { + +Status SparseSoftmaxCrossEntropyLoss( + AbstractContext* ctx, absl::Span inputs, + absl::Span outputs, const char* name); + +Status ReluGrad(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, const char* name); + +} // namespace ops +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_NN_OPS_H_ diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index b2e432782de..2feb7c1b33e 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -44,7 +44,9 @@ cc_library( ], deps = [ ":concrete_function", + ":signature_def_function", "//tensorflow/core:lib", + "@com_google_absl//absl/strings", ], ) @@ -70,6 +72,26 @@ cc_library( ], ) +cc_library( + name = "signature_def_function", + hdrs = [ + "signature_def_function.h", + ], + deps = [ + ":signature_def_function_metadata", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "signature_def_function_metadata", + hdrs = [ + "signature_def_function_metadata.h", + ], +) + cc_library( name = "test_utils", testonly = True, @@ -115,6 +137,7 @@ cc_library( ":concrete_function", ":saved_model_api", ":saved_model_utils", + ":signature_def_function", "//tensorflow/c:tensor_interface", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/eager:immediate_execution_tensor_handle", @@ -206,13 +229,13 @@ tf_cc_test( "//tensorflow/c/experimental/saved_model/core/revived_types:constant", "//tensorflow/core:all_kernels", "//tensorflow/core:framework", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/common_runtime:core_cpu_lib", "//tensorflow/core/common_runtime/eager:context", "//tensorflow/core/common_runtime/eager:core", + "//tensorflow/core/common_runtime/eager:tensor_handle", ], ) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index da3a64b91a3..934fa6d2bda 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -26,10 +26,14 @@ limitations under the License. 
namespace tensorflow { -// Note that ConcreteFunctions's lifetimes are effectively bound -// to the SavedModel they are loaded from, since they retain pointers -// to the TensorHandles owned by the SavedModel, and the FunctionDef -// of the SavedModel. +// ConcreteFunctions correspond to an instance of a tf.function with a known set +// of inputs (either through get_concrete_function) or an input_signature. +// ConcreteFunction attempts to preserve the user-facing semantics of the +// tf.function python API and can take a limited set of types as arguments +// (to be modeled in tensorflow::Value), not just Tensors. +// SavedModelAPI's ConcreteFunctions' lifetimes are bound to the SavedModel they +// are loaded from, since they retain pointers to the TensorHandles owned by the +// SavedModel, and the FunctionDef of the SavedModel. // Note(bmzhao): This class is only TEMPORARILY virtual, as a way to unblock // TFRT integration with TF Serving. Do not add more virtual implementations of // this class. Eventually we want to remove this virtual base class indirection diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index 492a58f816d..be9ffff99ff 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -37,10 +37,11 @@ static const char kNoSharingResourceID[] = Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + const char* raw_device_name, ImmediateTensorHandlePtr* handle) { ImmediateOpPtr varhandle_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr)); + TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", raw_device_name)); TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype)); // Note that if shape is unknown rank, shape.dim_sizes() will be empty, and diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h index 13c941a77fe..accad1591da 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -31,6 +31,7 @@ namespace internal { // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872 Status CreateUninitializedResourceVariable(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, + const char* raw_device_name, ImmediateTensorHandlePtr* handle); // Executes an AssignVariableOp using `ctx`, assigning the variable associated diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index 55a4a32e983..5ce027fe6d8 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -55,7 +55,7 @@ TEST_F(VariableOpsTest, CreateVariableSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor ImmediateTensorHandlePtr handle; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( - context(), DT_FLOAT, {}, &handle)); + context(), DT_FLOAT, {}, nullptr, &handle)); // The created TensorHandle should be a DT_Resource EXPECT_EQ(handle->DataType(), DT_RESOURCE); } @@ -65,7 +65,7 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) 
{ // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor ImmediateTensorHandlePtr handle; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( - context(), DT_FLOAT, {}, &handle)); + context(), DT_FLOAT, {}, nullptr, &handle)); // Destroy the variable TF_EXPECT_OK(internal::DestroyResource(context(), handle.get())); @@ -76,7 +76,7 @@ TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) { // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor ImmediateTensorHandlePtr variable; TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( - context(), DT_FLOAT, {}, &variable)); + context(), DT_FLOAT, {}, nullptr, &variable)); // Create a Scalar float TensorHandle with value 42, and assign it to // the variable. diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc index d831a8dd840..a212c25bd28 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.cc @@ -65,10 +65,11 @@ Status Variable::ReadValue(ImmediateTensorHandlePtr* out) { Status Variable::CreateUninitialized(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, absl::optional name, + const char* raw_device_name, std::unique_ptr* output) { ImmediateTensorHandlePtr handle; TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable( - ctx, dtype, shape, &handle)); + ctx, dtype, shape, raw_device_name, &handle)); output->reset( new Variable(ctx, dtype, shape, std::move(name), std::move(handle))); diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h index 48ea1d08862..13f56fda5f3 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/variable.h +++ b/tensorflow/c/experimental/saved_model/core/revived_types/variable.h @@ -37,6 +37,7 @@ class Variable : public TensorHandleConvertible { static Status CreateUninitialized(ImmediateExecutionContext* ctx, DataType dtype, TensorShape shape, absl::optional name, + const char* raw_device_name, std::unique_ptr* output); // The dtype of the underlying variable. diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_api.h b/tensorflow/c/experimental/saved_model/core/saved_model_api.h index 5d0ed63a765..ff891e13ba4 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_api.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/core/platform/status.h" namespace tensorflow { @@ -39,11 +40,11 @@ class SavedModelAPI { virtual Status GetFunction(const std::string& function_path, ConcreteFunction** function) = 0; - // Retrieve a function from a SavedModel, using the key of the + // Retrieve a SignatureDefFunction from a SavedModel, using the key of the // SignatureDef map: // https://github.com/tensorflow/tensorflow/blob/69b08900b1e991d84bce31f3b404f5ed768f339f/tensorflow/core/protobuf/meta_graph.proto#L89 virtual Status GetSignatureDefFunction(const std::string& signature_def_key, - ConcreteFunction** function) = 0; + SignatureDefFunction** function) = 0; virtual std::vector ListFunctions() = 0; diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 0d97741d7f0..e79fd8d7001 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -122,9 +122,9 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, tensorflow::TensorShape shape(variable.shape()); tensorflow::DataType dtype = variable.dtype(); - TF_RETURN_IF_ERROR( - Variable::CreateUninitialized(ctx, dtype, shape, name, output)); - + TF_RETURN_IF_ERROR(Variable::CreateUninitialized( + ctx, dtype, shape, name, + variable.device().empty() ? nullptr : variable.device().c_str(), output)); return Status(); } diff --git a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc index cf58e5e3536..45b0ac00c9b 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_variable_loading_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/tensor_interface.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/common_runtime/eager/tensor_handle.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -38,9 +39,15 @@ namespace { class SavedVariableLoadingTest : public ::testing::TestWithParam< std::tuple>> { public: - SavedVariableLoadingTest() - : device_mgr_(testing::CreateTestingDeviceMgr()), - ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} + SavedVariableLoadingTest() { + SessionOptions options; + options.config.mutable_device_count()->insert({"CPU", 3}); + std::vector> devices; + TF_CHECK_OK(DeviceFactory::AddDevices( + options, "/job:localhost/replica:0/task:0", &devices)); + device_mgr_ = absl::make_unique(std::move(devices)); + ctx_ = testing::CreateTestingEagerContext(device_mgr_.get()); + } EagerContext* context() { return ctx_.get(); } @@ -67,6 +74,39 @@ TEST_P(SavedVariableLoadingTest, LoadSavedVariableSuccessful) { EXPECT_EQ(var->shape(), shape); } +// Verify that a device specified in the SavedVariable is kept. 
+TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithDevice) { + auto& test_params = GetParam(); + DataType dtype = std::get<0>(test_params); + TensorShape shape(std::get<1>(test_params)); + + SavedVariable saved_variable; + saved_variable.set_dtype(dtype); + saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:1"), + shape.AsProto(saved_variable.mutable_shape()); + + std::unique_ptr var; + TF_ASSERT_OK(internal::LoadSavedVariable(context(), saved_variable, &var)); + EXPECT_EQ(down_cast(var->handle())->resource_device()->name(), + "/job:localhost/replica:0/task:0/device:CPU:1"); +} + +// Verify load failure if a non-existing device is specified. +TEST_P(SavedVariableLoadingTest, LoadSavedVariableWithInvalidDevice) { + auto& test_params = GetParam(); + DataType dtype = std::get<0>(test_params); + TensorShape shape(std::get<1>(test_params)); + + SavedVariable saved_variable; + saved_variable.set_dtype(dtype); + saved_variable.set_device("/job:localhost/replica:0/task:0/device:CPU:99"), + shape.AsProto(saved_variable.mutable_shape()); + + std::unique_ptr var; + ASSERT_NE(Status::OK(), + internal::LoadSavedVariable(context(), saved_variable, &var)); +} + // Assigning and reading values should yield // consistent results. TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { @@ -79,7 +119,7 @@ TEST_P(SavedVariableLoadingTest, AssignAndReadVariableSuccesful) { Status status; std::unique_ptr var; TF_EXPECT_OK(Variable::CreateUninitialized(context(), dtype, shape, - absl::nullopt, &var)); + absl::nullopt, nullptr, &var)); // Create a TensorHandle ImmediateTensorHandlePtr expected_handle = diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function.h b/tensorflow/c/experimental/saved_model/core/signature_def_function.h new file mode 100644 index 00000000000..0a217f3cc21 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function.h @@ -0,0 +1,62 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ + +#include +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +namespace tensorflow { + +// See tensorflow/cc/experimental/saved_model/public/signature_def_function.h +// for SignatureDefFunction's intended user-facing semantics. +// This class is the "implementation" C++ part of the C++/C/C++ sandwich for +// a SignatureDefFunction. 
+// Note(bmzhao): Implementation-wise, SignatureDefFunctions are always saved as +// a "BareConcreteFunction", w/o a FunctionSpec, rather than a SavedFunction: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/saved_object_graph.proto#L60 +// Additionally they are guaranteed to be children of the .signatures attribute +// of the root object, where the child object "name" is the signature_def key: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/python/saved_model/signature_serialization.py#L181-L230 +// One of the critical requirements of SignatureDef functions is that their +// inputs and outputs are "named". For example, a `.signatures` function: +// a. Requires users to pass: kwargs of all inputs: +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L119-L126 +// b. Returns a dictionary of named outputs. +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/python/saved_model/signature_serialization.py#L153-L161 +// Since SignatureDefFunctions do not have FunctionSpecs, but guarantee the +// dictionary of inputs/outputs, we can parse these dictionaries' keys to obtain +// the input/output names of the SignatureDef: +// https://github.com/tensorflow/tensorflow/blob/9bcefa44cd335c1db4a703a13da09f29ae1bbdb2/tensorflow/core/protobuf/meta_graph.proto#L318-L321 +class SignatureDefFunction { + public: + virtual ~SignatureDefFunction() = default; + + // Creates a "Call" Op used to execute the function. + virtual Status MakeCallOp(absl::Span inputs, + ImmediateOpPtr* out) const = 0; + + virtual const SignatureDefFunctionMetadata& GetFunctionMetadata() const = 0; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h new file mode 100644 index 00000000000..5a579676d4e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +namespace tensorflow { + +class SignatureDefFunctionMetadata { + // TODO(bmzhao): Fill in with fields as necessary +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index b803d129b90..d551919ea94 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -28,7 +28,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc index 0f0102be857..ab7052b52ed 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/core/framework/attr_value.pb.h" @@ -305,7 +306,7 @@ Status TFSavedModelAPI::GetFunction(const std::string& function_path, } Status TFSavedModelAPI::GetSignatureDefFunction( - const std::string& signature_def_key, ConcreteFunction** function) { + const std::string& signature_def_key, SignatureDefFunction** function) { // TODO(bmzhao): Add support for retrieving a signaturedef function. return errors::Unimplemented( "Retrieving SignatureDef functions is unimplemented currently"); diff --git a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h index fc8e738e86f..fd07c09474b 100644 --- a/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/core/platform/status.h" @@ -55,7 +56,7 @@ class TFSavedModelAPI : public SavedModelAPI { ConcreteFunction** function) override; Status GetSignatureDefFunction(const std::string& signature_def_key, - ConcreteFunction** function) override; + SignatureDefFunction** function) override; static Status Load( const std::string& directory, diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 323298c5fc1..c0d121a4aee 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -142,6 +142,8 @@ cc_library( ":concrete_function_list_type", ":concrete_function_type", ":saved_model_api_type", + ":signature_def_function", + ":signature_def_function_type", "//tensorflow/c:c_api_macros", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_internal", @@ -165,6 +167,77 @@ cc_library( ], ) +cc_library( + name = "signature_def_function", + srcs = [ + "signature_def_function.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_function.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_function_metadata", + ":signature_def_function_metadata_type", + ":signature_def_function_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status_internal", + "//tensorflow/c/eager:abstract_tensor_handle", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:tfe_op_internal", + "//tensorflow/c/eager:tfe_tensorhandle_internal", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], +) + +cc_library( + name = "signature_def_function_type", + hdrs = [ + "signature_def_function_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function", + ], +) + +cc_library( + name = "signature_def_function_metadata", + srcs = [ + "signature_def_function_metadata.cc", + ], + hdrs = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata.h", + ], + copts = tf_copts(), + visibility = [ + "//tensorflow/c/experimental/saved_model/public:__pkg__", + ], + deps = [ + ":signature_def_function_metadata_type", + "//tensorflow/c:c_api_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + +cc_library( + name = "signature_def_function_metadata_type", + hdrs = [ + "signature_def_function_metadata_type.h", + ], + deps = [ + "//tensorflow/c:conversion_macros", + "//tensorflow/c/experimental/saved_model/core:signature_def_function_metadata", + ], +) + tf_cc_test( name = "saved_model_api_test", size = "small", diff --git a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc index 983c98affb2..b89fb9f6d64 100644 --- 
a/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc +++ b/tensorflow/c/experimental/saved_model/internal/saved_model_api.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_internal.h" #include "tensorflow/core/common_runtime/eager/context.h" @@ -106,9 +107,11 @@ TF_ConcreteFunction* TF_GetSavedModelConcreteFunction(TF_SavedModel* model, return tensorflow::wrap(result); } -TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, const char* signature_def_key, TF_Status* status) { - tensorflow::ConcreteFunction* result = nullptr; +TF_CAPI_EXPORT extern TF_SignatureDefFunction* +TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, + const char* signature_def_key, + TF_Status* status) { + tensorflow::SignatureDefFunction* result = nullptr; tensorflow::Status get_function_status = tensorflow::unwrap(model)->GetSignatureDefFunction(signature_def_key, &result); diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_function.cc new file mode 100644 index 00000000000..64f7506f32e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function.cc @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/tfe_op_internal.h" +#include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/platform/status.h" + +extern "C" { + +TF_SignatureDefFunctionMetadata* TF_SignatureDefFunctionGetMetadata( + TF_SignatureDefFunction* func) { + return tensorflow::wrap(const_cast( + &tensorflow::unwrap(func)->GetFunctionMetadata())); +} + +TFE_Op* TF_SignatureDefFunctionMakeCallOp(TF_SignatureDefFunction* func, + TFE_TensorHandle** inputs, + int num_inputs, TF_Status* status) { + tensorflow::ImmediateOpPtr call_op; + absl::Span input_span( + reinterpret_cast( + tensorflow::unwrap(inputs)), + static_cast(num_inputs)); + status->status = tensorflow::unwrap(func)->MakeCallOp(input_span, &call_op); + if (!status->status.ok()) { + return nullptr; + } + return tensorflow::wrap(call_op.release()); +} + +} // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc new file mode 100644 index 00000000000..c5c3616211c --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata.cc @@ -0,0 +1,20 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +#include "tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h" + +// TODO(bmzhao): Add getter functions here as necessary. diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h new file mode 100644 index 00000000000..fa6d0f6541e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_metadata_type.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
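To show how the new C entry points above fit together, a hypothetical call sequence is sketched here. `model`, `inputs`, and `num_inputs` are assumed to already exist, the "serving_default" key is only an example, and per the TFSavedModelAPI change earlier in this patch GetSignatureDefFunction still returns Unimplemented, so this is the intended shape of the API rather than working code.

  TF_Status* status = TF_NewStatus();
  TF_SignatureDefFunction* serving_fn = TF_GetSavedModelSignatureDefFunction(
      model, "serving_default", status);
  if (TF_GetCode(status) == TF_OK) {
    // Metadata getters are still TODO per signature_def_function_metadata.cc.
    TF_SignatureDefFunctionMetadata* metadata =
        TF_SignatureDefFunctionGetMetadata(serving_fn);
    TFE_Op* call_op = TF_SignatureDefFunctionMakeCallOp(serving_fn, inputs,
                                                        num_inputs, status);
    // The returned TFE_Op would then be run via TFE_Execute and deleted with
    // TFE_DeleteOp by the caller.
  }
  TF_DeleteStatus(status);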
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function_metadata.h" + +typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunctionMetadata, + TF_SignatureDefFunctionMetadata) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_METADATA_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h b/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h new file mode 100644 index 00000000000..ca44dc43bd6 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/internal/signature_def_function_type.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ + +#include "tensorflow/c/conversion_macros.h" +#include "tensorflow/c/experimental/saved_model/core/signature_def_function.h" + +typedef struct TF_SignatureDefFunction TF_SignatureDefFunction; + +namespace tensorflow { + +DEFINE_CONVERSION_FUNCTIONS(tensorflow::SignatureDefFunction, + TF_SignatureDefFunction) + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_INTERNAL_SIGNATURE_DEF_FUNCTION_TYPE_H_ diff --git a/tensorflow/c/experimental/saved_model/public/BUILD b/tensorflow/c/experimental/saved_model/public/BUILD index af65e05e7f6..d29585ae1ba 100644 --- a/tensorflow/c/experimental/saved_model/public/BUILD +++ b/tensorflow/c/experimental/saved_model/public/BUILD @@ -24,6 +24,8 @@ exports_files( "concrete_function_list.h", "function_metadata.h", "saved_model_api.h", + "signature_def_function.h", + "signature_def_function_metadata.h", ], visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"], ) @@ -39,6 +41,8 @@ cc_library( ":concrete_function_list", ":function_metadata", ":saved_model_api", + ":signature_def_function", + ":signature_def_function_metadata", ], ) @@ -61,3 +65,13 @@ alias( name = "saved_model_api", actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api", ) + +alias( + name = "signature_def_function", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function", +) + +alias( + name = "signature_def_function_metadata", + actual = "//tensorflow/c/experimental/saved_model/internal:signature_def_function_metadata", +) diff --git a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h index 30f533f140a..cedb9de66b8 100644 --- a/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/c_saved_model_api.h @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" #include "tensorflow/c/experimental/saved_model/public/function_metadata.h" #include "tensorflow/c/experimental/saved_model/public/saved_model_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" // IWYU pragma: end_exports #endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_ diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index ee5292294d6..0fd0f70cf16 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -40,6 +40,13 @@ TF_CAPI_EXPORT extern TF_FunctionMetadata* TF_ConcreteFunctionGetMetadata( // The caller is responsible for deleting the returned TFE_Op. If op // construction fails, `status` will be non-OK and the returned pointer will be // null. +// TODO(bmzhao): Remove this function in a subsequent change; Design + implement +// a Function Execution interface for ConcreteFunction that accepts a tagged +// union of types (tensorflow::Value). 
This effectively requires moving much of +// the implementation of function.py/def_function.py to C++, and exposing a +// high-level API here. A strawman for what this interface could look like: +// TF_Value* TF_ExecuteFunction(TFE_Context*, TF_ConcreteFunction*, TF_Value* +// inputs, int num_inputs, TF_Status* status); TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( TF_ConcreteFunction* func, TFE_TensorHandle** inputs, int num_inputs, TF_Status* status); diff --git a/tensorflow/c/experimental/saved_model/public/saved_model_api.h b/tensorflow/c/experimental/saved_model/public/saved_model_api.h index 875167bec63..80ba37bab26 100644 --- a/tensorflow/c/experimental/saved_model/public/saved_model_api.h +++ b/tensorflow/c/experimental/saved_model/public/saved_model_api.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/c/c_api_macros.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" #include "tensorflow/c/tf_status.h" #ifdef __cplusplus @@ -91,10 +92,13 @@ TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelConcreteFunction( // status - Set to OK on success and an appropriate error on failure. // Returns: // If status is not OK, returns nullptr. Otherwise, returns a -// TF_ConcreteFunction instance. Once `model` is deleted, all -// `TF_ConcreteFunctions` retrieved from it are invalid, and have been deleted. -TF_CAPI_EXPORT extern TF_ConcreteFunction* TF_GetSavedModelSignatureDefFunction( - TF_SavedModel* model, const char* signature_def_key, TF_Status* status); +// TF_SignatureDefFunction instance. Once `model` is deleted, all +// `TF_SignatureDefFunctions` retrieved from it are invalid, and have been +// deleted. +TF_CAPI_EXPORT extern TF_SignatureDefFunction* +TF_GetSavedModelSignatureDefFunction(TF_SavedModel* model, + const char* signature_def_key, + TF_Status* status); // Returns a list of all ConcreteFunctions stored in this SavedModel. // The lifetime of the returned list is bound to `model`. diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_function.h b/tensorflow/c/experimental/saved_model/public/signature_def_function.h new file mode 100644 index 00000000000..16471fdc1fa --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_function.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a SignatureDefFunction loaded from a +// SavedModel. +typedef struct TF_SignatureDefFunction TF_SignatureDefFunction; + +// Returns FunctionMetadata associated with `func`. Metadata's lifetime is +// bound to `func`, which is bound to the TF_SavedModel it was loaded from. +TF_CAPI_EXPORT extern TF_SignatureDefFunctionMetadata* +TF_SignatureDefFunctionGetMetadata(TF_SignatureDefFunction* func); + +// Returns a TFE_Op suitable for executing this function. Caller must provide +// all function inputs in `inputs`, and must not add any additional inputs on +// the returned op. (i.e. don't call TFE_OpAddInput or TFE_OpAddInputList). +// The caller is responsible for deleting the returned TFE_Op. If op +// construction fails, `status` will be non-OK and the returned pointer will be +// null. +TF_CAPI_EXPORT extern TFE_Op* TF_SignatureDefFunctionMakeCallOp( + TF_SignatureDefFunction* func, TFE_TensorHandle** inputs, int num_inputs, + TF_Status* status); + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h new file mode 100644 index 00000000000..6f4459732c4 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h @@ -0,0 +1,31 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// An opaque type that corresponds to a SignatureDefFunction loaded from a +// SavedModel. +typedef struct TF_SignatureDefFunctionMetadata TF_SignatureDefFunctionMetadata; + +#ifdef __cplusplus +} // end extern "C" +#endif // __cplusplus + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/tensorflow/c/experimental/stream_executor/BUILD b/tensorflow/c/experimental/stream_executor/BUILD new file mode 100644 index 00000000000..7daa311d461 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/BUILD @@ -0,0 +1,60 @@ +# Description: +# StreamExecutor C API. 
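Before the StreamExecutor C API files below, a quick aside on the SignatureDef surface added above: taken together, these public headers give an external caller an end-to-end path from a SavedModel on disk to an executable call op. The following minimal sketch (not part of the diff) shows that flow; the model directory and signature key are hypothetical, error checking is elided, and it assumes the TF_LoadSavedModel / TF_DeleteSavedModel entry points already declared in saved_model_api.h.

  #include "tensorflow/c/eager/c_api.h"
  #include "tensorflow/c/experimental/saved_model/public/c_saved_model_api.h"

  void CallSignature() {
    TF_Status* status = TF_NewStatus();
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    TFE_Context* ctx = TFE_NewContext(opts, status);
    TFE_DeleteContextOptions(opts);

    // "/tmp/my_model" and "serving_default" are placeholder values.
    TF_SavedModel* model = TF_LoadSavedModel("/tmp/my_model", ctx, status);
    TF_SignatureDefFunction* fn =
        TF_GetSavedModelSignatureDefFunction(model, "serving_default", status);

    TFE_TensorHandle* inputs[1] = {nullptr};  // caller-provided input handle(s)
    TFE_Op* call_op = TF_SignatureDefFunctionMakeCallOp(fn, inputs, 1, status);

    TFE_TensorHandle* outputs[1];
    int num_outputs = 1;  // must match the signature's output count
    TFE_Execute(call_op, outputs, &num_outputs, status);

    TFE_DeleteOp(call_op);
    TF_DeleteSavedModel(model);
    TFE_DeleteContext(ctx);
    TF_DeleteStatus(status);
  }

Note that the TF_SignatureDefFunction itself is owned by the TF_SavedModel (per the saved_model_api.h comment above), so only the call op and the model are deleted here.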
+ +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +package( + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "stream_executor", + srcs = ["stream_executor.cc"], + hdrs = ["stream_executor.h"], + visibility = ["//visibility:public"], + deps = [ + ":stream_executor_internal", + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + "//tensorflow/c:tf_status_helper", + "//tensorflow/core:lib", + "//tensorflow/stream_executor:executor_cache", + "//tensorflow/stream_executor:multi_platform_manager", + "//tensorflow/stream_executor:platform", + "//tensorflow/stream_executor:stream_executor_internal", + "//tensorflow/stream_executor:stream_executor_pimpl", + "//tensorflow/stream_executor:timer", + ], +) + +cc_library( + name = "stream_executor_internal", + hdrs = [ + "stream_executor.h", + "stream_executor_internal.h", + ], + deps = [ + "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status", + "//tensorflow/stream_executor:executor_cache", + "//tensorflow/stream_executor/lib", + ], +) + +tf_cc_test( + name = "stream_executor_test", + srcs = ["stream_executor_test.cc"], + deps = [ + ":stream_executor", + ":stream_executor_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/protobuf:error_codes_proto_impl_cc", + "//tensorflow/stream_executor:multi_platform_manager", + "//tensorflow/stream_executor:stream", + "//tensorflow/stream_executor:stream_executor_pimpl", + ], +) diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.cc b/tensorflow/c/experimental/stream_executor/stream_executor.cc new file mode 100644 index 00000000000..0e55ba3d72a --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor.cc @@ -0,0 +1,809 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This file extends/implements core stream executor base classes in terms of +// the C API defined in stream_executor.h. A class "CSomething" represents a +// "Something" that can be manipulated via calls in the C interface and a C +// struct called "SP_Something". +// +// This file also contains stream_executor::Platform registration for pluggable +// device. 
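+//
+// Intended usage from the TensorFlow side is roughly the following (the
+// plugin path and `test_plugin_init_fn` are hypothetical, and the exact call
+// site that loads plugins is still to be decided); see RegisterDevicePlugin
+// below:
+//
+//   #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
+//
+//   // Load a plugin .so and register its platform:
+//   TF_CHECK_OK(stream_executor::RegisterDevicePlugin(
+//       "/path/to/libmy_device_plugin.so"));
+//
+//   // Or, in tests, register directly from an SEPluginInitFn:
+//   TF_CHECK_OK(stream_executor::RegisterDevicePlugin(test_plugin_init_fn));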
+#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +#include + +#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/stream_executor/executor_cache.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/platform.h" +#include "tensorflow/stream_executor/stream_executor_internal.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/stream_executor/timer.h" + +using tensorflow::StatusFromTF_Status; + +namespace stream_executor { +namespace { + +#define VALIDATE_STRUCT_SIZE(STRUCT_NAME, STRUCT_OBJ, SIZE_VALUE_NAME) \ + do { \ + if (STRUCT_OBJ.struct_size == 0) { \ + return port::FailedPreconditionError( \ + "struct_size field in " #STRUCT_NAME \ + " must be set to " #SIZE_VALUE_NAME "."); \ + } \ + } while (0) + +#define VALIDATE_MEMBER(STRUCT_NAME, STRUCT_OBJ, NAME) \ + do { \ + if (STRUCT_OBJ.NAME == 0) { \ + return port::FailedPreconditionError( \ + "'" #NAME "' field in " #STRUCT_NAME " must be set."); \ + } \ + } while (0) + +port::Status ValidateSPPlatform(const SP_Platform& platform) { + VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE); + VALIDATE_MEMBER(SP_Platform, platform, name); + VALIDATE_MEMBER(SP_Platform, platform, type); + VALIDATE_MEMBER(SP_Platform, platform, visible_device_count); + VALIDATE_MEMBER(SP_Platform, platform, create_device); + VALIDATE_MEMBER(SP_Platform, platform, destroy_device); + VALIDATE_MEMBER(SP_Platform, platform, create_stream_executor); + VALIDATE_MEMBER(SP_Platform, platform, destroy_stream_executor); + VALIDATE_MEMBER(SP_Platform, platform, create_timer_fns); + VALIDATE_MEMBER(SP_Platform, platform, destroy_timer_fns); + return port::Status::OK(); +} + +port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) { + VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE); + VALIDATE_MEMBER(SP_TimerFns, timer_fns, nanoseconds); + return port::Status::OK(); +} + +port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) { + VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats, SP_ALLOCATORSTATS_STRUCT_SIZE); + // All other fields could theoretically be zero/null. + return port::Status::OK(); +} + +port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) { + VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem, + SP_DEVICE_MEMORY_BASE_STRUCT_SIZE); + // All other fields could theoretically be zero/null. + return port::Status::OK(); +} + +port::Status ValidateSPDevice(const SP_Device& device) { + VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE); + // All other fields could theoretically be zero/null. 
+ return port::Status::OK(); +} + +port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se) { + VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se, SP_STREAM_EXECUTOR_STRUCT_SIZE); + VALIDATE_MEMBER(SP_StreamExecutor, se, allocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, deallocate); + VALIDATE_MEMBER(SP_StreamExecutor, se, get_allocator_stats); + VALIDATE_MEMBER(SP_StreamExecutor, se, device_memory_usage); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream); + VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_stream); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_stream_dependency); + VALIDATE_MEMBER(SP_StreamExecutor, se, get_stream_status); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, get_event_status); + VALIDATE_MEMBER(SP_StreamExecutor, se, record_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, wait_for_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, create_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, destroy_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, start_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, stop_timer); + VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_dtoh); + VALIDATE_MEMBER(SP_StreamExecutor, se, memcpy_htod); + VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_dtoh); + VALIDATE_MEMBER(SP_StreamExecutor, se, sync_memcpy_htod); + VALIDATE_MEMBER(SP_StreamExecutor, se, block_host_for_event); + VALIDATE_MEMBER(SP_StreamExecutor, se, synchronize_all_activity); + VALIDATE_MEMBER(SP_StreamExecutor, se, host_callback); + return port::Status::OK(); +} + +port::Status ValidateSEPlatformRegistrationParams( + const SE_PlatformRegistrationParams& params) { + VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params, + SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE); + VALIDATE_MEMBER(SE_PlatformRegistrationParams, params, destroy_platform); + return port::Status::OK(); +} + +#undef VALIDATE_MEMBER + +struct TFStatusDeleter { + void operator()(TF_Status* s) const { TF_DeleteStatus(s); } +}; +using OwnedTFStatus = std::unique_ptr; + +class CStream : public internal::StreamInterface { + public: + CStream(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + stream_handle_(nullptr) {} + ~CStream() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); + port::Status s = StatusFromTF_Status(c_status.get()); + return s; + } + + void Destroy() { + if (stream_handle_ != nullptr) { + stream_executor_->destroy_stream(device_, stream_handle_); + stream_handle_ = nullptr; + } + } + + SP_Stream Handle() { return stream_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Stream stream_handle_; +}; + +// Converts SE_EventStatus to Event::Status. 
+Event::Status SEEventStatusToEventStatus(SE_EventStatus s) { + switch (s) { + case SE_EVENT_ERROR: + return Event::Status::kError; + case SE_EVENT_PENDING: + return Event::Status::kPending; + case SE_EVENT_COMPLETE: + return Event::Status::kComplete; + default: + return Event::Status::kUnknown; + } +} + +class CEvent : public internal::EventInterface { + public: + CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) + : device_(device), + stream_executor_(stream_executor), + event_handle_(nullptr) {} + ~CEvent() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_event(device_, &event_handle_, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + port::Status Record(SP_Stream stream_handle) { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->record_event(device_, stream_handle, event_handle_, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (event_handle_ != nullptr) { + stream_executor_->destroy_event(device_, event_handle_); + event_handle_ = nullptr; + } + } + + SP_Event Handle() { return event_handle_; } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Event event_handle_; +}; + +class CTimer : public internal::TimerInterface { + public: + CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, + SP_TimerFns* timer_fns) + : device_(device), + stream_executor_(stream_executor), + timer_handle_(nullptr), + timer_fns_(timer_fns) {} + ~CTimer() override { Destroy(); } + + port::Status Create() { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + void Destroy() { + if (timer_handle_ != nullptr) { + stream_executor_->destroy_timer(device_, timer_handle_); + timer_handle_ = nullptr; + } + } + + SP_Timer Handle() { return timer_handle_; } + + uint64 Microseconds() const override { + return timer_fns_->nanoseconds(timer_handle_) / 1000; + } + + uint64 Nanoseconds() const override { + return timer_fns_->nanoseconds(timer_handle_); + } + + private: + SP_Device* device_; + SP_StreamExecutor* stream_executor_; + SP_Timer timer_handle_; + SP_TimerFns* timer_fns_; +}; + +// Converts DeviceMemoryBase to a C struct. +SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) { + SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; + // `opaque` field inside SP_DeviceMemoryBase is not const. + // Therefore, we need to cast away the constness before setting it. + device_memory_base.opaque = const_cast(mem->opaque()); + device_memory_base.size = mem->size(); + device_memory_base.payload = mem->payload(); + // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here. + return device_memory_base; +} + +DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) { + DeviceMemoryBase base(mem.opaque, mem.size); + base.SetPayload(mem.payload); + // TODO(annarev): Add `ext` field to DeviceMemoryBase and set it here. + return base; +} + +// Wrapper that allows passing std::function across C API. +struct HostCallbackContext { + std::function callback; +}; + +// This wrapper allows calling `HostCallbackContext::callback` across C API. +// This function matches `SE_StatusCallbackFn` signature and will be passed as +// `callback_fn` to `host_callback` in `SP_StreamExecutor`. 
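+// CStreamExecutor::HostCallback() below allocates a HostCallbackContext on the
+// heap and passes this trampoline plus the context pointer to the plugin's
+// `host_callback`; when the plugin eventually invokes the trampoline, it runs
+// the wrapped callback, converts the resulting Status into the given
+// TF_Status, and deletes the context.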
+void HostCallbackTrampoline(void* ctx, TF_Status* status) { + HostCallbackContext* host_ctx = static_cast(ctx); + port::Status s = host_ctx->callback(); + Set_TF_Status_from_Status(status, s); + delete host_ctx; +} + +class CStreamExecutor : public internal::StreamExecutorInterface { + public: + explicit CStreamExecutor(SP_Device device, + void (*destroy_device)(SP_Device* const device), + SP_StreamExecutor* stream_executor, + SP_TimerFns* timer_fns, const std::string& name, + int visible_device_count) + : device_(std::move(device)), + destroy_device_(destroy_device), + stream_executor_(stream_executor), + timer_fns_(timer_fns), + platform_name_(name), + visible_device_count_(visible_device_count) {} + + ~CStreamExecutor() override { destroy_device_(&device_); } + + port::Status Init(int device_ordinal, DeviceOptions device_options) override { + return port::Status::OK(); + } + + DeviceMemoryBase Allocate(uint64 size, int64 memory_space) override { + SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE}; + stream_executor_->allocate(&device_, size, memory_space, &mem); + port::Status status = ValidateSPDeviceMemoryBase(mem); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + } + return DeviceMemoryBaseFromC(mem); + } + DeviceMemoryBase Allocate(uint64 size) { + return Allocate(size, /*memory_space=*/0); + } + void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset, + uint64 size) override { + LOG(FATAL) << "GetSubBuffer is not supported by pluggable device."; + } + + void Deallocate(DeviceMemoryBase* mem) override { + SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem); + stream_executor_->deallocate(&device_, &device_memory_base); + } + + void* HostMemoryAllocate(uint64 size) override { + return stream_executor_->host_memory_allocate(&device_, size); + } + + void HostMemoryDeallocate(void* mem) override { + stream_executor_->host_memory_deallocate(&device_, mem); + } + + bool HostMemoryRegister(void* mem, uint64 size) override { return false; } + bool HostMemoryUnregister(void* mem) override { return false; } + + absl::optional GetAllocatorStats() override { + SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE}; + TF_Bool has_stats = + stream_executor_->get_allocator_stats(&device_, &c_stats); + if (!has_stats) { + return absl::nullopt; + } + port::Status status = ValidateSPAllocatorStats(c_stats); + if (!status.ok()) { + LOG(ERROR) << status.error_message(); + return absl::nullopt; + } + // TODO(annarev): validate SP_AllocatorStats. 
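+    // The conversion below copies the plugin-provided SP_AllocatorStats
+    // fields into stream_executor::AllocatorStats; the optional bytes_limit
+    // and bytes_reservable_limit fields are only set when the corresponding
+    // has_* flag is true.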
+ ::stream_executor::AllocatorStats stats; + stats.num_allocs = c_stats.num_allocs; + stats.bytes_in_use = c_stats.bytes_in_use; + stats.peak_bytes_in_use = c_stats.peak_bytes_in_use; + stats.largest_alloc_size = c_stats.largest_alloc_size; + if (c_stats.has_bytes_limit) { + stats.bytes_limit = c_stats.bytes_limit; + } + stats.bytes_reserved = c_stats.bytes_reserved; + stats.peak_bytes_reserved = c_stats.peak_bytes_reserved; + if (c_stats.has_bytes_reservable_limit) { + stats.bytes_reservable_limit = c_stats.bytes_reservable_limit; + } + stats.largest_free_block_bytes = c_stats.largest_free_block_bytes; + return stats; + } + bool SynchronizeAllActivity() override { + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->synchronize_all_activity(&device_, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + port::Status SynchronousMemZero(DeviceMemoryBase* location, + uint64 size) override { + // TODO(annarev): figure out if we should support memzero/memset + // functionality by allocating on host and then copying to device. + return port::UnimplementedError( + "SynchronousMemZero is not supported by pluggable device."); + } + port::Status SynchronousMemSet(DeviceMemoryBase* location, int value, + uint64 size) override { + return port::UnimplementedError( + "SynchronousMemSet is not supported by pluggable device."); + } + port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst, + const void* host_src, uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst); + stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src, + size, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + port::Status SynchronousMemcpy(void* host_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base, + size, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst, + &device_mem_src, size, c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + port::Status MemZero(Stream* stream, DeviceMemoryBase* location, + uint64 size) override { + return port::UnimplementedError( + "MemZero is not supported by pluggable device."); + } + port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern, + uint64 size) override { + return port::UnimplementedError( + "Memset is not supported by pluggable device."); + } + port::Status Memset32(Stream* stream, DeviceMemoryBase* location, + uint32 pattern, uint64 size) override { + return port::UnimplementedError( + "Memset32 is not supported by pluggable device."); + } + bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + 
stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst, + &device_mem_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst, + host_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst, + const DeviceMemoryBase& gpu_src, + uint64 size) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst); + SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src); + stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst, + &device_mem_src, size, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool HostCallback(Stream* stream, + std::function callback) override { + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + HostCallbackContext* ctx = new HostCallbackContext{callback}; + return stream_executor_->host_callback(&device_, stream_handle, + &HostCallbackTrampoline, ctx); + } + port::Status AllocateEvent(Event* event) override { + DCHECK(event != nullptr); + return static_cast(event->implementation())->Create(); + } + port::Status DeallocateEvent(Event* event) override { + static_cast(event->implementation())->Destroy(); + return port::Status::OK(); + } + port::Status RecordEvent(Stream* stream, Event* event) override { + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + return static_cast(event->implementation())->Record(stream_handle); + } + port::Status WaitForEvent(Stream* stream, Event* event) override { + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_Event event_handle = + static_cast(event->implementation())->Handle(); + OwnedTFStatus c_status(TF_NewStatus()); + stream_executor_->wait_for_event(&device_, stream_handle, event_handle, + c_status.get()); + port::Status s = StatusFromTF_Status(c_status.get()); + return s; + } + Event::Status PollForEventStatus(Event* event) override { + SP_Event event_handle = + static_cast(event->implementation())->Handle(); + SE_EventStatus event_status = + stream_executor_->get_event_status(&device_, event_handle); + return SEEventStatusToEventStatus(event_status); + } + bool AllocateStream(Stream* stream) override { + DCHECK(stream != nullptr); + port::Status status = + static_cast(stream->implementation())->Create(); + // TODO(annarev): update AllocateStream to return status instead + // (similar to AllocateEvent). 
+ return status.ok(); + } + void DeallocateStream(Stream* stream) override { + static_cast(stream->implementation())->Destroy(); + } + bool CreateStreamDependency(Stream* dependent, Stream* other) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream dependent_handle = + static_cast(dependent->implementation())->Handle(); + SP_Stream other_handle = + static_cast(other->implementation())->Handle(); + stream_executor_->create_stream_dependency(&device_, dependent_handle, + other_handle, c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool AllocateTimer(Timer* timer) override { + port::Status status = + static_cast(timer->implementation())->Create(); + // TODO(annarev): change return value of AllocateTimer + // to status (similar to AllocateEvent). + return status.ok(); + } + void DeallocateTimer(Timer* timer) override { + static_cast(timer->implementation())->Destroy(); + } + bool StartTimer(Stream* stream, Timer* timer) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_Timer timer_handle = + static_cast(timer->implementation())->Handle(); + stream_executor_->start_timer(&device_, stream_handle, timer_handle, + c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + bool StopTimer(Stream* stream, Timer* timer) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + SP_Timer timer_handle = + static_cast(timer->implementation())->Handle(); + stream_executor_->stop_timer(&device_, stream_handle, timer_handle, + c_status.get()); + if (TF_GetCode(c_status.get()) != TF_OK) { + LOG(ERROR) << TF_Message(c_status.get()); + return false; + } + return true; + } + port::Status BlockHostForEvent(Stream* stream, Event* event) { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Event event_handle = + static_cast(event->implementation())->Handle(); + stream_executor_->block_host_for_event(&device_, event_handle, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + + port::Status BlockHostUntilDone(Stream* stream) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Event event_handle; + stream_executor_->create_event(&device_, &event_handle, c_status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get())); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + stream_executor_->record_event(&device_, stream_handle, event_handle, + c_status.get()); + port::Status s = StatusFromTF_Status(c_status.get()); + if (!s.ok()) { + stream_executor_->destroy_event(&device_, event_handle); + return s; + } + stream_executor_->block_host_for_event(&device_, event_handle, + c_status.get()); + stream_executor_->destroy_event(&device_, event_handle); + return StatusFromTF_Status(c_status.get()); + } + + port::Status GetStatus(Stream* stream) override { + OwnedTFStatus c_status(TF_NewStatus()); + SP_Stream stream_handle = + static_cast(stream->implementation())->Handle(); + stream_executor_->get_stream_status(&device_, stream_handle, + c_status.get()); + return StatusFromTF_Status(c_status.get()); + } + int PlatformDeviceCount() override { return visible_device_count_; } + port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override { + return port::UnimplementedError( + "EnablePeerAccessTo is not 
supported by pluggable device.");
+  }
+  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
+    return false;
+  }
+
+  bool DeviceMemoryUsage(int64* free, int64* total) const override {
+    static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                  "64-bit int types should match in size");
+    return stream_executor_->device_memory_usage(
+        &device_, reinterpret_cast<int64_t*>(free),
+        reinterpret_cast<int64_t*>(total));
+  }
+
+  // Creates a new DeviceDescription object.
+  // Ownership is transferred to the caller.
+  port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
+      const override {
+    // TODO(annarev): Figure out if we need to support more description fields.
+    internal::DeviceDescriptionBuilder builder;
+    builder.set_name(platform_name_);
+    return builder.Build();
+  }
+
+  // Each call creates a new instance of the platform-specific implementation
+  // of the corresponding interface type.
+  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+      override {
+    return std::unique_ptr<internal::EventInterface>(
+        new CEvent(&device_, stream_executor_));
+  }
+  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+      override {
+    LOG(FATAL)
+        << "CreateKernelImplementation is not supported by pluggable device.";
+  }
+  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
+      override {
+    return std::unique_ptr<internal::StreamInterface>(
+        new CStream(&device_, stream_executor_));
+  }
+  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
+    return std::unique_ptr<internal::TimerInterface>(
+        new CTimer(&device_, stream_executor_, timer_fns_));
+  }
+
+ private:
+  SP_Device device_;
+  void (*destroy_device_)(SP_Device* const device);
+  SP_StreamExecutor* stream_executor_;
+  SP_TimerFns* timer_fns_;
+  std::string platform_name_;
+  int visible_device_count_;
+};
+}  // namespace
+
+CPlatform::CPlatform(SP_Platform platform,
+                     void (*destroy_platform)(SP_Platform*),
+                     SP_StreamExecutor stream_executor, SP_TimerFns timer_fns)
+    : platform_(std::move(platform)),
+      destroy_platform_(destroy_platform),
+      stream_executor_(std::move(stream_executor)),
+      timer_fns_(std::move(timer_fns)),
+      name_(platform.name) {}
+
+CPlatform::~CPlatform() {
+  executor_cache_.DestroyAllExecutors();
+  platform_.destroy_stream_executor(&stream_executor_);
+  platform_.destroy_timer_fns(&timer_fns_);
+  destroy_platform_(&platform_);
+}
+
+port::StatusOr<std::unique_ptr<DeviceDescription>>
+CPlatform::DescriptionForDevice(int ordinal) const {
+  // TODO(annarev): see if we can get StreamExecutor instance
+  // and call GetDeviceDescription. executor_cache_.Get would need
+  // to be made const for it to work.
+  internal::DeviceDescriptionBuilder builder;
+  builder.set_name(name_);
+  return builder.Build();
+}
+port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
+  stream_executor::StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  return GetExecutor(config);
+}
+port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
+    int ordinal, const PluginConfig& plugin_config) {
+  StreamExecutorConfig config;
+  config.ordinal = ordinal;
+  config.plugin_config = plugin_config;
+  return GetExecutor(config);
+}
+port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
+    const StreamExecutorConfig& config) {
+  return executor_cache_.GetOrCreate(
+      config, [&]() { return GetUncachedExecutor(config); });
+}
+port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
+    const StreamExecutorConfig& config) {
+  // Fill device creation params
+  SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
+  SP_Device device{SP_DEVICE_STRUCT_SIZE};
+  device_params.device = &device;
+  device_params.ext = nullptr;
+  device_params.ordinal = config.ordinal;
+  OwnedTFStatus c_status(TF_NewStatus());
+
+  // Create Device
+  platform_.create_device(&device_params, c_status.get());
+  TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSPDevice(device));
+
+  auto executor = absl::make_unique<CStreamExecutor>(
+      std::move(device), platform_.destroy_device, &stream_executor_,
+      &timer_fns_, name_, platform_.visible_device_count);
+  auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
+                                                  config.ordinal);
+  return result;
+}
+
+port::Status RegisterDevicePlugin(const std::string& dso_path) {
+  // Step 1: Load plugin
+  tensorflow::Env* env = tensorflow::Env::Default();
+  void* dso_handle;
+  TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle));
+
+  // Step 2: Load symbol for `SE_InitPlugin`
+  void* dso_symbol;
+  TF_RETURN_IF_ERROR(
+      env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));
+
+  // Step 3: Call `SE_InitPlugin`
+  auto init_fn = reinterpret_cast<SEPluginInitFn>(dso_symbol);
+  return RegisterDevicePlugin(init_fn);
+}
+
+port::Status RegisterDevicePlugin(SEPluginInitFn init_fn) {
+  SE_PlatformRegistrationParams params{
+      SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
+  SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
+  params.major_version = SE_MAJOR;
+  params.minor_version = SE_MINOR;
+  params.revision_version = SE_REVISION;
+  params.platform = &platform;
+
+  OwnedTFStatus c_status(TF_NewStatus());
+  init_fn(&params, c_status.get());
+  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
+  TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
+
+  // Fill stream executor creation params
+  SE_CreateStreamExecutorParams se_params{
+      SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
+  SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
+  se_params.stream_executor = &se;
+
+  // Create StreamExecutor
+  platform.create_stream_executor(&se_params, c_status.get());
+  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se));
+
+  SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
+  platform.create_timer_fns(&timer_fns, c_status.get());
+  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
+  TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
+
+  // Register new platform
+  std::string platform_name = std::string(platform.name);
+  std::unique_ptr<stream_executor::CPlatform> cplatform(
+      new stream_executor::CPlatform(std::move(platform),
+                                     params.destroy_platform,
std::move(se), + std::move(timer_fns))); + SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform( + std::move(cplatform))); + + // TODO(annarev): Add pluggable device registration here. + return port::Status::OK(); +} +} // namespace stream_executor diff --git a/tensorflow/c/experimental/stream_executor/stream_executor.h b/tensorflow/c/experimental/stream_executor/stream_executor.h new file mode 100644 index 00000000000..b3459a29ccc --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor.h @@ -0,0 +1,395 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ +#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ +#include +#include + +#include "tensorflow/c/c_api_macros.h" +#include "tensorflow/c/tf_status.h" + +// -------------------------------------------------------------------------- +// C API for StreamExecutor. The API is under active development and eventually +// should allow registering a pluggable device with TensorFlow. +// +// Conventions: +// * Struct prefix indicates whether struct fields should be filled by the +// plugin or core implementation: +// * SE_ : set/filled by core unless explicitly marked otherwise. +// * SP_ : set/filled by plugin unless explicitly marked otherwise. +// * We use `struct_size` for version checking. It is exempt from the `SE/SP` +// rule above and should be set both by core and the plugin. +// * For example, `create_device` function receives `SP_Device*` as input +// with `struct_size` populated by core. The plugin is responsible for +// setting `struct_size` as well, along with all other fields. +// * Refer to "TensorFlow Versioning Strategy" section at +// https://github.com/tensorflow/community/pull/257/files. +// * Note that the API is still under active development and doesn't have +// versioning guarantees yet. +// * `void* ext` is a free-form field that can be populated by +// a plugin in `SP_*` structs or potential future extension points in `SE_` +// structs. +// +// Example usage: +// +// /* Sample TensorFlow code below, exact implementation might differ. */ +// // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule +// // above and should be set both by core and the plugin." +// SP_Device device { SP_DEVICE_STRUCT_SIZE }; +// SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE } ; +// params.device = &device; +// +// /* Plugin code below */ +// constexpr char DEVICE_NAME[] = "MyDevice"; +// constexpr char DEVICE_TYPE[] = "GPU"; +// +// void create_device(SE_CreateDeviceParams* params, TF_Status* status) { +// // Custom actions based on TensorFlow's view of SP_Device. 
+// OnTFDeviceView(params->device->struct_size); +// params->device = { SP_DEVICE_STRUCT_SIZE }; +// params->device->device_handle = get_my_device_handle(device->ordinal); +// params->device->ordinal = params->ordinal; +// ... +// } +// +// void destroy_device(SP_Device* device) { +// delete_my_device_handle(device->device_handle); +// } +// +// void SE_InitPlugin( +// SE_PlatformRegistrationParams* params, +// TF_Status* status) { +// params->platform = { SP_PLATFORM_STRUCT_SIZE }; +// // Values such as `name` and `type` must outlive SE_InitPlugin call. +// params->platform->name = DEVICE_NAME; +// params->platform->type = DEVICE_TYPE; +// params->platform->visible_device_count = 2; +// params->platform->create_device = create_device; +// params->platform->destroy_device = destroy_device; +// ... +// } + +#define SE_MAJOR 0 +#define SE_MINOR 0 +#define SE_REVISION 1 + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct SP_Stream_st* SP_Stream; +typedef struct SP_Event_st* SP_Event; +typedef struct SP_Timer_st* SP_Timer; +// Takes `callback_arg` passed to `host_callback` as the first argument. +typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const); + +typedef struct SP_TimerFns { + size_t struct_size; + void* ext; // reserved for future use + uint64_t (*nanoseconds)(SP_Timer timer); +} SP_TimerFns; + +#define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds) + +typedef struct SP_AllocatorStats { + size_t struct_size; + int64_t num_allocs; + int64_t bytes_in_use; + int64_t peak_bytes_in_use; + int64_t largest_alloc_size; + + int8_t has_bytes_limit; + int64_t bytes_limit; + + int64_t bytes_reserved; + int64_t peak_bytes_reserved; + + int8_t has_bytes_reservable_limit; + int64_t bytes_reservable_limit; + + int64_t largest_free_block_bytes; +} SP_AllocatorStats; + +#define SP_ALLOCATORSTATS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes) + +// Potential states for an SP_Event. If `poll_for_status` returns anything aside +// from kPending or kComplete, an error has occurred; kUnknown is a bad state. +typedef enum SE_EventStatus { + SE_EVENT_UNKNOWN, + SE_EVENT_ERROR, + SE_EVENT_PENDING, + SE_EVENT_COMPLETE, +} SE_EventStatus; + +// Memory allocation information. +// This matches DeviceMemoryBase defined here: +// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/stream_executor/device_memory.h;l=57 +typedef struct SP_DeviceMemoryBase { + size_t struct_size; + void* ext; // free-form data set by plugin + // Platform-dependent value representing allocated memory. + void* opaque; + uint64_t size; // Size in bytes of this allocation. + uint64_t payload; // Value for plugin's use +} SP_DeviceMemoryBase; + +#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_DeviceMemoryBase, size) + +typedef struct SP_Device { + size_t struct_size; + void* ext; // free-form data set by plugin + int32_t ordinal; // device index + + // Device vendor can store handle to their device representation + // here. + void* device_handle; +} SP_Device; + +#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, device_handle) + +typedef struct SE_CreateDeviceParams { + size_t struct_size; + void* ext; // reserved for future use + int32_t ordinal; // device index + + SP_Device* device; // Input/output, struct_size set by TF for plugin to read. + // Subsequently plugin fills the entire struct. 
+} SE_CreateDeviceParams; + +#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateDeviceParams, device) + +typedef struct SP_StreamExecutor { + size_t struct_size; + void* ext; // reserved for future use + + /*** ALLOCATION CALLBACKS ***/ + // Synchronously allocates `size` bytes on the underlying platform and returns + // `SP_DeviceMemoryBase` representing that allocation. In the case of failure, + // nullptr is returned. + // `memory_space` is reserved for a potential future usage and should be set + // to 0. + void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space, + SP_DeviceMemoryBase* mem); + + // Deallocate the device memory previously allocated via this interface. + // Deallocation of a nullptr-representative value is permitted. + void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory); + + // Allocates a region of host memory and registers it with the platform API. + // Memory allocated in this manner is required for use in asynchronous memcpy + // operations, such as `memcpy_dtoh`. + void* (*host_memory_allocate)(const SP_Device* device, uint64_t size); + + // Deallocates a region of host memory allocated by `host_memory_allocate`. + void (*host_memory_deallocate)(const SP_Device* device, void* mem); + + // Fills SP_AllocatorStats with allocator statistics, if it is available. + // If it is not available, return false. + TF_Bool (*get_allocator_stats)(const SP_Device* device, + SP_AllocatorStats* stats); + // Fills the underlying device memory usage information, if it is + // available. If it is not available (false is returned), free/total need not + // be initialized. + TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free, + int64_t* total); + + /*** STREAM CALLBACKS ***/ + // Creates SP_Stream. This call should also allocate stream + // resources on the underlying platform and initializes its + // internals. + void (*create_stream)(const SP_Device* device, SP_Stream* stream, + TF_Status* status); + + // Destroys SP_Stream and deallocates any underlying resources. + void (*destroy_stream)(const SP_Device* device, SP_Stream stream); + + // Causes `dependent` to not begin execution until `other` has finished its + // last-enqueued work. + void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent, + SP_Stream other, TF_Status* status); + + // Without blocking the device, retrieve the current stream status. + void (*get_stream_status)(const SP_Device* device, SP_Stream stream, + TF_Status* status); + + /*** EVENT CALLBACKS ***/ + // Create SP_Event. Performs platform-specific allocation and initialization + // of an event. + void (*create_event)(const SP_Device* device, SP_Event* event, + TF_Status* status); + + // Destroy SE_Event and perform any platform-specific deallocation and + // cleanup of an event. + void (*destroy_event)(const SP_Device* device, SP_Event event); + + // Requests the current status of the event from the underlying platform. + SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event); + // Inserts the specified event at the end of the specified stream. + void (*record_event)(const SP_Device* device, SP_Stream stream, + SP_Event event, TF_Status* status); + + // Wait for the specified event at the end of the specified stream. + void (*wait_for_event)(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status); + + /*** TIMER CALLBACKS ***/ + // Creates SP_Timer. 
Allocates timer resources on the underlying platform + // and initializes its internals, setting `timer` output variable. Sets + // values in `timer_fns` struct. + void (*create_timer)(const SP_Device* device, SP_Timer* timer, + TF_Status* status); + + // Destroy timer and deallocates timer resources on the underlying platform. + void (*destroy_timer)(const SP_Device* device, SP_Timer timer); + + // Records a start event for an interval timer. + void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, + TF_Status* status); + + // Records a stop event for an interval timer. + void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer, + TF_Status* status); + + /*** MEMCPY CALLBACKS ***/ + // Enqueues a memcpy operation onto stream, with a host destination location + // `host_dst` and a device memory source, with target size `size`. + void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Enqueues a memcpy operation onto stream, with a device destination + // location and a host memory source, with target size `size`. + void (*memcpy_htod)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, const void* host_src, + uint64_t size, TF_Status* status); + + // Enqueues a memcpy operation onto stream, with a device destination + // location and a device memory source, with target size `size`. + void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is + // copied from the device source to the host destination. + void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is + // copied from the host source to the device destination. + void (*sync_memcpy_htod)(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const void* host_src, uint64_t size, + TF_Status* status); + + // Blocks the caller while a data segment of the given size is copied from the + // device source to the device destination. + void (*sync_memcpy_dtod)(const SP_Device* device, + SP_DeviceMemoryBase* device_dst, + const SP_DeviceMemoryBase* device_src, uint64_t size, + TF_Status* status); + + // Causes the host code to synchronously wait for the event to complete. + void (*block_host_for_event)(const SP_Device* device, SP_Event event, + TF_Status* status); + + // Synchronizes all activity occurring in the StreamExecutor's context (most + // likely a whole device). + void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status); + + // Enqueues on a stream a user-specified function to be run on the host. + // `callback_arg` should be passed as the first argument to `callback_fn`. 
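+  //
+  // A plugin could implement this roughly as follows (a simplified sketch
+  // that runs the callback inline; `my_host_callback` is hypothetical, and a
+  // real plugin would normally defer the callback until previously enqueued
+  // work on `stream` has completed):
+  //
+  //   TF_Bool my_host_callback(SP_Device* device, SP_Stream stream,
+  //                            SE_StatusCallbackFn callback_fn,
+  //                            void* callback_arg) {
+  //     TF_Status* status = TF_NewStatus();
+  //     callback_fn(callback_arg, status);
+  //     TF_Bool ok = TF_GetCode(status) == TF_OK;
+  //     TF_DeleteStatus(status);
+  //     return ok;
+  //   }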
+ TF_Bool (*host_callback)(SP_Device* device, SP_Stream stream, + SE_StatusCallbackFn callback_fn, void* callback_arg); +} SP_StreamExecutor; + +#define SP_STREAMEXECUTOR_STRUCT_SIZE \ + TF_OFFSET_OF_END(SP_StreamExecutor, host_callback) + +typedef struct SE_CreateStreamExecutorParams { + size_t struct_size; + void* ext; // reserved for future use + + SP_StreamExecutor* stream_executor; // output, to be filled by plugin +} SE_CreateStreamExecutorParams; + +#define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor) + +typedef struct SP_Platform { + size_t struct_size; + + void* ext; // free-form data set by plugin + + // Platform name. Must be null-terminated. + const char* name; + + // Device type name, for example GPU. Must be null-terminated. + const char* type; + + // Number of visible devices + size_t visible_device_count; + + // Callbacks for creating/destroying SP_Device. + void (*create_device)(SE_CreateDeviceParams* params, TF_Status* status); + + // Clean up fields inside SP_Device that were allocated + // by the plugin. `device` itself should not be deleted here. + void (*destroy_device)(SP_Device* device); + + // Callbacks for creating/destroying SP_StreamExecutor. + void (*create_stream_executor)(SE_CreateStreamExecutorParams* params, + TF_Status* status); + // Clean up fields inside SP_StreamExecutor that were allocated + // by the plugin. `stream_executor` itself should not be deleted here. + void (*destroy_stream_executor)(SP_StreamExecutor* stream_executor); + + // Callbacks for creating/destroying SP_TimerFns. + void (*create_timer_fns)(SP_TimerFns* timer, TF_Status* status); + + void (*destroy_timer_fns)(SP_TimerFns* timer_fns); +} SP_Platform; + +#define SP_PLATFORM_STRUCT_SIZE TF_OFFSET_OF_END(SP_Platform, destroy_timer_fns) + +typedef struct SE_PlatformRegistrationParams { + size_t struct_size; + void* ext; // reserved for future use + + // StreamExecutor C API version. + int32_t major_version; + int32_t minor_version; + int32_t revision_version; + + SP_Platform* platform; // output, set by plugin + // Clean up fields inside SP_Platform that were allocated + // by the plugin. `platform` itself should not be deleted here. + void (*destroy_platform)(SP_Platform* platform); // out, set by plugin +} SE_PlatformRegistrationParams; + +#define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \ + TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform) + +void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_internal.h b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h new file mode 100644 index 00000000000..2285fe85867 --- /dev/null +++ b/tensorflow/c/experimental/stream_executor/stream_executor_internal.h @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Classes and utilities that work with StreamExecutor C API for internal use.
+// This includes functions used for device registration and interfaces needed
+// for testing.
+#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
+#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
+
+#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/executor_cache.h"
+#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/platform.h"
+
+namespace stream_executor {
+
+// Plugin initialization function that a device plugin
+// must define.
+typedef void (*SEPluginInitFn)(SE_PlatformRegistrationParams* const,
+                               TF_Status* const);
+
+// Loads dso and registers StreamExecutor-based pluggable device.
+port::Status RegisterDevicePlugin(const std::string& dso_path);
+
+// Allow registering a plugin using a function (used for testing).
+port::Status RegisterDevicePlugin(SEPluginInitFn init_fn);
+
+class CPlatform : public Platform {
+ public:
+  explicit CPlatform(SP_Platform platform,
+                     void (*destroy_platform)(SP_Platform*),
+                     SP_StreamExecutor stream_executor, SP_TimerFns timer_fns);
+  ~CPlatform() override;
+
+  Id id() const override { return const_cast<int*>(&plugin_id_value_); }
+  const std::string& Name() const override { return name_; }
+  int VisibleDeviceCount() const override {
+    return platform_.visible_device_count;
+  }
+  port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
+      int ordinal) const override;
+  port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
+  port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
+      int ordinal, const PluginConfig& plugin_config) override;
+  port::StatusOr<StreamExecutor*> GetExecutor(
+      const StreamExecutorConfig& config) override;
+  port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
+      const StreamExecutorConfig& config) override;
+
+  // Trace listener is not supported
+  void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
+    LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
+  }
+  void UnregisterTraceListener(TraceListener* listener) override {}
+
+  void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
+
+ private:
+  SP_Platform platform_;
+  void (*destroy_platform_)(SP_Platform*);
+  SP_StreamExecutor stream_executor_;
+  SP_TimerFns timer_fns_;
+  const std::string name_;
+  int plugin_id_value_;
+  stream_executor::ExecutorCache executor_cache_;
+};
+
+}  // namespace stream_executor
+#endif  // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
diff --git a/tensorflow/c/experimental/stream_executor/stream_executor_test.cc b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
new file mode 100644
index 00000000000..86fe00fe5ad
--- /dev/null
+++ b/tensorflow/c/experimental/stream_executor/stream_executor_test.cc
@@ -0,0 +1,802 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0(the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/stream_executor/stream_executor.h" + +#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/stream_executor/event.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/stream.h" +#include "tensorflow/stream_executor/stream_executor_pimpl.h" +#include "tensorflow/stream_executor/timer.h" + +struct SP_Stream_st { + explicit SP_Stream_st(int id) : stream_id(id) {} + int stream_id; +}; + +struct SP_Event_st { + explicit SP_Event_st(int id) : event_id(id) {} + int event_id; +}; + +struct SP_Timer_st { + explicit SP_Timer_st(int id) : timer_id(id) {} + int timer_id; +}; + +namespace stream_executor { +namespace { +constexpr int DEVICE_COUNT = 2; +constexpr char DEVICE_NAME[] = "MyDevice"; +constexpr char DEVICE_TYPE[] = "GPU"; + +/*** Create SP_StreamExecutor (with empty functions) ***/ +void allocate(const SP_Device* const device, uint64_t size, + int64_t memory_space, SP_DeviceMemoryBase* const mem) {} +void deallocate(const SP_Device* const device, SP_DeviceMemoryBase* const mem) { +} +TF_Bool get_allocator_stats(const SP_Device* const device, + SP_AllocatorStats* const stats) { + return true; +} +TF_Bool device_memory_usage(const SP_Device* const device, int64_t* const free, + int64_t* const total) { + return true; +} +void create_stream(const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) { + stream = nullptr; +} +void destroy_stream(const SP_Device* const device, SP_Stream stream) {} +void create_stream_dependency(const SP_Device* const device, + SP_Stream dependent, SP_Stream other, + TF_Status* const status) {} +void get_stream_status(const SP_Device* const device, SP_Stream stream, + TF_Status* const status) {} +void create_event(const SP_Device* const device, SP_Event* event, + TF_Status* const status) { + event = nullptr; +} +void destroy_event(const SP_Device* const device, SP_Event event) {} +SE_EventStatus get_event_status(const SP_Device* const device, SP_Event event) { + return SE_EVENT_UNKNOWN; +} +void record_event(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} +void wait_for_event(const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) {} +void create_timer(const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) {} +void destroy_timer(const SP_Device* const device, SP_Timer timer) {} +void start_timer(const SP_Device* const device, SP_Stream stream, + SP_Timer timer, TF_Status* const status) {} +void stop_timer(const SP_Device* const device, SP_Stream stream, SP_Timer timer, + TF_Status* const status) {} +void memcpy_dtoh(const SP_Device* const device, SP_Stream stream, + void* host_dst, const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) {} +void memcpy_htod(const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, const void* host_src, + uint64_t size, TF_Status* const status) {} +void sync_memcpy_dtoh(const SP_Device* const device, void* host_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, 
TF_Status* const status) {} +void sync_memcpy_htod(const SP_Device* const device, + SP_DeviceMemoryBase* const device_dst, + const void* host_src, uint64_t size, + TF_Status* const status) {} +void block_host_for_event(const SP_Device* const device, SP_Event event, + TF_Status* const status) {} +void synchronize_all_activity(const SP_Device* const device, + TF_Status* const status) {} +TF_Bool host_callback(SP_Device* const device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) { + return true; +} + +void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) { + se->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE; + se->allocate = allocate; + se->deallocate = deallocate; + se->get_allocator_stats = get_allocator_stats; + se->device_memory_usage = device_memory_usage; + se->create_stream = create_stream; + se->destroy_stream = destroy_stream; + se->create_stream_dependency = create_stream_dependency; + se->get_stream_status = get_stream_status; + se->create_event = create_event; + se->destroy_event = destroy_event; + se->get_event_status = get_event_status; + se->record_event = record_event; + se->wait_for_event = wait_for_event; + se->create_timer = create_timer; + se->destroy_timer = destroy_timer; + se->start_timer = start_timer; + se->stop_timer = stop_timer; + se->memcpy_dtoh = memcpy_dtoh; + se->memcpy_htod = memcpy_htod; + se->sync_memcpy_dtoh = sync_memcpy_dtoh; + se->sync_memcpy_htod = sync_memcpy_htod; + se->block_host_for_event = block_host_for_event; + se->synchronize_all_activity = synchronize_all_activity; + se->host_callback = host_callback; +} + +/*** Create SP_TimerFns ***/ +uint64_t nanoseconds(SP_Timer timer) { return timer->timer_id; } + +void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) { + timer_fns->nanoseconds = nanoseconds; +} + +/*** Create SP_Platform ***/ +void create_timer_fns(SP_TimerFns* timer_fns, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultTimerFns(timer_fns); +} +void destroy_timer_fns(SP_TimerFns* timer_fns) {} + +void create_stream_executor(SE_CreateStreamExecutorParams* params, + TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultStreamExecutor(params->stream_executor); +} +void destroy_stream_executor(SP_StreamExecutor* se) {} + +void create_device(SE_CreateDeviceParams* params, TF_Status* status) { + TF_SetStatus(status, TF_OK, ""); + params->device->struct_size = SP_DEVICE_STRUCT_SIZE; +} +void destroy_device(SP_Device* device) {} + +void PopulateDefaultPlatform(SP_Platform* platform) { + platform->struct_size = SP_PLATFORM_STRUCT_SIZE; + platform->name = DEVICE_NAME; + platform->type = DEVICE_TYPE; + platform->visible_device_count = DEVICE_COUNT; + platform->create_device = create_device; + platform->destroy_device = destroy_device; + platform->create_stream_executor = create_stream_executor; + platform->destroy_stream_executor = destroy_stream_executor; + platform->create_timer_fns = create_timer_fns; + platform->destroy_timer_fns = destroy_timer_fns; +} + +void destroy_platform(SP_Platform* const platform) {} + +/*** Registration tests ***/ +TEST(StreamExecutor, SuccessfulRegistration) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform); + params->destroy_platform = destroy_platform; + }; + port::Status status = RegisterDevicePlugin(plugin_init); + TF_ASSERT_OK(status); + port::StatusOr maybe_platform = + 
MultiPlatformManager::PlatformWithName("MyDevice"); + TF_ASSERT_OK(maybe_platform.status()); + Platform* platform = maybe_platform.ConsumeValueOrDie(); + ASSERT_EQ(platform->Name(), DEVICE_NAME); + ASSERT_EQ(platform->VisibleDeviceCount(), DEVICE_COUNT); + + port::StatusOr maybe_executor = + platform->ExecutorForDevice(0); + TF_ASSERT_OK(maybe_executor.status()); + StreamExecutor* executor = maybe_executor.ConsumeValueOrDie(); + ASSERT_EQ(executor->GetDeviceDescription().name(), "MyDevice"); +} + +TEST(StreamExecutor, NameNotSet) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform); + params->platform->name = nullptr; + params->destroy_platform = destroy_platform; + }; + + port::Status status = RegisterDevicePlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ(status.error_message(), "'name' field in SP_Platform must be set."); +} + +TEST(StreamExecutor, CreateDeviceNotSet) { + auto plugin_init = [](SE_PlatformRegistrationParams* const params, + TF_Status* const status) -> void { + TF_SetStatus(status, TF_OK, ""); + PopulateDefaultPlatform(params->platform); + params->platform->create_device = nullptr; + params->destroy_platform = destroy_platform; + }; + + port::Status status = RegisterDevicePlugin(plugin_init); + ASSERT_EQ(status.code(), tensorflow::error::FAILED_PRECONDITION); + ASSERT_EQ(status.error_message(), + "'create_device' field in SP_Platform must be set."); +} + +/*** StreamExecutor behavior tests ***/ +class StreamExecutorTest : public ::testing::Test { + protected: + StreamExecutorTest() {} + void SetUp() override { + PopulateDefaultPlatform(&platform_); + PopulateDefaultStreamExecutor(&se_); + PopulateDefaultTimerFns(&timer_fns_); + } + void TearDown() override {} + + StreamExecutor* GetExecutor(int ordinal) { + if (!cplatform_) { + cplatform_ = absl::make_unique(platform_, destroy_platform, + se_, timer_fns_); + } + port::StatusOr maybe_executor = + cplatform_->ExecutorForDevice(ordinal); + TF_CHECK_OK(maybe_executor.status()); + return maybe_executor.ConsumeValueOrDie(); + } + SP_Platform platform_; + SP_StreamExecutor se_; + SP_TimerFns timer_fns_; + std::unique_ptr cplatform_; +}; + +TEST_F(StreamExecutorTest, Allocate) { + se_.allocate = [](const SP_Device* const device, uint64_t size, + int64_t memory_space, SP_DeviceMemoryBase* const mem) { + mem->struct_size = SP_DEVICE_MEMORY_BASE_STRUCT_SIZE; + mem->opaque = std::malloc(size); + mem->size = size; + }; + se_.deallocate = [](const SP_Device* const device, + SP_DeviceMemoryBase* const mem) { + EXPECT_EQ(mem->size, 2 * sizeof(int)); + std::free(mem->opaque); + mem->opaque = nullptr; + mem->size = 0; + }; + StreamExecutor* executor = GetExecutor(0); + DeviceMemory mem = executor->AllocateArray(2); + ASSERT_NE(mem.opaque(), nullptr); + ASSERT_EQ(mem.size(), 2 * sizeof(int)); + executor->Deallocate(&mem); + ASSERT_EQ(mem.opaque(), nullptr); +} + +TEST_F(StreamExecutorTest, HostMemoryAllocate) { + static bool allocate_called = false; + static bool deallocate_called = false; + se_.host_memory_allocate = [](const SP_Device* const device, uint64_t size) { + allocate_called = true; + return std::malloc(size); + }; + se_.host_memory_deallocate = [](const SP_Device* const device, void* mem) { + std::free(mem); + deallocate_called = true; + }; + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(allocate_called); + void* mem = 
executor->HostMemoryAllocate(8); + ASSERT_NE(mem, nullptr); + ASSERT_TRUE(allocate_called); + ASSERT_FALSE(deallocate_called); + executor->HostMemoryDeallocate(mem); + ASSERT_TRUE(deallocate_called); +} + +TEST_F(StreamExecutorTest, GetAllocatorStats) { + se_.get_allocator_stats = [](const SP_Device* const device, + SP_AllocatorStats* const stat) -> TF_Bool { + stat->struct_size = SP_ALLOCATORSTATS_STRUCT_SIZE; + stat->bytes_in_use = 123; + return true; + }; + + StreamExecutor* executor = GetExecutor(0); + absl::optional optional_stats = executor->GetAllocatorStats(); + ASSERT_TRUE(optional_stats.has_value()); + AllocatorStats stats = optional_stats.value(); + ASSERT_EQ(stats.bytes_in_use, 123); +} + +TEST_F(StreamExecutorTest, DeviceMemoryUsage) { + se_.device_memory_usage = [](const SP_Device* const device, + int64_t* const free, + int64_t* const total) -> TF_Bool { + *free = 45; + *total = 7; + return true; + }; + + StreamExecutor* executor = GetExecutor(0); + int64 free = 0; + int64 total = 0; + executor->DeviceMemoryUsage(&free, &total); + ASSERT_EQ(free, 45); + ASSERT_EQ(total, 7); +} + +TEST_F(StreamExecutorTest, CreateStream) { + static bool stream_created = false; + static bool stream_deleted = false; + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) -> void { + *stream = new SP_Stream_st(14); + stream_created = true; + }; + se_.destroy_stream = [](const SP_Device* const device, + SP_Stream stream) -> void { + auto custom_stream = static_cast(stream); + ASSERT_EQ(custom_stream->stream_id, 14); + delete custom_stream; + stream_deleted = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(stream_created); + Stream* stream = new Stream(executor); + stream->Init(); + ASSERT_TRUE(stream->ok()); + ASSERT_TRUE(stream_created); + ASSERT_FALSE(stream_deleted); + delete stream; + ASSERT_TRUE(stream_deleted); +} + +TEST_F(StreamExecutorTest, CreateStreamDependency) { + static bool create_stream_dependency_called = false; + se_.create_stream_dependency = [](const SP_Device* const device, + SP_Stream dependent, SP_Stream other, + TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + create_stream_dependency_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Stream dependent(executor); + dependent.Init(); + Stream other(executor); + other.Init(); + ASSERT_FALSE(create_stream_dependency_called); + dependent.ThenWaitFor(&other); + ASSERT_TRUE(create_stream_dependency_called); +} + +TEST_F(StreamExecutorTest, StreamStatus) { + static bool status_ok = true; + se_.get_stream_status = [](const SP_Device* const device, SP_Stream stream, + TF_Status* const status) -> void { + if (status_ok) { + TF_SetStatus(status, TF_OK, ""); + } else { + TF_SetStatus(status, TF_INTERNAL, "Test error"); + } + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + ASSERT_TRUE(stream.ok()); + TF_ASSERT_OK(stream.RefreshStatus()); + status_ok = false; + auto updated_status = stream.RefreshStatus(); + ASSERT_FALSE(stream.ok()); + ASSERT_EQ(updated_status.error_message(), "Test error"); +} + +TEST_F(StreamExecutorTest, CreateEvent) { + static bool event_created = false; + static bool event_deleted = false; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) -> void { + *event = new SP_Event_st(123); + event_created = true; + }; + se_.destroy_event = [](const SP_Device* const device, + SP_Event event) -> void { + auto custom_event = 
static_cast(event); + ASSERT_EQ(custom_event->event_id, 123); + delete custom_event; + event_deleted = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(event_created); + Event* event = new Event(executor); + event->Init(); + ASSERT_TRUE(event_created); + ASSERT_FALSE(event_deleted); + delete event; + ASSERT_TRUE(event_deleted); +} + +TEST_F(StreamExecutorTest, PollForEventStatus) { + static SE_EventStatus event_status = SE_EVENT_COMPLETE; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) -> void { + *event = new SP_Event_st(123); + }; + se_.destroy_event = [](const SP_Device* const device, + SP_Event event) -> void { delete event; }; + se_.get_event_status = [](const SP_Device* const device, + SP_Event event) -> SE_EventStatus { + EXPECT_EQ(event->event_id, 123); + return event_status; + }; + + StreamExecutor* executor = GetExecutor(0); + Event event(executor); + event.Init(); + ASSERT_EQ(event.PollForStatus(), Event::Status::kComplete); + event_status = SE_EVENT_ERROR; + ASSERT_EQ(event.PollForStatus(), Event::Status::kError); +} + +TEST_F(StreamExecutorTest, RecordAndWaitForEvent) { + static bool record_called = false; + static bool wait_called = false; + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) -> void { + *stream = new SP_Stream_st(1); + }; + se_.destroy_stream = [](const SP_Device* const device, + SP_Stream stream) -> void { delete stream; }; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) -> void { + *event = new SP_Event_st(2); + }; + se_.destroy_event = [](const SP_Device* const device, + SP_Event event) -> void { delete event; }; + se_.record_event = [](const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) { + EXPECT_EQ(stream->stream_id, 1); + EXPECT_EQ(event->event_id, 2); + TF_SetStatus(status, TF_OK, ""); + record_called = true; + }; + se_.wait_for_event = [](const SP_Device* const device, SP_Stream stream, + SP_Event event, TF_Status* const status) { + EXPECT_EQ(stream->stream_id, 1); + EXPECT_EQ(event->event_id, 2); + TF_SetStatus(status, TF_OK, ""); + wait_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Event event(executor); + event.Init(); + Stream stream(executor); + stream.Init(); + ASSERT_FALSE(record_called); + stream.ThenRecordEvent(&event); + ASSERT_TRUE(record_called); + ASSERT_FALSE(wait_called); + stream.ThenWaitFor(&event); + ASSERT_TRUE(wait_called); +} + +TEST_F(StreamExecutorTest, CreateTimer) { + static bool timer_created = false; + static bool timer_deleted = false; + se_.create_timer = [](const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) -> void { + *timer = new SP_Timer_st(25); + timer_created = true; + }; + se_.destroy_timer = [](const SP_Device* const device, + SP_Timer timer) -> void { + auto custom_timer = static_cast(timer); + EXPECT_EQ(custom_timer->timer_id, 25); + delete custom_timer; + timer_deleted = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(timer_created); + Stream stream(executor); + stream.Init(); + Timer* timer = new Timer(executor); + stream.InitTimer(timer); + ASSERT_TRUE(stream.ok()); + ASSERT_TRUE(timer_created); + ASSERT_FALSE(timer_deleted); + delete timer; + ASSERT_TRUE(timer_deleted); +} + +TEST_F(StreamExecutorTest, StartTimer) { + static bool start_called = false; + static bool stop_called = false; + static TF_Code start_timer_status = 
TF_OK; + static TF_Code stop_timer_status = TF_OK; + se_.create_timer = [](const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) -> void { + *timer = new SP_Timer_st(7); + }; + se_.destroy_timer = [](const SP_Device* const device, + SP_Timer timer) -> void { delete timer; }; + se_.start_timer = [](const SP_Device* const device, SP_Stream stream, + SP_Timer timer, TF_Status* const status) { + TF_SetStatus(status, start_timer_status, ""); + EXPECT_EQ(timer->timer_id, 7); + start_called = true; + }; + se_.stop_timer = [](const SP_Device* const device, SP_Stream stream, + SP_Timer timer, TF_Status* const status) { + TF_SetStatus(status, stop_timer_status, ""); + EXPECT_EQ(timer->timer_id, 7); + stop_called = true; + }; + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + Timer timer(executor); + stream.InitTimer(&timer); + + // Check both start and stop succeed + ASSERT_FALSE(start_called); + stream.ThenStartTimer(&timer); + ASSERT_TRUE(start_called); + ASSERT_FALSE(stop_called); + stream.ThenStopTimer(&timer); + ASSERT_TRUE(stop_called); + + // Check start timer fails + ASSERT_TRUE(stream.ok()); + start_timer_status = TF_UNKNOWN; + stream.ThenStartTimer(&timer); + ASSERT_FALSE(stream.ok()); + + // Check stop timer fails + start_timer_status = TF_OK; + stop_timer_status = TF_UNKNOWN; + Stream stream2(executor); + stream2.Init(); + Timer timer2(executor); + stream2.InitTimer(&timer2); + stream2.ThenStartTimer(&timer2); + ASSERT_TRUE(stream2.ok()); + stream2.ThenStopTimer(&timer2); + ASSERT_FALSE(stream2.ok()); +} + +TEST_F(StreamExecutorTest, TimerFns) { + se_.create_timer = [](const SP_Device* const device, SP_Timer* timer, + TF_Status* const status) -> void { + *timer = new SP_Timer_st(25000); + }; + se_.destroy_timer = [](const SP_Device* const device, + SP_Timer timer) -> void { delete timer; }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + Timer timer(executor); + stream.InitTimer(&timer); + // Our test nanoseconds function just returns value + // passed to SP_Timer_st constructor. 
+ ASSERT_EQ(timer.Nanoseconds(), 25000); + ASSERT_EQ(timer.Microseconds(), 25); +} + +TEST_F(StreamExecutorTest, MemcpyToHost) { + se_.create_stream = [](const SP_Device* const device, SP_Stream* stream, + TF_Status* const status) -> void { + *stream = new SP_Stream_st(14); + }; + se_.destroy_stream = [](const SP_Device* const device, + SP_Stream stream) -> void { delete stream; }; + + se_.memcpy_dtoh = [](const SP_Device* const device, SP_Stream stream, + void* host_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + EXPECT_EQ(stream->stream_id, 14); + std::memcpy(host_dst, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + size_t size = sizeof(int); + int src_data = 34; + int dst_data = 2; + DeviceMemoryBase device_src(&src_data, size); + Stream& stream_ref = stream.ThenMemcpy(&dst_data, device_src, size); + ASSERT_EQ(dst_data, 34); + ASSERT_EQ(stream_ref.implementation(), stream.implementation()); +} + +TEST_F(StreamExecutorTest, MemcpyFromHost) { + se_.memcpy_htod = [](const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, + const void* host_src, uint64_t size, + TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, host_src, size); + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + stream.ThenMemcpy(&device_dst, &src_data, size); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, MemcpyDeviceToDevice) { + se_.memcpy_dtod = [](const SP_Device* const device, SP_Stream stream, + SP_DeviceMemoryBase* const device_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + DeviceMemoryBase device_src(&src_data, size); + stream.ThenMemcpy(&device_dst, device_src, size); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, SyncMemcpyToHost) { + se_.sync_memcpy_dtoh = [](const SP_Device* const device, void* host_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(host_dst, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + size_t size = sizeof(int); + int src_data = 34; + int dst_data = 2; + DeviceMemoryBase device_src(&src_data, size); + TF_ASSERT_OK(executor->SynchronousMemcpyD2H(device_src, size, &dst_data)); + ASSERT_EQ(dst_data, 34); +} + +TEST_F(StreamExecutorTest, SyncMemcpyFromHost) { + se_.sync_memcpy_htod = + [](const SP_Device* const device, SP_DeviceMemoryBase* const device_dst, + const void* host_src, uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, host_src, size); + }; + + StreamExecutor* executor = GetExecutor(0); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + TF_ASSERT_OK(executor->SynchronousMemcpyH2D(&src_data, size, &device_dst)); + ASSERT_EQ(dst_data, 18); +} + 
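// For orientation: given the SP_Platform / SP_StreamExecutor structs and the
// SE_InitPlugin entry point declared earlier in this change, a minimal device
// plugin might register itself roughly as sketched below. This is only an
// illustrative sketch: the "MyVendorDevice" name and the My* helpers are
// hypothetical placeholders, and a real plugin would also populate the
// allocation, stream, event, timer, and memcpy callbacks of SP_StreamExecutor
// (compare PopulateDefaultStreamExecutor() in the tests above).

#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

namespace {

void MyCreateDevice(SE_CreateDeviceParams* params, TF_Status* status) {
  // A real plugin would acquire its hardware device here.
  params->device->struct_size = SP_DEVICE_STRUCT_SIZE;
  TF_SetStatus(status, TF_OK, "");
}
void MyDestroyDevice(SP_Device* device) {}

void MyCreateStreamExecutor(SE_CreateStreamExecutorParams* params,
                            TF_Status* status) {
  params->stream_executor->struct_size = SP_STREAMEXECUTOR_STRUCT_SIZE;
  // Assign allocate/deallocate, stream, event, timer, and memcpy callbacks
  // here, mirroring PopulateDefaultStreamExecutor() in the tests above.
  TF_SetStatus(status, TF_OK, "");
}
void MyDestroyStreamExecutor(SP_StreamExecutor* stream_executor) {}

void MyCreateTimerFns(SP_TimerFns* timer_fns, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
}
void MyDestroyTimerFns(SP_TimerFns* timer_fns) {}
void MyDestroyPlatform(SP_Platform* platform) {}

}  // namespace

// Exported entry point that TensorFlow resolves after loading the plugin DSO
// (see RegisterDevicePlugin(dso_path) in stream_executor_internal.h).
void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status) {
  params->platform->struct_size = SP_PLATFORM_STRUCT_SIZE;
  params->platform->name = "MyVendorDevice";  // placeholder platform name
  params->platform->type = "GPU";             // device type reported to TF
  params->platform->visible_device_count = 1;
  params->platform->create_device = MyCreateDevice;
  params->platform->destroy_device = MyDestroyDevice;
  params->platform->create_stream_executor = MyCreateStreamExecutor;
  params->platform->destroy_stream_executor = MyDestroyStreamExecutor;
  params->platform->create_timer_fns = MyCreateTimerFns;
  params->platform->destroy_timer_fns = MyDestroyTimerFns;
  params->destroy_platform = MyDestroyPlatform;
  TF_SetStatus(status, TF_OK, "");
}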
+TEST_F(StreamExecutorTest, SyncMemcpyDeviceToDevice) { + se_.sync_memcpy_dtod = [](const SP_Device* const device, + SP_DeviceMemoryBase* const device_dst, + const SP_DeviceMemoryBase* const device_src, + uint64_t size, TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + std::memcpy(device_dst->opaque, device_src->opaque, size); + }; + + StreamExecutor* executor = GetExecutor(0); + size_t size = sizeof(int); + int src_data = 18; + int dst_data = 0; + DeviceMemoryBase device_dst(&dst_data, size); + DeviceMemoryBase device_src(&src_data, size); + ASSERT_TRUE(executor->SynchronousMemcpy(&device_dst, device_src, size)); + ASSERT_EQ(dst_data, 18); +} + +TEST_F(StreamExecutorTest, BlockHostForEvent) { + static bool block_host_for_event_called = false; + se_.create_event = [](const SP_Device* const device, SP_Event* event, + TF_Status* const status) { + *event = new SP_Event_st(357); + }; + se_.destroy_event = [](const SP_Device* const device, SP_Event event) { + delete event; + }; + se_.block_host_for_event = [](const SP_Device* const device, SP_Event event, + TF_Status* const status) -> void { + ASSERT_EQ(event->event_id, 357); + TF_SetStatus(status, TF_OK, ""); + block_host_for_event_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + ASSERT_FALSE(block_host_for_event_called); + TF_ASSERT_OK(stream.BlockHostUntilDone()); + ASSERT_TRUE(block_host_for_event_called); +} + +TEST_F(StreamExecutorTest, SynchronizeAllActivity) { + static bool synchronize_all_called = false; + se_.synchronize_all_activity = [](const SP_Device* const device, + TF_Status* const status) { + TF_SetStatus(status, TF_OK, ""); + synchronize_all_called = true; + }; + + StreamExecutor* executor = GetExecutor(0); + ASSERT_FALSE(synchronize_all_called); + ASSERT_TRUE(executor->SynchronizeAllActivity()); + ASSERT_TRUE(synchronize_all_called); +} + +TEST_F(StreamExecutorTest, HostCallbackOk) { + se_.host_callback = [](SP_Device* const device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) -> TF_Bool { + TF_Status* status = TF_NewStatus(); + callback_fn(callback_arg, status); + bool ok = TF_GetCode(status) == TF_OK; + TF_DeleteStatus(status); + return ok; + }; + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + std::function callback = []() -> port::Status { + return port::Status::OK(); + }; + stream.ThenDoHostCallbackWithStatus(callback); + ASSERT_TRUE(stream.ok()); +} + +TEST_F(StreamExecutorTest, HostCallbackError) { + se_.host_callback = [](SP_Device* const device, SP_Stream stream, + SE_StatusCallbackFn const callback_fn, + void* const callback_arg) -> TF_Bool { + TF_Status* status = TF_NewStatus(); + callback_fn(callback_arg, status); + bool ok = TF_GetCode(status) == TF_OK; + TF_DeleteStatus(status); + return ok; + }; + StreamExecutor* executor = GetExecutor(0); + Stream stream(executor); + stream.Init(); + std::function callback = []() -> port::Status { + return port::UnimplementedError("Unimplemented"); + }; + stream.ThenDoHostCallbackWithStatus(callback); + ASSERT_FALSE(stream.ok()); +} +} // namespace +} // namespace stream_executor diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 0b12b17c09b..ed501b5b101 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -280,6 +280,36 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, return tf_tensor; } +TF_Tensor* TF_ForwardInputOrAllocateOutput( + TF_OpKernelContext* 
context, int* candidate_input_indices,
+    int num_candidate_input_indices, int output_index, int64_t* output_dims,
+    int output_num_dims, int* forwarded_input, TF_Status* status) {
+  TF_SetStatus(status, TF_OK, "");
+  auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
+
+  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
+                "64-bit int types should match in size");
+  tensorflow::gtl::ArraySlice<int> input_indices_array(
+      candidate_input_indices, num_candidate_input_indices);
+  tensorflow::gtl::ArraySlice<tensorflow::int64> output_dimarray(
+      reinterpret_cast<tensorflow::int64*>(output_dims), output_num_dims);
+  tensorflow::Tensor* output_tensor_pointer;
+  tensorflow::Status s = cc_ctx->forward_input_or_allocate_output(
+      input_indices_array, output_index,
+      tensorflow::TensorShape(output_dimarray), &output_tensor_pointer,
+      forwarded_input);
+  if (!s.ok()) {
+    ::tensorflow::Set_TF_Status_from_Status(status, s);
+    return nullptr;
+  }
+  TF_Tensor* tf_tensor_output = TF_TensorFromTensor(*output_tensor_pointer, &s);
+  if (!s.ok()) {
+    ::tensorflow::Set_TF_Status_from_Status(status, s);
+    return nullptr;
+  }
+  return tf_tensor_output;
+}
+
 TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
                            int64_t* dims, int num_dims,
                            TF_AllocatorAttributes* attributes,
diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h
index 15fcf0f5188..489aa5399a5 100644
--- a/tensorflow/c/kernels.h
+++ b/tensorflow/c/kernels.h
@@ -200,6 +200,17 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
                                             int64_t* dims, int num_dims,
                                             size_t len, TF_Status* status);
 
+// Tries to forward one of the inputs given in input_indices to
+// output[output_index]. If none of the given inputs can be forwarded, calls
+// allocate_output() to allocate a new output buffer. The index of the
+// forwarded input will be assigned to output argument forwarded_input (if it's
+// not nullptr). If no inputs are forwarded, forwarded_input will be assigned
+// -1.
+TF_CAPI_EXPORT TF_Tensor* TF_ForwardInputOrAllocateOutput(
+    TF_OpKernelContext* context, int* candidate_input_indices,
+    int num_candidate_input_indices, int output_index, int64_t* output_dims,
+    int output_num_dims, int* forwarded_input, TF_Status* status);
+
 // Allocates a temporary Tensor of the specified type and shape. The
 // Tensor must not be used after kernel construction is
 // complete.
diff --git a/tensorflow/c/kernels/histogram_summary_op.cc b/tensorflow/c/kernels/histogram_summary_op.cc
index ada1bd3c630..5de52703f5d 100644
--- a/tensorflow/c/kernels/histogram_summary_op.cc
+++ b/tensorflow/c/kernels/histogram_summary_op.cc
@@ -20,8 +20,8 @@ limitations under the License.
 #include "tensorflow/core/framework/selective_registration.h"
 #include "tensorflow/core/framework/summary.pb.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/lib/bfloat16/bfloat16.h"
 #include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/bfloat16.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/c/kernels/summary_op.cc b/tensorflow/c/kernels/summary_op.cc
index bd528da4165..ac7eced0ae7 100644
--- a/tensorflow/c/kernels/summary_op.cc
+++ b/tensorflow/c/kernels/summary_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/selective_registration.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/protobuf.h" diff --git a/tensorflow/c/kernels_test.cc b/tensorflow/c/kernels_test.cc index e8223e40064..c9df2cc34d1 100644 --- a/tensorflow/c/kernels_test.cc +++ b/tensorflow/c/kernels_test.cc @@ -565,6 +565,74 @@ TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) { output->DebugString(100)); } +TEST_F(DeviceKernelOpTest, TestForwardInputOrAllocateOutput) { + const char* node_name = "TestForwardInputOrAllocateOutputKernel"; + const char* op_name = "BazOp"; + const char* device_name = "FakeDeviceName"; + + REGISTER_OP(op_name) + .Input("input1: float") + .Input("input2: float") + .Output("output1: float") + .Attr("SomeDataTypeAttr: type"); + + // A kernel whose Compute function that forwards a scalar input to output + auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) { + TF_Status* s = TF_NewStatus(); + int candidate_input_indices[1] = {0}; + int forwarded_input; + int64_t output_dims[1] = {}; + TF_Tensor* output = TF_ForwardInputOrAllocateOutput( + /*context=*/ctx, candidate_input_indices, + /*num_candidate_input_indices=*/1, + /*output_index=*/0, output_dims, /*output_num_dims=*/0, + &forwarded_input, /*status=*/s); + EXPECT_EQ(TF_OK, TF_GetCode(s)); + EXPECT_EQ(forwarded_input, 0); + EXPECT_EQ(TF_FLOAT, TF_TensorType(output)); + EXPECT_EQ(0, TF_NumDims(output)); + TF_DeleteStatus(s); + TF_DeleteTensor(output); + }; + + TF_KernelBuilder* builder = TF_NewKernelBuilder(op_name, device_name, nullptr, + my_compute_func, nullptr); + + { + TF_Status* status = TF_NewStatus(); + TF_RegisterKernelBuilder(node_name, builder, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)); + TF_DeleteStatus(status); + } + + { + OpKernelContext::Params p; + DummyDevice dummy_device(nullptr); + p.device = &dummy_device; + AllocatorAttributes alloc_attrs; + p.output_attr_array = &alloc_attrs; + + Tensor t(123.0f); + + gtl::InlinedVector inputs; + // GetFakeKernel requires a NodeDef with two inputs + inputs.emplace_back(&t); + inputs.emplace_back(); + p.inputs = &inputs; + + Status status; + std::unique_ptr kernel = + GetFakeKernel(device_name, op_name, node_name, &status); + TF_EXPECT_OK(status); + ASSERT_NE(nullptr, kernel.get()); + + p.op_kernel = kernel.get(); + OpKernelContext ctx(&p); + kernel->Compute(&ctx); + ASSERT_EQ(123, ctx.mutable_output(0)->scalar()()); + } +} + void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims, TF_DataType dtype) { EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor)); diff --git a/tensorflow/c/logging.cc b/tensorflow/c/logging.cc index bf6bf069fff..13c9e6ac208 100644 --- a/tensorflow/c/logging.cc +++ b/tensorflow/c/logging.cc @@ -28,6 +28,7 @@ void TF_Log(TF_LogLevel level, const char* fmt, ...) { va_list args; va_start(args, fmt); auto message = BuildMessage(fmt, args); + va_end(args); switch (level) { case TF_INFO: LOG(INFO) << message; @@ -48,6 +49,7 @@ void TF_VLog(int level, const char* fmt, ...) { va_list args; va_start(args, fmt); auto message = BuildMessage(fmt, args); + va_end(args); VLOG(level) << message; } @@ -55,5 +57,6 @@ void TF_DVLog(int level, const char* fmt, ...) 
{ va_list args; va_start(args, fmt); auto message = BuildMessage(fmt, args); + va_end(args); DVLOG(level) << message; } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index a67d349bab7..a3ea0c75bc7 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -47,6 +47,7 @@ cc_library( # TODO(b/111634734): :lib and :protos_all contain dependencies that # cannot be built on mobile platforms. Instead, include the appropriate # tf_lib depending on the build platform. + "@com_google_absl//absl/memory:memory", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ]), @@ -171,6 +172,7 @@ tf_cc_test( deps = [ ":constants", ":loader", + ":reader", ":signature_constants", ":tag_constants", "//tensorflow/core:lib", diff --git a/tensorflow/cc/saved_model/experimental/public/BUILD b/tensorflow/cc/saved_model/experimental/public/BUILD index 3e9a671a61f..9640848ebf5 100644 --- a/tensorflow/cc/saved_model/experimental/public/BUILD +++ b/tensorflow/cc/saved_model/experimental/public/BUILD @@ -51,8 +51,32 @@ cc_library( deps = [ ":concrete_function", ":concrete_function_list", + ":signature_def_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", "//tensorflow/cc/experimental/base/public:runtime", "//tensorflow/cc/experimental/base/public:status", ], ) + +cc_library( + name = "signature_def_function", + hdrs = [ + "signature_def_function.h", + ], + deps = [ + ":signature_def_function_metadata", + "//tensorflow/c/eager:c_api", + "//tensorflow/c/experimental/saved_model/public:signature_def_function", + "//tensorflow/cc/experimental/base/public:status", + ], +) + +cc_library( + name = "signature_def_function_metadata", + hdrs = [ + "signature_def_function_metadata.h", + ], + deps = [ + "//tensorflow/c/experimental/saved_model/public:signature_def_function_metadata", + ], +) diff --git a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h index 04018bf2aab..c2bfb4dcf83 100644 --- a/tensorflow/cc/saved_model/experimental/public/saved_model_api.h +++ b/tensorflow/cc/saved_model/experimental/public/saved_model_api.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/experimental/base/public/status.h" #include "tensorflow/cc/saved_model/experimental/public/concrete_function.h" #include "tensorflow/cc/saved_model/experimental/public/concrete_function_list.h" +#include "tensorflow/cc/saved_model/experimental/public/signature_def_function.h" namespace tensorflow { namespace experimental { @@ -80,8 +81,8 @@ class SavedModelAPI { // If status is not OK, returns nullptr. Otherwise, returns a // tensorflow::cc::ConcreteFunction pointer. The lifetime of this pointer // is bound to SavedModelAPI it was loaded from. - ConcreteFunction* GetSignatureDefFunction(const std::string& function_path, - Status* status); + SignatureDefFunction* GetSignatureDefFunction( + const std::string& function_path, Status* status); // Lists all Conrete Functions available from the SavedModel. 
std::vector ListFunctions(); @@ -140,14 +141,14 @@ inline ConcreteFunction* SavedModelAPI::GetConcreteFunction( return ConcreteFunction::wrap(function); } -inline ConcreteFunction* SavedModelAPI::GetSignatureDefFunction( +inline SignatureDefFunction* SavedModelAPI::GetSignatureDefFunction( const std::string& function_path, Status* status) { - TF_ConcreteFunction* function = TF_GetSavedModelSignatureDefFunction( + TF_SignatureDefFunction* function = TF_GetSavedModelSignatureDefFunction( saved_model_.get(), function_path.c_str(), status->GetTFStatus()); if (!status->ok()) { return nullptr; } - return ConcreteFunction::wrap(function); + return SignatureDefFunction::wrap(function); } inline std::vector SavedModelAPI::ListFunctions() { diff --git a/tensorflow/cc/saved_model/experimental/public/signature_def_function.h b/tensorflow/cc/saved_model/experimental/public/signature_def_function.h new file mode 100644 index 00000000000..bc72d208e87 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/signature_def_function.h @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ + +#include + +#include "tensorflow/c/eager/c_api.h" +#include "tensorflow/c/experimental/saved_model/public/signature_def_function.h" +#include "tensorflow/cc/experimental/base/public/status.h" +#include "tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SignatureDefFunctions are functions that correspond to either: +// "signatures" saved from a TF2 SavedModel APIs: +// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/save.py#L830-L854 +// Or the "SignatureDefMap" saved from TF1 SavedModel APIs: +// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/python/saved_model/load_v1_in_v2_test.py#L170-L174 +// In both cases, a SignatureDef is serialized as a SignatureDef protobuf: +// https://github.com/tensorflow/tensorflow/blob/8ce0600f58ed84a8c84a7bbdb014d1f09e44f4c8/tensorflow/core/protobuf/meta_graph.proto#L260-L330 +// and represents a computation defined by a TF subgraph. +// These Signatures were primarily designed to be interoperable with the legacy +// TF 1 Session-based C++ SavedModelBundle loading APIs: +// https://github.com/tensorflow/tensorflow/blob/26c4ee0c833e74f94d0102d8b005c41a28b44445/tensorflow/cc/saved_model/loader.h#L96-L108 +// SignatureDefFunctions have different semantics from regular TF2 +// ConcreteFunctions, and are mainly intended provide a serving-friendly +// transition point from the TF1 Session API. +// First, SignatureDefFunctions have different calling conventions. 
+// SignatureDefFunctions' inputs and outputs are constrained to **flattened +// lists of TensorHandles only**. They do not support more exotic input/output +// types (like optionals, generators, etc). Additionally, this flattening means +// they will not preserve the exact interface of the original tf.function they +// were traced from, as things like composite tensors decay into their +// internal dense tensor representation. +// Second, all inputs and outputs are "named", and these names are load bearing +// (eg: they are part of the interface of tensorflow_serving): +// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L21 +// https://github.com/tensorflow/serving/blob/e0d247b2e4050713194b8fad0be24a0636df7209/tensorflow_serving/apis/predict.proto#L39 +// The name of each input/output is stored in the corresponding tf::Argument in +// SignatureDefFunctionMetadata::arguments(). Users must ensure the order of +// TensorHandles passed to the function matches with the order of named +// arguments. Similarly the name of the outputs is stored in +// SignatureDefFunctionMetadata::returns(). +class SignatureDefFunction final { + public: + // Returns FunctionMetadata associated with this ConcreteFunction. + const SignatureDefFunctionMetadata* GetFunctionMetadata(); + + private: + friend class SavedModelAPI; + friend class ConcreteFunctionList; + + // TODO(bmzhao): Consider adding a macro for wrapping/unwrapping + // when moving out of experimental. + static SignatureDefFunction* wrap(TF_SignatureDefFunction* p) { + return reinterpret_cast(p); + } + static TF_SignatureDefFunction* unwrap(SignatureDefFunction* p) { + return reinterpret_cast(p); + } +}; + +inline const SignatureDefFunctionMetadata* +SignatureDefFunction::GetFunctionMetadata() { + return SignatureDefFunctionMetadata::wrap( + TF_SignatureDefFunctionGetMetadata(unwrap(this))); +} + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_H_ diff --git a/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h b/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h new file mode 100644 index 00000000000..6cb01bf1a26 --- /dev/null +++ b/tensorflow/cc/saved_model/experimental/public/signature_def_function_metadata.h @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ +#define TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ + +#include + +#include "tensorflow/c/experimental/saved_model/public/signature_def_function_metadata.h" + +namespace tensorflow { +namespace experimental { +namespace cc { + +// SignatureDefFunctionMetadata stores additional information on each input +// and output's names, dtypes, and shape. +class SignatureDefFunctionMetadata final { + // TODO(bmzhao): Add getters here as necessary. + private: + friend class SignatureDefFunction; + static SignatureDefFunctionMetadata* wrap( + TF_SignatureDefFunctionMetadata* p) { + return reinterpret_cast(p); + } + static TF_SignatureDefFunctionMetadata* unwrap( + SignatureDefFunctionMetadata* p) { + return reinterpret_cast(p); + } +}; + +} // namespace cc +} // namespace experimental +} // namespace tensorflow + +#endif // TENSORFLOW_CC_SAVED_MODEL_EXPERIMENTAL_PUBLIC_SIGNATURE_DEF_FUNCTION_METADATA_H_ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f9c720a2ba2..ecefe7d0406 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" #include "tensorflow/core/protobuf/saver.pb.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" @@ -95,16 +96,6 @@ static Status ValidateSavedTensors(const GraphDef& graph_def) { return Status::OK(); } -Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, - const SessionOptions& session_options, - std::unique_ptr* session) { - Session* session_p = nullptr; - TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); - session->reset(session_p); - TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph_def.graph_def())); - return (*session)->Create(meta_graph_def.graph_def()); -} - Tensor CreateStringTensor(const string& value) { Tensor tensor(DT_STRING, TensorShape({})); tensor.scalar()() = value; @@ -228,22 +219,18 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir, nullptr /* outputs */, &run_metadata, session); } -Status ReadSavedModelDebugInfoIfPresent( - const string& export_dir, - std::unique_ptr* debug_info_proto) { - LOG(INFO) << "Reading SavedModel debug info (if present) from: " - << export_dir; +} // namespace - const string debug_info_pb_path = - io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb"); - if (Env::Default()->FileExists(debug_info_pb_path).ok()) { - GraphDebugInfo debug_info; - TF_RETURN_IF_ERROR( - ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info)); - *debug_info_proto = - absl::make_unique(std::move(debug_info)); - } - return Status::OK(); +SavedModelBundleInterface::~SavedModelBundleInterface() {} + +Status LoadMetagraphIntoSession(const SessionOptions& session_options, + const MetaGraphDef& meta_graph, + std::unique_ptr* session) { + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); + TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def())); + return (*session)->Create(meta_graph.graph_def()); } Status LoadSavedModelInternal(const 
SessionOptions& session_options, @@ -251,46 +238,17 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, const string& export_dir, const std::unordered_set& tags, SavedModelBundle* const bundle) { - const uint64 read_start_microseconds = Env::Default()->NowMicros(); TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags, &bundle->meta_graph_def)); TF_RETURN_IF_ERROR( ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info)); - TF_RETURN_IF_ERROR(LoadMetaGraphIntoSession( - bundle->meta_graph_def, session_options, &bundle->session)); - - std::vector asset_file_defs; - TF_RETURN_IF_ERROR( - internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs)); - TF_RETURN_IF_ERROR( - RunRestore(run_options, export_dir, - bundle->meta_graph_def.saver_def().restore_op_name(), - bundle->meta_graph_def.saver_def().filename_tensor_name(), - asset_file_defs, bundle->session.get())); - // Record walltime spent in restoring graph from disk, but postpone metric - // increments until graph init finishes. - const uint64 restore_graph_walltime = - GetLatencyMicroseconds(read_start_microseconds); - - const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); - string init_op_name; - TF_RETURN_IF_ERROR( - internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); - TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, - asset_file_defs, bundle->session.get(), - init_op_name)); - load_latency_by_stage->GetCell(export_dir, "restore_graph") - ->Add(restore_graph_walltime); - // Record wall time spent in init op. - load_latency_by_stage->GetCell(export_dir, "init_graph") - ->Add(GetLatencyMicroseconds(graph_init_start_microseconds)); + TF_RETURN_IF_ERROR(LoadMetagraphIntoSession( + session_options, bundle->meta_graph_def, &bundle->session)); + TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def, + export_dir, &bundle->session)); return Status::OK(); } -} // namespace - -SavedModelBundleInterface::~SavedModelBundleInterface() {} - Status LoadSavedModel(const SessionOptions& session_options, const RunOptions& run_options, const string& export_dir, const std::unordered_set& tags, @@ -424,6 +382,35 @@ class LiteSessionWrapper : public Session { }; } // namespace +Status RestoreSession(const RunOptions& run_options, + const MetaGraphDef& meta_graph, const string& export_dir, + std::unique_ptr* session) { + const uint64 read_start_microseconds = Env::Default()->NowMicros(); + std::vector asset_file_defs; + TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs)); + TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir, + meta_graph.saver_def().restore_op_name(), + meta_graph.saver_def().filename_tensor_name(), + asset_file_defs, session->get())); + // Record walltime spent in restoring graph from disk, but postpone metric + // increments until graph init finishes. + const uint64 restore_graph_walltime = + GetLatencyMicroseconds(read_start_microseconds); + + const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); + string init_op_name; + TF_RETURN_IF_ERROR( + internal::GetInitOp(export_dir, meta_graph, &init_op_name)); + TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph, + asset_file_defs, session->get(), init_op_name)); + load_latency_by_stage->GetCell(export_dir, "restore_graph") + ->Add(restore_graph_walltime); + // Record wall time spent in init op. 
+  load_latency_by_stage->GetCell(export_dir, "init_graph")
+      ->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
+  return Status::OK();
+}
+
 Status LoadSavedModel(const SessionOptions& session_options,
                       const RunOptions& run_options, const string& export_dir,
                       const std::unordered_set<string>& tags,
diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h
index 2b2e44bc619..5ef6070998e 100644
--- a/tensorflow/cc/saved_model/loader.h
+++ b/tensorflow/cc/saved_model/loader.h
@@ -96,6 +96,21 @@ class SavedModelBundleLite : public SavedModelBundleInterface {
   protobuf::Map<string, SignatureDef> signatures_;
 };
 
+// Restore variables and resources in the SavedModel export dir for the
+// indicated metagraph.
+// The recommended way to load a saved model is to call LoadSavedModel,
+// which provides an already initialized Metagraph, Session, and DebugInfo.
+Status RestoreSession(const RunOptions& run_options,
+                      const MetaGraphDef& meta_graph, const string& export_dir,
+                      std::unique_ptr<Session>* session);
+
+// Initialize a session which wraps this metagraph.
+// The recommended way to load a saved model is to call LoadSavedModel,
+// which provides an already initialized Metagraph, Session, and DebugInfo.
+Status LoadMetagraphIntoSession(const SessionOptions& session_options,
+                                const MetaGraphDef& meta_graph,
+                                std::unique_ptr<Session>* session);
+
 /// Loads a SavedModel from the specified export directory. The MetaGraphDef
 /// to be loaded is identified by the supplied tags, corresponding exactly to
 /// the set of tags used at SavedModel build time. Stores a SavedModel bundle in
diff --git a/tensorflow/cc/saved_model/reader.cc b/tensorflow/cc/saved_model/reader.cc
index d6d99229372..c1d4736f6b9 100644
--- a/tensorflow/cc/saved_model/reader.cc
+++ b/tensorflow/cc/saved_model/reader.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include
+#include "absl/memory/memory.h"
 #include "tensorflow/cc/saved_model/constants.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/strings/str_util.h"
@@ -86,4 +87,22 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
   return Status::OK();
 }
 
+Status ReadSavedModelDebugInfoIfPresent(
+    const string& export_dir,
+    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
+  LOG(INFO) << "Reading SavedModel debug info (if present) from: "
+            << export_dir;
+
+  const string debug_info_pb_path =
+      io::JoinPath(export_dir, "debug", "saved_model_debug_info.pb");
+  if (Env::Default()->FileExists(debug_info_pb_path).ok()) {
+    GraphDebugInfo debug_info;
+    TF_RETURN_IF_ERROR(
+        ReadBinaryProto(Env::Default(), debug_info_pb_path, &debug_info));
+    *debug_info_proto =
+        absl::make_unique<GraphDebugInfo>(std::move(debug_info));
+  }
+  return Status::OK();
+}
+
 }  // namespace tensorflow
diff --git a/tensorflow/cc/saved_model/reader.h b/tensorflow/cc/saved_model/reader.h
index 5815108df2a..602f6cb21c1 100644
--- a/tensorflow/cc/saved_model/reader.h
+++ b/tensorflow/cc/saved_model/reader.h
@@ -22,6 +22,7 @@ limitations under the License.
 #include
 #include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/core/protobuf/meta_graph.pb.h"
 
 namespace tensorflow {
@@ -34,6 +35,11 @@ Status ReadMetaGraphDefFromSavedModel(const string& export_dir,
                                       const std::unordered_set<string>& tags,
                                       MetaGraphDef* const meta_graph_def);
 
+// Store debug info from the SavedModel export dir.
+Status ReadSavedModelDebugInfoIfPresent( + const string& export_dir, + std::unique_ptr* debug_info_proto); + } // namespace tensorflow #endif // TENSORFLOW_CC_SAVED_MODEL_READER_H_ diff --git a/tensorflow/cc/saved_model/reader_test.cc b/tensorflow/cc/saved_model/reader_test.cc index bc630bcaede..b5e8b67a123 100644 --- a/tensorflow/cc/saved_model/reader_test.cc +++ b/tensorflow/cc/saved_model/reader_test.cc @@ -106,5 +106,11 @@ TEST_F(ReaderTest, InvalidExportPath) { EXPECT_FALSE(st.ok()); } +TEST_F(ReaderTest, ReadSavedModelDebugInfoIfPresent) { + const string export_dir = GetDataDependencyFilepath(TestDataSharded()); + std::unique_ptr debug_info_proto; + TF_ASSERT_OK(ReadSavedModelDebugInfoIfPresent(export_dir, &debug_info_proto)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/saved_model/saved_model_bundle_test.cc b/tensorflow/cc/saved_model/saved_model_bundle_test.cc index d6c375c7448..31f676920aa 100644 --- a/tensorflow/cc/saved_model/saved_model_bundle_test.cc +++ b/tensorflow/cc/saved_model/saved_model_bundle_test.cc @@ -13,9 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/cc/saved_model/loader.h" - #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/cc/saved_model/reader.h" #include "tensorflow/cc/saved_model/signature_constants.h" #include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" namespace tensorflow { namespace { @@ -131,6 +132,43 @@ TEST_F(LoaderTest, TagMatch) { CheckSavedModelBundle(export_dir, bundle); } +TEST_F(LoaderTest, ReadMetaGraphFromSavedModel) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + MetaGraphDef actual_metagraph; + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, {kSavedModelTagServe}, + &actual_metagraph)); + EXPECT_EQ(actual_metagraph.DebugString(), + bundle.meta_graph_def.DebugString()); +} + +TEST_F(LoaderTest, RestoreSession) { + SavedModelBundle bundle; + SessionOptions session_options; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + TF_ASSERT_OK(LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle)); + + SavedModelBundle actual_bundle; + const std::unordered_set tags = {kSavedModelTagServe}; + TF_ASSERT_OK(ReadMetaGraphDefFromSavedModel(export_dir, tags, + &actual_bundle.meta_graph_def)); + TF_ASSERT_OK(LoadMetagraphIntoSession( + session_options, actual_bundle.meta_graph_def, &actual_bundle.session)); + TF_ASSERT_OK(RestoreSession(run_options, actual_bundle.meta_graph_def, + export_dir, &actual_bundle.session)); + CheckSavedModelBundle(export_dir, actual_bundle); +} + TEST_F(LoaderTest, NoTagMatch) { SavedModelBundle bundle; RunOptions run_options; diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc 
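For reference, the three new entry points introduced above compose as follows; this is a minimal sketch that mirrors the LoaderTest.RestoreSession case, with illustrative function and variable names (only the ReadMetaGraphDefFromSavedModel, LoadMetagraphIntoSession, and RestoreSession signatures come from this change):

#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// Load a SavedModel in three explicit steps instead of calling LoadSavedModel.
tensorflow::Status LoadPiecewise(const std::string& export_dir,
                                 tensorflow::SavedModelBundle* bundle) {
  tensorflow::SessionOptions session_options;
  tensorflow::RunOptions run_options;
  // 1. Read only the MetaGraphDef matching the "serve" tag from disk.
  tensorflow::Status status = tensorflow::ReadMetaGraphDefFromSavedModel(
      export_dir, {tensorflow::kSavedModelTagServe}, &bundle->meta_graph_def);
  if (!status.ok()) return status;
  // 2. Create a session and import the graph into it.
  status = tensorflow::LoadMetagraphIntoSession(
      session_options, bundle->meta_graph_def, &bundle->session);
  if (!status.ok()) return status;
  // 3. Restore variables/assets and run the init op, recording load metrics.
  return tensorflow::RestoreSession(run_options, bundle->meta_graph_def,
                                    export_dir, &bundle->session);
}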
index b1525337dbc..971a5383f6b 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -278,16 +278,14 @@ Status XlaCompilationCache::CompileSingleOp( const NodeDef& node_def = ctx->op_kernel().def(); TF_ASSIGN_OR_RETURN(auto graph, CreateGraph(node_def, args, result_dtypes)); - bool are_args_supported = - absl::c_all_of(args, [](const XlaCompiler::Argument arg) { - return arg.kind == XlaCompiler::Argument::kConstant || - arg.kind == XlaCompiler::Argument::kParameter; + bool has_tensor_list_arg = + absl::c_any_of(args, [](const XlaCompiler::Argument arg) { + return arg.kind == XlaCompiler::Argument::kTensorList; }); const ConfigProto* config = ctx->function_library()->config_proto(); bool use_mlir = config && config->experimental().enable_mlir_bridge(); - // TODO(b/155596779): Understand the source of other argument types and - // depending on the source either support those or avoid these codepath. - if (!use_mlir || !are_args_supported) { + // TODO(b/155596779): Support TensorList args. + if (!use_mlir || !has_tensor_list_arg) { return compiler->CompileGraph(compile_options, node_def.name(), std::move(graph), args, result); } diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 01c187790b7..d8b4fe5bcef 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -40,13 +40,16 @@ cc_library( srcs = ["tf_mlir_opt_main.cc"], deps = [ ":init_mlir", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Shape", "@llvm-project//mlir:Support", ], ) @@ -127,9 +130,7 @@ tf_cc_binary( deps = [ ":passes", ":tf_mlir_opt_main", - "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass_registration", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration", "//tensorflow/compiler/mlir/tensorflow:tf_graph_optimization_pass", "//tensorflow/compiler/mlir/tfjs:tensorflow_js_dialect_registration", "//tensorflow/compiler/mlir/xla:all_xla_passes_for_testing", diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 126d44670a0..7be39aef9da 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -813,7 +813,8 @@ cc_binary( ], deps = [ ":all_passes", - ":hlo_dialect_registration", + ":hlo", + ":lhlo", "@llvm-project//llvm:Support", "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index ad044e1d322..4286c837a24 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -56,19 +56,9 @@ class MhloDialect : public Dialect { void printType(Type type, DialectAsmPrinter &os) const override; }; -namespace HLOTypes { -enum Kind { - Token = Type::FIRST_XLA_HLO_TYPE, -}; -} // namespace HLOTypes - class TokenType : public Type::TypeBase { public: using Base::Base; - - static TokenType get(MLIRContext 
*context) { - return Base::get(context, HLOTypes::Token); - } }; // Shape derivation function that computes the shape of the result based on diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 3fa46584ca2..750cce65b62 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -81,6 +81,8 @@ def LHLO_ConstOp : LHLO_Op<"constant", []>, BASE_HLO_ConstOp { ElementsAttr:$value, Arg:$output ); + + let hasCanonicalizer = 1; } def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h index 1e335ae6b82..74ea9c9b1a7 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h @@ -65,9 +65,24 @@ static ElementsAttr getSplat(Builder* b, Value val, T constant) { // Returns DenseElementsAttr of rank zero with the given element type and the // value. -// Requires `ty` to be either FloatType of IntegerType. +// Requires `ty` to be either FloatType, IntegerType, or ComplexType. DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value); +// Enum type used to specify scalar argument to GetScalarLimitOfType. +enum ScalarLimit { + kLowest, // The scalar corresponding to numeric_limits::lowest. + kInfinityLowest, // Like kMax, but returns -infinity where available. + kMax, // The scalar corresponding to numeric_limits::max. + kInfinityMax, // Like kMax, but returns infinity where available. +}; + +// Returns a scalar limit value for the given type. +// +// The argument 'limit' describes which scalar value to return. +// +// Requires `ty` to be either FloatType or IntegerType. +DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit); + } // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index f61a66397e7..81407c89204 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" +#include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" @@ -56,6 +57,38 @@ LmhloDialect::LmhloDialect(MLIRContext *context) >(); } +//===----------------------------------------------------------------------===// +// ConstOp. +//===----------------------------------------------------------------------===// + +/// An lho.constant on an memref that is locally allocated and with no other +/// users (other than dealloc's) can be erased. +// TODO: This can be generalized to an arbitrary op by making use of memory +// effects (write memory effect). +struct EraseConstOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ConstOp op, + PatternRewriter& rewriter) const override { + Value memref = op.output(); + if (!memref.getDefiningOp()) { + return failure(); + } + + // Check that all uses of the memref are either DeallocOps or this op. 
+ for (Operation* user : memref.getUsers()) + if (user != op && !isa(user)) return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + +void ConstOp::getCanonicalizationPatterns(OwningRewritePatternList& results, + MLIRContext* context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // StaticMemRefCastOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc index f47f2c2fbdc..033021c36ac 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/legalize_to_linalg.cc @@ -15,6 +15,8 @@ limitations under the License. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. +#include + #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h" @@ -598,6 +600,7 @@ class ReshapeOpConverter : public OpConversionPattern { unsigned currSrcDim = 0, currDstDim = 0; SmallVector reassociationMap( dstShape.size()); + bool isExpandingOrCollapsing = true; while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { int64_t dstSize = dstShape[currDstDim]; int64_t srcSize = srcShape[currSrcDim]; @@ -619,11 +622,47 @@ class ReshapeOpConverter : public OpConversionPattern { } } } else { - return failure(); + isExpandingOrCollapsing = false; + break; } currDstDim++; } - if (currSrcDim != srcShape.size()) return failure(); + if (currSrcDim != srcShape.size()) isExpandingOrCollapsing = false; + + if (!isExpandingOrCollapsing) { + auto getIdentityExprs = [&rewriter](int n) { + SmallVector exprs; + for (int i = 0; i < n; ++i) + exprs.push_back(rewriter.getAffineDimExpr(i)); + return exprs; + }; + Location loc = reshapeOp.getLoc(); + int64_t totalElems = std::accumulate(srcShape.begin(), srcShape.end(), 1, + std::multiplies()); + auto elemType = operandType.getElementType(); + SmallVector collapsingMap = { + getIdentityExprs(dstShape.size())}; + SmallVector expandingMap = { + getIdentityExprs(srcShape.size())}; + + if (isLHLO) { + auto collapsedType = MemRefType::get({totalElems}, elemType); + Value collapsedOp = rewriter.create( + loc, collapsedType, args[0], collapsingMap); + Value reshapeBuffer = rewriter.create( + loc, resultType, collapsedOp, expandingMap); + rewriter.replaceOpWithNewOp( + reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, + /*outputPermutation =*/nullptr); + } else { + auto collapsedType = RankedTensorType::get({totalElems}, elemType); + Value collapsedOp = rewriter.create( + loc, collapsedType, args[0], collapsingMap); + rewriter.replaceOpWithNewOp( + reshapeOp, resultType, collapsedOp, expandingMap); + } + return success(); + } if (isLHLO) { Value reshapeBuffer = rewriter.create( diff --git a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc index df2442cc4b6..0bbd91e0680 100644 --- a/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc +++ b/tensorflow/compiler/mlir/hlo/lib/utils/hlo_utils.cc @@ -60,10 +60,76 @@ DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { if (auto float_ty = ty.dyn_cast()) { APFloat value(float_ty.getFloatSemantics(), raw_value); return DenseElementsAttr::get(scalar_ty, value); + 
} else if (auto int_ty = ty.dyn_cast()) { + APInt value(int_ty.getWidth(), static_cast(raw_value), true); + return DenseElementsAttr::get(scalar_ty, value); + } else if (auto complex_ty = ty.dyn_cast()) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } } - auto int_ty = ty.cast(); - APInt value(int_ty.getWidth(), static_cast(raw_value), true); - return DenseElementsAttr::get(scalar_ty, value); + llvm_unreachable("unsupported type"); +} + +static APFloat GetScalarLimitOfFloatType(FloatType float_ty, + ScalarLimit limit) { + auto &semantics = float_ty.getFloatSemantics(); + switch (limit) { + case kLowest: + return APFloat::getLargest(semantics, /*negative=*/true); + case kInfinityLowest: + return APFloat::getInf(semantics, /*negative=*/true); + case kMax: + return APFloat::getLargest(semantics, /*negative=*/false); + case kInfinityMax: + return APFloat::getInf(semantics, /*negative=*/false); + } + llvm_unreachable("invalid limit"); +} + +// Returns a scalar value for the given integer type. +// +// The argument 'scalar' describes which scalar value to return. `integer_value` +// is used to specify the integer value for kInteger. For any other scalar, +// integer_value is ignored. +static APInt GetScalarLimitOfIntegerType(IntegerType integer_ty, + ScalarLimit limit) { + unsigned width = integer_ty.getWidth(); + switch (limit) { + case kLowest: + case kInfinityLowest: + if (integer_ty.isUnsigned()) { + return APInt::getMinValue(width); + } else { + return APInt::getSignedMinValue(width); + } + + case kMax: + case kInfinityMax: + if (integer_ty.isUnsigned()) { + return APInt::getMaxValue(width); + } else { + return APInt::getSignedMaxValue(width); + } + } + llvm_unreachable("invalid limit"); +} + +DenseElementsAttr GetScalarLimitOfType(Type ty, ScalarLimit limit) { + RankedTensorType scalar_ty = RankedTensorType::get({}, ty); + if (auto float_ty = ty.dyn_cast()) { + return DenseElementsAttr::get(scalar_ty, + GetScalarLimitOfFloatType(float_ty, limit)); + } else if (auto integer_ty = ty.dyn_cast()) { + return DenseElementsAttr::get( + scalar_ty, GetScalarLimitOfIntegerType(integer_ty, limit)); + } + llvm_unreachable("unsupported type"); } } // namespace hlo diff --git a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir index 15b1a150fdd..0d20c3f517b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir @@ -597,3 +597,24 @@ func @unpack_repack_same_tuple_single_element(%arg0: tuple>) -> tupl // CHECK: return [[ARG0]] return %3 : tuple> } + +// CHECK-LABEL: func @erase_dead_lhlo_constant +func @erase_dead_lhlo_constant() { + %M = alloc() : memref<256x1024xf32> + // CHECK-NEXT: return + "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () + dealloc %M : memref<256x1024xf32> + return +} + +// A negative test for dead lhlo constant op erasure. 
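As a usage sketch for the new GetScalarLimitOfType helper above (not part of the patch itself), the limit values can seed reduction-style lowerings with the appropriate identity element; the builder `b` and the variable names below are assumptions:

// Splat attribute holding -inf for f32, e.g. as the init value of a max
// reduction (kInfinityLowest falls back to the lowest finite value for
// types without an infinity).
mlir::DenseElementsAttr neg_inf = mlir::hlo::GetScalarLimitOfType(
    b.getF32Type(), mlir::hlo::kInfinityLowest);

// For a 32-bit signed integer the same request yields
// std::numeric_limits<int32_t>::lowest(); kInfinityLowest and kLowest agree
// for integer types.
mlir::DenseElementsAttr int_lowest = mlir::hlo::GetScalarLimitOfType(
    b.getIntegerType(32), mlir::hlo::kLowest);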
+// CHECK-LABEL: func @erase_dead_lhlo_constant_negative +func @erase_dead_lhlo_constant_negative(%M : memref<4xf32>) -> memref<256x1024xf32> { + // CHECK-NEXT: lmhlo.constant + "lmhlo.constant"(%M) {value = dense<0.0> : tensor} : (memref<4xf32>) -> () + // CHECK-NEXT: alloc + // CHECK-NEXT: lmhlo.constant + %N = alloc() : memref<256x1024xf32> + "lmhlo.constant"(%N) {value = dense<0.0> : tensor} : (memref<256x1024xf32>) -> () + return %N : memref<256x1024xf32> +} diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir index 46725e0bd09..aecf612962a 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-linalg.mlir @@ -373,6 +373,18 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { // ----- +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape_3D_4D +func @reshape_3D_4D(%arg0: tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> { + %0 = "mhlo.reshape"(%arg0) : (tensor<1x49x16xf32>) -> tensor<1x784x1x1xf32> + return %0 : tensor<1x784x1x1xf32> +} +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]]] +// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP2]]] + +// ----- + // CHECK-LABEL: func @minf func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %0 = "mhlo.minimum"(%lhs, %rhs) diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index 768d8da22bd..f174b005a8d 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -688,6 +688,20 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { // ----- +// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-LABEL: func @reshape_3D_4D +func @reshape_3D_4D(%arg0: memref<1x49x16xf32>, %arg1: memref<1x784x1x1xf32>) { + "lmhlo.reshape"(%arg0, %arg1) + : (memref<1x49x16xf32>, memref<1x784x1x1xf32>) -> () + return +} +// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP1]]] +// CHECK: linalg.reshape %{{.*}} [#[[RESHAPE_MAP2]]] +// CHECK: linalg.copy + +// ----- + // CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse diff --git a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir index 56a7cf7294c..01ef250efd0 100644 --- a/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/mhlo-transform-unranked.mlir @@ -69,7 +69,7 @@ func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[SHAPE_A:.*]] = shape.shape_of %[[A]] // CHECK: %[[SHAPE_B:.*]] = shape.shape_of %[[B]] - // CHECK: %[[SHAPE:.*]] = "shape.any"(%[[SHAPE_A]], %[[SHAPE_B]]) + // CHECK: %[[SHAPE:.*]] = shape.any %[[SHAPE_A]], %[[SHAPE_B]] // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS]]) : tensor<1xindex> // 
CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt index 754469a3c84..69971f4c024 100644 --- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt +++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/CMakeLists.txt @@ -30,3 +30,5 @@ add_llvm_executable(mlir-hlo-opt mlir-hlo-opt.cpp ) llvm_update_compile_flags(mlir-hlo-opt) target_link_libraries(mlir-hlo-opt PRIVATE ${LIBS}) + +mlir_check_all_link_libraries(mlir-hlo-opt) diff --git a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp index 70fc21d6959..d0c0e3c51e1 100644 --- a/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp +++ b/tensorflow/compiler/mlir/hlo/tools/mlir-hlo-opt/mlir-hlo-opt.cpp @@ -13,109 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/InitLLVM.h" -#include "llvm/Support/SourceMgr.h" -#include "llvm/Support/ToolOutputFile.h" -#include "mlir-hlo/Dialect/mhlo/IR/register.h" +#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" +#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "mlir-hlo/Dialect/mhlo/transforms/register_passes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Support/FileUtilities.h" #include "mlir/Support/MlirOptMain.h" -// NOLINTNEXTLINE -static llvm::cl::opt inputFilename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt outputFilename( - "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt splitInputFile( - "split-input-file", - llvm::cl::desc("Split the input file into pieces and process each " - "chunk independently"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verifyDiagnostics( - "verify-diagnostics", - llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verifyPasses( - "verify-each", - llvm::cl::desc("Run the verifier after each transformation pass"), - llvm::cl::init(true)); - -// NOLINTNEXTLINE -static llvm::cl::opt allowUnregisteredDialects( - "allow-unregistered-dialect", - llvm::cl::desc("Allow operation with no registered dialects"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt showDialects( - "show-dialects", llvm::cl::desc("Print the list of registered dialects"), - llvm::cl::init(false)); - int main(int argc, char **argv) { - mlir::registerAllDialects(); mlir::registerAllPasses(); - - mlir::mhlo::registerAllDialects(); mlir::mhlo::registerAllMhloPasses(); mlir::lmhlo::registerAllLmhloPasses(); - llvm::InitLLVM y(argc, argv); + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registry.insert(); + registry.insert(); + registry.insert(); - // Register any pass manager command line options. 
- mlir::registerPassManagerCLOptions(); - mlir::PassPipelineCLParser passPipeline("", "Compiler passes to run"); - - // Parse pass names in main to ensure static initialization completed. - llvm::cl::ParseCommandLineOptions(argc, argv, - "MLIR modular optimizer driver\n"); - - if (showDialects) { - mlir::MLIRContext context; - llvm::outs() << "Registered Dialects:\n"; - for (mlir::Dialect *dialect : context.getRegisteredDialects()) { - llvm::outs() << dialect->getNamespace() << "\n"; - } - return 0; - } - - // Set up the input file. - std::string errorMessage; - auto file = mlir::openInputFile(inputFilename, &errorMessage); - if (!file) { - llvm::errs() << errorMessage << "\n"; - return 1; - } - - auto output = mlir::openOutputFile(outputFilename, &errorMessage); - if (!output) { - llvm::errs() << errorMessage << "\n"; - exit(1); - } - - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, - splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects))) { - return 1; - } - // Keep the output file if the invocation of MlirOptMain was successful. - output->keep(); - return 0; + return failed( + mlir::MlirOptMain(argc, argv, "MLIR HLO pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index bd1dcdf06ea..2d3a58b5b9d 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -29,6 +29,7 @@ filegroup( "ir/tfl_ops.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "@llvm-project//mlir:OpBaseTdFiles", + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/SideEffectInterfaces.td", ], @@ -227,6 +228,7 @@ cc_library( "@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LoopLikeInterface", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffects", @@ -500,6 +502,7 @@ gentbl( tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", td_srcs = [ + "@llvm-project//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", "@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td", "//tensorflow/compiler/mlir/lite/quantization:quantization_td_files", "ir/tfl_op_interfaces.td", @@ -670,6 +673,7 @@ cc_library( ":flatbuffer_tflite_operator_lib", ":tensorflow_lite", ":tensorflow_lite_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/xla:statusor", @@ -737,16 +741,13 @@ cc_library( ], deps = [ ":flatbuffer_translate_lib", + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", - "@com_google_absl//absl/base", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/container:flat_hash_set", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirTranslateMain", "@llvm-project//mlir:QuantOps", - "@llvm-project//mlir:SCFTransforms", "@llvm-project//mlir:StandardOps", "@llvm-project//mlir:Support", "@llvm-project//mlir:Translation", @@ -759,7 +760,7 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_registeration", # TODO(b/155809683): 
Link only necessary dialects. - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) @@ -811,7 +812,7 @@ tf_cc_binary( "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", # TODO(b/155809683): Link only necessary dialects. - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", @@ -835,19 +836,18 @@ tf_cc_binary( deps = [ ":flatbuffer_translate_lib", ":flatbuffer_translate_registeration", - "@com_google_absl//absl/strings", - "@llvm-project//llvm:Support", - # TODO(b/155809683): Link only necessary dialects. - "@llvm-project//mlir:AllPassesAndDialects", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Parser", - "@llvm-project//mlir:Support", - "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + ":tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", "//tensorflow/core/platform:logging", "//tensorflow/lite:framework", "//tensorflow/lite/delegates/flex:delegate", "//tensorflow/lite/kernels:builtin_ops", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:StandardOps", ], ) @@ -874,7 +874,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/core:core_cpu_base", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Transforms", @@ -908,7 +908,7 @@ cc_library( "//tensorflow/stream_executor/lib", "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 89fae87cb25..34200fb88b6 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -61,6 +61,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/stateful_ops_utils.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" @@ -133,63 +134,59 @@ static StatusOr GetTFLiteType(Type type, return Status(error::INVALID_ARGUMENT, "'isSigned' can only be set for 8-bits integer type"); } - switch (type.getKind()) { - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::StandardTypes::F64: - return tflite::TensorType_FLOAT64; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::TF::TensorFlowTypes::QUINT8: - return tflite::TensorType_UINT8; - case mlir::StandardTypes::Complex: { - auto ftype = type.cast().getElementType(); - if (ftype && ftype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - if (ftype && ftype.isF64()) { - return tflite::TensorType_COMPLEX128; - } - return Status(error::INVALID_ARGUMENT, "Unsupported type"); + + if (type.isF32()) { + return tflite::TensorType_FLOAT32; + } else if (type.isF16()) { + return tflite::TensorType_FLOAT16; + } else if (type.isF64()) { + return tflite::TensorType_FLOAT64; + } else if (type.isa()) { + return tflite::TensorType_STRING; + } else if (type.isa()) { + return tflite::TensorType_UINT8; + } else if (auto complex_type = type.dyn_cast()) { + auto ftype = complex_type.getElementType(); + if (ftype.isF32()) { + return tflite::TensorType_COMPLEX64; } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return itype.isUnsigned() ? tflite::TensorType_UINT8 - : tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - } + if (ftype.isF64()) { + return tflite::TensorType_COMPLEX128; } - case mlir::quant::QuantizationTypes::UniformQuantized: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); + return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return itype.isUnsigned() ? tflite::TensorType_UINT8 + : tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; } - case mlir::quant::QuantizationTypes::UniformQuantizedPerAxis: { - auto qtype = type.cast(); - return GetTFLiteType(qtype.getStorageType(), qtype.isSigned()); - } - case mlir::TF::TensorFlowTypes::RESOURCE: { - // Treat tf.resource values as integer values in flatbuffer. - // TODO(b/146131919): Maybe need to have a detailed design for supporting - // other resource types beyonds hash table resources and resource - // variables. - return tflite::TensorType_INT32; - } - default: - // TFLite export fills FLOAT32 for unknown data types. Returning an error - // for now for safety and this could be revisited when required. 
- return Status(error::INVALID_ARGUMENT, "Unsupported type"); + } else if (auto q_uniform_type = + type.dyn_cast()) { + return GetTFLiteType(q_uniform_type.getStorageType(), + q_uniform_type.isSigned()); + + } else if (auto q_peraxis_type = + type.dyn_cast()) { + return GetTFLiteType(q_peraxis_type.getStorageType(), + q_peraxis_type.isSigned()); + } else if (type.isa()) { + // Treat tf.resource values as integer values in flatbuffer. + // TODO(b/146131919): Maybe need to have a detailed design for supporting + // other resource types beyonds hash table resources and resource + // variables. + return tflite::TensorType_INT32; } + // TFLite export fills FLOAT32 for unknown data types. Returning an error + // for now for safety and this could be revisited when required. + return Status(error::INVALID_ARGUMENT, "Unsupported type"); } static bool IsConst(Operation* op) { @@ -358,8 +355,13 @@ class Translator { if (emit_custom_ops) { enabled_op_types_.emplace(OpType::kCustomOp); } - tf_dialect_ = module.getContext()->getRegisteredDialect("tf"); - tfl_dialect_ = module.getContext()->getRegisteredDialect("tfl"); + tf_dialect_ = + module.getContext()->getOrLoadDialect(); + tfl_dialect_ = module.getContext() + ->getOrLoadDialect(); + // Right now the TF executor dialect is still needed to build NodeDef. + module.getContext() + ->getOrLoadDialect(); } Optional TranslateInternal(); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 3c8bf26aa14..230383729c4 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -65,6 +65,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -254,20 +255,35 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, layer_stats, axis_stats, axis); } -StatusOr OpNameForOpCode(const tflite::OperatorCodeT opcode) { - if (opcode.builtin_code == tflite::BuiltinOperator_CUSTOM) { +// Returns true if this is a basic LSTM op. +bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { + if (const auto* op = op_union.AsLSTMOptions()) { + return op->kernel_type == tflite::LSTMKernelType_BASIC; + } else { + return false; + } +} + +// Gets the MLIR op name with the dialect name for the flatbuffer operator. 
+StatusOr GetMlirOpName(const tflite::OperatorT& op, + const tflite::OperatorCodeT& op_code) { + if (IsBasicLSTMOp(op.builtin_options)) { + return std::string("tfl.basic_lstm"); + } + + if (op_code.builtin_code == tflite::BuiltinOperator_CUSTOM) { return std::string("tfl.custom"); } - if (opcode.builtin_code == tflite::BuiltinOperator_IF) { + if (op_code.builtin_code == tflite::BuiltinOperator_IF) { return std::string("tf.If"); } - if (opcode.builtin_code == tflite::BuiltinOperator_WHILE) { + if (op_code.builtin_code == tflite::BuiltinOperator_WHILE) { return std::string("tf.While"); } - const char* op_name = tflite::EnumNameBuiltinOperator(opcode.builtin_code); - std::string lowered_name = llvm::StringRef(op_name).lower(); - return llvm::Twine("tfl.", lowered_name).str(); + llvm::StringRef op_name( + tflite::EnumNameBuiltinOperator(op_code.builtin_code)); + return llvm::Twine("tfl.", op_name.lower()).str(); } // The buffers in TFLite flatbuffers have their contents stored as a vector of @@ -464,7 +480,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, value = mlir::DenseStringElementsAttr::get(shaped_type, refs); } else if (elem_type.isa()) { - auto dialect = elem_type.getContext()->getRegisteredDialect("tf"); + auto dialect = elem_type.getContext()->getLoadedDialect("tf"); tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); @@ -510,14 +526,6 @@ llvm::SmallVector ConvertSubgraphIdxsToFunctionAttrs( return {}; } -// Returns true if this is a basic LSTM op. -bool IsBasicLSTMOp(tflite::BuiltinOptionsUnion op_union) { - if (const auto* op = op_union.AsLSTMOptions()) { - return op->kernel_type == tflite::LSTMKernelType_BASIC; - } else { - return false; - } -} // TODO(krzysd) Handle function calls StatusOr ConvertOp( @@ -525,7 +533,6 @@ StatusOr ConvertOp( const std::vector& intermediate_types, Value optional_arg_marker, const std::vector>& op_codes, - const std::vector& op_names, const std::vector& func_names, const std::vector>& tensors, Location loc, OpBuilder builder) { @@ -537,10 +544,10 @@ StatusOr ConvertOp( return emitError(loc, err.ToString()), err; } - const bool is_basic_lstm = IsBasicLSTMOp(op.builtin_options); - const tflite::OperatorCodeT op_code = *op_codes.at(op.opcode_index); - const std::string& op_name = - is_basic_lstm ? "tfl.basic_lstm" : op_names.at(op.opcode_index); + const tflite::OperatorCodeT& op_code = *op_codes.at(op.opcode_index); + + TF_ASSIGN_OR_RETURN(const std::string op_name, GetMlirOpName(op, op_code)); + OperationState op_state(loc, op_name); for (auto input_num : op.inputs) { @@ -791,8 +798,7 @@ static StatusOr PostProcessFuncOp(FuncOp func) { } // Build a FuncOp from a tflite SubGraph -// The op_names are a mapping from indexes into the TFLite operators array to -// the operator name MLIR expects (tfl.foo_op). The buffers are directly taken +// The buffers are directly taken // from the deserialized flatbuffer as we do not have the type information to // interpret them until this point. The base_loc parameter is the location of // the flatbuffer as a whole (usually a file). 
The is_entry_point flag @@ -802,7 +808,6 @@ static StatusOr PostProcessFuncOp(FuncOp func) { StatusOr ConvertSubgraph( const tflite::SubGraphT& subgraph, llvm::StringRef name, const std::vector>& op_codes, - const std::vector& op_names, const std::vector& func_names, const std::vector>& buffers, Location base_loc, Builder builder, bool is_entry_point, @@ -1002,8 +1007,7 @@ StatusOr ConvertSubgraph( TF_ASSIGN_OR_RETURN( auto* mlir_op, ConvertOp(*op, vals_map, intermediate_types, maybe_optional_arg_marker, - op_codes, op_names, func_names, subgraph.tensors, op_loc, - op_builder)); + op_codes, func_names, subgraph.tensors, op_loc, op_builder)); // Add the results to the value maps. There are two cases: 1. the result // tensor does not have min/max values, the original op result is used @@ -1069,6 +1073,10 @@ OwningModuleRef tflite::FlatBufferToMlir( const std::vector& ordered_input_arrays, const std::vector& ordered_output_arrays, bool experimental_prune_unreachable_nodes_unconditionally) { + context->loadDialect< + mlir::StandardOpsDialect, mlir::quant::QuantizationDialect, + mlir::TFL::TensorFlowLiteDialect, mlir::TF::TensorFlowDialect>(); + auto model_ptr = FlatBufferModel::VerifyAndBuildFromBuffer(buffer.data(), buffer.length()); if (nullptr == model_ptr) { @@ -1079,17 +1087,6 @@ OwningModuleRef tflite::FlatBufferToMlir( auto builder = Builder(context); - std::vector operator_names; - operator_names.reserve(model->operator_codes.size()); - - for (auto& opcode : model->operator_codes) { - auto operator_name_or_error = OpNameForOpCode(*opcode); - if (!operator_name_or_error.ok()) { - return emitError(base_loc, operator_name_or_error.status().ToString()), - nullptr; - } - operator_names.push_back(operator_name_or_error.ConsumeValueOrDie()); - } std::vector func_names; for (auto& subgraph : model->subgraphs) { @@ -1110,8 +1107,8 @@ OwningModuleRef tflite::FlatBufferToMlir( auto& subgraph = e.value(); std::string name = SubgraphName(e.index(), *subgraph); auto func_or_error = ConvertSubgraph( - *subgraph, name, model->operator_codes, operator_names, func_names, - model->buffers, base_loc, builder, + *subgraph, name, model->operator_codes, func_names, model->buffers, + base_loc, builder, // TODO(b/131175224,b/132239787) Support multiple entry points /*is_entry_point=*/e.index() == 0, /*use_external_constant=*/use_external_constant, ordered_input_arrays, diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc index ceaa4e215cf..5accb419e83 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_operator.cc @@ -95,40 +95,34 @@ static tflite::MirrorPadMode ConvertTFL_MirrorPaddingAttrForOptionWriter( static tflite::TensorType ConvertDerivedTypeAttrForOptionWriter( mlir::Type type, flatbuffers::FlatBufferBuilder* builder) { - switch (type.getKind()) { - case mlir::StandardTypes::F16: - return tflite::TensorType_FLOAT16; - case mlir::StandardTypes::F32: - return tflite::TensorType_FLOAT32; - case mlir::TF::TensorFlowTypes::STRING: - return tflite::TensorType_STRING; - case mlir::StandardTypes::Complex: { - auto etype = type.cast().getElementType(); - if (etype.isF32()) { - return tflite::TensorType_COMPLEX64; - } - llvm_unreachable("invalid complex Type in conversion"); + if (type.isF16()) { + return tflite::TensorType_FLOAT16; + } else if (type.isF32()) { + return tflite::TensorType_FLOAT32; + } else if (type.isa()) { + return tflite::TensorType_STRING; + } else if (auto 
complex_type = type.dyn_cast()) { + if (complex_type.getElementType().isF32()) { + return tflite::TensorType_COMPLEX64; } - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return tflite::TensorType_BOOL; - case 8: - return tflite::TensorType_INT8; - case 16: - return tflite::TensorType_INT16; - case 32: - return tflite::TensorType_INT32; - case 64: - return tflite::TensorType_INT64; - default: - llvm_unreachable("invalid integer Type in conversion"); - } + llvm_unreachable("invalid complex Type in conversion"); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + return tflite::TensorType_BOOL; + case 8: + return tflite::TensorType_INT8; + case 16: + return tflite::TensorType_INT16; + case 32: + return tflite::TensorType_INT32; + case 64: + return tflite::TensorType_INT64; + default: + llvm_unreachable("invalid integer Type in conversion"); } - default: - llvm_unreachable("invalid Type in conversion"); } + llvm_unreachable("invalid Type in conversion"); } // I32Attr already returns an int as required by flatbuffer builders. @@ -255,7 +249,7 @@ Status mlir::CustomOptionsToAttributes( {static_cast(custom_options.size())}, builder.getIntegerType(8)); attributes->emplace_back(builder.getNamedAttr( "custom_option", - OpaqueElementsAttr::get(builder.getContext()->getRegisteredDialect("tfl"), + OpaqueElementsAttr::get(builder.getContext()->getLoadedDialect("tfl"), type, content))); return Status::OK(); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc index 5b95b30a96c..94f7e2261f7 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_translate.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -33,6 +34,8 @@ limitations under the License. #include "mlir/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" using llvm::cl::opt; @@ -175,5 +178,11 @@ static TranslateToMLIRRegistration FlatBufferFileToMlirTransReg( }); static TranslateFromMLIRRegistration MLIRToFlatBufferTranslate( - "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction); + "mlir-to-tflite-flatbuffer", MlirToFlatBufferFileTranslateFunction, + [](DialectRegistry& registry) { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + }); } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index b5fcd5e82e2..403b3dd18ad 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -253,9 +254,8 @@ struct TensorFlowLiteInlinerInterface : public DialectInlinerInterface { } }; -struct TensorFlowLiteOpFolderDialectInterface - : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TensorFlowLiteDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; // Registered hook to check if the given region, which is attached to an // operation that is *not* isolated from above (i.e. no internal regions @@ -275,7 +275,7 @@ TensorFlowLiteDialect::TensorFlowLiteDialect(mlir::MLIRContext *context) #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.cc.inc" >(); addInterfaces(); + TensorFlowLiteDialectFoldInterface>(); } //===----------------------------------------------------------------------===// @@ -1028,9 +1028,12 @@ static LogicalResult Verify(PackOp op) { // Check axis bounds. if (input_type.hasRank()) { int64_t axis_value = op.axis().getSExtValue(); - if (abs(axis_value) > input_type.getRank()) - return op.emitOpError("op attribute 'axis' is out of bounds, got ") - << axis_value; + if (axis_value < 0) axis_value += input_type.getRank() + 1; + if (axis_value < 0 || axis_value >= input_type.getRank() + 1) + return op.emitOpError() + << "op attribute 'axis' should be in range [-rank - 1, rank + 1), " + << "got rank = " << input_type.getRank() + << ", and axis = " << op.axis().getSExtValue(); } // Make sure all inputs have the same shape and element type. @@ -1443,12 +1446,59 @@ void FakeQuantOp::getCanonicalizationPatterns(OwningRewritePatternList &results, // TODO(b/133486129): Implement shape inference for unpack -static LogicalResult Verify(UnpackOp op) { - // TODO(antiagainst): Implement other checks as in - // tensorflow/lite/kernels/unpack.cc +LogicalResult UnpackOp::inferReturnTypes( + MLIRContext *context, Optional loc, ValueRange operands, + DictionaryAttr attributes, RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + UnpackOpAdaptor op(operands, attributes); + // TODO(jpienaar): Refactor verify + if (failed(op.verify(loc.hasValue() ? *loc : UnknownLoc::get(context)))) + return failure(); - if (op.getOperation()->getNumResults() != op.num()) - return op.emitOpError("output count should match 'num' attribute"); + if (operands.size() != 1) { + return emitOptionalError(loc, "input count should be equal to 1"); + } + + const int64_t num_value = op.num().getInt(); + auto input_type = operands[0].getType().dyn_cast(); + if (!input_type || !input_type.hasRank()) { + // If input is unranked, then so is output. 
+ inferredReturnTypes.assign( + num_value, UnrankedTensorType::get(input_type.getElementType())); + return success(); + } + + if (input_type.hasStaticShape() && input_type.getNumElements() <= 0) { + return emitOptionalError( + loc, "number of elements in input shoule be larger than 0"); + } + + const int64_t rank = input_type.getRank(); + if (rank <= 0) { + return emitOptionalError(loc, "input should be of rank larger than 0"); + } + + int64_t axis_value = op.axis().getInt(); + if (axis_value < 0) { + axis_value += rank; + } + if (axis_value < 0 || axis_value >= rank) { + return emitOptionalError( + loc, "attribute 'axis' should be in range [-rank, rank), got axis = ", + op.axis().getInt(), ", and rank = ", rank); + } + + if (!ShapedType::isDynamic(input_type.getDimSize(axis_value)) && + input_type.getDimSize(axis_value) != num_value) { + return emitOptionalError(loc, "output count should match 'num' attribute"); + } + + auto output_shape = llvm::to_vector<4>(input_type.getShape()); + output_shape.erase(output_shape.begin() + axis_value); + + auto output_type = + RankedTensorType::get(output_shape, input_type.getElementType()); + inferredReturnTypes.assign(num_value, output_type); return success(); } diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index caed0bb3ad9..d2d8442155b 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -26,6 +26,7 @@ limitations under the License. #include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/Interfaces/DerivedAttributeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/LoopLikeInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 6dc9fda656f..f1cdfec631d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -19,6 +19,7 @@ limitations under the License. #define TFL_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" @@ -107,7 +108,11 @@ def OpaqueBytesAttr : ElementsAttrBase< ".getElementType().isInteger(8)">, ]>, "opaque bytes attribute" - >; + > { + let storageType = [{ OpaqueElementsAttr }]; + let returnType = [{ OpaqueElementsAttr }]; + let convertFromStorage = "$_self"; +} //===----------------------------------------------------------------------===// // Derived shape attribute class. 
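As a standalone sketch of the shape rule the new UnpackOp::inferReturnTypes implements (illustration only; the helper name below is hypothetical, not code from the patch): for a ranked input the unpacked axis dimension is erased and each of the `num` results gets the remaining shape, matching the tfl.unpack cases in tests/ops.mlir later in this change.

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

// Result shape for tfl.unpack on a ranked input: erase the unpacked axis.
// E.g. input shape {2, 3} with axis = 1 -> each of the num (= 3) results has
// shape {2}, i.e. tensor<2x3xi32> unpacks into 3 x tensor<2xi32>.
llvm::SmallVector<int64_t, 4> UnpackResultShape(
    llvm::ArrayRef<int64_t> input_shape, int64_t axis) {
  if (axis < 0) axis += input_shape.size();
  llvm::SmallVector<int64_t, 4> result(input_shape.begin(), input_shape.end());
  result.erase(result.begin() + axis);
  return result;
}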
@@ -3024,7 +3029,8 @@ def TFL_TransposeOp : TFL_Op<"transpose", [ def TFL_UnpackOp : TFL_Op<"unpack", [ NoSideEffect, SameOperandsAndResultElementType, - SameOperandsAndResultsScale]> { + SameOperandsAndResultsScale, + DeclareOpInterfaceMethods]> { let summary = "Unpacks a tensor along a dimension into multiple tensors"; let description = [{ @@ -3047,7 +3053,7 @@ def TFL_UnpackOp : TFL_Op<"unpack", [ let arguments = (ins TFL_TensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$input, - I32Attr:$num, + Confined:$num, I32Attr:$axis ); @@ -3055,8 +3061,6 @@ def TFL_UnpackOp : TFL_Op<"unpack", [ TFL_VariadicTensorOf<[F32, I1, I8, UI8, I32, QI8, QUI8, I16, QI16]>:$outputs ); - let verifier = [{ return Verify(*this); }]; - let hasOptions = 1; } diff --git a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc index 0d42fbb9646..35a58a01a29 100644 --- a/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc +++ b/tensorflow/compiler/mlir/lite/mlir_tflite_runner.cc @@ -30,12 +30,16 @@ limitations under the License. #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/SMLoc.h" #include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/lite/delegates/flex/delegate.h" @@ -98,6 +102,10 @@ int main(int argc, char** argv) { // Load the MLIR module. mlir::MLIRContext context; + context.getDialectRegistry() + .insert(); + llvm::SourceMgr source_mgr; source_mgr.AddNewSourceBuffer(std::move(*file_or_err), llvm::SMLoc()); mlir::OwningModuleRef module(mlir::parseSourceFile(source_mgr, &context)); diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 6299a70b1df..7e7d4678a87 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -62,6 +62,10 @@ class ImportQuantStatsPass void runOnFunction() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + // Parses the serialized quant stats protobuf and initialize the internal // data structure. This method must be called after the pass is created. 
bool ParseQuantStats(const std::string &stats_str); diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index 31c0e4cb8a9..38c7ad86e05 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -28,6 +28,7 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:common", "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite_quantize", "//tensorflow/compiler/mlir/lite/quantization:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", @@ -74,6 +75,6 @@ tf_cc_binary( "//tensorflow/lite/schema:schema_fbs", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", ], ) diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index a2e3c065113..238710bcf13 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" @@ -52,6 +53,7 @@ TfLiteStatus QuantizeModel( } MLIRContext context; + context.getDialectRegistry().insert(); StatusScopedDiagnosticHandler statusHandler(&context, /*propagate=*/true); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 9e0ad990657..16b51496b5f 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -99,12 +99,14 @@ class QuantizationDriver { public: explicit QuantizationDriver(FuncOp fn, bool is_signed, bool disable_per_channel, - OpQuantSpecGetter op_quant_spec_getter) + OpQuantSpecGetter op_quant_spec_getter, + bool enforce_fixed_output_range) : fn_(fn), builder_(fn.getBody()), is_signed_(is_signed), disable_per_channel_(disable_per_channel), - op_quant_spec_getter_(op_quant_spec_getter) {} + op_quant_spec_getter_(op_quant_spec_getter), + enforce_fixed_output_range_(enforce_fixed_output_range) {} // The entry point of the quantization parameters propagation. void Run(); @@ -354,6 +356,8 @@ class QuantizationDriver { llvm::SmallVector args_; OpQuantSpecGetter op_quant_spec_getter_; + + bool enforce_fixed_output_range_; }; } // namespace @@ -794,7 +798,8 @@ bool QuantizationDriver::PropagateParams() { } // TODO(fengliuai): make the bit width configurable. - if (auto restricted = llvm::dyn_cast(op)) { + auto restricted = llvm::dyn_cast(op); + if (restricted && enforce_fixed_output_range_) { // TODO(fengliuai): different result can have different fixed range. 
auto params = restricted.GetFixedOutputRange(is_signed_, /*bit_width=*/8); for (auto i = 0; i < op->getNumResults(); ++i) { @@ -864,10 +869,12 @@ void QuantizationDriver::Run() { } } -void ApplyQuantizationParamsPropagation( - mlir::FuncOp func, bool is_signed, bool disable_per_channel, - OpQuantSpecGetter op_quant_spec_getter) { - QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter) +void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, + bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool post_training_quantization) { + QuantizationDriver(func, is_signed, disable_per_channel, op_quant_spec_getter, + post_training_quantization) .Run(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 07e5ba4e879..6e356acbbdf 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -490,9 +490,13 @@ quant::QuantizedType GetUniformQuantizedTypeForBias( // and the propagation results are materialized by inserting pairs of quantize // and dequantize ops to this function. Set `disable_per_channel` to true to not // use per channel quantization even the op supports it. +// Setting `enforce_fixed_output_range` to true, to infer quantization +// parameters from the fixed output range ops. This is only used for +// post-training quantization. void ApplyQuantizationParamsPropagation(mlir::FuncOp func, bool is_signed, bool disable_per_channel, - OpQuantSpecGetter op_quant_spec_getter); + OpQuantSpecGetter op_quant_spec_getter, + bool enforce_fixed_output_range); // The function might contain more stats ops than required, and it will // introduce requantize if the calibration stats have conflicts. 
This method diff --git a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt index f482e3db6b9..a7f6040f211 100644 --- a/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt +++ b/tensorflow/compiler/mlir/lite/tests/end2end/if_op.pbtxt @@ -1,4 +1,4 @@ -# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=4:4 -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s +# RUN: tf_tfl_translate -tf-input-arrays=a,b -tf-input-data-types=DT_FLOAT,DT_FLOAT -tf-input-shapes=: -tf-output-arrays=StatefulIf,StatelessIf %s -o - --output-mlir | FileCheck %s node { name: "tf.Less" op: "Less" diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir deleted file mode 100644 index 7e9f66baa90..00000000000 --- a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/unknown-op.mlir +++ /dev/null @@ -1,8 +0,0 @@ -// RUN: not flatbuffer_translate -mlir-to-tflite-flatbuffer %s -o - 2>&1 | FileCheck %s - -func @main(tensor<3x2xi32>) -> tensor<3x2xi32> { -^bb0(%arg0: tensor<3x2xi32>): - // CHECK: error: 'unknown_op' op dialect is not registered - %0 = "unknown_op"(%arg0) : (tensor<3x2xi32>) -> tensor<3x2xi32> - return %0 : tensor<3x2xi32> -} diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 7ef6997f938..cbb562c2e03 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1139,9 +1139,15 @@ func @packInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x // ----- -func @packNegInputRank(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { +func @packNegInputAxis2(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<1x2x4xi32> { // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} - %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> + %0 = "tfl.pack"(%arg0, %arg1) {axis = -2 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<1x2x4xi32> + return %0 : tensor<1x2x4xi32> +} + +func @packNegInputAxis3(%arg0: tensor<1x4xi32>, %arg1: tensor<1x4xi32>) -> tensor<2x1x4xi32> { + // CHECK: "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} + %0 = "tfl.pack"(%arg0, %arg1) {axis = -3 : i32, values_count = 2 : i32} : (tensor<1x4xi32>, tensor<1x4xi32>) -> tensor<2x1x4xi32> return %0 : tensor<2x1x4xi32> } @@ -1172,7 +1178,7 @@ func @pack(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { // ----- func @pack(%arg0: tensor<2xi32>, %arg1: tensor<2xi32>) -> tensor<2x2xi32> { - // expected-error @+1 {{op attribute 'axis' is out of bounds, got 3}} + // expected-error @+1 {{op attribute 'axis' should be in range [-rank - 1, rank + 1), got rank = 1, and axis = 3}} %0 = "tfl.pack"(%arg0, %arg1) {axis = 3 : i32, values_count = 2 : i32} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2x2xi32> return %0 : tensor<2x2xi32> } @@ -1183,7 +1189,22 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // CHECK: "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) return %0#0 : tensor<2xi32> +} +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // CHECK: "tfl.unpack"(%arg0) 
{axis = -1 : i32, num = 3 : i32} + %0:3 = "tfl.unpack"(%arg0) {axis = -1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<3xi32> { + // CHECK: "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} + %0:2 = "tfl.unpack"(%arg0) {axis = -2 : i32, num = 2 : i32} : (tensor<2x3xi32>) -> (tensor<3xi32>, tensor<3xi32>) + return %0#0 : tensor<3xi32> } // ----- @@ -1204,6 +1225,45 @@ func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { // ----- +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = 2, and rank = 2}} + %0:3 = "tfl.unpack"(%arg0) {axis = 2 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+1 {{attribute 'axis' should be in range [-rank, rank), got axis = -3, and rank = 2}} + %0:3 = "tfl.unpack"(%arg0) {axis = -3 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor) -> tensor<2xi32> { + // expected-error @+1 {{input should be of rank larger than 0}} + %0:3 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 3 : i32} : (tensor) -> (tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<2x3xi32>) -> tensor<2xi32> { + // expected-error @+1 {{op inferred type incompatible with return type of operation}} + %0:3 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 3 : i32} : (tensor<2x3xi32>) -> (tensor<2xi32>, tensor<2x1xi32>, tensor<2xi32>) + return %0#0 : tensor<2xi32> +} + +// ----- + +func @unpack(%arg0: tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) { + %0:2 = "tfl.unpack"(%arg0) {axis = 1 : i32, num = 2 : i32} : (tensor<*xi32>) -> (tensor<*xi32>, tensor<*xi32>) + return %0#0, %0#1 : tensor<*xi32>, tensor<*xi32> +} + +// ----- + // CHECK-LABEL: testMean func @testMean(%arg0: tensor<2x2xf32>, %arg1 : tensor<1xi32>) -> tensor<1x2xf32> { // CHECK: "tfl.mean"(%arg0, %arg1) {keep_dims = false} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 7923c82ba92..edbcef3d321 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -1115,3 +1115,63 @@ func @ConvertIdentityScatterNd(%arg0: tensor<4x3xf32>) -> tensor<4x3xf32> { // CHECK-SAME: (%[[ARG:.*]]: tensor<4x3xf32>) -> tensor<4x3xf32> // CHECK-NEXT: return %[[ARG]] : tensor<4x3xf32> } + +func @ReshapeAddUnknownShape(%arg0: tensor<*xf32>) -> tensor<3x4xf32> { + %cst = constant dense<[3, 4]> : tensor<2xi32> + %cst_0 = constant dense<1.000000e+00> : tensor<3x4xf32> + %0 = "tfl.reshape"(%arg0, %cst) : (tensor<*xf32>, tensor<2xi32>) -> tensor<3x4xf32> + %1 = "tfl.add"(%0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> + return %1 : tensor<3x4xf32> +// CHECK-LABEL: ReshapeAddUnknownShape +// CHECK: %[[rs1:.*]] = "tfl.reshape"(%arg0 +// CHECK: %[[rs2:.*]] = tfl.add %[[rs1]] +// CHECK: return %[[rs2]] +} + +func @FoldSumKeepDim(%arg0: tensor<8x128xf32>) -> tensor<8x1xf32> { + %cst = constant dense<1> : tensor<1xi32> + %cst_1 = constant dense<[8, 1]> : tensor<2xi32> + %0 = "tfl.sum"(%arg0, %cst) {keep_dims = false} : 
(tensor<8x128xf32>, tensor<1xi32>) -> tensor<8xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<8xf32>, tensor<2xi32>) -> tensor<8x1xf32> + return %1 : tensor<8x1xf32> + +// CHECK-LABEL: FoldSumKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.sum"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<8x1xf32> +// CHECK: return %[[RESULT]] : tensor<8x1xf32> +} + +func @FoldReduceMinKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { + %cst = constant dense<0> : tensor<1xi32> + %cst_1 = constant dense<[1, 128]> : tensor<2xi32> + %0 = "tfl.reduce_min"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<128xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<128xf32>, tensor<2xi32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + +// CHECK-LABEL: FoldReduceMinKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.reduce_min"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: return %[[RESULT]] : tensor<1x128xf32> +} + +func @FoldReduceMaxKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x128xf32> { + %cst = constant dense<0> : tensor<1xi32> + %cst_1 = constant dense<[1, 128]> : tensor<2xi32> + %0 = "tfl.reduce_max"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<128xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<128xf32>, tensor<2xi32>) -> tensor<1x128xf32> + return %1 : tensor<1x128xf32> + +// CHECK-LABEL: FoldReduceMaxKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.reduce_max"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<1xi32>) -> tensor<1x128xf32> +// CHECK: return %[[RESULT]] : tensor<1x128xf32> +} + +func @FoldReduceProdKeepDim(%arg0: tensor<8x128xf32>) -> tensor<1x1xf32> { + %cst = constant dense<[0, 1]> : tensor<2xi32> + %cst_1 = constant dense<[1, 1]> : tensor<2xi32> + %0 = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = false} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.reshape"(%0, %cst_1) : (tensor, tensor<2xi32>) -> tensor<1x1xf32> + return %1 : tensor<1x1xf32> + +// CHECK-LABEL: FoldReduceProdKeepDim +// CHECK: %[[RESULT:.*]] = "tfl.reduce_prod"(%arg0, %cst) {keep_dims = true} : (tensor<8x128xf32>, tensor<2xi32>) -> tensor<1x1xf32> +// CHECK: return %[[RESULT]] : tensor<1x1xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 6ee5b67d65e..6a992d6dfe4 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -615,4 +615,18 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3 // CHECK: return [[MUL]] : tensor<3x3xi32> } +// CHECK-LABEL: lower_rfft_to_rfft2d +func @lower_rfft_to_rfft2d(%input: tensor<10x20x30xf32>, %fft_len: tensor<1xi32>) -> tensor<10x20x30xcomplex> { + %0 = "tf.RFFT"(%input, %fft_len) : (tensor<10x20x30xf32>, tensor<1xi32>) -> tensor<10x20x30xcomplex> + return %0: tensor<10x20x30xcomplex> + +// CHECK: %[[CST:.*]] = constant dense<-2> : tensor +// CHECK: %[[CST0:.*]] = constant dense<1> : tensor<1xi32> +// CHECK: %[[CST1:.*]] = constant dense<0> : tensor +// CHECK: %[[EXP:.*]] = "tf.ExpandDims"(%arg0, %[[CST]]) : (tensor<10x20x30xf32>, tensor) -> tensor<10x20x1x30xf32> +// CHECK: %[[CON:.*]] = "tf.ConcatV2"(%[[CST0]], %arg1, %[[CST1]]) : (tensor<1xi32>, tensor<1xi32>, tensor) -> tensor<2xi32> +// CHECK: %[[RFF:.*]] = "tf.RFFT2D"(%[[EXP]], %[[CON]]) : (tensor<10x20x1x30xf32>, tensor<2xi32>) -> tensor<10x20x1x30xcomplex> +// CHECK: %[[SQE:.*]] = 
"tf.Squeeze"(%[[RFF]]) {squeeze_dims = [-2]} : (tensor<10x20x1x30xcomplex>) -> tensor<10x20x30xcomplex> +} + } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index eeecfac67cf..d28ee4b31fa 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" @@ -37,8 +38,10 @@ limitations under the License. #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -103,7 +106,8 @@ bool OperandsBroadcastToOutputType(Type a, Type b, Type expected_output) { bool IsTailOfShape(Type type1, Type type2) { auto tail_type = type1.dyn_cast(); auto full_type = type2.dyn_cast(); - if (!tail_type || !full_type || tail_type.getRank() > full_type.getRank()) + if (!tail_type || !full_type || !tail_type.hasRank() || + !full_type.hasRank() || tail_type.getRank() > full_type.getRank()) return false; auto i1 = tail_type.getShape().rbegin(), e1 = tail_type.getShape().rend(); auto i2 = full_type.getShape().rbegin(); @@ -244,6 +248,38 @@ static Type GetShapeStrippedType(TypeAttr type_attr) { } } +// Returns `true` if reducing `axes` in `input` with `keep_dims=true` results in +// the specified `shape` and `false` otherwise. +static bool ShapeMatchesReduceWithKeepAxes(Value input, + const mlir::Attribute &axes, + const mlir::Attribute &shape) { + RankedTensorType type = input.getType().dyn_cast_or_null(); + if (!type) return false; + + DenseIntElementsAttr axes_attr = + axes.dyn_cast_or_null(); + DenseIntElementsAttr shape_attr = + shape.dyn_cast_or_null(); + if (!axes_attr || !shape_attr) return false; + + if (shape_attr.getNumElements() != type.getRank()) return false; + + llvm::SmallSet axes_set; + for (auto a : axes_attr.getIntValues()) { + axes_set.insert(a.getZExtValue()); + } + + auto type_shape = type.getShape(); + for (uint64_t i = 0; i < type.getRank(); ++i) { + if (axes_set.contains(i)) { + if (shape_attr.getValue({i}) != 1) return false; + } else { + if (shape_attr.getValue({i}) != type_shape[i]) return false; + } + } + return true; +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" // Fuse Add with proceeding FullyConnected. 
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 3c5fc7a0c5e..559d22dcf47 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -535,4 +535,20 @@ def OptimizeIdentityScatterNdOp : Pat< (replaceWithValue $params), [(CanOptimizeIdentityGatherNdOrScatterNdOp $params, $indices)]>; +def ShapeMatchesReduceWithKeepAxes : Constraint>; + +// Fold reshapes re-inserting reduced dimensions into the results of a reduction +// with `keep_dims=false` by chaning it to one using `keep_dims=true`. +foreach ReduceOp = [TFL_ReduceMaxOp, TFL_ReduceMinOp, TFL_ReduceProdOp, + TFL_SumOp] in { + def FoldReshapeTo#ReduceOp : Pat< + (TFL_ReshapeOp + (ReduceOp:$reduce $input, (ConstantOp I32ElementsAttr: $axes), + ConstBoolAttrFalse), + (ConstantOp I32ElementsAttr: $shape)), + (ReduceOp $input, (ConstantOp $axes), ConstBoolAttrTrue), + [(ShapeMatchesReduceWithKeepAxes $input, $axes, $shape), + (HasOneUse $reduce)]>; +} diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 9a27d0de62a..07b7aacd95d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/Support/CommandLine.h" #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -122,6 +123,10 @@ class PrepareQuantizePass // the best quantization practise. This also fixes some simple violations. void SanityCheckAndAdjustment(FuncOp func); + // Whether the func contains Quantize ops. This is used to determine whether + // to use the quantization parameters from the fixed output range property. + bool ContainsQuantizeOps(FuncOp func); + QuantizationSpecs quant_specs_; }; @@ -285,6 +290,13 @@ void PrepareQuantizePass::SanityCheckAndAdjustment(FuncOp func) { }); } +bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) { + for (const auto& op : func.getOps()) { + if (llvm::isa(op)) return true; + } + return false; +} + using PrepareQuantStats = quant::ConvertStatsToQDQs; @@ -309,6 +321,7 @@ void PrepareQuantizePass::runOnFunction() { OwningRewritePatternList patterns; bool is_signed = quant_specs_.IsSignedInferenceType(); int bit_width = quant_specs_.GetQuantizationTypeWidth(); + bool enforce_fixed_output_range = ContainsQuantizeOps(func); if (is_signed) { patterns.insert>(ctx); // Convert quant stats to int8 quantization parameters. @@ -327,7 +340,8 @@ void PrepareQuantizePass::runOnFunction() { // values (tensors). 
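Regarding the `enforce_fixed_output_range` flag being threaded into the propagation call below: ops with a known ("fixed") output range, for instance a logistic-style activation whose outputs lie in [0, 1], get their output quantization parameters from that static range rather than from collected statistics, and the pass only enables this when the function already contains Quantize ops or `post_training_quantization` is set. As a generic, self-contained illustration of what a fixed range implies for unsigned 8-bit affine quantization (standard quantization math, not code from this patch):

#include <cmath>
#include <cstdint>
#include <utility>

// Computes (scale, zero_point) for an output whose range is known a priori,
// quantized to unsigned `bit_width` bits: real_value ~= scale * (q - zero_point).
std::pair<double, int32_t> FixedRangeQuantParams(double range_min,
                                                 double range_max,
                                                 int bit_width = 8) {
  const double qmax = static_cast<double>((1 << bit_width) - 1);  // 255 for 8 bits
  const double scale = (range_max - range_min) / qmax;
  const int32_t zero_point =
      static_cast<int32_t>(std::lround(-range_min / scale));
  return {scale, zero_point};
}
// Example: FixedRangeQuantParams(0.0, 1.0) yields scale = 1/255, zero_point = 0.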
ApplyQuantizationParamsPropagation( func, is_signed, disable_per_channel || quant_specs_.disable_per_channel, - GetOpQuantSpec); + GetOpQuantSpec, + enforce_fixed_output_range || quant_specs_.post_training_quantization); ConvertMlirQuantOpsToTFLQuantOps(func); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 918c3c69c93..c521ca0ed53 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -40,6 +40,7 @@ limitations under the License. #include "llvm/Support/Debug.h" #include "mlir/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Quant/FakeQuantSupport.h" // from @llvm-project +#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/UniformSupport.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -84,6 +85,11 @@ class PrepareTFPass : public PassWrapper { : unfold_batch_matmul_(unfold_batch_matmul) {} void runOnFunction() override; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + private: bool unfold_batch_matmul_; }; @@ -706,10 +712,8 @@ struct ConvertTFBroadcastTo : public RewritePattern { shape_type.getDimSize(0) <= 5))) return failure(); - if (!((element_type.getKind() == mlir::StandardTypes::F32) || - (element_type.getKind() == mlir::StandardTypes::BF16) || - (element_type.getKind() == mlir::StandardTypes::Integer && - element_type.cast().getWidth() == 32))) + if (!(element_type.isa() || + element_type.isInteger(32))) return failure(); auto status_or_const_op = @@ -762,6 +766,102 @@ LogicalResult ConvertTf2XlaOps(FuncOp func, MLIRContext *context) { return applyPartialConversion(func, target, patterns); } +// Convert rfft to rfft2d. +// The transformation pattern looks like below: +// +// input fft_len +// \ / +// rfft +// +// || +// \/ +// +// input fft_len +// \ / +// expand_dim concat with [1] at the front +// \ / +// rfft_2d +// | +// squeeze +struct ConvertRfftToRfft2d : public RewritePattern { + explicit ConvertRfftToRfft2d(MLIRContext *context) + : RewritePattern(TF::RFFTOp::getOperationName(), 1, context) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto rfft_op = dyn_cast(op); + + auto input = rfft_op.input(); + auto input_type = input.getType().dyn_cast_or_null(); + if (!input_type) return failure(); + auto fft_len = rfft_op.fft_length(); + auto fft_len_type = fft_len.getType().dyn_cast_or_null(); + if (!fft_len_type) return failure(); + + auto output_type = + rfft_op.getResult().getType().dyn_cast_or_null(); + if (!output_type) return failure(); + + // Expanded inputs. + // Insert at -2 location. 
+ auto one_ele_type = + mlir::RankedTensorType::get({1}, rewriter.getIntegerType(32)); + auto minus_two = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(), + one_ele_type, -2); + + SmallVector expanded_input_shape; + SmallVector expanded_output_shape; + int expanded_rank = input_type.getRank() + 1; + int r = 0; + for (int i = 0; i < expanded_rank; ++i) { + if (i == expanded_rank - 2) { + expanded_input_shape.push_back(1); + expanded_output_shape.push_back(1); + } else { + expanded_input_shape.push_back(input_type.getDimSize(r)); + expanded_output_shape.push_back(output_type.getDimSize(r)); + r++; + } + } + + auto expaned_input_type = mlir::RankedTensorType::get( + expanded_input_shape, input_type.getElementType()); + TF::ExpandDimsOp expanded_input = rewriter.create( + rfft_op.getLoc(), expaned_input_type, input, minus_two->getResult()); + + // Expanded fft_len. + auto one_attr = mlir::DenseIntElementsAttr::get(one_ele_type, {1}); + + auto one = rewriter.create(rfft_op.getLoc(), one_attr); + + auto zero = CreateConstOpWithSingleValue(&rewriter, rfft_op.getLoc(), + one_ele_type, 0); + + auto expanded_fft_len_type = + mlir::RankedTensorType::get({2}, fft_len_type.getElementType()); + + TF::ConcatV2Op expanded_fft_len = rewriter.create( + rfft_op.getLoc(), expanded_fft_len_type, + SmallVector({one.getResult(), fft_len}), zero->getResult()); + + // Insert the rfft_2d. + auto rfft2d_out_type = mlir::RankedTensorType::get( + expanded_output_shape, output_type.getElementType()); + TF::RFFT2DOp rfft2d = rewriter.create( + rfft_op.getLoc(), rfft2d_out_type, expanded_input.getResult(), + expanded_fft_len.getResult()); + + // Insert the squeeze op. + auto squeeze_dim = rewriter.getI64ArrayAttr({-2}); + TF::SqueezeOp squeeze = rewriter.create( + rfft_op.getLoc(), output_type, rfft2d.getResult(), squeeze_dim); + + rewriter.replaceOp(op, squeeze.getResult()); + + return success(); + } +}; + void PrepareTFPass::runOnFunction() { OwningRewritePatternList patterns; auto func = getFunction(); @@ -811,7 +911,8 @@ void PrepareTFPass::runOnFunction() { TF::ConvertTFBatchMatMulOp>(ctx); } patterns.insert(ctx); + ConvertTFDepthwiseConv2dNative, ConvertTFStridedSlice, + ConvertRfftToRfft2d>(ctx); applyPatternsAndFoldGreedily(func, patterns); } diff --git a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc index 8562f623258..b32da24d00f 100644 --- a/tensorflow/compiler/mlir/lite/utils/constant_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/constant_utils.cc @@ -30,80 +30,66 @@ stream_executor::port::StatusOr CreateConstOpWithSingleValue( Type element_type = shaped_type.getElementType(); ShapedType scalar_type = RankedTensorType::get({}, element_type); Attribute attr; - switch (element_type.getKind()) { - case mlir::StandardTypes::F16: { - auto floatType = mlir::FloatType::getF16(element_type.getContext()); - auto floatAttr = - mlir::FloatAttr::get(floatType, static_cast(value)); - std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); - break; - } - case mlir::StandardTypes::BF16: { - auto floatType = mlir::FloatType::getBF16(element_type.getContext()); - auto floatAttr = - mlir::FloatAttr::get(floatType, static_cast(value)); - std::vector floatValues({floatAttr}); - attr = DenseElementsAttr::get(scalar_type, floatValues); - break; - } - case mlir::StandardTypes::F32: { - attr = - DenseElementsAttr::get(scalar_type, static_cast(value)); - break; - } - case mlir::StandardTypes::Complex: { - auto etype = 
element_type.cast().getElementType(); - if (etype.isF32()) { - auto dialect = etype.getContext()->getRegisteredDialect("tf"); - tensorflow::TensorProto repr; - repr.set_dtype(tensorflow::DT_COMPLEX64); + if (element_type.isF16()) { + auto floatType = mlir::FloatType::getF16(element_type.getContext()); + auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + } else if (element_type.isBF16()) { + auto floatType = mlir::FloatType::getBF16(element_type.getContext()); + auto floatAttr = mlir::FloatAttr::get(floatType, static_cast(value)); + std::vector floatValues({floatAttr}); + attr = DenseElementsAttr::get(scalar_type, floatValues); + } else if (element_type.isF32()) { + attr = + DenseElementsAttr::get(scalar_type, static_cast(value)); + } else if (auto complex_type = element_type.dyn_cast()) { + auto etype = complex_type.getElementType(); + if (etype.isF32()) { + auto dialect = etype.getContext()->getLoadedDialect("tf"); + tensorflow::TensorProto repr; + repr.set_dtype(tensorflow::DT_COMPLEX64); - tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); - shape->set_unknown_rank(false); - shape->add_dim()->set_size(int64_t{1}); - std::string content; - auto complex_value = - std::complex(static_cast(value), 0.0f); - content.assign(reinterpret_cast(&complex_value), - sizeof(complex_value)); - repr.set_tensor_content(content); - std::string mangled = tensorflow::mangling_util::MangleTensor(repr); + tensorflow::TensorShapeProto* shape = repr.mutable_tensor_shape(); + shape->set_unknown_rank(false); + shape->add_dim()->set_size(int64_t{1}); + std::string content; + auto complex_value = std::complex(static_cast(value), 0.0f); + content.assign(reinterpret_cast(&complex_value), + sizeof(complex_value)); + repr.set_tensor_content(content); + std::string mangled = tensorflow::mangling_util::MangleTensor(repr); - attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); + attr = mlir::OpaqueElementsAttr::get(dialect, scalar_type, mangled); + } else { + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); + } + } else if (auto itype = element_type.dyn_cast()) { + switch (itype.getWidth()) { + case 8: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); break; - } - return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, - "Unsupported type"); + case 16: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 32: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + case 64: + attr = DenseElementsAttr::get(scalar_type, + static_cast(value)); + break; + default: + return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); } - case mlir::StandardTypes::Integer: { - const auto& itype = element_type.cast(); - switch (itype.getWidth()) { - case 8: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 16: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 32: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - case 64: - attr = DenseElementsAttr::get(scalar_type, - static_cast(value)); - break; - default: - return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, - "Unsupported type"); - } - break; - } - default: - return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, - "Unsupported type"); + } else { + return 
tensorflow::Status(tensorflow::error::INVALID_ARGUMENT, + "Unsupported type"); } return rewriter->create(loc, scalar_type, attr); } diff --git a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc index 081ba7ac6e7..f26689fac5e 100644 --- a/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/utils/lstm_utils_test.cc @@ -93,8 +93,9 @@ class LstmUtilsTest : public ::testing::Test { LstmUtilsTest() {} void SetUp() override { - RegisterDialects(); context_ = std::make_unique(); + context_->loadDialect(); builder_ = std::unique_ptr(new Builder(context_.get())); fused_lstm_func_ = createLstmCompositeFunc(builder_.get(), false, false); fused_lstm_func_cifg_ = @@ -109,12 +110,6 @@ class LstmUtilsTest : public ::testing::Test { builder_.reset(); } - void RegisterDialects() { - mlir::registerDialect(); - mlir::registerDialect(); - mlir::registerDialect(); - } - FuncOp fused_lstm_func_; FuncOp fused_lstm_func_cifg_; FuncOp fused_ln_lstm_func_; diff --git a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc index 96d22cb51e9..4035fed221d 100644 --- a/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/tftext_utils.cc @@ -56,9 +56,9 @@ inline OpaqueElementsAttr CustomOption(OpBuilder* builder, const std::string& content) { ShapedType type = RankedTensorType::get( {static_cast(content.size())}, builder->getIntegerType(8)); - return OpaqueElementsAttr::get( - builder->getContext()->getRegisteredDialect("tfl"), type, - StringRef(content.data(), content.size())); + return OpaqueElementsAttr::get(builder->getContext()->getLoadedDialect("tfl"), + type, + StringRef(content.data(), content.size())); } inline TensorType GetInputType(FuncOp func, int idx) { diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 8be6facce38..00efffff144 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -128,6 +128,7 @@ Status MlirFunctionOptimizationPass::Run( GraphDebugInfo debug_info; RegisterDialects(); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); GraphImportConfig import_config; import_config.graph_as_function = true; import_config.control_outputs = *control_ret_node_names; @@ -208,6 +209,7 @@ Status MlirV1CompatGraphOptimizationPass::Run( GraphDebugInfo debug_info; RegisterDialects(); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); GraphImportConfig import_config; import_config.upgrade_legacy = true; // Restrict functionalization to TPU nodes to avoid problems in v1 session diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 5ce0ca8cfcb..f1f6c43d3b3 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -41,6 +41,7 @@ std::string ImportGraphDef(const std::string &proto, GraphDebugInfo debug_info; GraphImportConfig specs; mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); if (!module.ok()) { Set_TF_Status_from_Status(status, module.status()); @@ -85,6 +86,7 @@ std::string ExperimentalConvertSavedModelToMlir( std::vector exported_names = absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); mlir::MLIRContext context; + 
context.loadAllGloballyRegisteredDialects(); auto module_or = ConvertSavedModelToMlir( &bundle, &context, absl::Span(exported_names)); if (!module_or.status().ok()) { @@ -115,6 +117,7 @@ std::string ExperimentalConvertSavedModelV1ToMlir( // Convert the SavedModelBundle to an MLIR module. mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); auto module_or = ConvertSavedModelV1ToMlir(bundle, {}, &context, upgrade_legacy); if (!module_or.status().ok()) { diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc index 63ca4c7bb28..4152b576e71 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.cc @@ -38,6 +38,7 @@ PYBIND11_MODULE(mlir_wrapper, m) { SM.AddNewSourceBuffer(llvm::MemoryBuffer::getMemBuffer(input), llvm::SMLoc()); mlir::MLIRContext ctx; + ctx.loadAllGloballyRegisteredDialects(); auto module = mlir::parseSourceFile(SM, &ctx); if (!module) { return false; diff --git a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc index 2be67f8e93e..be2dc2065f3 100644 --- a/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc +++ b/tensorflow/compiler/mlir/python/mlir_wrapper/types.cc @@ -20,11 +20,6 @@ limitations under the License. void init_types(py::module& m) { // Type py::class_ Type(m, "Type"); - Type.def("getKind", &mlir::Type::getKind); - - // Type Enums - py::enum_(Type, "StandardTypes_Kind") - .value("BF16", mlir::StandardTypes::BF16); // Type Sub-classes py::class_(m, "FunctionType") @@ -32,7 +27,10 @@ void init_types(py::module& m) { [](mlir::FunctionType& ft) { return ft.getResults().vec(); }); py::class_(m, "FloatType") - .def("get", &mlir::FloatType::get); + .def("getBF16", &mlir::FloatType::getBF16) + .def("getF16", &mlir::FloatType::getF16) + .def("getF32", &mlir::FloatType::getF32) + .def("getF64", &mlir::FloatType::getF64); py::class_(m, "IntegerType") .def("get", py::overload_cast( diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index f9b1abcccc6..b8c7376ebd3 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -355,6 +355,7 @@ cc_library( "ir/tf_remaining_ops.h.inc", ] + ["ir/tf_" + target["name"] + ".h.inc" for target in tf_ops_category_list], deps = [ + ":attribute_utils", ":tensorflow_attributes", ":tensorflow_canonicalize_inc_gen", ":tensorflow_op_interfaces", @@ -722,6 +723,7 @@ cc_library( "//tensorflow/core:framework", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", + "@llvm-project//mlir:Analysis", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", @@ -775,6 +777,7 @@ cc_library( "transforms/sink_constant.cc", "transforms/stack_ops_decomposition.cc", "transforms/tensor_array_ops_decomposition.cc", + "transforms/tensor_device_copy_conversion.cc", "transforms/tensor_list_ops_decomposition.cc", "transforms/test_resource_alias_analysis.cc", "transforms/test_side_effect_analysis.cc", @@ -787,6 +790,7 @@ cc_library( "transforms/tpu_extract_head_tail_outside_compilation.cc", "transforms/tpu_extract_outside_compilation.cc", "transforms/tpu_host_computation_expansion.cc", + "transforms/tpu_identity_pruning.cc", "transforms/tpu_merge_variables_with_execute.cc", "transforms/tpu_outside_compilation_cluster.cc", "transforms/tpu_rewrite_pass.cc", @@ -799,7 +803,6 @@ cc_library( 
"translate/tf_functional_to_executor.cc", ], hdrs = [ - "transforms/attribute_utils.h", "transforms/batchmatmul_to_einsum.h", "transforms/bridge.h", "transforms/collection_ops_util.h", @@ -809,6 +812,7 @@ cc_library( ], includes = ["include"], deps = [ + ":attribute_utils", ":bridge_logger", ":convert_tensor", ":convert_type", @@ -1269,7 +1273,7 @@ cc_library( name = "tf_dialect_passes", srcs = [ "transforms/constant_fold.cc", - "transforms/dialect_hooks.cc", + "transforms/decode_attributes_hook.cc", ], hdrs = [ "transforms/constant_fold.h", @@ -1632,6 +1636,7 @@ cc_library( deps = [ ":lower_tf_inc_gen", ":tensorflow", + ":tensorflow_ops", ":tensorflow_types", "//tensorflow/core:framework", "@llvm-project//llvm:Support", @@ -1819,3 +1824,11 @@ cc_library( "@llvm-project//mlir:Support", ], ) + +cc_library( + name = "attribute_utils", + hdrs = ["utils/attribute_utils.h"], + deps = [ + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc index 7ad2705263b..8ec7513f81f 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.cc @@ -21,11 +21,13 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "mlir/Analysis/CallGraph.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -35,6 +37,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -134,12 +137,46 @@ class BacktrackAnalysis { return GetAnalysisForRegion(region); } + // Returns the backtrack analysis for the given region if it exists. + // If the region has not yet been analyzed, returns llvm::None. + Optional GetAnalysisIfExists(Region& region) const { + auto it = info_map_.find(®ion); + if (it == info_map_.end()) return llvm::None; + return &it->second; + } + + Optional GetAnalysisIfExists(FuncOp func) const { + return GetAnalysisIfExists(func.getBody()); + } + private: llvm::SmallDenseMap info_map_; }; // Analyzes all regions attached to all operations in the module. BacktrackAnalysis::BacktrackAnalysis(ModuleOp module) { + const CallGraph call_graph(module); + + // Visit functions bottom up when doing the analysis. Note that SCC iterator + // has the property that if there is an edge from SCC1->SCC2, SCC1 is visited + // after SCC2, i.e., the graph is traversed bottom up just the way we want. + auto scc_begin = llvm::scc_begin(&call_graph); + auto scc_end = llvm::scc_end(&call_graph); + for (auto& scc : make_range(scc_begin, scc_end)) { + // Each SCC node is a collection of callgraph nodes that form a cycle. We + // will visit these nodes in an arbitrary order. 
If a node being visited + // calls a function that has not yet been analyzed, we will not be able to + // backtrack through that function call (our analysis will be correct but + // pessimistic). + for (CallGraphNode* node : scc) { + if (node->isExternal()) continue; + Region* region = node->getCallableRegion(); + GetOrCreateAnalysis(*region); + } + } + + // This above call graph analysis will cover all regions attached to functions + // but we also need to analyze regions attached to other ops. module.walk([this](Operation* op) { for (Region& region : op->getRegions()) GetOrCreateAnalysis(region); }); @@ -160,6 +197,18 @@ Value BacktrackAnalysis::BacktrackValue(Value value) { value = island.GetYield().getOperand(res_index); } else if (isa(op)) { value = op->getOperand(res_index); + } else if (auto call = dyn_cast(op)) { + FuncOp func = dyn_cast(call.resolveCallable()); + if (!func) break; + // Check if the function being called has been analyzed. if not, + // we cannot backtrack the value further. + Optional callee_info = GetAnalysisIfExists(func); + if (!callee_info) break; + Optional passthrough_arg = callee_info.getValue()->GetArg(res_index); + if (!passthrough_arg) break; + value = call.getArgOperands()[passthrough_arg.getValue()]; + } else if (isa(op)) { + value = op->getRegion(0).front().getTerminator()->getOperand(res_index); } else { break; } @@ -359,6 +408,13 @@ ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo( AddValueUniqueIDMapping(result, kUnknownResourceId); } } + } else if (isa(op)) { + Region& region = op->getRegion(0); + const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region); + for (auto result : filter_resources(op->getResults())) { + Value body_result = body_info.GetValue(result.getResultNumber()); + PropagateInputToOutput(body_result, result); + } } else { assign_unknown_id_to_all(op->getResults()); } @@ -493,10 +549,7 @@ llvm::SmallSetVector ResourceAliasAnalysisInfo::GetResourceAliases( // ResourceAliasAnalysis //===----------------------------------------------------------------------===// -ResourceAliasAnalysis::ResourceAliasAnalysis(Operation* op) { - auto module = dyn_cast(op); - assert(module); - +ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) { // Analyze all regions for backtracking info. detail::BacktrackAnalysis backtrack_analysis(module); diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h index c965b5d7602..46bb57c942d 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h @@ -102,7 +102,7 @@ class ResourceAliasAnalysis : public detail::PerFunctionAggregateAnalysis< detail::ResourceAliasAnalysisInfo> { public: // Constructs analysis by analyzing the given module operation. 
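A note on the new `GetAnalysisIfExists` lookups used by `BacktrackValue` above: when a call is encountered, the callee's summary is consulted only if it was already computed by the bottom-up (callee-before-caller) SCC walk; if it was not, backtracking simply stops at the call, keeping the analysis correct but conservative. The sketch below captures that lookup-or-stop step with plain standard-library types standing in for the real region/info classes; all names here are illustrative.

#include <optional>
#include <vector>

// Per-callee summary: for each result index, the index of the argument it
// passes through unchanged, if any (stand-in for the real per-region info).
struct CalleeSummary {
  std::vector<std::optional<int>> passthrough_arg_for_result;
};

// Backtracks one call edge: if the callee is analyzed and result `res_index`
// is a passthrough of argument k, the caller continues from its k-th call
// operand; otherwise backtracking stops (conservative but still correct).
std::optional<int> BacktrackThroughCall(
    const std::optional<CalleeSummary>& callee_info, int res_index) {
  if (!callee_info) return std::nullopt;
  const auto& passthrough = callee_info->passthrough_arg_for_result;
  if (res_index < 0 || res_index >= static_cast<int>(passthrough.size()))
    return std::nullopt;
  return passthrough[res_index];
}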
- explicit ResourceAliasAnalysis(Operation* op); + explicit ResourceAliasAnalysis(ModuleOp module); }; // Returns a range with just resource type values from the input range diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index e382bdb28c6..c78a7e403c4 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -320,10 +320,7 @@ SideEffectAnalysisInfo::DirectControlSuccessors( } } // namespace detail -SideEffectAnalysis::SideEffectAnalysis(Operation* op) { - auto module = dyn_cast(op); - assert(module); - +SideEffectAnalysis::SideEffectAnalysis(ModuleOp module) { // Analyze entire module for alias analysis info. ResourceAliasAnalysis alias_analysis(module); diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h index c92c6e1882c..a75f7eb7dee 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h @@ -130,7 +130,7 @@ class SideEffectAnalysis : public detail::PerFunctionAggregateAnalysis< detail::SideEffectAnalysisInfo> { public: // Constructs analysis by analyzing the given module operation. - explicit SideEffectAnalysis(Operation* op); + explicit SideEffectAnalysis(ModuleOp module); }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/c/BUILD b/tensorflow/compiler/mlir/tensorflow/c/BUILD index 801e35280d6..243f4b5139f 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/c/BUILD @@ -41,6 +41,7 @@ tf_cuda_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index edf5d09b401..c62d62a2d3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -16,6 +16,7 @@ limitations under the License. #include #include +#include "absl/strings/str_cat.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/raw_ostream.h" @@ -64,6 +65,9 @@ using tensorflow::AbstractTensorInterface; using tensorflow::dyn_cast; using tensorflow::OutputList; using tensorflow::string; +using tensorflow::errors::FailedPrecondition; +using tensorflow::errors::InvalidArgument; +using tensorflow::errors::Unimplemented; using tensorflow::tracing::TracingContext; using tensorflow::tracing::TracingOperation; using tensorflow::tracing::TracingTensorHandle; @@ -103,6 +107,9 @@ class MlirTensor : public TracingTensorHandle { } Value getValue() { return value_; } + Type getElementType() { + return value_.getType().cast().getElementType(); + } // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { @@ -184,11 +191,18 @@ class MlirAbstractOp : public TracingOperation { } private: + // Return true is there are still unfilled ODS slots for adding more inputs. 
+ bool IsNextODSArgAvailable(); + MLIRContext* context_; MlirFunctionContext* function_context_; SmallVector operands_; llvm::StringMap attrs_; std::unique_ptr state_; + // This is the index of the next ODS operand that will be added with AddInput + // or AddInput; + int current_ods_input_ = 0; + const tensorflow::OpDef* op_def_ = nullptr; const char* op_name_ = nullptr; string tf_op_type_; // TODO(srbs): Use this. @@ -244,12 +258,12 @@ class MlirFunctionContext : public TracingContext { Status Finalize(OutputList* outputs, AbstractFunction** f) override; Status RegisterFunction(AbstractFunction* func) override { - return tensorflow::errors::Unimplemented( + return Unimplemented( "Registering graph functions has not been implemented yet."); } Status RemoveFunction(const string& func) override { - return tensorflow::errors::Unimplemented( + return Unimplemented( "MlirFunctionContext::RemoveFunction has not been implemented yet."); } @@ -264,9 +278,12 @@ class MlirFunctionContext : public TracingContext { Status MlirAbstractOp::Reset(const char* op, const char* device_name) { if (state_) { - return tensorflow::errors::FailedPrecondition( - "Reset called on already built op."); + return FailedPrecondition("Reset called on already built op."); } + TF_RETURN_IF_ERROR( + tensorflow::OpRegistry::Global()->LookUpOpDef(op, &op_def_)); + assert(op_def_); + tf_op_type_ = op; std::string name = "tf."; name += op; @@ -277,13 +294,12 @@ Status MlirAbstractOp::Reset(const char* op, const char* device_name) { Status MlirAbstractOp::SetAttrType(const char* attr_name, tensorflow::DataType dtype) { - if (!state_) { - return Status(tensorflow::error::Code::FAILED_PRECONDITION, - "op_type must be specified before specifying attrs."); - } + if (!state_) + return FailedPrecondition( + "op_type must be specified before specifying attrs."); Type mlir_type; Builder builder(context_); - TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder, &mlir_type)); + TF_RETURN_IF_ERROR(ConvertDataType(dtype, builder, &mlir_type)); attrs_[attr_name] = TypeAttr::get(mlir_type); return Status::OK(); } @@ -291,8 +307,7 @@ Status MlirAbstractOp::SetAttrType(const char* attr_name, Status MlirAbstractOp::SetOpName(const char* const op_name) { // TODO(aminim): should we use a location? if (op_name_) { - return tensorflow::errors::FailedPrecondition( - "SetOpName called on already built op."); + return FailedPrecondition("SetOpName called on already built op."); } op_name_ = op_name; return Status::OK(); @@ -301,8 +316,7 @@ Status MlirAbstractOp::SetOpName(const char* const op_name) { Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Type elt_type = getElementTypeOrSelf(type); if (elt_type.isa()) { - return tensorflow::errors::InvalidArgument( - "Requested reference to a reference type"); + return InvalidArgument("Requested reference to a reference type"); } elt_type = TensorFlowRefType::get(elt_type); if (RankedTensorType tensor_type = type.dyn_cast()) { @@ -315,138 +329,97 @@ Status MlirAbstractOp::AddRef(Type type, Type* output_type) { Status MlirAbstractOp::Create(ArrayRef operands, OperationState** state) { state_->operands = llvm::to_vector<4>(operands); - const tensorflow::OpDef* op_def; - auto node_name = state_->name.getStringRef().drop_front( - TensorFlowDialect::getDialectNamespace().size() + 1); - TF_RETURN_IF_ERROR( - tensorflow::OpRegistry::Global()->LookUpOpDef(node_name.str(), &op_def)); Builder builder(context_); - // Process operands according to the op_def and infer derived attributes. 
- int current_operand = 0; - for (const tensorflow::OpDef::ArgDef& input_arg : op_def->input_arg()) { - if (!input_arg.number_attr().empty()) { - // TODO(b/156122856): we don't support variadic operands. - return tensorflow::errors::Unimplemented( - "Unsupported 'number_attr' for '", input_arg.number_attr(), "'"); - } else if (!input_arg.type_list_attr().empty()) { - return tensorflow::errors::InvalidArgument( - "Unsupported 'type_list_attr' for '", input_arg.number_attr(), "'"); - } - if (current_operand >= operands.size()) { - return tensorflow::errors::InvalidArgument("Missing operand for '", - input_arg.name(), "'"); - } - Type expected_type; - if (input_arg.type() != tensorflow::DT_INVALID) { - TF_RETURN_IF_ERROR( - ConvertDataTypeToTensor(input_arg.type(), builder, &expected_type)); - Type output_type; - if (input_arg.is_ref()) - TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type)); - expected_type = output_type; - } else { - expected_type = operands[current_operand].getType(); - } - if (!input_arg.type_attr().empty()) { - attrs_[input_arg.type_attr()] = TypeAttr::get(expected_type); - } - ++current_operand; - } - for (const tensorflow::OpDef::ArgDef& output_arg : op_def->output_arg()) { + if (current_ods_input_ != op_def_->input_arg_size()) + return InvalidArgument(absl::StrCat("Mismatch in operands number: got ", + current_ods_input_, " expected ", + op_def_->input_arg_size(), " ; for op ", + state_->name.getStringRef().str())); + + // Process results according to the op_def and infer types for derived + // attributes. + for (const tensorflow::OpDef::ArgDef& output_arg : op_def_->output_arg()) { int original_size = state_->types.size(); if (!output_arg.number_attr().empty()) { // Same type repeated "repeats" times. Attribute repeats_attr = attrs_[output_arg.number_attr()]; - if (!repeats_attr) { - return tensorflow::errors::InvalidArgument( - "Missing attribute '", output_arg.number_attr(), - "' required for output list '", output_arg.name(), "'"); - } - if (!repeats_attr.isa()) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.number_attr(), - "' required for output list '", output_arg.name(), - "' isn't an integer"); - } + if (!repeats_attr) + return InvalidArgument("Missing attribute '", output_arg.number_attr(), + "' required for output list '", + output_arg.name(), "'"); + if (!repeats_attr.isa()) + return InvalidArgument("Attribute '", output_arg.number_attr(), + "' required for output list '", + output_arg.name(), "' isn't an integer"); int64_t repeats = repeats_attr.cast().getInt(); if (!output_arg.type_attr().empty()) { // Same type repeated "repeats" times. 
Attribute attr = attrs_[output_arg.type_attr()]; - if (!attr) { - return tensorflow::errors::InvalidArgument( - "Missing attribute '", output_arg.type_attr(), - "' required for output '", output_arg.name(), "'"); - } + if (!attr) + return InvalidArgument("Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "'"); TypeAttr type_attr = attr.dyn_cast(); - if (!type_attr) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.type_attr(), "' required for output '", - output_arg.name(), "' isn't a type attribute"); - } + if (!type_attr) + return InvalidArgument("Attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "' isn't a type attribute"); for (int i = 0; i < repeats; ++i) - state_->types.push_back(type_attr.getType()); + state_->types.push_back(UnrankedTensorType::get(type_attr.getType())); } else if (output_arg.type() != tensorflow::DT_INVALID) { for (int i = 0; i < repeats; ++i) { Type type; TF_RETURN_IF_ERROR( - ConvertDataTypeToTensor(output_arg.type(), builder, &type)); + ConvertDataType(output_arg.type(), builder, &type)); state_->types.push_back(type); } } else { - return tensorflow::errors::InvalidArgument( - "Missing type or type_attr field in ", - output_arg.ShortDebugString()); + return InvalidArgument("Missing type or type_attr field in ", + output_arg.ShortDebugString()); } } else if (!output_arg.type_attr().empty()) { Attribute attr = attrs_[output_arg.type_attr()]; - if (!attr) { - return tensorflow::errors::InvalidArgument( - "Missing attribute '", output_arg.type_attr(), - "' required for output '", output_arg.name(), "'"); - } + if (!attr) + return InvalidArgument("Missing attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "'"); TypeAttr type_attr = attr.dyn_cast(); - if (!type_attr) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.type_attr(), "' required for output '", - output_arg.name(), "' isn't a type attribute"); - } - state_->types.push_back(type_attr.getValue()); + if (!type_attr) + return InvalidArgument("Attribute '", output_arg.type_attr(), + "' required for output '", output_arg.name(), + "' isn't a type attribute"); + state_->types.push_back(UnrankedTensorType::get(type_attr.getValue())); } else if (!output_arg.type_list_attr().empty()) { // This is pointing to an attribute which is an array of types. 
Attribute attr = attrs_[output_arg.type_list_attr()]; - if (!attr) { - return tensorflow::errors::InvalidArgument( + if (!attr) + return InvalidArgument( "Missing attribute '", output_arg.type_list_attr(), "' required for output '", output_arg.name(), "'"); - } ArrayAttr array_attr = attr.dyn_cast(); - if (!array_attr) { - return tensorflow::errors::InvalidArgument( - "Attribute '", output_arg.type_list_attr(), - "' required for output '", output_arg.name(), - "' isn't an array attribute"); - } + if (!array_attr) + return InvalidArgument("Attribute '", output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' isn't an array attribute"); for (Attribute attr : array_attr) { TypeAttr type_attr = attr.dyn_cast(); - if (!type_attr) { - return tensorflow::errors::InvalidArgument( - "Array Attribute '", output_arg.type_list_attr(), - "' required for output '", output_arg.name(), - "' has a non-Type element"); - } - state_->types.push_back(type_attr.getValue()); + if (!type_attr) + return InvalidArgument("Array Attribute '", + output_arg.type_list_attr(), + "' required for output '", output_arg.name(), + "' has a non-Type element"); + state_->types.push_back(UnrankedTensorType::get(type_attr.getValue())); } } else if (output_arg.type() != tensorflow::DT_INVALID) { Type type; Builder builder(context_); - TF_RETURN_IF_ERROR( - ConvertDataTypeToTensor(output_arg.type(), builder, &type)); + TF_RETURN_IF_ERROR(ConvertDataType(output_arg.type(), builder, &type)); state_->types.push_back(type); } else { - return tensorflow::errors::InvalidArgument("No type fields in ", - output_arg.ShortDebugString()); + return InvalidArgument("No type fields in ", + output_arg.ShortDebugString()); } if (output_arg.is_ref()) { // For all types that were added by this function call, make them refs. 
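One detail worth calling out in the result-type handling above: element types taken from `type_attr`/`type_list_attr` derived attributes are now wrapped in unranked tensor types rather than used directly, presumably because only the element dtype is known at op-construction time and shapes are left to later inference. Below is a compressed sketch of the repeated-result (`number_attr`) branch with the error handling elided; it only uses MLIR calls that appear in the hunk above, and the function and variable names are hypothetical.

#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/Types.h"

// Sketch only: results declared via an OpDef `number_attr` repeat one element
// type `repeats` times; the real code in MlirAbstractOp::Create also checks
// that both attributes are present and of the expected kind.
llvm::SmallVector<mlir::Type, 4> RepeatedResultTypes(
    mlir::Attribute repeats_attr, mlir::Attribute elt_type_attr) {
  llvm::SmallVector<mlir::Type, 4> types;
  const int64_t repeats = repeats_attr.cast<mlir::IntegerAttr>().getInt();
  const mlir::Type elt = elt_type_attr.cast<mlir::TypeAttr>().getValue();
  for (int64_t i = 0; i < repeats; ++i)
    types.push_back(mlir::UnrankedTensorType::get(elt));  // shape inferred later
  return types;
}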
@@ -472,88 +445,67 @@ Status MlirAbstractOp::SetDeviceName(const char* name) { return Status::OK(); } -Status MlirAbstractOp::AddInputList( - absl::Span inputs) { - return tensorflow::errors::Unimplemented( - "AddInputList has not been implemented yet."); -} - Status MlirAbstractOp::SetAttrString(const char* attr_name, const char* data, size_t length) { - return tensorflow::errors::Unimplemented( - "SetAttrString has not been implemented yet."); + return Unimplemented("SetAttrString has not been implemented yet."); } Status MlirAbstractOp::SetAttrInt(const char* attr_name, int64_t value) { - return tensorflow::errors::Unimplemented( - "SetAttrInt has not been implemented yet."); + return Unimplemented("SetAttrInt has not been implemented yet."); } Status MlirAbstractOp::SetAttrFloat(const char* attr_name, float value) { - return tensorflow::errors::Unimplemented( - "SetAttrFloat has not been implemented yet."); + return Unimplemented("SetAttrFloat has not been implemented yet."); } Status MlirAbstractOp::SetAttrBool(const char* attr_name, bool value) { - return tensorflow::errors::Unimplemented( - "SetAttrBool has not been implemented yet."); + return Unimplemented("SetAttrBool has not been implemented yet."); } Status MlirAbstractOp::SetAttrShape(const char* attr_name, const int64_t* dims, const int num_dims) { - return tensorflow::errors::Unimplemented( - "SetAttrShape has not been implemented yet."); + return Unimplemented("SetAttrShape has not been implemented yet."); } Status MlirAbstractOp::SetAttrFunction(const char* attr_name, const AbstractOperation* value) { - return tensorflow::errors::Unimplemented( - "SetAttrFunction has not been implemented yet."); + return Unimplemented("SetAttrFunction has not been implemented yet."); } Status MlirAbstractOp::SetAttrFunctionName(const char* attr_name, const char* value, size_t length) { - return tensorflow::errors::Unimplemented( - "SetAttrFunctionName has not been implemented yet."); + return Unimplemented("SetAttrFunctionName has not been implemented yet."); } Status MlirAbstractOp::SetAttrTensor(const char* attr_name, AbstractTensorInterface* tensor) { - return tensorflow::errors::Unimplemented( - "SetAttrTensor has not been implemented yet."); + return Unimplemented("SetAttrTensor has not been implemented yet."); } Status MlirAbstractOp::SetAttrStringList(const char* attr_name, const void* const* values, const size_t* lengths, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrStringList has not been implemented yet."); + return Unimplemented("SetAttrStringList has not been implemented yet."); } Status MlirAbstractOp::SetAttrFloatList(const char* attr_name, const float* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrFloatList has not been implemented yet."); + return Unimplemented("SetAttrFloatList has not been implemented yet."); } Status MlirAbstractOp::SetAttrIntList(const char* attr_name, const int64_t* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrIntList has not been implemented yet."); + return Unimplemented("SetAttrIntList has not been implemented yet."); } Status MlirAbstractOp::SetAttrTypeList(const char* attr_name, const tensorflow::DataType* values, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrTypeList has not been implemented yet."); + return Unimplemented("SetAttrTypeList has not been implemented yet."); } Status MlirAbstractOp::SetAttrBoolList(const char* attr_name, const unsigned char* values, int 
num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrBoolList has not been implemented yet."); + return Unimplemented("SetAttrBoolList has not been implemented yet."); } Status MlirAbstractOp::SetAttrShapeList(const char* attr_name, const int64_t** dims, const int* num_dims, int num_values) { - return tensorflow::errors::Unimplemented( - "SetAttrShapeList has not been implemented yet."); + return Unimplemented("SetAttrShapeList has not been implemented yet."); } Status MlirAbstractOp::SetAttrFunctionList( const char* attr_name, absl::Span values) { - return tensorflow::errors::Unimplemented( - "SetAttrFunctionList has not been implemented yet."); + return Unimplemented("SetAttrFunctionList has not been implemented yet."); } Status MlirFunction::GetFunctionDef(tensorflow::FunctionDef** f) { @@ -605,28 +557,101 @@ Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype, } Status MlirAbstractOp::AddInput(AbstractTensorHandle* input) { + if (current_ods_input_ >= op_def_->input_arg_size()) + return InvalidArgument( + absl::StrCat("More Input() (", current_ods_input_, ") calls than the ", + op_def_->input_arg_size(), " allowed input_args ; for op ", + state_->name.getStringRef().str())); + auto* operand = dyn_cast(input); - if (!operand) { - return tensorflow::errors::InvalidArgument( - "Unable to cast input to MlirTensor"); - } + if (!operand) return InvalidArgument("Unable to cast input to MlirTensor"); operands_.push_back(operand->getValue()); + + // Get the next ArgDef and use it to infer the derived attributes associated + // to this input. + const tensorflow::OpDef::ArgDef& arg_def = + op_def_->input_arg(current_ods_input_++); + Type expected_type; + if (arg_def.type() != tensorflow::DT_INVALID) { + Builder builder(context_); + TF_RETURN_IF_ERROR( + tensorflow::ConvertDataType(arg_def.type(), builder, &expected_type)); + if (arg_def.is_ref()) { + Type output_type; + TF_RETURN_IF_ERROR(AddRef(expected_type, &output_type)); + expected_type = output_type; + } + } else { + expected_type = cast(input)->getElementType(); + } + if (!arg_def.type_attr().empty()) + attrs_[arg_def.type_attr()] = TypeAttr::get(expected_type); + return Status::OK(); } + +Status MlirAbstractOp::AddInputList( + absl::Span inputs) { + if (current_ods_input_ >= op_def_->input_arg_size()) + return InvalidArgument( + absl::StrCat("More Input() (", current_ods_input_, ") calls than the ", + op_def_->input_arg_size(), " allowed input_args")); + + for (AbstractTensorHandle* input : inputs) { + auto* operand = dyn_cast(input); + if (!operand) return InvalidArgument("Unable to cast input to MlirTensor"); + operands_.push_back(operand->getValue()); + } + + // Get the next ArgDef and use it to infer the derived attributes associated + // to this input. + const tensorflow::OpDef::ArgDef& arg_def = + op_def_->input_arg(current_ods_input_++); + if (!arg_def.number_attr().empty()) { + Builder builder(context_); + attrs_[arg_def.number_attr()] = builder.getI32IntegerAttr(inputs.size()); + // TODO(aminim): handle ref variable. + if (arg_def.type() != tensorflow::DT_INVALID) { + // TODO(aminim): check type wrt input + Type arg_def_type; + TF_RETURN_IF_ERROR( + ConvertDataType(arg_def.type(), builder, &arg_def_type)); + // Ensure each of the type in the list matches the op def type. + // TODO(aminim): can we improve the error message with the actual types? 
+ for (AbstractTensorHandle* input : inputs) + if (arg_def_type != cast(input)->getElementType()) + return InvalidArgument( + "Invalid input list: type mismatch the op def expectation"); + } else if (!inputs.empty()) { + if (arg_def.type_attr().empty()) + return FailedPrecondition( + "Invalid opdef type constraint: either type or type_attr required"); + + attrs_[arg_def.type_attr()] = + TypeAttr::get(cast(inputs.front())->getElementType()); + } + } else if (!arg_def.type_list_attr().empty()) { + // TODO(aminim): handle ref variable. + SmallVector types; + types.reserve(inputs.size()); + for (AbstractTensorHandle* input : inputs) + types.push_back(TypeAttr::get(cast(input)->getElementType())); + attrs_[arg_def.type_list_attr()] = ArrayAttr::get(types, GetContext()); + } + return Status::OK(); +} + Status MlirFunctionContext::Finalize(OutputList* outputs, AbstractFunction** f) { Block& body = func_.getBody().front(); SmallVector ret_operands; for (auto* output : outputs->outputs) { auto* operand = dyn_cast(output); - if (!operand) { - return tensorflow::errors::InvalidArgument( - "Capturing eager tensors is not supported yet."); - } - if (operand->getValue().getContext() != context_.get()) { - return tensorflow::errors::InvalidArgument( + if (!operand) + return InvalidArgument("Capturing eager tensors is not supported yet."); + if (operand->getValue().getContext() != context_.get()) + return InvalidArgument( "Capturing tensors from other context is not supported."); - } ret_operands.push_back(operand->getValue()); } builder_.create(func_.getLoc(), ret_operands); diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc index dfad1fce26d..40cc2c99c27 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.cc @@ -74,12 +74,9 @@ struct FuncAttrStorage : public AttributeStorage { // Get or create a shape attribute. ShapeAttr ShapeAttr::get(mlir::MLIRContext* context, llvm::Optional> shape) { - if (shape) - return Base::get(context, AttrKind::SHAPE, *shape, - /*unranked=*/false); + if (shape) return Base::get(context, *shape, /*unranked=*/false); - return Base::get(context, AttrKind::SHAPE, ArrayRef(), - /*unranked=*/true); + return Base::get(context, ArrayRef(), /*unranked=*/true); } llvm::Optional> ShapeAttr::getValue() const { @@ -112,12 +109,12 @@ bool ShapeAttr::hasStaticShape() const { FuncAttr FuncAttr::get(mlir::MLIRContext* context, llvm::StringRef name, DictionaryAttr attr) { auto symbol = SymbolRefAttr::get(name, context); - return Base::get(context, AttrKind::FUNC, symbol, attr); + return Base::get(context, symbol, attr); } FuncAttr FuncAttr::get(mlir::MLIRContext* context, SymbolRefAttr symbol, DictionaryAttr attr) { - return Base::get(context, AttrKind::FUNC, symbol, attr); + return Base::get(context, symbol, attr); } SymbolRefAttr FuncAttr::GetName() const { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h index e0fef228eb4..5a18b77ab5c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h @@ -24,19 +24,6 @@ limitations under the License. namespace mlir { namespace TF { -namespace AttrKind { - -// List of supported custom TensorFlow Attribute kinds, necessary for -// isa/dyn_cast. 
-enum Kind { - FIRST_USED_TENSORFLOW_ATTR = Attribute::FIRST_TENSORFLOW_ATTR, - SHAPE = FIRST_USED_TENSORFLOW_ATTR, - FUNC, - LAST_USED_TENSORFLOW_ATTR, -}; - -} // namespace AttrKind - namespace detail { struct ShapeAttrStorage; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 9c2968fab37..ea9ae5d9477 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -54,9 +54,6 @@ namespace tf_executor { namespace { -using TF::DropRefType; -using TF::DropTypeSubTypes; - struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -75,9 +72,8 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { } }; -struct TensorFlowExecutorOpFolderDialectInterface - : public OpFolderDialectInterface { - using OpFolderDialectInterface::OpFolderDialectInterface; +struct TensorFlowExecutorDialectFoldInterface : public DialectFoldInterface { + using DialectFoldInterface::DialectFoldInterface; // Registered hook to check if the given region, which is attached to an // operation that is *not* isolated from above (i.e. no internal regions @@ -100,7 +96,7 @@ TensorFlowExecutorDialect::TensorFlowExecutorDialect(MLIRContext *context) >(); addInterfaces(); + TensorFlowExecutorDialectFoldInterface>(); addTypes(); } @@ -551,8 +547,8 @@ LogicalResult Verify(SwitchNOp switchn) { << operand0_tensor_type << " vs " << output_tensor_type; } Type broadcasted_type = OpTrait::util::getBroadcastedType( - DropRefType(DropTypeSubTypes(operand0_tensor_type)), - DropRefType(DropTypeSubTypes(output_tensor_type))); + TF::DropRefAndSubTypes(operand0_tensor_type), + TF::DropRefAndSubTypes(output_tensor_type)); if (!broadcasted_type) { return switchn.emitOpError() << "expects data operand to be broadcastable with all output types" @@ -668,8 +664,8 @@ LogicalResult Verify(MergeOp merge) { << operand_tensor_ty << " vs " << output_tensor_ty; } Type broadcasted_type = OpTrait::util::getBroadcastedType( - DropRefType(DropTypeSubTypes(output_tensor_ty)), - DropRefType(DropTypeSubTypes(operand_tensor_ty))); + TF::DropRefAndSubTypes(output_tensor_ty), + TF::DropRefAndSubTypes(operand_tensor_ty)); if (!broadcasted_type) return merge.emitOpError() << "expects all operands to be broadcastable with output type" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h index da63826a6d4..60036ddc9f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h @@ -45,31 +45,16 @@ class TensorFlowExecutorDialect : public Dialect { void printType(Type type, DialectAsmPrinter &os) const override; }; -namespace TFTypes { -enum Kind { - Control = Type::FIRST_TENSORFLOW_EXECUTOR_TYPE, - Token, -}; -} // namespace TFTypes - // The Control type is a token-like value that models control dependencies from // TensorFlow graphs. class ControlType : public Type::TypeBase { public: using Base::Base; - - static ControlType get(MLIRContext *context) { - return Base::get(context, TFTypes::Control); - } }; class TokenType : public Type::TypeBase { public: using Base::Base; - - static TokenType get(MLIRContext *context) { - return Base::get(context, TFTypes::Token); - } }; // Declares the operations for this dialect using the generated header. 
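The tf_attributes and tf_executor hunks above remove the hand-maintained `AttrKind` and `TFTypes::Kind` enums: singleton types such as `ControlType` are now obtained with the plain `Base::get(context)`, and dispatch over attributes and types goes through `isa`/`dyn_cast` driven by each class's `classof` (the later `printAttribute` change in tf_ops.cc follows the same pattern). A standalone toy sketch of that dispatch style, with a hand-rolled `dyn_cast` rather than the LLVM one:

```cpp
// Hypothetical, self-contained illustration (toy types, not the MLIR API):
// instead of switching on a public kind enum, the discriminator becomes an
// implementation detail behind classof(), and callers dispatch with
// isa/dyn_cast-style checks.
#include <iostream>

struct Attr {
  enum class Id { Shape, Func };
  Id id;  // storage detail only; no longer part of the public interface
  explicit Attr(Id id) : id(id) {}
  virtual ~Attr() = default;
};
struct ShapeAttr : Attr {
  ShapeAttr() : Attr(Id::Shape) {}
  static bool classof(const Attr *a) { return a->id == Id::Shape; }
};
struct FuncAttr : Attr {
  FuncAttr() : Attr(Id::Func) {}
  static bool classof(const Attr *a) { return a->id == Id::Func; }
};

// Minimal stand-in for llvm::dyn_cast.
template <typename T>
const T *dyn_cast(const Attr *a) {
  return T::classof(a) ? static_cast<const T *>(a) : nullptr;
}

void PrintAttribute(const Attr *attr) {
  if (dyn_cast<ShapeAttr>(attr))
    std::cout << "shape attribute\n";
  else if (dyn_cast<FuncAttr>(attr))
    std::cout << "func attribute\n";
  else
    std::cout << "unexpected attribute\n";
}

int main() {
  ShapeAttr shape;
  FuncAttr func;
  PrintAttribute(&shape);
  PrintAttribute(&func);
}
```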
diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index cc07d50eee2..283e3326029 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -136,7 +136,7 @@ Inputs must be of same size and shape. let hasFolder = 1; } -def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, +def TF_AddV2Op : TF_Op<"AddV2", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_LayoutAgnostic, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x + y element-wise."; @@ -859,15 +859,15 @@ about broadcasting }]; let arguments = (ins - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$x, - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$y, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$x, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$y, DefaultValuedAttr:$adj_x, DefaultValuedAttr:$adj_y ); let results = (outs - TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, TF_Complex128, TF_Complex64]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; @@ -965,6 +965,40 @@ reverse of SpaceToBatch. See below for a precise description. TF_DerivedOperandTypeAttr Tblock_shape = TF_DerivedOperandTypeAttr<1>; } +def TF_BetaincOp : TF_Op<"Betainc", [NoSideEffect]> { + let summary = [{ +Compute the regularized incomplete beta integral \\(I_x(a, b)\\). + }]; + + let description = [{ +The regularized incomplete beta integral is defined as: + + +\\(I_x(a, b) = \frac{B(x; a, b)}{B(a, b)}\\) + +where + + +\\(B(x; a, b) = \int_0^x t^{a-1} (1 - t)^{b-1} dt\\) + + +is the incomplete beta function and \\(B(a, b)\\) is the *complete* +beta function. + }]; + + let arguments = (ins + TF_F32OrF64Tensor:$a, + TF_F32OrF64Tensor:$b, + TF_F32OrF64Tensor:$x + ); + + let results = (outs + TF_F32OrF64Tensor:$z + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_BiasAddOp : TF_Op<"BiasAdd", [NoSideEffect]> { let summary = "Adds `bias` to `value`."; @@ -1319,6 +1353,7 @@ subsequent operation and then be optimized away, however.) let verifier = [{ return Verify(*this); }]; + let hasFolder = 1; } def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> { @@ -1404,6 +1439,38 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CholeskyOp : TF_Op<"Cholesky", [NoSideEffect]> { + let summary = [{ +Computes the Cholesky decomposition of one or more square matrices. + }]; + + let description = [{ +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. + +The input has to be symmetric and positive definite. Only the lower-triangular +part of the input will be used for this operation. The upper-triangular part +will not be read. + +The output is a tensor of the same shape as the input +containing the Cholesky decompositions for all input submatrices `[..., :, :]`. + +**Note**: The gradient computation on GPU is faster for large matrices but +not for large batch dimensions when the submatrices are small. 
In this +case it might be faster to use the CPU. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = "Clips tensor values to a specified min and max."; @@ -2025,17 +2092,73 @@ and `B, D, F, H` as group 1. Thus we get the outputs: }]; let arguments = (ins - TensorOf<[BF16, F32, I32, TF_Uint32]>:$input, + TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$input, I32Tensor:$group_assignment ); let results = (outs - TensorOf<[BF16, F32, I32, TF_Uint32]>:$output + TensorOf<[BF16, F16, F32, I32, TF_Uint32]>:$output ); TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_CumprodOp : TF_Op<"Cumprod", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { + let summary = [{ +Compute the cumulative product of the tensor `x` along `axis`. + }]; + + let description = [{ +By default, this op performs an inclusive cumprod, which means that the first +element of the input is identical to the first element of the output: + +```python +tf.cumprod([a, b, c]) # => [a, a * b, a * b * c] +``` + +By setting the `exclusive` kwarg to `True`, an exclusive cumprod is +performed instead: + +```python +tf.cumprod([a, b, c], exclusive=True) # => [1, a, a * b] +``` + +By setting the `reverse` kwarg to `True`, the cumprod is performed in the +opposite direction: + +```python +tf.cumprod([a, b, c], reverse=True) # => [a * b * c, b * c, c] +``` + +This is more efficient than using separate `tf.reverse` ops. + +The `reverse` and `exclusive` kwargs can also be combined: + +```python +tf.cumprod([a, b, c], exclusive=True, reverse=True) # => [b * c, c, 1] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x, + TF_I32OrI64Tensor:$axis, + + DefaultValuedAttr:$exclusive, + DefaultValuedAttr:$reverse + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$out + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; +} + def TF_CumsumOp : TF_Op<"Cumsum", [NoSideEffect, TF_AllTypesMatch<["x", "out"]>]> { let summary = "Compute the cumulative sum of the tensor `x` along `axis`."; @@ -2084,6 +2207,10 @@ tf.cumsum([a, b, c], exclusive=True, reverse=True) # => [b + c, c, 0] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let verifier = [{ + return Verify(*this); + }]; } def TF_DataFormatDimMapOp : TF_Op<"DataFormatDimMap", [NoSideEffect, SameOperandsAndResultType]> { @@ -2109,6 +2236,82 @@ the source data format. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Permute input tensor from `src_format` to `dst_format`."; + + let description = [{ +Input tensor must be a vector of size 4, or a 4x2 tensor. 
+ +For example, with `src_format` of `NHWC`, `dst_format` of `NCHW`, and inputs: +``` +[1, 2, 3, 4] +``` +and +``` +[[1, 2, 3, 4], + [5, 6, 7, 8]] +``` +, the outputs will be (respectively): +``` +[1, 4, 2, 3] +``` +and +``` +[[1, 4, 2, 3], + [5, 8, 6, 7]] +``` + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$x, + + DefaultValuedAttr:$src_format, + DefaultValuedAttr:$dst_format + ); + + let results = (outs + TF_I32OrI64Tensor:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let verifier = [{ return Verify(*this); }]; +} + +def TF_DebugIdentityV2Op : TF_Op<"DebugIdentityV2", []> { + let summary = "Debug Identity V2 Op."; + + let description = [{ +Provides an identity mapping from input to output, while writing the content of +the input tensor by calling DebugEventsWriter. + +The semantics of the input tensor depends on tensor_debug_mode. In typical +usage, the input tensor comes directly from the user computation only when +graph_debug_mode is FULL_TENSOR (see protobuf/debug_event.proto for a +list of all the possible values of graph_debug_mode). For the other debug modes, +the input tensor should be produced by an additional op or subgraph that +computes summary information about one or more tensors. + }]; + + let arguments = (ins + TF_Tensor:$input, + + StrAttr:$tfdbg_context_id, + StrAttr:$op_name, + DefaultValuedAttr:$output_slot, + DefaultValuedAttr:$tensor_debug_mode, + DefaultValuedAttr:$debug_urls, + DefaultValuedAttr:$circular_buffer_size, + StrAttr:$tfdbg_run_id + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DecodeAndCropJpegOp : TF_Op<"DecodeAndCropJpeg", [NoSideEffect]> { let summary = "Decode and Crop a JPEG-encoded image to a uint8 tensor."; @@ -2402,6 +2605,54 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_DepthwiseConv2dNativeBackpropFilterOp : TF_Op<"DepthwiseConv2dNativeBackpropFilter", [NoSideEffect]> { + let summary = [{ +Computes the gradients of depthwise convolution with respect to the filter. + }]; + + let arguments = (ins + TF_FpTensor:$input, + I32Tensor:$filter_sizes, + TF_FpTensor:$out_backprop, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$dilations + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_DepthwiseConv2dNativeBackpropInputOp : TF_Op<"DepthwiseConv2dNativeBackpropInput", [NoSideEffect]> { + let summary = [{ +Computes the gradients of depthwise convolution with respect to the input. + }]; + + let arguments = (ins + I32Tensor:$input_sizes, + TF_FpTensor:$filter, + TF_FpTensor:$out_backprop, + + I64ArrayAttr:$strides, + TF_AnyStrAttrOf<["SAME", "VALID", "EXPLICIT"]>:$padding, + DefaultValuedAttr:$explicit_paddings, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$dilations + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<1>; +} + def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> { let summary = "Return the index of device the op runs."; @@ -2421,6 +2672,40 @@ this op runs. 
The length of the list is returned in two cases: ); } +def TF_DiagOp : TF_Op<"Diag", [NoSideEffect, SameOperandsAndResultElementType]> { + let summary = "Returns a diagonal tensor with a given diagonal values."; + + let description = [{ +Given a `diagonal`, this operation returns a tensor with the `diagonal` and +everything else padded with zeros. The diagonal is computed as follows: + +Assume `diagonal` has dimensions [D1,..., Dk], then the output is a tensor of +rank 2k with dimensions [D1,..., Dk, D1,..., Dk] where: + +`output[i1,..., ik, i1,..., ik] = diagonal[i1, ..., ik]` and 0 everywhere else. + +For example: + +``` +# 'diagonal' is [1, 2, 3, 4] +tf.diag(diagonal) ==> [[1, 0, 0, 0] + [0, 2, 0, 0] + [0, 0, 3, 0] + [0, 0, 0, 4]] +``` + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$diagonal + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I32, I64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> { let summary = "Returns the diagonal part of the tensor."; @@ -3075,6 +3360,27 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_ExtractImagePatchesOp : TF_Op<"ExtractImagePatches", [NoSideEffect]> { + let summary = [{ +Extract `patches` from `images` and put them in the "depth" output dimension. + }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$images, + + Confined]>:$ksizes, + Confined]>:$strides, + Confined]>:$rates, + TF_AnyStrAttrOf<["SAME", "VALID"]>:$padding + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I1, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$patches + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> { let summary = "Fast Fourier transform."; @@ -4185,6 +4491,22 @@ tf.imag(input) ==> [4.75, 5.75] TF_DerivedResultTypeAttr Tout = TF_DerivedResultTypeAttr<0>; } +def TF_InfeedDequeueOp : TF_Op<"InfeedDequeue", []> { + let summary = [{ +A placeholder op for a value that will be fed into the computation. + }]; + + let arguments = (ins + TF_ShapeAttr:$shape + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_InitializeTableFromTextFileV2Op : TF_Op<"InitializeTableFromTextFileV2", []> { let summary = "Initializes a table from a text file."; @@ -4730,6 +5052,49 @@ tf.linspace(10.0, 12.0, 3, name="linspace") => [ 10.0 11.0 12.0] TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<2>; } +def TF_ListDiffOp : TF_Op<"ListDiff", [NoSideEffect]> { + let summary = [{ +Computes the difference between two lists of numbers or strings. + }]; + + let description = [{ +Given a list `x` and a list `y`, this operation returns a list `out` that +represents all values that are in `x` but not in `y`. The returned list `out` +is sorted in the same order that the numbers appear in `x` (duplicates are +preserved). This operation also returns a list `idx` that represents the +position of each `out` element in `x`. 
In other words: + +`out[i] = x[idx[i]] for i in [0, 1, ..., len(out) - 1]` + +For example, given this input: + +``` +x = [1, 2, 3, 4, 5, 6] +y = [1, 3, 5] +``` + +This operation would return: + +``` +out ==> [2, 4, 6] +idx ==> [1, 3, 5] +``` + }]; + + let arguments = (ins + TF_Tensor:$x, + TF_Tensor:$y + ); + + let results = (outs + TF_Tensor:$out, + TF_I32OrI64Tensor:$idx + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_idx = TF_DerivedResultTypeAttr<1>; +} + def TF_LogOp : TF_Op<"Log", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes natural logarithm of x element-wise."; @@ -4913,6 +5278,44 @@ def TF_LookupTableSizeV2Op : TF_Op<"LookupTableSizeV2", []> { ); } +def TF_LowerBoundOp : TF_Op<"LowerBound", [NoSideEffect]> { + let summary = [{ +Applies lower_bound(sorted_search_values, values) along each row. + }]; + + let description = [{ +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='left')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = LowerBound(sorted_sequence, values) + + result == [[1, 2, 2], + [0, 1, 5]] + }]; + + let arguments = (ins + TF_Tensor:$sorted_inputs, + TF_Tensor:$values + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; +} + def TF_MatMulOp : TF_Op<"MatMul", [NoSideEffect, TF_SameOperandsAndResultElementTypeResolveRef]> { let summary = [{ Multiply the matrix "a" by the matrix "b". @@ -5422,6 +5825,36 @@ tf.matrix_diag(diagonal, k = -1, num_rows = 3, padding_value = 9) TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixInverseOp : TF_Op<"MatrixInverse", [NoSideEffect]> { + let summary = [{ +Computes the inverse of one or more square invertible matrices or their adjoints (conjugate transposes). + }]; + + let description = [{ +The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. The output is a tensor of the same shape as the input +containing the inverse for all input submatrices `[..., :, :]`. + +The op uses LU decomposition with partial pivoting to compute the inverses. + +If a matrix is not invertible there is no guarantee what the op does. It +may detect the condition and raise an exception or it may simply return a +garbage result. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$input, + + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MatrixSetDiagOp : TF_Op<"MatrixSetDiag", [NoSideEffect]> { let summary = [{ Returns a batched matrix tensor with new batched diagonal values. @@ -5673,6 +6106,100 @@ tf.matrix_set_diag(input, diagonals, k = (-1, 2), align="LEFT_RIGHT") TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MatrixSolveOp : TF_Op<"MatrixSolve", [NoSideEffect]> { + let summary = "Solves systems of linear equations."; + + let description = [{ +`Matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions +form square matrices. 
`Rhs` is a tensor of shape `[..., M, K]`. The `output` is +a tensor shape `[..., M, K]`. If `adjoint` is `False` then each output matrix +satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +If `adjoint` is `True` then each output matrix satisfies +`adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + +def TF_MatrixTriangularSolveOp : TF_Op<"MatrixTriangularSolve", [NoSideEffect]> { + let summary = [{ +Solves systems of linear equations with upper or lower triangular matrices by backsubstitution. + }]; + + let description = [{ +`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form +square matrices. If `lower` is `True` then the strictly upper triangular part +of each inner-most matrix is assumed to be zero and not accessed. +If `lower` is False then the strictly lower triangular part of each inner-most +matrix is assumed to be zero and not accessed. +`rhs` is a tensor of shape `[..., M, N]`. + +The output is a tensor of shape `[..., M, N]`. If `adjoint` is +`True` then the innermost matrices in `output` satisfy matrix equations +`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`. +If `adjoint` is `False` then the strictly then the innermost matrices in +`output` satisfy matrix equations +`adjoint(matrix[..., i, k]) * output[..., k, j] = rhs[..., i, j]`. + +Note, the batch shapes for the inputs only need to broadcast. + +Example: +```python + +a = tf.constant([[3, 0, 0, 0], + [2, 1, 0, 0], + [1, 0, 1, 0], + [1, 1, 1, 1]], dtype=tf.float32) + +b = tf.constant([[4], + [2], + [4], + [2]], dtype=tf.float32) + +x = tf.linalg.triangular_solve(a, b, lower=True) +x +# + +# in python3 one can use `a@x` +tf.matmul(a, x) +# +``` + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$matrix, + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$rhs, + + DefaultValuedAttr:$lower, + DefaultValuedAttr:$adjoint + ); + + let results = (outs + TensorOf<[F16, F32, F64, TF_Complex128, TF_Complex64]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_MaxOp : TF_Op<"Max", [NoSideEffect]> { let summary = [{ Computes the maximum of elements across dimensions of a tensor. @@ -5818,12 +6345,44 @@ def TF_MaximumOp : TF_Op<"Maximum", [NoSideEffect, ResultsBroadcastableShape, TF TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { + let summary = "Computes the mean of elements across dimensions of a tensor."; + + let description = [{ +Reduces `input` along the dimensions given in `axis`. Unless +`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +`axis`. If `keep_dims` is true, the reduced dimensions are +retained with length 1. 
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$input, + TF_I32OrI64Tensor:$reduction_indices, + + DefaultValuedAttr:$keep_dims + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; + + let extraClassDeclaration = [{ + // TF_FoldOperandsTransposeInterface: + SmallVector GetLayoutDependentArgs() { return {0}; } + SmallVector GetLayoutDependentResults() { return {}; } + LogicalResult FoldOperandsPermutation(ArrayRef permutation); + }]; +} + def TF_MergeSummaryOp : TF_Op<"MergeSummary", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Merges summaries."; let description = [{ This op creates a -[`Summary`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/summary.proto) +[`Summary`](https://www.tensorflow.org/code/tensorflow/core/framework/summary.proto) protocol buffer that contains the union of all the values in the input summaries. @@ -6054,7 +6613,7 @@ the result here is consistent with a truncating divide. E.g. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, +def TF_MulOp : TF_Op<"Mul", [Commutative, NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x * y element-wise."; @@ -7215,9 +7774,6 @@ def TF_RangeDatasetOp : TF_Op<"RangeDataset", []> { Creates a dataset with a range of values. Corresponds to python's xrange. }]; - let description = [{ - }]; - let arguments = (ins I64Tensor:$start, I64Tensor:$stop, @@ -8111,6 +8667,47 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.] TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF_RollOp : TF_Op<"Roll", [NoSideEffect]> { + let summary = "Rolls the elements of a tensor along an axis."; + + let description = [{ +The elements are shifted positively (towards larger indices) by the offset of +`shift` along the dimension of `axis`. Negative `shift` values will shift +elements in the opposite direction. Elements that roll passed the last position +will wrap around to the first and vice versa. Multiple shifts along multiple +axes may be specified. 
+ +For example: + +``` +# 't' is [0, 1, 2, 3, 4] +roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2] + +# shifting along multiple dimensions +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]] + +# shifting along the same axis multiple times +# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] +roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]] +``` + }]; + + let arguments = (ins + TF_Tensor:$input, + TF_I32OrI64Tensor:$shift, + TF_I32OrI64Tensor:$axis + ); + + let results = (outs + TF_Tensor:$output + ); + + TF_DerivedOperandTypeAttr Tshift = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Taxis = TF_DerivedOperandTypeAttr<2>; +} + def TF_RoundOp : TF_Op<"Round", [NoSideEffect, SameOperandsAndResultType]> { let summary = [{ Rounds the values of a tensor to the nearest integer, element-wise. @@ -8858,6 +9455,8 @@ size(t) ==> 12 let verifier = [{ return Verify(*this); }]; + + let hasFolder = 1; } def TF_SliceOp : TF_Op<"Slice", [NoSideEffect]> { @@ -9464,7 +10063,7 @@ I.e., \\(y = x * x = x^2\\). def TF_SquaredDifferenceOp : TF_Op<"SquaredDifference", [Commutative, NoSideEffect, ResultsBroadcastableShape]>, WithBroadcastableBinOpBuilder { - let summary = "Returns (x - y)(x - y) element-wise."; + let summary = "Returns conj(x - y)(x - y) element-wise."; let description = [{ *NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting @@ -9576,6 +10175,49 @@ def TF_StackV2Op : TF_Op<"StackV2", []> { ); } +def TF_StatelessMultinomialOp : TF_Op<"StatelessMultinomial", [NoSideEffect]> { + let summary = "Draws samples from a multinomial distribution."; + + let arguments = (ins + TF_IntOrFpTensor:$logits, + I32Tensor:$num_samples, + TF_I32OrI64Tensor:$seed + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<2>; + TF_DerivedResultTypeAttr output_dtype = TF_DerivedResultTypeAttr<0>; +} + +def TF_StatelessRandomNormalOp : TF_Op<"StatelessRandomNormal", [NoSideEffect]> { + let summary = [{ +Outputs deterministic pseudorandom values from a normal distribution. + }]; + + let description = [{ +The generated values will have mean 0 and standard deviation 1. + +The outputs are a deterministic function of `shape` and `seed`. + }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed + ); + + let results = (outs + TF_FpTensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; +} + def TF_StatelessRandomUniformOp : TF_Op<"StatelessRandomUniform", [NoSideEffect]> { let summary = [{ Outputs deterministic pseudorandom random values from a uniform distribution. @@ -9602,6 +10244,33 @@ The outputs are a deterministic function of `shape` and `seed`. TF_DerivedResultTypeAttr dtype = TF_DerivedResultTypeAttr<0>; } +def TF_StatelessRandomUniformIntOp : TF_Op<"StatelessRandomUniformInt", [NoSideEffect]> { + let summary = [{ +Outputs deterministic pseudorandom random integers from a uniform distribution. + }]; + + let description = [{ +The generated values follow a uniform distribution in the range `[minval, maxval)`. + +The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. 
+ }]; + + let arguments = (ins + TF_I32OrI64Tensor:$shape, + TF_I32OrI64Tensor:$seed, + TF_I32OrI64Tensor:$minval, + TF_I32OrI64Tensor:$maxval + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr Tseed = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr dtype = TF_DerivedOperandTypeAttr<2>; +} + def TF_StatelessTruncatedNormalOp : TF_Op<"StatelessTruncatedNormal", [NoSideEffect]> { let summary = [{ Outputs deterministic pseudorandom values from a truncated normal distribution. @@ -9871,7 +10540,37 @@ Examples: TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<0>; } -def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_SameOperandsAndResultElementTypeResolveRef, TF_CwiseBinary]>, +def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> { + let summary = [{ +Converts each string in the input Tensor to its hash mod by a number of buckets. + }]; + + let description = [{ +The hash function is deterministic on the content of the string within the +process and will never change. However, it is not suitable for cryptography. +This function may be used when CPU time is scarce and inputs are trusted or +unimportant. There is a risk of adversaries constructing inputs that all hash +to the same bucket. To prevent this problem, use a strong hash function with +`tf.string_to_hash_bucket_strong`. + +Examples: + +>>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy() +array([0, 2, 2]) + }]; + + let arguments = (ins + TF_StrTensor:$input, + + Confined]>:$num_buckets + ); + + let results = (outs + I64Tensor:$output + ); +} + +def TF_SubOp : TF_Op<"Sub", [NoSideEffect, ResultsBroadcastableShape, TF_CwiseBinary, TF_SameOperandsAndResultElementTypeResolveRef]>, WithBroadcastableBinOpBuilder { let summary = "Returns x - y element-wise."; @@ -9926,6 +10625,25 @@ retained with length 1. >]; } +def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> { + let summary = [{ +Computes the gradient function for function f via backpropagation. + }]; + + let arguments = (ins + Variadic:$input, + + SymbolRefAttr:$f + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; +} + def TF_TPUCompilationResultOp : TF_Op<"TPUCompilationResult", [NoSideEffect]> { let summary = "Returns the result of a TPU compilation."; @@ -10912,43 +11630,6 @@ array([[1, 2, 3, 1, 2, 3], // input.rank() } -def TF_ToBoolOp : TF_Op<"ToBool", [NoSideEffect]> { - let summary = "Converts a tensor to a scalar predicate."; - - let description = [{ -Converts a tensor to a scalar predicate with the following rules: - -- For 0D tensors, truthiness is determined by comparing against a "zero" - value. For numerical types it is the obvious zero. For strings it is the - empty string. - -- For >0D tensors, truthiness is determined by looking at the number of - elements. If has zero elements, then the result is false. Otherwise the - result is true. - -This matches the behavior of If and While for determining if a tensor counts -as true/false for a branch condition. 
- }]; - - let arguments = (ins - TF_Tensor:$input - ); - - let results = (outs - I1Tensor:$output - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let builders = [OpBuilder< - "OpBuilder &builder, OperationState &result, Value value", [{ - build(builder, result, RankedTensorType::get({}, builder.getI1Type()), - value); - }]>]; - - let hasCanonicalizer = 1; -} - def TF_TopKV2Op : TF_Op<"TopKV2", [NoSideEffect]> { let summary = [{ Finds values and indices of the `k` largest elements for the last dimension. @@ -11370,6 +12051,44 @@ tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2) let verifier = [{ return VerifyUnsortedSegmentReduction(*this); }]; } +def TF_UpperBoundOp : TF_Op<"UpperBound", [NoSideEffect]> { + let summary = [{ +Applies upper_bound(sorted_search_values, values) along each row. + }]; + + let description = [{ +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='right')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = UpperBound(sorted_sequence, values) + + result == [[1, 2, 4], + [0, 2, 5]] + }]; + + let arguments = (ins + TF_Tensor:$sorted_inputs, + TF_Tensor:$values + ); + + let results = (outs + TF_I32OrI64Tensor:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedResultTypeAttr out_type = TF_DerivedResultTypeAttr<0>; +} + def TF_VarIsInitializedOp : TF_Op<"VarIsInitializedOp", []> { let summary = [{ Checks whether a resource handle-based variable has been initialized. @@ -11901,6 +12620,13 @@ https://www.tensorflow.org/performance/xla/operation_semantics#pad def TF_XlaRecvFromHostOp : TF_Op<"XlaRecvFromHost", []> { let summary = "An op to receive a tensor from the host."; + let description = [{ +output: the tensor that will be received from the host. +Toutput: element type for output. +shape: shape for output. +key: A unique identifier for this region used to match up host transfers. + }]; + let arguments = (ins TF_ShapeAttr:$shape, StrAttr:$key @@ -11945,6 +12671,31 @@ def TF_XlaReplicaIdOp : TF_Op<"XlaReplicaId", [NoSideEffect]> { ); } +def TF_XlaScatterOp : TF_Op<"XlaScatter", [NoSideEffect]> { + let summary = "Wraps the XLA Scatter operator documented at"; + + let description = [{ +https://www.tensorflow.org/xla/operation_semantics#scatter. 
+ }]; + + let arguments = (ins + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$operand, + TF_I32OrI64Tensor:$scatter_indices, + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$updates, + + SymbolRefAttr:$update_computation, + StrAttr:$dimension_numbers, + BoolAttr:$indices_are_sorted + ); + + let results = (outs + TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output + ); + + TF_DerivedOperandTypeAttr Tindices = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF_XlaSelfAdjointEigOp : TF_Op<"XlaSelfAdjointEig", [NoSideEffect]> { let summary = [{ Computes the eigen decomposition of a batch of self-adjoint matrices @@ -11977,6 +12728,12 @@ i=0...N-1. def TF_XlaSendToHostOp : TF_Op<"XlaSendToHost", []> { let summary = "An op to send a tensor to the host."; + let description = [{ +input: the tensor that will be sent to the host. +Tinput: element type for input. +key: A unique identifier for this region used to match up host transfers. + }]; + let arguments = (ins TF_Tensor:$input, @@ -12062,6 +12819,43 @@ def TF_ZerosLikeOp : TF_Op<"ZerosLike", [NoSideEffect, SameOperandsAndResultType TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } +def TF__FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { + let summary = "Internal FusedBatchNorm operation: reserved for internal use."; + + let description = [{ +Do not invoke this operator directly in Python. A fusion optimization is +expected to create these operators. + }]; + + let arguments = (ins + TensorOf<[F16, F32]>:$x, + F32Tensor:$scale, + F32Tensor:$offset, + F32Tensor:$mean, + F32Tensor:$variance, + Variadic>:$side_input, + + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$exponential_avg_factor, + DefaultValuedAttr:$activation_mode, + DefaultValuedAttr:$data_format, + DefaultValuedAttr:$is_training + ); + + let results = (outs + TensorOf<[F16, F32]>:$y, + F32Tensor:$batch_mean, + F32Tensor:$batch_variance, + F32Tensor:$reserve_space_1, + F32Tensor:$reserve_space_2, + F32Tensor:$reserve_space_3 + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; + TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; +} + def TF__FusedConv2DOp : TF_Op<"_FusedConv2D", [NoSideEffect]> { let summary = [{ Performs a convolution followed by a specified series of operations. @@ -12183,18 +12977,17 @@ Compiles a computations for execution on one or more TPU devices. }]; let description = [{ -For the internal use of the distributed TPU compiler. Note that currently only -single TPU device is supported. +For the internal use of the distributed TPU compiler. 'mlir_module' is a serialized MLIR module with a `main` function that contains target computation. 'dynamic_shapes' contains dynamic shapes of arguments whose shapes were not known statically at TPUReplication rewrite time. -'metadata' is a serialized TPUCompileMetadataProto describing -the shapes and types of the inputs to the computation, as well as a mapping onto -the TPU pod topology. -'program' output is a string key that is passed to the _TPUExecute op and -used to look up the program in the compilation cache. 
+'metadata' is a serialized TPUCompileMetadataProto describing the shapes and +types of the inputs to the computation, as well as a mapping onto the TPU pod +topology. +'program' output is a string key that is passed to the TPUExecute op and used to +look up the program in the compilation cache. }]; let arguments = (ins @@ -12231,6 +13024,28 @@ rewrite passes must replace this op with a _TPUCompileMlir op `program` output. ); } +def TF__UnaryOpsCompositionOp : TF_Op<"_UnaryOpsComposition", [NoSideEffect, SameOperandsAndResultType]> { + let summary = [{ +*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is + }]; + + let description = [{ +expected to create these operators. + }]; + + let arguments = (ins + TensorOf<[F16, F32, F64]>:$x, + + StrArrayAttr:$op_names + ); + + let results = (outs + TensorOf<[F16, F32, F64]>:$y + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; +} + def TF__XlaHostComputeMlirOp : TF_Op<"_XlaHostComputeMlir", []> { let summary = [{ A pseudo-op to represent host-side computation in an XLA program. diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index e35e5dc40a8..737442d5f8c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -55,6 +55,8 @@ limitations under the License. #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/DecodeAttributesInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project #include "mlir/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -112,6 +114,22 @@ bool HasSingleUse(FuncOp func) { return true; } +struct TFConstantFoldInterface : public DialectFoldInterface { + TFConstantFoldInterface(Dialect *dialect) : DialectFoldInterface(dialect) {} + LogicalResult fold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) const final { + return TensorFlowDialect::constantFold(op, operands, results); + } +}; + +struct TFDecodeAttributesInterface : public DialectDecodeAttributesInterface { + TFDecodeAttributesInterface(Dialect *dialect) + : DialectDecodeAttributesInterface(dialect) {} + LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) const { + return TensorFlowDialect::decode(input, output); + } +}; + struct TFInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; @@ -206,6 +224,9 @@ std::vector *TensorFlowDialect::additional_operation_hooks_ = new std::vector(); +TensorFlowDialect::ConstantFoldHook TensorFlowDialect::constant_fold_hook_; +TensorFlowDialect::DecodeConstantHook TensorFlowDialect::decode_constant_hook_; + TensorFlowDialect::TensorFlowDialect(MLIRContext *context) : Dialect(/*name=*/"tf", context, TypeID::get()) { addOperations< @@ -217,7 +238,8 @@ TensorFlowDialect::TensorFlowDialect(MLIRContext *context) #define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" >(); - addInterfaces(); + addInterfaces(); addAttributes(); // Support unknown operations because not all TensorFlow operations are @@ -336,16 +358,12 @@ Attribute TensorFlowDialect::parseAttribute(DialectAsmParser &parser, void TensorFlowDialect::printAttribute(Attribute attr, 
DialectAsmPrinter &os) const { - switch (attr.getKind()) { - case AttrKind::SHAPE: - PrintShapeAttr(attr.cast(), os); - break; - case AttrKind::FUNC: - PrintFuncAttr(attr.cast(), os); - break; - default: - llvm_unreachable("unexpected tensorflow attribute kind"); - } + if (auto shape_attr = attr.dyn_cast()) + PrintShapeAttr(shape_attr, os); + else if (auto func_attr = attr.dyn_cast()) + PrintFuncAttr(func_attr, os); + else + llvm_unreachable("unexpected tensorflow attribute type"); } // Parses a type registered to this dialect. @@ -354,51 +372,37 @@ Type TensorFlowDialect::parseType(DialectAsmParser &parser) const { if (parser.parseKeyword(&data)) return Type(); Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); - auto typeKind = llvm::StringSwitch(data) + #define HANDLE_TF_TYPE(tftype, enumerant, name) \ - .Case(name, TensorFlowTypes::enumerant) + if (data == name) return tftype##Type::get(getContext()); // Custom TensorFlow types are handled separately at the end as they do partial // match. #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - .StartsWith("resource", TensorFlowTypes::RESOURCE) - .StartsWith("variant", TensorFlowTypes::VARIANT) - .Default(0); - switch (typeKind) { - default: - return (emitError(loc, "unknown TensorFlow type: " + data), nullptr); -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - return tftype##Type::get(getContext()); -#define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) -// NOLINTNEXTLINE -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - case TensorFlowTypes::RESOURCE: - return ParseResourceType(parser, loc); - case TensorFlowTypes::VARIANT: - return ParseVariantType(parser, loc); - } + if (data.startswith("resource")) return ParseResourceType(parser, loc); + if (data.startswith("variant")) return ParseVariantType(parser, loc); + return (emitError(loc, "unknown TensorFlow type: " + data), nullptr); } // Prints a type registered to this dialect. 
void TensorFlowDialect::printType(Type ty, DialectAsmPrinter &os) const { assert(ty.isa()); - switch (ty.getKind()) { - default: - llvm_unreachable("unexpected tensorflow type kind"); -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - os << name; \ - break; +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = ty.dyn_cast()) { \ + os << name; \ + return; \ + } #define HANDLE_CUSTOM_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ - Print##tftype##Type(ty.cast(), os); \ - break; + if (auto derived_ty = ty.dyn_cast()) { \ + Print##tftype##Type(derived_ty, os); \ + return; \ + } // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - } + + llvm_unreachable("unexpected tensorflow type kind"); } namespace { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h index bbcce4ee177..3169f7fba8d 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h @@ -116,10 +116,35 @@ class TensorFlowDialect : public Dialect { 0, (addOperation(AbstractOperation::get(*this)), 0)...}; } + using ConstantFoldHook = LogicalResult (*)(Operation *, ArrayRef, + SmallVectorImpl &); + static void RegisterConstantFoldHook(ConstantFoldHook fn) { + constant_fold_hook_ = std::move(fn); + } + + static LogicalResult constantFold(Operation *op, ArrayRef operands, + SmallVectorImpl &results) { + if (constant_fold_hook_) return constant_fold_hook_(op, operands, results); + return failure(); + } + + using DecodeConstantHook = LogicalResult (*)(OpaqueElementsAttr input, + ElementsAttr &output); + static void RegisterDecodeConstantHook(DecodeConstantHook fn) { + decode_constant_hook_ = std::move(fn); + } + static LogicalResult decode(OpaqueElementsAttr input, ElementsAttr &output) { + if (decode_constant_hook_) return decode_constant_hook_(input, output); + return failure(); + } + private: // Hook functions which may add additional operations to the dialect. // These are invoked at construction time. static std::vector *additional_operation_hooks_; + + static ConstantFoldHook constant_fold_hook_; + static DecodeConstantHook decode_constant_hook_; }; } // namespace TF diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 5269bb82239..db0a97d4b96 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -97,10 +97,10 @@ An n-way switch statement, implementing the following: Variadic:$input, Confined]>:$branches, - DefaultValuedAttr:$output_shapes, - // Used to map StatelessCase and Case to a common op. - DefaultValuedAttr:$is_stateless + // Used to map StatelessCase and Case op defined in TensorFlow to a common + // op. + BoolAttr:$is_stateless ); let results = (outs @@ -109,8 +109,57 @@ An n-way switch statement, implementing the following: TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>; let hasCanonicalizer = 1; + + let verifier = [{ + return Verify(*this); + }]; +} + +def TF_CaseRegionOp : TF_Op<"CaseRegion", + [SingleBlockImplicitTerminator<"YieldOp">, NoRegionArguments]> { + let summary = [{ +An n-way switch statement which calls a single branch function. 
+ }]; + + let description = [{ +An n-way switch statement, implementing the following: + ``` + switch (branch_index) { + case 0: + output = branches[0](input); + break; + case 1: + output = branches[1](input); + break; + ... + case [[nbranches-1]]: + default: + output = branches[nbranches-1](input); + break; + } + ``` + }]; + + let arguments = (ins + I32Tensor:$branch_index, + + // Used to map StatelessCase and Case op defined in TensorFlow to a common + // op. + BoolAttr:$is_stateless + ); + + let results = (outs + Variadic:$output + ); + + let regions = (region VariadicRegion>:$branches); + + let verifier = [{ + return Verify(*this); + }]; } // In MLIR, the TensorFlow tensor value is represented as an ElementsAttr, with @@ -168,30 +217,6 @@ source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs: TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } - -def TF_DataFormatVecPermuteOp : TF_Op<"DataFormatVecPermute", [NoSideEffect, SameOperandsAndResultType]> { - let summary = "Permute input tensor from `src_format` to `dst_format`"; - - let description = [{ -Input tensor must be a vector of size 4, or a 4x2 tensor. - }]; - - let arguments = (ins - TF_I32OrI64Tensor:$x, - - DefaultValuedAttr:$src_format, - DefaultValuedAttr:$dst_format - ); - - let results = (outs - TF_I32OrI64Tensor:$y - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - - let verifier = [{ return Verify(*this); }]; -} - def TF_EmptyTensorListOp : TF_TensorListInitOp<"EmptyTensorList"> { let summary = "Creates and returns an empty tensor list."; @@ -292,7 +317,7 @@ else_branch: A function that takes 'inputs' and returns a list of } def TF_YieldOp : TF_Op<"Yield", - [Terminator, ParentOneOf<["IfRegionOp", "WhileRegionOp"]>]> { + [Terminator, ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>]> { let summary = "Yield operation"; let description = [{ @@ -328,7 +353,7 @@ else_branch: A region that computes the outputs of the op if cond = false. }]; let arguments = (ins - TF_Tensor:$cond, + 0DTensorOf<[I1]>:$cond, // Used to map StatelessIf and If op defined in TensorFlow to a common op. BoolAttr:$is_stateless @@ -338,47 +363,13 @@ else_branch: A region that computes the outputs of the op if cond = false. Variadic:$output ); - TF_DerivedOperandTypeAttr Tcond = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; - TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; - let regions = (region SizedRegion<1>:$then_branch, SizedRegion<1>:$else_branch); let verifier = [{ return Verify(*this); }]; -} -def TF_MeanOp : TF_Op<"Mean", [NoSideEffect, TF_FoldOperandsTransposeInterface]> { - let summary = "Computes the mean of elements across dimensions of a tensor."; - - let description = [{ -Reduces `input` along the dimensions given in `axis`. Unless -`keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -`axis`. If `keep_dims` is true, the reduced dimensions are -retained with length 1. 
- }]; - - let arguments = (ins - TF_NumberTensor:$input, - TF_I32OrI64Tensor:$reduction_indices, - - DefaultValuedAttr:$keep_dims - ); - - let results = (outs - TF_NumberTensor:$output - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>; - - let extraClassDeclaration = [{ - // TF_FoldOperandsTransposeInterface: - SmallVector GetLayoutDependentArgs() { return {0}; } - SmallVector GetLayoutDependentResults() { return {}; } - LogicalResult FoldOperandsPermutation(ArrayRef permutation); - }]; + let hasCanonicalizer = 1; } def TF_LegacyCallOp : TF_Op<"LegacyCall", @@ -755,8 +746,6 @@ def TL_WhileRegionOp : TF_Op<"WhileRegion", ); let results = (outs Variadic:$output); - TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>; - let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let verifier = [{ return Verify(*this); }]; @@ -841,45 +830,6 @@ Example: TF_DerivedOperandOrResultHandleShapeAttr<"resource">; } -// Not generated because it begins with an underscore, which isn't allowed by -// the C++ standard. -def TF_FusedBatchNormExOp : TF_Op<"_FusedBatchNormEx", [NoSideEffect]> { - let summary = "Internal FusedBatchNorm operation: reserved for internal use"; - - let description = [{ - Do not invoke this operator directly in Python. A fusion optimization is - expected to create these operators. - }]; - - let arguments = (ins - TensorOf<[F16, F32]>:$x, - F32Tensor:$scale, - F32Tensor:$offset, - F32Tensor:$mean, - F32Tensor:$variance, - Variadic>:$side_input, - - DefaultValuedAttr:$epsilon, - DefaultValuedAttr:$exponential_avg_factor, - DefaultValuedAttr:$activation_mode, - DefaultValuedAttr:$data_format, - DefaultValuedAttr:$is_training - ); - - let results = (outs - TensorOf<[F16, F32]>:$y, - F32Tensor:$batch_mean, - F32Tensor:$batch_variance, - F32Tensor:$reserve_space_1, - F32Tensor:$reserve_space_2, - F32Tensor:$reserve_space_3 - ); - - TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; - TF_DerivedOperandTypeAttr U = TF_DerivedOperandTypeAttr<1>; - TF_DerivedOperandSizeAttr num_side_inputs = TF_DerivedOperandSizeAttr<5>; -} - // Multiple variadic operands with different sizes are not supported by the // dialect generator, so we manually added the op. def TF_SendTPUEmbeddingGradientsOp : TF_Op<"SendTPUEmbeddingGradients", [AttrSizedOperandSegments]> { @@ -1150,6 +1100,43 @@ def TF_TensorSliceDatasetOp : TF_Op<"TensorSliceDataset", []> { TF_DerivedOperandTypeListAttr Toutput_types = TF_DerivedOperandTypeListAttr<0>; } +def TF_ToBoolOp : TF_Op<"ToBool", [NoSideEffect]> { + let summary = "Converts a tensor to a scalar predicate."; + + let description = [{ +Converts a tensor to a scalar predicate with the following rules: + +- For 0D tensors, truthiness is determined by comparing against a "zero" + value. For numerical types it is the obvious zero. For strings it is the + empty string. + +- For >0D tensors, truthiness is determined by looking at the number of + elements. If has zero elements, then the result is false. Otherwise the + result is true. + +This matches the behavior of If and While for determining if a tensor counts +as true/false for a branch condition. 
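+
+For illustration only (the operand types below are arbitrary examples and not
+part of the op definition), the op is expected to appear in the IR as:
+
+```
+// 0D input: the predicate is "value != 0" (or != "" for strings).
+%p = "tf.ToBool"(%scalar) : (tensor<f32>) -> tensor<i1>
+// >0D input: the predicate is "the tensor has at least one element".
+%q = "tf.ToBool"(%matrix) : (tensor<2x0xf32>) -> tensor<i1>
+```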
+ }]; + + let arguments = (ins + TF_Tensor:$input + ); + + let results = (outs + 0DTensorOf<[I1]>:$output + ); + + TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; + + let builders = [OpBuilder< + "OpBuilder &builder, OperationState &result, Value value", [{ + build(builder, result, RankedTensorType::get({}, builder.getI1Type()), + value); + }]>]; + + let hasCanonicalizer = 1; +} + def TF_BesselI0eOp : TF_Op<"BesselI0e", [NoSideEffect, SameOperandsAndResultType]> { let summary = "Computes the Bessel i0e function of `x` element-wise."; @@ -1192,36 +1179,6 @@ This function is faster and numerically stabler than `bessel_i1(x)`. TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; } -def TF_StringToHashBucketFastOp : TF_Op<"StringToHashBucketFast", [NoSideEffect]> { - let summary = [{ -Converts each string in the input Tensor to its hash mod by a number of buckets. - }]; - - let description = [{ -The hash function is deterministic on the content of the string within the -process and will never change. However, it is not suitable for cryptography. -This function may be used when CPU time is scarce and inputs are trusted or -unimportant. There is a risk of adversaries constructing inputs that all hash -to the same bucket. To prevent this problem, use a strong hash function with -`tf.string_to_hash_bucket_strong`. - -Examples: - ->>> tf.strings.to_hash_bucket_fast(["Hello", "TensorFlow", "2.x"], 3).numpy() -array([0, 2, 2]) - }]; - - let arguments = (ins - TF_StrTensor:$input, - - Confined]>:$num_buckets - ); - - let results = (outs - I64Tensor:$output - ); -} - def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { let summary = "Calls a function placed on a specified TPU device."; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 1a730a38618..b465c1da68c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -37,6 +38,7 @@ limitations under the License. #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -64,6 +66,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" @@ -438,6 +441,19 @@ static LogicalResult Verify(BroadcastToOp op) { return success(); } +OpFoldResult BroadcastToOp::fold(ArrayRef operands) { + Value input = this->input(); + + // Fold broadcast if operand and result types are the same and all dimensions + // are statically known (no-op broadcast). 
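+  // Illustrative sketch (mirrors the @testBroadcastToNoOp case added to
+  // canonicalize.mlir in this change; shapes are examples only):
+  //   %0 = "tf.BroadcastTo"(%arg0, %shape)
+  //          : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32>
+  // folds to %arg0, since the result type already equals the operand type.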
+ auto result_ty = getType().dyn_cast(); + if (result_ty && result_ty.hasStaticShape() && result_ty == input.getType()) { + return input; + } + + return {}; +} + //===----------------------------------------------------------------------===// // CaseOp //===----------------------------------------------------------------------===// @@ -456,28 +472,139 @@ LogicalResult FoldConstantCaseOp::matchAndRewrite( DenseIntElementsAttr branch; if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure(); - // Only attempt to fold scalar valued case statements. - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. - if (!branch.getType().cast().getShape().empty()) - return failure(); - int index = *branch.getValues().begin(); - // TODO(jpienaar): This can be removed if CaseOp's verifier covers it. - if (index >= op.branches().size()) return failure(); + if (index < 0 || index >= op.branches().size()) + index = op.branches().size() - 1; auto func = op.branches()[index].cast(); auto empty = rewriter.getStringAttr(""); auto call_op = rewriter.create( op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); - PropagateDeviceAndInternalAttrs(op.getOperation(), call_op); + CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op); rewriter.replaceOp(op, call_op.getResults()); return success(); } void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert>(context); +} + +static LogicalResult VerifyCaseOpBase(Operation *op, Value branch_index) { + if (!IsOfRankOrUnranked(branch_index, 0)) + return op->emitOpError() + << "expects 'branch_index' to be a scalar, but got " + << branch_index.getType(); + return success(); +} + +static LogicalResult VerifyCaseOrIfOpBranchFunctions( + Operation *op, ArrayRef branches, + llvm::function_ref branch_name) { + SmallVector branch_types; + branch_types.reserve(branches.size()); + + // Functions have one less operand compared to op as first operand is elided + // (`cond` of `tf.If` and `branch_index` of `tf.Case`). + int expected_num_inputs = op->getNumOperands() - 1; + int expected_num_results = op->getNumResults(); + for (auto branch : llvm::enumerate(branches)) { + auto branch_func = SymbolTable::lookupNearestSymbolFrom( + op, branch.value().cast()); + if (!branch_func) + return op->emitOpError() + << "expects " << branch_name(branch.index()) << " (" + << branch.value() << ") to point to a defined function"; + + FunctionType branch_type = branch_func.getType(); + if (branch_type.getNumInputs() != expected_num_inputs) + return op->emitOpError() + << "expects all branches to have " << expected_num_inputs + << " input(s), but " << branch_name(branch.index()) << " has " + << branch_type.getNumInputs() << " input(s)"; + + if (branch_type.getNumResults() != expected_num_results) + return op->emitOpError() + << "expects all branches to have " << expected_num_results + << " result(s), but " << branch_name(branch.index()) << " has " + << branch_type.getNumResults() << " result(s)"; + + // Non-conditional operands starting with the second operand are passed to + // branches and should be compatible across all branches' inputs. 
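+    // For example (branch names and types here are illustrative only), given
+    //   %r = "tf.Case"(%idx, %x) {branches = [@b0, @b1], is_stateless = false}
+    //          : (tensor<i32>, tensor<2xf32>) -> tensor<2xf32>
+    // every branch function must take one argument cast compatible with
+    // tensor<2xf32> and return one result cast compatible with tensor<2xf32>.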
+ for (auto operand_type : + llvm::enumerate(llvm::drop_begin(op->getOperandTypes(), 1))) { + Type branch_input_i_type = branch_type.getInput(operand_type.index()); + if (!AreCastCompatible({operand_type.value(), branch_input_i_type})) + return op->emitOpError() + << "expects operand type " << operand_type.value() + << " to be cast compatible with " << branch_name(branch.index()) + << " input type " << branch_input_i_type << " at index " + << operand_type.index(); + } + + // Branches' results should be pair-wise compatible with the op results. + for (auto result_type : llvm::enumerate(op->getResultTypes())) { + Type branch_result_i_type = branch_type.getResult(result_type.index()); + if (!AreCastCompatible({result_type.value(), branch_result_i_type})) + return op->emitOpError() + << "expects result type " << result_type.value() + << " to be cast compatible with " << branch_name(branch.index()) + << " result type " << branch_result_i_type << " at index " + << result_type.index(); + } + + branch_types.push_back(branch_type); + } + + // If branches have incompatible input types that means that no tensor can + // serve as input to all the functions. Hence, the op is invalid. + for (int i = 0; i < expected_num_inputs; ++i) { + SmallVector branch_input_i_types; + branch_input_i_types.reserve(branches.size()); + llvm::transform( + branch_types, std::back_inserter(branch_input_i_types), + [i](FunctionType &branch_type) { return branch_type.getInput(i); }); + if (!AreCastCompatible(branch_input_i_types)) { + std::string input_types_str; + llvm::raw_string_ostream os(input_types_str); + llvm::interleaveComma(branch_input_i_types, os); + return op->emitOpError() + << "expects all branch input type(s) (" << os.str() + << ") at index " << i << " to be cast compatible"; + } + } + + return success(); +} + +static LogicalResult Verify(CaseOp op) { + if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); + auto branch_name = [](unsigned index) { + return llvm::formatv("branch #{0}", index).str(); + }; + return VerifyCaseOrIfOpBranchFunctions(op, op.branches().getValue(), + branch_name); +} + +//===----------------------------------------------------------------------===// +// CaseRegionOp +//===----------------------------------------------------------------------===// + +static LogicalResult Verify(CaseRegionOp op) { + if (op.branches().empty()) + return op.emitOpError() << "expects to have at least 1 region"; + + if (failed(VerifyCaseOpBase(op, op.branch_index()))) return failure(); + + for (auto region_and_idx : llvm::enumerate(op.branches())) { + std::string region_name = + llvm::formatv("region #{0}", region_and_idx.index()).str(); + if (failed(VerifyRegionResults(op, region_and_idx.value(), region_name))) + return failure(); + } + + return success(); } //===----------------------------------------------------------------------===// @@ -734,6 +861,35 @@ void ConcatV2Op::getCanonicalizationPatterns(OwningRewritePatternList &results, context); } +//===----------------------------------------------------------------------===// +// CumsumOp and CumprodOp +//===----------------------------------------------------------------------===// + +template ::value>::type * = nullptr> +static LogicalResult Verify(OpT op) { + if (!IsOfRankOrUnranked(op.axis(), 0)) + return op.emitOpError("requires scalar axis operand"); + + DenseIntElementsAttr axis_attr; + if (matchPattern(op.axis(), m_Constant(&axis_attr))) { + auto input_ty = op.x().getType().template dyn_cast(); + if (input_ty) { + int64_t rank = 
input_ty.getRank(); + assert(axis_attr.getNumElements() == 1 && + "scalar attribute should have exactly one element"); + int64_t axis = (*axis_attr.begin()).getSExtValue(); + if (axis < -rank || axis >= rank) { + return op.emitError() + << "axis operand should be within range [" << -rank << ", " + << rank << "); actual value: " << axis; + } + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // ConcatOffsetOp //===----------------------------------------------------------------------===// @@ -1768,79 +1924,18 @@ static LogicalResult Verify(GatherV2Op op) { //===----------------------------------------------------------------------===// static LogicalResult Verify(IfOp op) { - auto then_fn = op.then_func(); - if (!then_fn) - return op.emitOpError("then_branch refers to an undefined function : ") - << op.then_branch(); - auto else_fn = op.else_func(); - if (!else_fn) - return op.emitOpError("else_branch refers to an undefined function : ") - << op.else_branch(); - auto then_fn_type = then_fn.getType(); - auto else_fn_type = else_fn.getType(); - - // Non-conditional operands starting with the second operand are passed to - // branches and should be pair-wise compatible with branches' inputs. - unsigned expected_num_inputs = op.getNumOperands() - 1; - if (then_fn_type.getNumInputs() != expected_num_inputs || - else_fn_type.getNumInputs() != expected_num_inputs) - return op.emitError("branches should have " + Twine(expected_num_inputs) + - " inputs"); - - for (unsigned i = 0; i < expected_num_inputs; ++i) { - auto operand_type = op.getOperand(i + 1).getType().cast(); - auto then_input_type = then_fn_type.getInput(i).cast(); - if (!AreCastCompatible({operand_type, then_input_type})) - return op.emitError( - llvm::formatv("then branch input type {0} is incompatible with " - "operand type {1} at index {2}", - then_input_type, operand_type, i)); - - auto else_input_type = else_fn_type.getInput(i).cast(); - if (!AreCastCompatible({operand_type, else_input_type})) - return op.emitError( - llvm::formatv("else branch input type {0} is incompatible with " - "operand type {1} at index {2}", - else_input_type, operand_type, i)); - - // If branches have incompatible input types that means that no tensor can - // serve as input to both the functions. Hence, the op is invalid. - if (!AreCastCompatible({then_input_type, else_input_type})) - return op.emitError(llvm::formatv( - "branches inputs have incompatible types {0} and {1} at index {2}", - then_input_type, else_input_type, i)); - } - - // Branches' results should be pair-wise compatible with the op results. 
- unsigned expected_num_results = op.getNumResults(); - if (then_fn_type.getNumResults() != expected_num_results || - else_fn_type.getNumResults() != expected_num_results) - return op.emitError("branches should have " + Twine(expected_num_results) + - " results"); - - for (unsigned i = 0; i < expected_num_results; ++i) { - auto result_type = op.getResult(i).getType().cast(); - auto then_result_type = then_fn_type.getResult(i).cast(); - if (!AreCastCompatible({then_result_type, result_type})) - return op.emitError( - llvm::formatv("then branch result type {0} is incompatible with op " - "result type {1} at index {2}", - then_result_type, result_type, i)); - - auto else_result_type = else_fn_type.getResult(i).cast(); - if (!AreCastCompatible({else_result_type, result_type})) - return op.emitError( - llvm::formatv("else branch result type {0} is incompatible with op " - "result type {1} at index {2}", - else_result_type, result_type, i)); - } - return success(); + auto branch_name = [](unsigned index) -> std::string { + return index == 0 ? "'then_branch'" : "'else_branch'"; + }; + return VerifyCaseOrIfOpBranchFunctions( + op, {op.then_branchAttr(), op.else_branchAttr()}, branch_name); } //===----------------------------------------------------------------------===// // IfOp canonicalization. //===----------------------------------------------------------------------===// +namespace { class FoldConstantIfOp : public OpRewritePattern { public: explicit FoldConstantIfOp(MLIRContext *context) @@ -1872,9 +1967,9 @@ LogicalResult FoldConstantIfOp::matchAndRewrite( auto rewrite = [&](auto op_type) { auto empty = rewriter.getStringAttr(""); auto call_op = rewriter.create( - op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func, + op.getLoc(), op.getResultTypes(), op.input(), func, /*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty); - PropagateDeviceAndInternalAttrs(op.getOperation(), call_op); + CopyDeviceAndUnderscoredAttributes(op.getOperation(), call_op); rewriter.replaceOp(op, call_op.getResults()); }; @@ -1885,6 +1980,7 @@ LogicalResult FoldConstantIfOp::matchAndRewrite( return success(); } +} // anonymous namespace void IfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { @@ -1903,6 +1999,61 @@ static LogicalResult Verify(IfRegionOp op) { return success(); } +namespace { +class FoldConstantIfRegionOp : public OpRewritePattern { + public: + explicit FoldConstantIfRegionOp(MLIRContext *context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(TF::IfRegionOp op, + PatternRewriter &rewriter) const override; +}; + +LogicalResult FoldConstantIfRegionOp::matchAndRewrite( + TF::IfRegionOp op, PatternRewriter &rewriter) const { + // Extract the constant cond value. + DenseIntElementsAttr cond_attr; + if (!matchPattern(op.cond(), m_Constant(&cond_attr))) return failure(); + + // IfRegion condition should always be a scalar. Select the region to fold to. + bool cond = cond_attr.getSplatValue().getValue(); + Region ®ion = cond ? op.then_branch() : op.else_branch(); + + // If the IfRegion is stateless but the region being inlined itself is not + // stateless, then inlining the region could cause a loss of information. + // However, its probably better to fold the IfRegion instead of having the + // dead branch stay. + + // Inline the region in place of the IfRegion op, and forward the yield + // inputs to the IfRegion op results. This is possible only if the yield + // types match the result types. 
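+  // Illustrative sketch (mirrors the @foldIfRegion tests added to
+  // canonicalize.mlir in this change): with a constant true condition,
+  //   %0 = "tf.IfRegion"(%true) ({ /* then */ }, { /* else */ })
+  //          {is_stateless = true} : (tensor<i1>) -> tensor<f32>
+  // is replaced by the body of the then region, with the operand of its
+  // tf.Yield forwarded (through a tf.Cast if the types differ) as %0.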
+ auto yield = cast(region.front().getTerminator()); + auto updated_results = llvm::to_vector<4>(yield.getOperands()); + + // If the yield types do not match the IfRegion result types, add appropriate + // casts. + rewriter.setInsertionPoint(yield); + for (auto it : llvm::zip(op.getResultTypes(), updated_results)) { + auto &updated_result = std::get<1>(it); + Type result_type = std::get<0>(it); + if (result_type != updated_result.getType()) { + updated_result = + rewriter.create(op.getLoc(), result_type, updated_result, + /*Truncate=*/rewriter.getBoolAttr(false)); + } + } + // Inline the region into the block containing the IfRegion. + rewriter.mergeBlockBefore(®ion.front(), op); + rewriter.eraseOp(yield); + rewriter.replaceOp(op, updated_results); + return success(); +} +} // anonymous namespace + +void IfRegionOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // InvertOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc index 71f1560aa6c..bb7d9a50521 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_helpers.inc @@ -18,17 +18,6 @@ limitations under the License. // tf_verifiers or tf_ops. // TODO(jpienaar): Remove this file post refactoring. -// Propagates underscore and device attributes from src to dst. -// TODO(b/158769932): This should be a general feature instead post some policy -// discussion. -static void PropagateDeviceAndInternalAttrs(Operation *src, Operation *dst) { - auto device = mlir::Identifier::get("device", src->getContext()); - for (auto named_attr : src->getAttrs()) { - if (*named_attr.first.begin() == '_' || named_attr.first == device) - dst->setAttr(named_attr.first, named_attr.second); - } -} - //===----------------------------------------------------------------------===// // TF op helper functions //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 887473efbea..cbac03f80f8 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -707,7 +707,6 @@ OpFoldResult ReshapeOp::fold(ArrayRef operands) { // Fold reshape if operand and result types are the same and all dimensions // are statically known (no-op reshape). - // TODO(ezhulenev): Add the same folding for BroadcastToOp. auto result_ty = getType().dyn_cast(); if (result_ty && result_ty.hasStaticShape() && result_ty == tensor.getType()) { @@ -1015,9 +1014,23 @@ static LogicalResult Verify(SizeOp op) { return op.emitOpError( "requires ranked input tensor to be of rank INT32_MAX or less"); + // Output type needs to be scalar. 
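+  // For example, "tf.Size"(%t) : (tensor<3x5x7xf32>) -> tensor<i32> is valid,
+  // whereas a non-scalar result such as tensor<2xi32> is rejected (shapes are
+  // illustrative only).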
+ if (!IsOfRankOrUnranked(op.output(), /*rank=*/0)) + return op.emitOpError("requires scalar output"); + return success(); } +OpFoldResult SizeOp::fold(ArrayRef operands) { + ShapedType output_type = getType().cast(); + ShapedType input_type = getOperand().getType().cast(); + if (!input_type.hasStaticShape()) return {}; + int size = input_type.getNumElements(); + return DenseElementsAttr::get( + output_type, + IntegerAttr::get(output_type.getElementType(), /*value=*/size)); +} + //===----------------------------------------------------------------------===// // SliceOp //===----------------------------------------------------------------------===// @@ -1783,26 +1796,57 @@ static LogicalResult Verify(TopKV2Op op) { //===----------------------------------------------------------------------===// namespace { -// If the input to ToBoolOp is a `tensor`, then the ToBoolOp is an identity -// function and can be removed. -class ToBoolOfZeroDBoolTensor : public OpRewritePattern { +// If the input to ToBoolOp is a ranked tensor, then the ToBoolOp can be folded +// into an identity or an equality comparison. +class ToBoolOfRankedTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ToBoolOp op, PatternRewriter &rewriter) const override { - if (auto type = op.getOperand().getType().dyn_cast()) { - if (type.getRank() == 0 && type.getElementType().isInteger(1)) { - rewriter.replaceOp(op, op.getOperand()); - return success(); - } + auto type = op.getOperand().getType().dyn_cast(); + // If the input is an unranked tensor, cannpt rewrite. + if (!type) return failure(); + + // Expected return type of the ToBool operation. + auto result_type = op.getResult().getType().cast(); + + // If input is already a tensor, it can be folded into an identity. + if (type == result_type) { + rewriter.replaceOp(op, op.getOperand()); + return success(); } - return failure(); + + if (type.getRank() == 0) { + // If the input is a scalar tensor, the ToBool can be expanded to + // element != 0 (for numerical values) or element == empty (for string). + Type element_type = type.getElementType(); + Attribute zero_attr; + if (element_type.isIntOrFloat()) + zero_attr = rewriter.getZeroAttr(type); + else if (element_type.isa()) + zero_attr = DenseStringElementsAttr::get(type, {""}); + + if (!zero_attr) return failure(); + + auto zero_const = rewriter.create(op.getLoc(), zero_attr); + rewriter.replaceOpWithNewOp( + op, result_type, op.getOperand(), zero_const, false); + } else { + // If the input is a non-scalar ranked tensor, ToBool can be expanded + // to numElements != 0. numElements will be 0 iff one of the dimensions is + // zero. 
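+      // Illustrative sketch (mirrors the @ToBool_2DTensor* tests added to
+      // canonicalize.mlir in this change):
+      //   "tf.ToBool"(%t) : (tensor<1x5xf32>) -> tensor<i1>  -> constant true
+      //   "tf.ToBool"(%t) : (tensor<1x0xf32>) -> tensor<i1>  -> constant false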
+ bool any_zero = + llvm::any_of(type.getShape(), [](int64_t dim) { return dim == 0; }); + rewriter.replaceOpWithNewOp( + op, result_type, DenseElementsAttr::get(result_type, {!any_zero})); + } + return success(); } }; } // namespace void ToBoolOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// @@ -1895,11 +1939,9 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, namespace { OpFoldResult FoldIdentityTranspose(TransposeOp op) { - auto const_perm = dyn_cast_or_null(op.perm().getDefiningOp()); - if (!const_perm) return {}; - - auto const_value = const_perm.value(); - const auto elements = const_value.getValues(); + DenseIntElementsAttr perm; + if (!matchPattern(op.perm(), m_Constant(&perm))) return {}; + const auto elements = perm.getValues(); for (auto it : llvm::enumerate(elements)) { if (it.index() != it.value()) return {}; @@ -1922,14 +1964,14 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) { if (!transpose) return {}; // Permutations defined by constant operations. - auto perm0 = dyn_cast_or_null(op.perm().getDefiningOp()); - auto perm1 = dyn_cast_or_null(transpose.perm().getDefiningOp()); - if (!perm0 || !perm1) return {}; + DenseIntElementsAttr perm0; + DenseIntElementsAttr perm1; + if (!matchPattern(op.perm(), m_Constant(&perm0)) || + !matchPattern(transpose.perm(), m_Constant(&perm1))) + return {}; // With permutation indices that cancel each other - auto perm0_value = perm0.value().cast(); - auto perm1_value = perm1.value().cast(); - if (!AreCancellablePermutations(perm0_value, perm1_value)) return {}; + if (!AreCancellablePermutations(perm0, perm1)) return {}; return transpose.x(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h index fc8e6f40f65..412bf113a0f 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h @@ -33,7 +33,7 @@ namespace TF { static inline LogicalResult VerifyRefTypeMatch(mlir::Type type, mlir::Type maybe_ref_type) { if (auto ref_type = maybe_ref_type.dyn_cast()) - return success(ref_type.RemoveRef().getKind() == type.getKind()); + return success(ref_type.RemoveRef().getTypeID() == type.getTypeID()); return failure(); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc index 994378ea1cf..2ec73824f6c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.cc @@ -17,6 +17,7 @@ limitations under the License. #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -100,7 +101,7 @@ mlir::Type GetCastCompatibleType(mlir::Type a, mlir::Type b, if (a == b) return a; } } - if (a.getKind() != b.getKind()) return nullptr; + if (a.getTypeID() != b.getTypeID()) return nullptr; // If either is not a type that contain subtypes then the types are not cast // compatible. 
@@ -178,127 +179,116 @@ ResultShapeIterator::ResultShapeIterator(Operation::result_iterator it) // TF types helper functions //===----------------------------------------------------------------------===// +bool TensorFlowType::classof(Type type) { + return type.getDialect().getNamespace() == "tf"; +} +bool TensorFlowRefType::classof(Type type) { + return type.isa< +#define HANDLE_TF_TYPE(tftype, enumerant, name) +#define HANDLE_TF_REF_TYPE(tftype, enumerant, name) tftype##Type, +#define HANDLE_LAST_TF_TYPE(tftype, enumerant, name) tftype##Type +// NOLINTNEXTLINE +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" + >(); +} +bool TensorFlowTypeWithSubtype::classof(Type type) { + return type.isa(); +} + TensorFlowType TensorFlowRefType::get(Type type) { MLIRContext* ctx = type.getContext(); - switch (getElementTypeOrSelf(type).getKind()) { - case StandardTypes::F16: - return HalfRefType::get(ctx); - case StandardTypes::F32: - return FloatRefType::get(ctx); - case StandardTypes::F64: - return DoubleRefType::get(ctx); - case StandardTypes::BF16: - return Bfloat16RefType::get(ctx); - case StandardTypes::Complex: { - const auto& etype = type.cast().getElementType(); - switch (getElementTypeOrSelf(etype).getKind()) { - case StandardTypes::F32: - return Complex64RefType::get(ctx); - case StandardTypes::F64: - return Complex128RefType::get(ctx); - default: - llvm_unreachable("unexpected complex type"); - } + type = getElementTypeOrSelf(type); + if (type.isF16()) { + return HalfRefType::get(ctx); + } else if (type.isF32()) { + return FloatRefType::get(ctx); + } else if (type.isF64()) { + return DoubleRefType::get(ctx); + } else if (type.isBF16()) { + return Bfloat16RefType::get(ctx); + } else if (auto complex_type = type.dyn_cast()) { + Type etype = complex_type.getElementType(); + if (etype.isF32()) { + return Complex64RefType::get(ctx); + } else if (etype.isF64()) { + return Complex128RefType::get(ctx); } - case StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - return BoolRefType::get(ctx); - case 8: - return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx)) - : Int8RefType::get(ctx); - case 16: - return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx)) - : Int16RefType::get(ctx); - case 32: - return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx)) - : Int32RefType::get(ctx); - case 64: - return itype.isUnsigned() ? TensorFlowType(Uint64RefType::get(ctx)) - : Int64RefType::get(ctx); - default: - llvm_unreachable("unexpected integer type"); - } + llvm_unreachable("unexpected complex type"); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + return BoolRefType::get(ctx); + case 8: + return itype.isUnsigned() ? TensorFlowType(Uint8RefType::get(ctx)) + : Int8RefType::get(ctx); + case 16: + return itype.isUnsigned() ? TensorFlowType(Uint16RefType::get(ctx)) + : Int16RefType::get(ctx); + case 32: + return itype.isUnsigned() ? TensorFlowType(Uint32RefType::get(ctx)) + : Int32RefType::get(ctx); + case 64: + return itype.isUnsigned() ? 
TensorFlowType(Uint64RefType::get(ctx)) + : Int64RefType::get(ctx); + default: + llvm_unreachable("unexpected integer type"); } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant: \ + } +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (auto derived_ty = type.dyn_cast()) \ return tftype##RefType::get(ctx); #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - default: - llvm_unreachable("unexpected type kind"); - } + llvm_unreachable("unexpected type kind"); } Type TensorFlowRefType::RemoveRef() { MLIRContext* ctx = getContext(); - switch (getKind()) { - case TensorFlowTypes::HALF_REF: - return mlir::FloatType::getF16(ctx); - case TensorFlowTypes::FLOAT_REF: - return mlir::FloatType::getF32(ctx); - case TensorFlowTypes::DOUBLE_REF: - return mlir::FloatType::getF64(ctx); - case TensorFlowTypes::BFLOAT16_REF: - return mlir::FloatType::getBF16(ctx); - case TensorFlowTypes::BOOL_REF: - return mlir::IntegerType::get(1, ctx); - case TensorFlowTypes::INT8_REF: - return mlir::IntegerType::get(8, ctx); - case TensorFlowTypes::INT16_REF: - return mlir::IntegerType::get(16, ctx); - case TensorFlowTypes::INT32_REF: - return mlir::IntegerType::get(32, ctx); - case TensorFlowTypes::INT64_REF: - return mlir::IntegerType::get(64, ctx); - case TensorFlowTypes::UINT8_REF: - return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx); - case TensorFlowTypes::UINT16_REF: - return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx); - case TensorFlowTypes::UINT32_REF: - return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx); - case TensorFlowTypes::UINT64_REF: - return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx); - case TensorFlowTypes::COMPLEX64_REF: - return mlir::ComplexType::get(mlir::FloatType::getF32(ctx)); - case TensorFlowTypes::COMPLEX128_REF: - return mlir::ComplexType::get(mlir::FloatType::getF64(ctx)); + if (isa()) return mlir::FloatType::getF16(ctx); + if (isa()) return mlir::FloatType::getF32(ctx); + if (isa()) return mlir::FloatType::getF64(ctx); + if (isa()) return mlir::FloatType::getBF16(ctx); + if (isa()) return mlir::IntegerType::get(1, ctx); + if (isa()) return mlir::IntegerType::get(8, ctx); + if (isa()) return mlir::IntegerType::get(16, ctx); + if (isa()) return mlir::IntegerType::get(32, ctx); + if (isa()) return mlir::IntegerType::get(64, ctx); + if (isa()) + return mlir::IntegerType::get(8, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::IntegerType::get(16, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::IntegerType::get(32, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::IntegerType::get(64, IntegerType::Unsigned, ctx); + if (isa()) + return mlir::ComplexType::get(mlir::FloatType::getF32(ctx)); + if (isa()) + return mlir::ComplexType::get(mlir::FloatType::getF64(ctx)); #define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case TensorFlowTypes::enumerant##_REF: \ - return tftype##Type::get(ctx); + if (isa()) return tftype##Type::get(ctx); #define HANDLE_TF_REF_TYPE(tftype, enumerant, name) // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - default: - llvm_unreachable("unexpected tensorflow ref type kind"); - } + llvm_unreachable("unexpected tensorflow ref type kind"); } Type TensorFlowTypeWithSubtype::RemoveSubtypes() { MLIRContext* ctx = getContext(); - switch (getKind()) { - case TensorFlowTypes::VARIANT: - return VariantType::get(ctx); - case TensorFlowTypes::RESOURCE: - return 
ResourceType::get(ctx); - default: - llvm_unreachable("unexpected tensorflow type with subtypes kind"); - } + if (isa()) return VariantType::get(ctx); + if (isa()) return ResourceType::get(ctx); + llvm_unreachable("unexpected tensorflow type with subtypes kind"); } ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() { - switch (getKind()) { - case TensorFlowTypes::VARIANT: - return this->cast().getSubtypes(); - case TensorFlowTypes::RESOURCE: - return this->cast().getSubtypes(); - default: - llvm_unreachable("unexpected tensorflow type with subtypes kind"); - } + if (auto variant_type = dyn_cast()) + return variant_type.getSubtypes(); + if (auto resource_type = dyn_cast()) + return resource_type.getSubtypes(); + llvm_unreachable("unexpected tensorflow type with subtypes kind"); } // TODO(jpienaar): BroadcastCompatible and HasCompatibleElementTypes have @@ -306,8 +296,11 @@ ArrayRef TensorFlowTypeWithSubtype::GetSubtypes() { bool BroadcastCompatible(ArrayRef lhs, ArrayRef rhs) { if (lhs.size() != rhs.size()) return false; for (auto types : llvm::zip(lhs, rhs)) { - auto lhs_type = std::get<0>(types); - auto rhs_type = std::get<1>(types); + // Drop ref types because they don't affect broadcast compatibility. E.g., + // `tensor` and `tensor` should be considered broadcast + // compatible. + auto lhs_type = DropRefType(std::get<0>(types)); + auto rhs_type = DropRefType(std::get<1>(types)); // This should be true for all TF ops: auto lhs_tt = lhs_type.dyn_cast(); @@ -366,27 +359,31 @@ bool AreCastCompatible(ArrayRef types) { return true; } -ShapedType DropTypeSubTypes(ShapedType ty) { - Type element_ty = ty.getElementType(); - auto subtype_ty = element_ty.dyn_cast(); - if (!subtype_ty) return ty; +// Assumes a function `GetDefaultTypeOf(ComposedType)` that returns the default +// type for a composed type (such as a ref type or a type with subtypes). 
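+// For example (illustrative types only): DropRefType maps
+// tensor<3x!tf.int32ref> to tensor<3xi32>, DropSubTypes maps
+// !tf.resource<tensor<f32>> to !tf.resource, and any other type is returned
+// unchanged.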
+template +Type DropTypeHelper(Type ty) { + Type element_ty = getElementTypeOrSelf(ty); + auto composed_type = element_ty.dyn_cast(); + if (!composed_type) return ty; - Type default_ty = GetDefaultTypeOf(subtype_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); + Type default_ty = GetDefaultTypeOf(composed_type); + if (auto ranked_ty = ty.dyn_cast()) { + return RankedTensorType::get(ranked_ty.getShape(), default_ty); + } else if (ty.dyn_cast()) { + return UnrankedTensorType::get(default_ty); + } else { + return default_ty; + } } -ShapedType DropRefType(ShapedType ty) { - Type element_ty = ty.getElementType(); - TF::TensorFlowRefType ref_ty = element_ty.dyn_cast(); - if (!ref_ty) return ty; - - Type default_ty = TF::GetDefaultTypeOf(ref_ty); - if (ty.hasRank()) return RankedTensorType::get(ty.getShape(), default_ty); - - return UnrankedTensorType::get(default_ty); +Type DropSubTypes(Type ty) { + return DropTypeHelper(ty); } +Type DropRefType(Type ty) { return DropTypeHelper(ty); } + +Type DropRefAndSubTypes(Type ty) { return DropRefType(DropSubTypes(ty)); } + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index 43d5f2fa476..f93f6b657da 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -67,26 +67,13 @@ using ResultShapeRange = iterator_range; // TensorFlow types //===----------------------------------------------------------------------===// -namespace TensorFlowTypes { -// List of supported TensorFlowType kinds, necessary for isa/dyn_cast. -enum Kind { - FIRST_USED_TENSORFLOW_TYPE = Type::FIRST_TENSORFLOW_TYPE, -#define HANDLE_TF_TYPE(tftype, enumerant, name) enumerant, -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - LAST_USED_TENSORFLOW_TYPE, -}; -} // namespace TensorFlowTypes - // The base class in the TensorFlow type hierarchy. class TensorFlowType : public Type { public: using Type::Type; // Support method to enable LLVM-style type casting. - static bool classof(Type type) { - return type.getKind() >= Type::FIRST_TENSORFLOW_TYPE && - type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE; - } + static bool classof(Type type); }; // Returns true if the specified type is a valid TensorFlow element type. @@ -105,10 +92,7 @@ static inline bool IsValidTFTensorType(Type type) { namespace detail { // Common implementation of TensorFlow types. The template argument indicates -// the concrete derived class per CRTP. Concrete classes must implement the -// following: -// - `static unsigned getTypeKind()` that returns the (fixed) kind of the -// type. +// the concrete derived class per CRTP. template class TensorFlowTypeImpl : public Type::TypeBase { @@ -116,11 +100,6 @@ class TensorFlowTypeImpl using Base = typename Type::TypeBase; using TFBase = TensorFlowTypeImpl; using Base::Base; - - // Get the unique'ed type in the given context. - static Derived get(MLIRContext* context) { - return Base::get(context, Derived::getTypeKind()); - } }; } // namespace detail @@ -130,10 +109,7 @@ class TensorFlowRefType : public TensorFlowType { using TensorFlowType::TensorFlowType; // Checks if a type is TensorFlow Ref type. 
- static bool classof(Type type) { - return type.getKind() >= TensorFlowTypes::FLOAT_REF && - type.getKind() <= TensorFlowTypes::LAST_USED_TENSORFLOW_TYPE; - } + static bool classof(Type type); // Converts a type to the corresponding TensorFlowRef type. static TensorFlowType get(Type type); @@ -179,7 +155,6 @@ static inline Type GetElementTypeOrSelfResolveRef(Type type) { class tftype##Type : public detail::TensorFlowTypeImpl { \ public: \ using TFBase::TFBase; \ - static unsigned getTypeKind() { return TensorFlowTypes::enumerant; } \ }; // Custom TensorFlow types are defined separately. @@ -217,8 +192,6 @@ class TypeWithSubtypeStorage : public TypeStorage { // opaque and their interpretation depends on the actual underlying type. // The template argument indicates the concrete derived class per CRTP. Concrete // classes must implement the following: -// - `static unsigned getTypeKind()` that returns the (fixed) kind of the -// type. // - `static std::string getTypeName()` that returns the name of the type for // verification logging. template @@ -230,12 +203,12 @@ class TypeWithSubtypeImpl using Base::Base; static Derived get(ArrayRef subtypes, MLIRContext* context) { - return Base::get(context, Derived::getTypeKind(), subtypes); + return Base::get(context, subtypes); } static Derived getChecked(ArrayRef subtypes, MLIRContext* context, Location loc) { - return Base::getChecked(loc, Derived::getTypeKind(), subtypes); + return Base::getChecked(loc, subtypes); } static Derived get(MLIRContext* context) { return get({}, context); } @@ -263,10 +236,7 @@ class TensorFlowTypeWithSubtype : public TensorFlowType { using TensorFlowType::TensorFlowType; // Checks if a type is TensorFlow type with subtypes. - static bool classof(Type type) { - return type.getKind() == TensorFlowTypes::VARIANT || - type.getKind() == TensorFlowTypes::RESOURCE; - } + static bool classof(Type type); // Converts a TypeWithSubtype type to the same type but without its subtypes. Type RemoveSubtypes(); @@ -288,7 +258,6 @@ static inline Type GetDefaultTypeOf(TensorFlowTypeWithSubtype type) { class ResourceType : public detail::TypeWithSubtypeImpl { public: using TFBase::TFBase; - static unsigned getTypeKind() { return TensorFlowTypes::RESOURCE; } static std::string getTypeName() { return "ResourceType"; } }; @@ -300,7 +269,6 @@ class ResourceType : public detail::TypeWithSubtypeImpl { class VariantType : public detail::TypeWithSubtypeImpl { public: using TFBase::TFBase; - static unsigned getTypeKind() { return TensorFlowTypes::VARIANT; } static std::string getTypeName() { return "VariantType"; } }; @@ -325,15 +293,21 @@ bool HasCompatibleElementTypes(Type lhs, Type rhs, // compatible. bool AreCastCompatible(ArrayRef types); -// If the given tensor has elements of type with subtypes, then returns a new -// type after dropping subtypes info. Otherwise, returns the original type as -// is. -ShapedType DropTypeSubTypes(ShapedType ty); +// If `ty` is a tensor type and its element type has subtypes, then returns a +// new type of same shape but dropped subtypes for the element type. +// Otherwise, if `ty` has subtypes, then returns corresponding type with dropped +// subtypes. +// Otherwise, returns the original type `ty`. +Type DropSubTypes(Type ty); -// If the given tensor has elements of type ref, then returns a new type -// of the shape, but corresponding non-ref type as element type. Otherwise, -// returns the original type as is. 
-ShapedType DropRefType(ShapedType ty); +// If `ty` is a tensor type and has elements of a ref type, then returns a new +// type of same shape but corresponding non-ref type as element type. +// Otherwise, if `ty` is a ref type, then returns corresponding non-ref type. +// Otherwise, returns the original type `ty`. +Type DropRefType(Type ty); + +// Convenience call for executing both `DropRefType` and `DropSubTypes`. +Type DropRefAndSubTypes(Type ty); } // end namespace TF } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 595bdce5be4..50486909694 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -444,6 +444,14 @@ func @testReshapeNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x return %0 : tensor<2x4xf32> } +// CHECK-LABEL: func @testBroadcastToNoOp +func @testBroadcastToNoOp(%arg0: tensor<2x4xf32>, %arg1: tensor<2xi32>) -> tensor<2x4xf32> { + %0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<2x4xf32>, tensor<2xi32>) -> tensor<2x4xf32> + + // CHECK: return %arg0 + return %0 : tensor<2x4xf32> +} + // CHECK-LABEL: func @testPackShapeComputation func @testPackShapeComputation(%arg0: tensor, %arg1: tensor, %arg2: tensor<*xf32>) -> (tensor<2xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>, tensor<*xi32>) { // Test dimensions sizes. @@ -620,6 +628,15 @@ func @testLogicalNotOfLessEqual(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32> // CHECK: return %0 } +// CHECK-LABEL: testSizeFolding +func @testSizeFolding(%arg0: tensor<3x5x7xf32>) -> tensor { + %0 = "tf.Size"(%arg0) : (tensor<3x5x7xf32>) -> tensor + return %0: tensor + +// CHECK: %0 = "tf.Const"() {value = dense<105> : tensor} : () -> tensor +// CHECK: return %0 : tensor +} + // CHECK-LABEL: testDivWithSqrtDivisor func @testDivWithSqrtDivisor(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sqrt"(%arg1) : (tensor<8x16xf32>) -> tensor<8x16xf32> @@ -685,6 +702,15 @@ func @identityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> { // CHECK: return %arg0 } +// CHECK-LABEL: @identityTransposeConst +func @identityTransposeConst(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x5x6xf32> { + %0 = constant dense<[0, 1, 2, 3, 4]> : tensor<5xi32> + %1 = "tf.Transpose"(%arg0, %0) : (tensor<2x3x4x5x6xf32>, tensor<5xi32>) -> tensor<2x3x4x5x6xf32> + + return %1 : tensor<2x3x4x5x6xf32> + // CHECK: return %arg0 +} + // CHECK-LABEL: @nonIdentityTranspose func @nonIdentityTranspose(%arg0: tensor<2x3x4x5x6xf32>) -> tensor<2x3x4x6x5xf32> { %0 = "tf.Const"() {value = dense<[0, 1, 2, 4, 3]> : tensor<5xi32>} : () -> tensor<5xi32> @@ -707,6 +733,17 @@ func @cancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { // CHECK: return %arg0 } +// CHECK-LABEL: @cancellableTransposeConst +func @cancellableTransposeConst(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x4x4x8xf32> { + %0 = constant dense<[0, 3, 1, 2]> : tensor<4xi32> + %1 = constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %2 = "tf.Transpose"(%arg0, %0) : (tensor<1x4x4x8xf32>, tensor<4xi32>) -> tensor<1x8x4x4xf32> + %3 = "tf.Transpose"(%2, %1) : (tensor<1x8x4x4xf32>, tensor<4xi32>) -> tensor<1x4x4x8xf32> + + return %3 : tensor<1x4x4x8xf32> + // CHECK: return %arg0 +} + // CHECK-LABEL: @nonCancellableTranspose func @nonCancellableTranspose(%arg0: tensor<1x4x4x8xf32>) -> tensor<4x1x4x8xf32> { %0 = "tf.Const"() {value = dense<[0, 3, 
1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -725,13 +762,72 @@ func @addN(%arg0: tensor<*xf32>) -> tensor<*xf32> { return %0 : tensor<*xf32> } -// CHECK-LABEL: func @ToBool_0DScalar -func @ToBool_0DScalar(%arg0: tensor) -> tensor { +// CHECK-LABEL: func @ToBool_0DScalarI1 +func @ToBool_0DScalarI1(%arg0: tensor) -> tensor { // CHECK: return %arg0 %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor return %0 : tensor } +// CHECK-LABEL: func @ToBool_0DScalarInt +func @ToBool_0DScalarInt(%arg0: tensor) -> tensor { + // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0> : tensor} + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) + // CHECK: return [[NE]] + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_0DScalarFloat +func @ToBool_0DScalarFloat(%arg0: tensor) -> tensor { + // CHECK: [[Zero:%.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor} : () -> tensor + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[Zero]]) + // CHECK: return [[NE]] + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_0DScalarString +func @ToBool_0DScalarString(%arg0: tensor) -> tensor { + // CHECK: [[EmptyStr:%.*]] = "tf.Const"() {value = dense<""> : tensor} : () -> tensor + // CHECK: [[NE:%.*]] = "tf.NotEqual"(%arg0, [[EmptyStr]]) {incompatible_shape_error = false} : (tensor, tensor) -> tensor + // CHECK: return [[NE]] : tensor + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_1DTensor +func @ToBool_1DTensor(%arg0: tensor<1xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_1DTensorZeroDim +func @ToBool_1DTensorZeroDim(%arg0: tensor<0xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<0xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_2DTensor +func @ToBool_2DTensor(%arg0: tensor<1x5xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1x5xf32>) -> tensor + return %0 : tensor +} + +// CHECK-LABEL: func @ToBool_2DTensorZeroDim +func @ToBool_2DTensorZeroDim(%arg0: tensor<1x0xf32>) -> tensor { + // CHECK: [[Const:%.*]] = "tf.Const"() {value = dense : tensor} : () -> tensor + // CHECK: return [[Const]] + %0 = "tf.ToBool"(%arg0) : (tensor<1x0xf32>) -> tensor + return %0 : tensor +} + // CHECK-LABEL: testReadVariableOpOfCast func @testReadVariableOpOfCast(%arg0: tensor>>) -> tensor<8x40xf32> { %0 = "tf.Cast"(%arg0) : (tensor>>) -> tensor<*x!tf.resource> @@ -826,6 +922,51 @@ func @foldIf(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tens return %4 : tensor } +// CHECK-LABEL: foldIfRegion +func @foldIfRegion(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { + %false = "tf.Const"() {value = dense : tensor} : () -> tensor + %true = "tf.Const"() {value = dense : tensor} : () -> tensor + + // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1) + %0 = "tf.IfRegion"(%true) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%true_value) : (tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%false_value) : (tensor) -> () + }) { is_stateless = true}: 
(tensor) -> tensor + + // CHECK: [[Val1:%.*]] = "tf.Sub"(%arg0, %arg1) + %1 = "tf.IfRegion"(%false) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%true_value) : (tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%false_value) : (tensor) -> () + }) { is_stateless = true}: (tensor) -> tensor + + // CHECK: return [[Val0]], [[Val1]] + return %0, %1 : tensor, tensor +} + +// CHECK-LABEL: foldIfRegionMismatchedTypes +func @foldIfRegionMismatchedTypes(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<1xf32> { + %false = "tf.Const"() {value = dense : tensor} : () -> tensor + %true = "tf.Const"() {value = dense : tensor} : () -> tensor + + // CHECK: [[Val0:%.*]] = "tf.Mul"(%arg0, %arg1) + // CHECK-NEXT: [[Cast:%.*]] = "tf.Cast"([[Val0]]) + // CHECK-NEXT: return [[Cast]] + %0 = "tf.IfRegion"(%true) ({ + %true_value = "tf.Mul"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%true_value) : (tensor) -> () + }, { + %false_value = "tf.Sub"(%arg0, %arg1) : (tensor, tensor) -> tensor + "tf.Yield"(%false_value) : (tensor) -> () + }) { is_stateless = true}: (tensor) -> tensor<1xf32> + return %0 : tensor<1xf32> +} + // CHECK-LABEL: foldCase func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { %2 = constant dense<1> : tensor @@ -834,11 +975,11 @@ func @foldCase(%arg0: tensor, %arg1: tensor) -> (tensor) { // CHECK: PartitionedCall // CHECK-SAME: device = "noodle" // CHECK-SAME: f = @add - %4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle", is_stateless = false} : (tensor, tensor, tensor) -> tensor // CHECK: PartitionedCall // CHECK-SAME: _cluster_launch = "not_ready" // CHECK-SAME: f = @sub - %5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor, tensor, tensor) -> tensor + %5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready", is_stateless = false} : (tensor, tensor, tensor) -> tensor return %5 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir b/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir new file mode 100644 index 00000000000..8250bcf7101 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/device_copy.mlir @@ -0,0 +1,16 @@ +// RUN: tf-opt -tf-tensor-device-copy %s | FileCheck %s --dump-input=fail + +// CHECK-LABEL: func @fold_identity +// CHECK-SAME: ([[arg0:%.*]]: tensor<2x2xf32>, [[arg1:%.*]]: tensor<2x2xf32> +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32}} { + func @fold_identity(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = tf_executor.graph { + // CHECK: tf.MatMul + %outputs, %control = tf_executor.island wraps "tf.MatMul"(%arg0, %arg1) {device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK-NOT: tf.Identity + %outputs_0, %control_1 = tf_executor.island wraps "tf.Identity"(%outputs) {device = ""} : (tensor<2x2xf32>) -> tensor<2x2xf32> + tf_executor.fetch %outputs_0 : tensor<2x2xf32> + } + return %0 : tensor<2x2xf32> + } +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir 
b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir index 7d761b5d690..0000d43823b 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/executor_tpuv1_outline_island/case_op.mlir @@ -16,7 +16,7 @@ module { "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", device = "device", num_replicas = 1, topology = "topology"} : () -> () %index = "tf.opA"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor %input = "tf.opB"(%arg0) {_tpu_replicate = "cluster"} : (tensor) -> tensor - %result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4]} : (tensor, tensor) -> tensor + %result = "tf.Case"(%index, %input) {branches = [@branch_0, @branch_1, @branch_2, @branch_3, @branch_4], is_stateless = false} : (tensor, tensor) -> tensor tf_executor.yield %result : tensor } tf_executor.fetch %output : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir index c8c82c5c08f..e4e7f0859c8 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/functional-control-flow-to-regions.mlir @@ -123,6 +123,27 @@ func @testIfNoInputAndNoResult(%arg0: tensor) -> () { // ----- +// If with non tensor condition + +// Simple If +// CHECK: func @testIf1Then{{.+}} +// CHECK: func @testIf1Else{{.+}} +func @testIf1Then(tensor<*xf32>) -> tensor<*xf32> +func @testIf1Else(tensor<*xf32>) -> tensor<*xf32> + +// CHECK-LABEL: func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) +func @testIf1Result(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.If"(%arg0, %arg1) { + then_branch = @testIf1Then, else_branch = @testIf1Else, is_stateless = false + } : (tensor, tensor<*xf32>) -> tensor<*xf32> + + // CHECK: [[ToBool:%.*]] = "tf.ToBool" + // CHECK: "tf.IfRegion"([[ToBool]]) + return %0 : tensor<*xf32> +} + +// ----- + // Simple While func @testWhileCond(tensor<*xf32>) -> (tensor) func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) @@ -200,3 +221,58 @@ func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { return %1 : tensor<*xf32> } +// ----- + +// While with non tensor condition +func @testWhileCond(tensor<*xf32>) -> (tensor) +func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>) + +// CHECK-LABEL: func @testWhileResult +func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) { +^bb0(%arg0: tensor<*xf32>): + %1 = "tf.While"(%arg0) { + cond = @testWhileCond, + body = @testWhileBody, + is_stateless = true, + _attr0 = 10, _attr1 = true, attr2 = "hello" + } : (tensor<*xf32>) -> (tensor<*xf32>) + + // CHECK: [[Result0:%.*]] = "tf.WhileRegion" + // CHECK: [[Result1:%.*]] = call @testWhileCond + // CHECK: [[ToBool:%.*]] = "tf.ToBool"([[Result1]]) + // CHECK: "tf.Yield"([[ToBool]]) + // CHECK: [[Result2:%.*]] = call @testWhileBody + // CHECK: "tf.Yield"([[Result2]]) + // CHECK: return [[Result0]] + return %1 : tensor<*xf32> +} + +// ----- + +func @then_branch() -> () +func @else_branch() -> () + +// Test tf.If device is preserved. 
+// CHECK-LABEL: func @testIfDevice +func @testIfDevice(%arg0: tensor) { + "tf.If"(%arg0) {then_branch = @then_branch, else_branch = @else_branch, is_stateless = false, device = "/device:CPU:0"} : (tensor) -> () + + // CHECK: "tf.IfRegion" + // CHECK: device = "/device:CPU:0" + return +} + +// ----- + +func @cond() -> tensor +func @body() -> () + +// Test tf.While device is preserved. +// CHECK-LABEL: func @testWhileDevice +func @testWhileDevice() { + "tf.While"() {cond = @cond, body = @body, is_stateless = false, device = "/device:CPU:0"} : () -> () + + // CHECK: "tf.WhileRegion" + // CHECK: device = "/device:CPU:0" + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index e11474c0755..ea55e50db30 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -479,13 +479,39 @@ func @DynamicStitch_duplicates(%arg0: tensor<2x2xf32>) -> tensor<1x2xf32> { return %0 : tensor<1x2xf32> } -func @Reciprocal(%arg0: tensor<*xf32>) -> tensor<*xf32> { +// CHECK-LABEL: @Reciprocal_i32 +func @Reciprocal_i32(%arg0: tensor<*xi32>) -> tensor<*xi32> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor, tensor<*xi32>) -> tensor<*xi32> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xi32>) -> tensor<*xi32> + return %0 : tensor<*xi32> +} + +// CHECK-LABEL: @Reciprocal_f32 +func @Reciprocal_f32(%arg0: tensor<*xf32>) -> tensor<*xf32> { // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Reciprocal"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } +// CHECK-LABEL: @Reciprocal_complexf32 +func @Reciprocal_complexf32(%arg0: tensor<*xcomplex>) -> tensor<*xcomplex> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor>, tensor<*xcomplex>) -> tensor<*xcomplex> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex>) -> tensor<*xcomplex> + return %0 : tensor<*xcomplex> +} + +// CHECK-LABEL: @Reciprocal_complexf64 +func @Reciprocal_complexf64(%arg0: tensor<*xcomplex>) -> tensor<*xcomplex> { + // CHECK: %[[ONE:.*]] = "tf.Const"() {value = dense<(1.000000e+00,0.000000e+00)> : tensor>} : () -> tensor> + // CHECK: "tf.Div"(%[[ONE]], %arg0) : (tensor>, tensor<*xcomplex>) -> tensor<*xcomplex> + %0 = "tf.Reciprocal"(%arg0) : (tensor<*xcomplex>) -> tensor<*xcomplex> + return %0 : tensor<*xcomplex> +} + +// CHECK-LABEL: @ScatterNd func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { // CHECK: %[[ZERO:.*]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<8xf32>} : () -> tensor<8xf32> // CHECK: "tf.TensorScatterUpdate"(%[[ZERO]], %arg0, %arg1) : (tensor<8xf32>, tensor<4x1xi32>, tensor<4xf32>) -> tensor<8xf32> @@ -494,3 +520,16 @@ func @ScatterNd(%arg0: tensor<4x1xi32>, %arg1: tensor<4xf32>) -> tensor<8xf32> { %0 = "tf.ScatterNd"(%arg0, %arg1, %shape) : (tensor<4x1xi32>, tensor<4xf32>, tensor<1xi32>) -> tensor<8xf32> return %0 : tensor<8xf32> } + +// CHECK-LABEL: @_UnaryOpsComposition +// CHECK-SAME: %[[ARG0:.*]]: tensor<4xf32> +func @_UnaryOpsComposition(%arg0: tensor<4xf32>) -> tensor<4xf32> { + + // CHECK: %[[RESULT0:.*]] = "tf.Asin"(%[[ARG0]]) + // CHECK: %[[RESULT1:.*]] = "tf.Abs"(%[[RESULT0]]) + // CHECK: %[[RESULT2:.*]] = 
"tf.Log"(%[[RESULT1]]) + // CHECK: return %[[RESULT2]] + + %0 = "tf._UnaryOpsComposition"(%arg0) {op_names = ["Asin", "Abs", "Log"]} : (tensor<4xf32>) -> tensor<4xf32> + return %0 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir index 9544a02dca4..df2add2208a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mark_ops_for_outside_compilation.mlir @@ -136,6 +136,7 @@ func @if_region_captured_string(%arg0: tensor, %arg1: tensor) -> // CHECK-NOT: _xla_outside_compilation // CHECK: "tf.IfRegion" // CHECK: "tf.StringToNumber" + // CHECK-NOT: _xla_outside_compilation // CHECK: _xla_outside_compilation = "auto", is_stateless = true %1 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor %2 = "tf.IfRegion"(%arg0) ( { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir index c6543f3121e..09a38b5b5de 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/func_list_attr.mlir @@ -43,7 +43,7 @@ func @main() { // CHECK-NEXT: } // CHECK-NEXT: } // CHECK: } - %1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = []} : (tensor) -> tensor<*xf32> loc("Case") + %1:2 = tf_executor.island wraps "tf.Case"(%0#0) {Tin = [], Tout = ["tfdtype$DT_FLOAT"], branches = [@foo, @bar], device = "", output_shapes = [], is_stateless = false} : (tensor) -> tensor<*xf32> loc("Case") tf_executor.fetch } return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir index e9d4e441a10..3e8935b699e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/region-control-flow-to-functional.mlir @@ -212,6 +212,28 @@ func @testNoOutputs(%arg0: tensor, %arg1: tensor<*xf32>) -> () { return } +// ----- +// Check ToBool folding for IfRegion +// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: "tf.Neg" +// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32> +// CHECK-NEXT: "tf.Abs" +// CHECK-LABEL: @testToBoolFold +func @testToBoolFold(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<*xf32> { + // CHECK-NEXT: "tf.If"(%arg0, %arg1) + // CHECK-SAME: else_branch = @tf.IfRegion_else + // CHECK-SAME: then_branch = @tf.IfRegion_then + %tobool = "tf.ToBool"(%arg0) : (tensor) -> tensor + %0 = "tf.IfRegion"(%tobool) ({ + %1 = "tf.Abs"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%1) : (tensor<*xf32>) -> () + }, { + %2 = "tf.Neg"(%arg1) : (tensor<*xf32>) -> tensor<*xf32> + "tf.Yield"(%2) : (tensor<*xf32>) -> () + }) {is_stateless = true} : (tensor) -> tensor<*xf32> + return %0 : tensor<*xf32> +} + // ----- // Simple WhileRegion @@ -592,3 +614,64 @@ func @testWhileRegionBlockArgMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor // CHECK: return [[Result]]#0 return %0#0 : tensor<*xf32> } + +// ----- + +// Simple trivially transformable while with ToBool +// CHECK: func @while_cond +// CHECK: func @while_body +// CHECK-LABEL: testWhileRegionTrivial +func 
@while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor +func @while_body(%arg0 : tensor<*xf32>, %arg1 : tensor) -> (tensor<*xf32>, tensor) +func @testWhileRegionTrivial(%arg0 : tensor<*xf32>, %arg1 : tensor) -> tensor<*xf32> { + // CHECK: [[Result:%.*]]:2 = "tf.While"(%arg0, %arg1) {body = @while_body, cond = @while_cond + %0:2 = "tf.WhileRegion"(%arg0, %arg1) ( + { + ^bb0(%carg0: tensor<*xf32>, %carg1: tensor): + %cond_i32 = call @while_cond(%carg0, %carg1) : (tensor<*xf32>, tensor) -> tensor + %cond = "tf.ToBool"(%cond_i32) : (tensor) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, + { + // loop body + ^bb0(%barg0: tensor<*xf32>, %barg1: tensor): + %bdy:2 = call @while_body(%barg0, %barg1) : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + "tf.Yield"(%bdy#0, %bdy#1) : (tensor<*xf32>, tensor) -> () + } + ) { is_stateless = false } : (tensor<*xf32>, tensor) -> (tensor<*xf32>, tensor) + // CHECK: return [[Result]]#0 + return %0#0 : tensor<*xf32> +} + +// ----- + +// Test tf.IfRegion device is preserved. +// CHECK-LABEL: func @testIfRegionDevice +func @testIfRegionDevice(%arg0: tensor) { + "tf.IfRegion"(%arg0) ({ + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false, device = "/device:CPU:0"} : (tensor) -> () + + // CHECK: "tf.If" + // CHECK-SAME: device = "/device:CPU:0" + return +} + +// ----- + +// Test tf.WhileRegion device is preserved. +// CHECK-LABEL: func @testWhileRegionDevice +func @testWhileRegionDevice() { + "tf.WhileRegion"() ( { + %0 = "tf.Const"() {value = dense : tensor} : () -> tensor + "tf.Yield"(%0) : (tensor) -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false, device = "/device:CPU:0"} : () -> () + + // CHECK: "tf.While" + // CHECK-SAME: device = "/device:CPU:0" + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir index 87da399b726..da0a2df9e6a 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-alias-analysis-test.mlir @@ -173,7 +173,7 @@ func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) { // ----- // Test aliasing through IfRegion -!tf_res = type tensor<*x!tf.resource>> +!tf_res = type tensor<*x!tf.resource>> // CHECK-LABEL: func @if_region_aliasing // expected-remark@below {{Region #0, Arg #0, ID 7 : 1, 4, 6, 7}} @@ -181,7 +181,7 @@ func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) { func @if_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { // expected-remark@below {{Result #0, ID 0 : 0, 1, 3, 4, 5}} %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res - %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor<32xf32> + %read0 = "tf.ReadVariableOp"(%vh0) : (!tf_res) -> tensor // expected-remark@below {{Result #0, ID 4 : Unknown}} // expected-remark@below {{Result #1, ID 5 : 0, 1, 2, 3, 4, 5, 6, 8}} // expected-remark@below {{Result #2, ID 6 : 1, 2, 4, 5, 6, 7, 8}} @@ -195,7 +195,7 @@ func @if_region_aliasing(%arg0: !tf_res, %arg1: !tf_res) { // expected-remark@below {{Result #0, ID 3 : 0, 1, 3, 4, 5}} %id0 = "tf.Identity"(%vh0) : (!tf_res) -> !tf_res "tf.Yield"(%id0, %id0, %arg0) : (!tf_res, !tf_res, !tf_res) -> () - }) {is_stateless = true} : (tensor<32xf32>) -> (!tf_res, !tf_res, !tf_res) + }) {is_stateless = true} : (tensor) -> (!tf_res, !tf_res, !tf_res) return } @@ -232,3 +232,55 @@ func @while_region_aliasing(%arg0: !tf_res, %arg1: !tf_res, %arg2: 
!tf_res) { return } +// ----- +// Test aliasing through calls +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @aliasing_through_calls +func @aliasing_through_calls(%arg0: tensor<32xf32>) -> () { + // expected-remark@below {{Result #0, ID 0 : 0, 1, 2}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + // expected-remark@below {{Result #0, ID 1 : Unknown}} + // expected-remark@below {{Result #1, ID 2 : 0, 1, 2}} + %c:2 = call @passthru(%vh0) : (!tf_res) -> (!tf_res, !tf_res) + return +} + +// expected-remark@below {{Region #0, Arg #0, ID 1 : 1}} +func @passthru(%arg0: !tf_res) -> (!tf_res, !tf_res) { + // expected-remark@below {{Result #0, ID 0 : 0}} + %vh0 = "tf.VarHandleOp"() {container = "c", shared_name = "v0"} : () -> !tf_res + return %vh0, %arg0 : !tf_res, !tf_res +} + +// ----- +// Test aliasing through tf_device.launch +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @aliasing_through_launch +func @aliasing_through_launch(%arg0: tensor<32xf32>) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res + + // expected-remark@below {{Result #0, ID 1 : 0, 1}} + %launch = "tf_device.launch"() ({ + tf_device.return %vh : !tf_res + }) {device = ""} : () -> !tf_res + return +} + +// ----- +// Test aliasing through tf_device.cluster +!tf_res = type tensor<*x!tf.resource>> + +// CHECK-LABEL: func @aliasing_through_cluster +func @aliasing_through_cluster(%arg0: tensor<32xf32>) { + // expected-remark@below {{Result #0, ID 0 : 0, 1}} + %vh = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> !tf_res + + // expected-remark@below {{Result #0, ID 1 : 0, 1}} + %cluster = "tf_device.cluster"() ({ + tf_device.return %vh : !tf_res + }) : () -> !tf_res + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir index dd622e565c0..75cafde88e3 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/resource-device-inference.mlir @@ -424,3 +424,117 @@ func @propagate_if_region_inlined( } return } + +// Test propagation through WhileRegion (inlined calls) +// CHECK-LABEL: func @propagate_while_region_inlined +func @propagate_while_region_inlined( + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res + // CHECK-NEXT: "tf.WhileRegion" + "tf.WhileRegion"(%arg1, %id0, %var_handle) ({ + ^bb0(%carg0: tensor, %carg1: !tf_res, %carg2: !tf_res): + // CHECK: ^bb + // CHECK: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %cid0 = "tf.Identity"(%carg1) : (!tf_res) -> !tf_res loc("cid0") + %read = "tf.ReadVariableOp"(%cid0) : (!tf_res) -> tensor<32xf32> + %cst = constant dense<3.0> : tensor<32xf32> + %cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1> + %dims = constant dense<0> : tensor<1xi32> + %reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor + "tf.Yield"(%reduce) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: !tf_res, %barg2: !tf_res): + // 
CHECK: ^bb + // CHECK: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %bid0 = "tf.Identity"(%barg1) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%barg2) : (!tf_res) -> !tf_res + "tf.Yield"(%barg0, %bid0, %id1) : (tensor, !tf_res,!tf_res) -> () + }){is_stateless = false} + : (tensor, !tf_res, !tf_res) -> (tensor, !tf_res, !tf_res) + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// Test propagation through WhileRegion (non-inlined calls) +// CHECK-LABEL: func @propagate_while_region +func @propagate_while_region( + %arg0: !tf_res {tf.device = "/TPU:0"}, + %arg1: tensor) { + tf_executor.graph { + // CHECK: tf_executor.island + %island = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg0) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.VarHandleOp" + %var_handle = "tf.VarHandleOp"() {container = "c", shared_name = "v0", device = "/TPU:1"} : () -> !tf_res + // CHECK-NEXT: "tf.WhileRegion" + "tf.WhileRegion"(%arg1, %id0, %var_handle) ({ + ^bb0(%carg0: tensor, %carg1: !tf_res, %carg2: !tf_res): + %cond = call @whileregion_cond(%carg0, %carg1, %carg2) : (tensor, !tf_res, !tf_res) -> tensor + "tf.Yield"(%cond) : (tensor) -> () + }, { + ^bb0(%barg0: tensor, %barg1: !tf_res, %barg2: !tf_res): + %new_values:3 = call @whileregion_body(%barg0, %barg1, %barg2) : (tensor, !tf_res,!tf_res) -> (tensor, !tf_res,!tf_res) + "tf.Yield"(%new_values#0, %new_values#1, %new_values#2) : (tensor, !tf_res,!tf_res) -> () + }){is_stateless = false} + : (tensor, !tf_res, !tf_res) -> (tensor, !tf_res, !tf_res) + tf_executor.yield + } + tf_executor.fetch %island : !tf_executor.control + } + return +} + +// CHECK-LABEL: func @whileregion_body +func @whileregion_body(%arg0: tensor, %arg1: !tf_res, %arg2: !tf_res) -> (tensor, !tf_res, !tf_res) { + %graph:3 = tf_executor.graph { + // CHECK: tf_executor.island + %island:4 = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:1"} + %id1 = "tf.Identity"(%arg2) : (!tf_res) -> !tf_res + tf_executor.yield %arg0, %id0, %id1 : tensor, !tf_res, !tf_res + } + tf_executor.fetch %island#0, %island#1, %island#2 : tensor, !tf_res, !tf_res + } + return %graph#0, %graph#1, %graph#2: tensor, !tf_res, !tf_res +} + +// CHECK-LABEL: func @whileregion_cond +func @whileregion_cond(%arg0: tensor, %arg1: !tf_res, %arg2: !tf_res) -> tensor { + %graph = tf_executor.graph { + // CHECK: tf_executor.island + %island:2 = tf_executor.island { + // CHECK-NEXT: "tf.Identity" + // CHECK-SAME: {device = "/TPU:0"} + %id0 = "tf.Identity"(%arg1) : (!tf_res) -> !tf_res + %read = "tf.ReadVariableOp"(%id0) : (!tf_res) -> tensor<32xf32> + %cst = constant dense<3.0> : tensor<32xf32> + %cmp = "tf.Less"(%read, %cst) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xi1> + %dims = constant dense<0> : tensor<1xi32> + %reduce = "tf.All"(%cmp, %dims) {keep_dims = false} : (tensor<32xi1>, tensor<1xi32>) -> tensor + tf_executor.yield %reduce : tensor + } + tf_executor.fetch %island#0 : tensor + } + return %graph : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir index ac5c2df8f7e..213ca402f56 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir +++ 
b/tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir @@ -112,26 +112,6 @@ func @internal_resource() -> tensor<*xi32> { // ----- -// Tests that pass fails when there are remaining resource operationss that can -// not be lifted. - -func @lifting_failure() -> tensor<*xi32> { - - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource> - - // expected-error @+1 {{has remaining resource inputs that can not be lifted}} - %1 = "tf_device.cluster"() ( { - %2 = "tf.ReadVariableOp"(%0) {dtype = i32} : (tensor<*x!tf.resource>) -> tensor<*xi32> - %3 = "tf.SomeResourceOp"(%0, %2) : (tensor<*x!tf.resource>, tensor<*xi32>) -> tensor<*xi32> - "tf.AssignVariableOp"(%0, %3) {dtype = i32} : (tensor<*x!tf.resource>, tensor<*xi32>) -> () - tf_device.return %3 : tensor<*xi32> - }) {cluster_attr = "cluster_attr"} : () -> tensor<*xi32> - - return %1 : tensor<*xi32> -} - -// ----- - // Tests that pass lifts resource reads/writes from a loop, and removed unused // resources. @@ -347,30 +327,6 @@ func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { // ----- -// Tests that pass reports error on unsupported ops in loop body. - -func @cluster_with_loop() -> () { - %0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>> - "tf_device.cluster"() ( { - %1 = "tf.While"(%0) { - body = @while_body, cond = @while_cond, device = "", is_stateless = false} - : (tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) - tf_device.return - }) {cluster_attr = "cluster_attr"} : () -> () - return -} -func @while_body(%arg0: tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>) { - // expected-error @+1 {{found unsupported operations on resource.}} - "tf._UnknownOp"(%arg0) : (tensor<*x!tf.resource>>) -> () - return %arg0 : tensor<*x!tf.resource>> -} -func @while_cond(%arg0: tensor<*x!tf.resource>>) -> tensor { - %read = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource>>) -> tensor - return %read : tensor -} - -// ----- - // Tests that pass reports error on unsupported ops in loop cond. 
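// (Background, stated as a hedged sketch rather than the pass contract: resource lifting
// hoists `tf.ReadVariableOp`s above the cluster/loop and sinks `tf.AssignVariableOp`s
// below it, so the lifted body computes only on plain tensors. Illustrative shape of a
// lifted read-modify-write cluster, with made-up names:
//   %read = "tf.ReadVariableOp"(%var) : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
//   %new = "tf_device.cluster"() ( {
//     %sum = "tf.AddV2"(%read, %read) : (tensor<f32>, tensor<f32>) -> tensor<f32>
//     tf_device.return %sum : tensor<f32>
//   }) {cluster_attr = "cluster_attr"} : () -> tensor<f32>
//   "tf.AssignVariableOp"(%var, %new) : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
// The test below covers a case where the rewrite is rejected instead.)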
func @cluster_with_loop() -> () { @@ -409,7 +365,7 @@ func @cluster_with_case(%arg0: tensor) -> tensor<4xf32> { // CHECK: %[[CLUSTER:.*]]:2 = "tf_device.cluster"() %2 = "tf_device.cluster"() ( { // CHECK: %[[CASE:.*]]:2 = "tf.Case"(%[[ARG0]], %[[READ0]], %[[READ1]]) - %3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2]} + %3:2 = "tf.Case"(%arg0, %0, %1) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false} : (tensor, tensor<*x!tf.resource>>, tensor<*x!tf.resource>>) -> (tensor<*x!tf.resource>>, tensor<4xf32>) // CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CASE]]#1, %[[CASE]]#0) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir index 4a5e3c8deaa..3e613573d42 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/shape_inference.mlir @@ -2,69 +2,69 @@ // RUN: tf-opt %s -tf-shape-inference=propagate-caller-callee-constants -verify-diagnostics | FileCheck %s module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 130 : i32}} { -// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> + // CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<1xi32> func @main(%arg0: tensor<1xi32>, %arg1: tensor<1xi32>) -> tensor<*xi32> { - // CHECK: %[[RESULT:.*]] = "tf.AddV2" - // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> - // CHECK: return %[[RESULT]] : tensor<1xi32> + // CHECK: %[[RESULT:.*]] = "tf.AddV2" + // CHECK-SAME: (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + // CHECK: return %[[RESULT]] : tensor<1xi32> %0 = "tf.Cast"(%arg0) : (tensor<1xi32>) -> tensor<*xi32> %1 = "tf.Cast"(%arg1) : (tensor<1xi32>) -> tensor<*xi32> %2 = "tf.AddV2"(%0, %1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi32> return %2 : tensor<*xi32> } -// CHECK-LABEL: func @simple_chain + // CHECK-LABEL: func @simple_chain func @simple_chain(%arg0: tensor<1xf32>) -> tensor<*xf32> { -// CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: return %[[ADD]] : tensor<1xf32> + // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: return %[[ADD]] : tensor<1xf32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> %1 = "tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> } -// CHECK-LABEL: func @simple_chain_with_broadcast + // CHECK-LABEL: func @simple_chain_with_broadcast func @simple_chain_with_broadcast(%arg0: tensor<1xf32>, %arg1: tensor<10xf32>) -> tensor<*xf32> { -// CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<10xf32>) -> tensor<10xf32> -// CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> -// CHECK: %[[CAST:.*]] = "tf.Cast"(%[[ADD]]) {{.*}} : (tensor<10xf32>) -> tensor<*xf32> -// CHECK: %[[UNKNOWN:.*]] = addf %[[CAST]], %[[CAST]] : tensor<*xf32> -// CHECK: return %[[UNKNOWN]] : tensor<*xf32> + // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[ADD:.*]] = "tf.Add"(%[[MUL]], %[[MUL]]) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + // CHECK: %[[CAST:.*]] = 
"tf.Cast"(%[[ADD]]) {{.*}} : (tensor<10xf32>) -> tensor<*xf32> + // CHECK: %[[UNKNOWN:.*]] = addf %[[CAST]], %[[CAST]] : tensor<*xf32> + // CHECK: return %[[UNKNOWN]] : tensor<*xf32> %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<10xf32>) -> tensor<*xf32> %1 = "tf.Add"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %2 = addf %1, %1 : tensor<*xf32> return %2 : tensor<*xf32> } -// CHECK-LABEL: func @unknown_op + // CHECK-LABEL: func @unknown_op func @unknown_op(%arg0: tensor<1xf32>) -> tensor<*xf32> { -// CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> -// CHECK: %[[UNKNOWN:.*]] = "tf.Unknown"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> -// CHECK: return %[[UNKNOWN]] : tensor<*xf32> + // CHECK: %[[MUL:.*]] = "tf.Mul"{{.*}} (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + // CHECK: %[[UNKNOWN:.*]] = "tf.Unknown"(%[[MUL]], %[[MUL]]) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> + // CHECK: return %[[UNKNOWN]] : tensor<*xf32> %0 = "tf.Mul"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<*xf32> %1 = "tf.Unknown"(%0, %0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %1 : tensor<*xf32> } -// CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor) -> tensor -func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { - br ^bb1 -^bb1: -// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor) -> tensor -// CHECK: return %[[IDENTITY]] : tensor - %ret = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> - return %ret : tensor<*xf32> -} + // CHECK-LABEL: func @multiple_blocks_one_return(%arg0: tensor) -> tensor + func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { + br ^bb1 + ^bb1: + // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg0) : (tensor) -> tensor + // CHECK: return %[[IDENTITY]] : tensor + %ret = "tf.Identity"(%arg0) : (tensor) -> tensor<*xf32> + return %ret : tensor<*xf32> + } -// Tests the case where an inference opportunity relies on folding. + // Tests the case where an inference opportunity relies on folding. -// CHECK-LABEL: func @simple_folding + // CHECK-LABEL: func @simple_folding func @simple_folding(%arg0: tensor<1x1x1x1xi32>, %arg1: tensor<1x1x1x1xf32>) -> tensor { -// CHECK: %[[SHAPE:.*]] = "tf.Shape" -// CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] -// CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> -// CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> + // CHECK: %[[SHAPE:.*]] = "tf.Shape" + // CHECK: %[[CONV:.*]] = "tf.Conv2DBackpropInput"(%[[SHAPE]] + // CHECK-SAME: (tensor<4xi32>, tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32>) -> tensor<1x1x1x1xf32> + // CHECK: return %[[CONV]] : tensor<1x1x1x1xf32> %0 = "tf.Shape"(%arg0) : (tensor<1x1x1x1xi32>) -> tensor<4xi32> %1 = "tf.Conv2DBackpropInput"(%0, %arg1, %arg1) { padding = "VALID", strides = [1, 1, 1, 1] @@ -72,7 +72,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %1 : tensor } -// Tests where tf.Const's value needs to be refined. + // Tests where tf.Const's value needs to be refined. func @const_refine() -> tensor<*xi32> { %0 = "tf.Const"() {value = dense<[3, 2]> : tensor<2xi32>} : () -> tensor<*xi32> @@ -81,9 +81,9 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %0 : tensor<*xi32> } -// Tests the case where an op's shape function returns non-fully-defined shapes. + // Tests the case where an op's shape function returns non-fully-defined shapes. 
-// CHECK-LABEL: func @op_non_fully_defined_shape_fn + // CHECK-LABEL: func @op_non_fully_defined_shape_fn func @op_non_fully_defined_shape_fn(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor { // CHECK: tf.BroadcastGradientArgs // CHECK-SAME: (tensor<0xi32>, tensor<0xi32>) -> (tensor, tensor) @@ -91,7 +91,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { return %2#0 : tensor } -// CHECK-LABEL: func @shape_from_const_input + // CHECK-LABEL: func @shape_from_const_input func @shape_from_const_input(%arg0: tensor<3x3x32x64xf32>, %arg1: tensor<200x24x24x64xf32>) -> tensor { %0 = "tf.Const"() {value = dense<[200, 26, 26, 32]> : tensor<4xi32>} : () -> tensor<4xi32> // CHECK: tf.Conv2DBackpropInput @@ -223,7 +223,7 @@ func @multiple_blocks_one_return(%arg0: tensor) -> tensor<*xf32> { // CHECK-SAME: %[[ARG_1:.*]]: tensor>> func @shape_from_case_to_branch_functions(%arg0: tensor, %arg1: tensor>>) -> tensor<1x2x3xf32> { // CHECK: %[[CASE:.*]] = "tf.Case"(%[[ARG_0]], %[[ARG_1]]) - %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1]} : (tensor, tensor>>) -> tensor<1x2x3xf32> + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch_0, @branch_1], is_stateless = false} : (tensor, tensor>>) -> tensor<1x2x3xf32> // CHECK: return %[[CASE]] : tensor<1x2x3xf32> return %0 : tensor<1x2x3xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir index 3d187aa5d60..92cb0458bf9 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tensor_list_ops_decomposition.mlir @@ -256,7 +256,7 @@ func @main(%arg0: tensor) -> () { %max_size = "tf.Const"() {value = dense<10> : tensor} : () -> tensor // CHECK-NOT: tf.EmptyTensorList %tl = "tf.EmptyTensorList"(%elem_shape, %max_size) : (tensor<0xi32>, tensor) -> tensor>> - %case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2]} + %case_op = "tf.Case"(%arg0, %tl) {branches = [@branch_0, @branch_1, @branch_2], is_stateless = false} : (tensor, tensor>>) -> tensor>> // CHECK: "tf.Slice" %pop:2 = "tf.TensorListPopBack"(%case_op, %elem_shape) : (tensor>>, tensor<0xi32>) -> (tensor>>, tensor) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 20a0e22c48e..9a8d97eddf1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -775,12 +775,30 @@ func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { // ----- func @testIfThen(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> -func @testIfElse(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +func @testIfElse(tensor<2xf32>) -> tensor<2xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{branches should have 1 inputs}} + // expected-error @+1 {{expects all branches to have 1 input(s), but 'then_branch' has 2 input(s)}} + %1 = "tf.If"(%arg0, %arg1) { + then_branch = @testIfThen, + else_branch = @testIfElse, + is_stateless = false + } : (tensor, tensor<2xf32>) -> tensor<2xf32> + + return %1 : tensor<2xf32> +} + +// ----- + +func @testIfThen(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) +func @testIfElse(tensor<2xf32>) -> tensor<2xf32> + +// Test invalid tf.If operation +func @testInvalidIfOp(tensor, tensor<2xf32>) -> 
tensor<2xf32> { +^bb0(%arg0: tensor, %arg1: tensor<2xf32>): + // expected-error @+1 {{expects all branches to have 1 result(s), but 'then_branch' has 2 result(s)}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -798,7 +816,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<*xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<2xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<2xf32>): - // expected-error @+1 {{then branch input type tensor<*xf16> is incompatible with operand type tensor<2xf32>}} + // expected-error @+1 {{expects operand type 'tensor<2xf32>' to be cast compatible with 'then_branch' input type 'tensor<*xf16>' at index 0}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -816,7 +834,7 @@ func @testIfElse(tensor<3xf32>) -> tensor<*xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>): - // expected-error @+1 {{branches inputs have incompatible types tensor<2xf32> and tensor<3xf32>}} + // expected-error @+1 {{expects all branch input type(s) (tensor<2xf32>, tensor<3xf32>) at index 0 to be cast compatible}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -834,7 +852,7 @@ func @testIfElse(tensor<*xf32>) -> tensor<3xf32> // Test invalid tf.If operation func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { ^bb0(%arg0: tensor, %arg1: tensor<*xf32>): - // expected-error @+1 {{else branch result type tensor<3xf32> is incompatible with op result type tensor<2xf32>}} + // expected-error @+1 {{expects result type 'tensor<2xf32>' to be cast compatible with 'else_branch' result type 'tensor<3xf32>' at index 0}} %1 = "tf.If"(%arg0, %arg1) { then_branch = @testIfThen, else_branch = @testIfElse, @@ -848,7 +866,7 @@ func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { // Test invalid tf.Yield operation (parent should be IfRegion) func @testInvalidYieldOp(%arg0: f32) -> () { - // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.IfRegion, tf.WhileRegion'}} + // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion'}} "tf.Yield"(%arg0) : (f32) -> () } @@ -895,7 +913,7 @@ func @testValidIfRegionOpWithMultipleResults(%arg0: tensor, %arg1: tensor<2x // Test invalid type for operand #0 for tf.IfRegion operation func @testInvalidIfRegionOpType0(%arg0: f32, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{operand #0 must be tensor of tf.dtype values}} + // expected-error @+1 {{operand #0 must be 0D tensor of 1-bit signless integer values, but got 'f32'}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () @@ -2033,6 +2051,15 @@ func @testConst() -> tensor { // ----- +// Test invalid tf.ToBool +func @testInvalidToBool(%arg0: tensor) -> tensor<1xi1> { + // expected-error @+1 {{op result #0 must be 0D tensor of 1-bit signless integer values, but got 'tensor<1xi1>'}} + %0 = "tf.ToBool"(%arg0) : (tensor) -> tensor<1xi1> + return %0 : tensor<1xi1> +} + +// ----- + // Test valid tf.Transpose // CHECK-LABEL: testTranspose func @testTranspose(tensor<2x3xf32>) -> tensor<3x2xf32> { @@ -3313,3 +3340,131 @@ func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: ten %0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> 
tensor<4x8x8x8xf32> return } + +// ----- + +func @branch() + +func @testCaseBadBranchIndicesShape(%arg0: tensor<8xi32>) { + // expected-error @+1 {{expects 'branch_index' to be a scalar, but got 'tensor<8xi32>'}} + "tf.Case"(%arg0) {branches = [@branch], is_stateless = false} : (tensor<8xi32>) -> () + return +} + +// ----- + +func @branch0(tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +func @branch1(tensor<2xf32>) -> tensor<2xf32> + +func @testCaseMismatchedNumOperands(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{expects all branches to have 1 input(s), but branch #0 has 2 input(s)}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) +func @branch1(tensor<2xf32>) -> tensor<2xf32> + +func @testCaseMismatchedNumResults(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{expects all branches to have 1 result(s), but branch #0 has 2 result(s)}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<*xf16>) -> tensor<*xf32> +func @branch1(tensor<*xf32>) -> tensor<*xf32> + +func @testCaseOperandNotCastCompatible(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { + // expected-error @+1 {{expects operand type 'tensor<2xf32>' to be cast compatible with branch #0 input type 'tensor<*xf16>' at index 0}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<2xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<2xf32>) -> tensor<*xf32> +func @branch1(tensor<3xf32>) -> tensor<*xf32> + +func @testCaseBranchArgumentsNotCastCompatible(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<2xf32> { + // expected-error @+1 {{expects all branch input type(s) (tensor<2xf32>, tensor<3xf32>) at index 0 to be cast compatible}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<*xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @branch0(tensor<*xf32>) -> tensor<*xf32> +func @branch1(tensor<*xf32>) -> tensor<3xf32> + +func @testCaseResultNotCastCompatible(%arg0: tensor, %arg1: tensor<*xf32>) -> tensor<2xf32> { + // expected-error @+1 {{expects result type 'tensor<2xf32>' to be cast compatible with branch #1 result type 'tensor<3xf32>' at index 0}} + %0 = "tf.Case"(%arg0, %arg1) {branches = [@branch0, @branch1], is_stateless = false} : (tensor, tensor<*xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + +func @testCaseRegionNoRegions(%arg0: tensor) { + // expected-error @+1 {{expects to have at least 1 region}} + "tf.CaseRegion"(%arg0) {is_stateless = false} : (tensor) -> () + return +} + +// ----- + +func @testCaseRegionBadBranchIndicesShape(%arg0: tensor<8xi32>) { + // expected-error @+1 {{expects 'branch_index' to be a scalar, but got 'tensor<8xi32>'}} + "tf.CaseRegion"(%arg0) ( { + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor<8xi32>) -> () + return +} + +// ----- + +func @testCaseRegionMismatchedNumResults(%arg0: tensor) { + // expected-error @+1 {{region #0 should have same number (1) of results as tf.CaseRegion but has 0 results}} + %1 = "tf.CaseRegion"(%arg0) ( { + "tf.Yield"() : () -> () + }) {is_stateless = false} : 
(tensor) -> tensor + return +} + +// ----- + +func @testCaseRegionMismatchedResultTypes(%arg0: tensor, %arg1: tensor) { + // expected-error @+1 {{region #0 result type tensor is incompatible with tf.CaseRegion result type tensor at index 0}} + %1 = "tf.CaseRegion"(%arg0) ( { + "tf.Yield"(%arg1) : (tensor) -> () + }) {is_stateless = false} : (tensor) -> tensor + return +} + +// ----- + +// Test valid tf.Cumsum +func @testCumsum(%arg: tensor<8x16xf32>, %axis: tensor) -> tensor<8x16xf32> { + %0 = "tf.Cumsum"(%arg, %axis) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + +// ----- + +func @testCumprod(%arg: tensor<8x16xf32>, %axis: tensor<2xi32>) -> tensor<8x16xf32> { + // expected-error @+1 {{requires scalar axis operand}} + %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor<2xi32>) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} + +// ----- + +func @testCumprod(%arg: tensor<8x16xf32>) -> tensor<8x16xf32> { + %axis = constant dense<-3> : tensor + // expected-error @+1 {{axis operand should be within range [-2, 2)}} + %0 = "tf.Cumprod"(%arg, %axis) : (tensor<8x16xf32>, tensor) -> tensor<8x16xf32> + return %0 : tensor<8x16xf32> +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir index 7fc2b210f91..11ceac1fe99 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf_device_index_selector.mlir @@ -9,17 +9,17 @@ func @select(%arg0: tensor, %arg1: tensor) -> (tensor, tensor tensor %1 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor - %4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>]} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], is_stateless = false} : (tensor, tensor, tensor) -> tensor return %0, %4 : tensor, tensor } -func @add(%i: tensor, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } -func @sub(%i: tensor, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { +func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> return %0 : tensor<*xf32> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir index 9467f890419..7b670cd831c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-dynamic-layout-pass.mlir @@ -11,9 +11,9 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"} NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false} // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext" @@ -31,7 +31,7 @@ func @non_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"} // CHECK-NEXT: "tf.TPUExecute"(%[[COPY0]], %[[COPY1]], %[[COMPILE]]#1) %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -49,9 +49,9 @@ func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device: NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} @@ -62,13 +62,13 @@ func @multiple_compile_uses(%arg0: tensor<*x!tf.resource> {tf.device = "/device: }) {device = "/device:CPU:0"} : () -> () %execute0 = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor %4:2 = "tf._UnKnownOp_"() : () -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>) %execute1 = "tf_device.launch"() ( { %5 = "tf.TPUExecute"(%4#0, %4#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %5 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute1 : tensor @@ -85,9 +85,9 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) - NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:TPU:0"} @@ -98,7 +98,7 @@ func @on_tpu_iter(%arg0: tensor<*x!tf.resource> {tf.device = "/device:TPU:0"}) - }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -116,9 +116,9 @@ func @arg_on_tpu_iter_on_cpu(%arg0: tensor<*x!tf.resource> {tf.device = "/device NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2:2 = "tf.IteratorGetNext"(%arg0) {device = "/device:CPU:0"} @@ -129,7 +129,7 @@ func @arg_on_tpu_iter_on_cpu(%arg0: tensor<*x!tf.resource> {tf.device = "/device }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -148,9 +148,9 @@ func @arg_on_tpu_intermediate_ops_on_cpu(%arg0: tensor<*x!tf.resource> {tf.devic NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) %id1 = "tf.Identity"(%arg0) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>) %id2 = "tf.Identity"(%id1) {device = "/device:CPU:0"} : (tensor<*x!tf.resource>) -> (tensor<*x!tf.resource>) // CHECK-NOT: "tf.TPUGetLayoutOp" @@ -163,7 +163,7 @@ func @arg_on_tpu_intermediate_ops_on_cpu(%arg0: tensor<*x!tf.resource> {tf.devic }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -181,9 +181,9 @@ func @var_handle_on_tpu_iter_on_cpu() -> tensor { NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) %var = "tf.VarHandleOp"() {container = "c", shared_name = "v", device = "/device:TPU:0"} : () -> tensor<*x!tf.resource> // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" @@ -195,7 +195,7 @@ func @var_handle_on_tpu_iter_on_cpu() -> tensor { }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -212,9 +212,9 @@ func @unsupported_ops(%arg0: tensor<3x3x1x32xf32> {tf.device = "/device:CPU:0"}) NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" %2 = "tf._Unknown_"() : () -> tensor<3x3x1x32xf32> @@ -224,7 +224,7 @@ func @unsupported_ops(%arg0: tensor<3x3x1x32xf32> {tf.device = "/device:CPU:0"}) }) {device = "/device:CPU:0"} : () -> () %execute = "tf_device.launch"() ( { %3 = "tf.TPUExecute"(%arg0, %2, %compile#1) - : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %3 : tensor }) {device = "/device:TPU:0"} : () -> tensor return %execute : tensor @@ -246,9 +246,9 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 1 : i64, is_output = false} // CHECK: %[[ITER1:.*]]:2 = "tf.IteratorGetNext" @@ -267,7 +267,7 @@ func @replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0"}) -> {n = 2 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/device:TPU:0", "/device:TPU:1"]}} { // CHECK: "tf.TPUExecute"(%[[R0]], %[[R1]], %[[COMPILE]]#1) %execute = "tf_device.launch"() ( { - %4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + %4 = "tf.TPUExecute"(%r0, %r1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %4 : tensor }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor tf_device.return %execute : tensor @@ -286,9 +286,9 @@ func @inside_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU: NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %1#0, %1#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) // CHECK-NOT: "tf.TPUGetLayoutOp" // CHECK-NOT: "tf.TPUCopyWithLayout" "tf_device.launch"() ( { @@ -300,7 +300,7 @@ func @inside_replicated(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU: %2:2 = "tf.IteratorGetNext"(%r0) : (tensor<*x!tf.resource>) -> (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>) %execute = "tf_device.launch"() ( { - %4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor) -> tensor + %4 = "tf.TPUExecute"(%2#0, %2#1, %compile#1) : (tensor<3x3x1x32xf32>, tensor<3x3x1x32xf32>, tensor<2x!tf.string>) -> tensor tf_device.return %4 : tensor }) {device = "TPU_REPLICATED_CORE_0"} : () -> tensor tf_device.return %execute : tensor @@ -330,9 +330,9 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0 // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() %compile:3 = "tf_device.launch"() ( { - %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor, tensor, tensor) - tf_device.return %1#0, %1#1, %1#2 : tensor, tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor, tensor) + %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\01 \02", mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1, %1#2 : tensor, tensor<2x!tf.string>, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false} // CHECK: %[[ITER:.*]]:2 = "tf.IteratorGetNext" @@ -351,7 +351,7 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0 // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/device:TPU:0" "tf_device.launch"() ( { - "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%2#0, %compile#1) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "/device:TPU:0"} : () -> () tf_device.return @@ -364,7 +364,7 @@ func @parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/device:CPU:0 // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "/device:TPU:1" "tf_device.launch"() ( { - "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%2#1, %compile#2) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "/device:TPU:1"} : () -> () tf_device.return @@ -396,9 +396,9 @@ func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/d // CHECK: %[[COMPILE:.*]]:3 = "tf_device.launch" // CHECK-NEXT: "tf._TPUCompileMlir"() %compile:3 = "tf_device.launch"() ( { - %1:3 = 
"tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor, tensor, tensor) - tf_device.return %1#0, %1#1, %1#2 : tensor, tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor, tensor) + %1:3 = "tf._TPUCompileMlir"() {NumDynamicShapes = 0 : i64, metadata = "\0A\09\08\01\12\05\12\03\08\80\01\18\02 \02", mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) + tf_device.return %1#0, %1#1, %1#2 : tensor, tensor<2x!tf.string>, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>, tensor<2x!tf.string>) // CHECK-DAG: %[[LAYOUT0:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#1) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[LAYOUT1:.*]] = "tf.TPUGetLayoutOp"(%[[COMPILE]]#2) {index = 0 : i64, is_output = false} // CHECK-DAG: %[[ITER0:.*]]:2 = "tf.IteratorGetNext"(%[[ARG0]]) @@ -423,7 +423,7 @@ func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/d // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_0" "tf_device.launch"() ( { - "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%r0, %compile#1) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () tf_device.return @@ -433,7 +433,7 @@ func @replicated_parallel_execute(%arg0: tensor<*x!tf.resource> {tf.device = "/d // CHECK-NEXT: tf_device.return // CHECK-NEXT: device = "TPU_REPLICATED_CORE_1" "tf_device.launch"() ( { - "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor) -> () + "tf.TPUExecute"(%r1, %compile#2) : (tensor<128xf32>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_1"} : () -> () tf_device.return diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir index 1e308b42bfc..277e4a8415e 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu-variable-runtime-reformatting.mlir @@ -61,9 +61,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -86,7 +86,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf_device.launch"() ( { "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -153,9 +153,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -173,7 +173,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf.TPUExecuteAndUpdateVariables"(%arg30, %arg31, %arg32, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, - tensor<*x!tf.resource>>, tensor) -> () + tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () tf_device.return @@ -239,9 +239,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. 
metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -254,7 +254,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf_device.launch"() ( { "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () tf_device.return @@ -342,9 +342,9 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr NumDynamicShapes = 0 : i64, // The metadata encodes 2 parameter and two return values. metadata = "\0A\0E\08\01\18\01\22\08\08\01\1A\01\01\22\01\00\0A \08\01\12\10\12\02\08\03\12\02\08\03\12\02\08\01\12\02\08 \18\01\22\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\12\0A\0A\08\08\01\1A\01\01\22\01\00\18\02 \01", - mlir_module = "..."} : () -> (tensor, tensor) - tf_device.return %2#0, %2#1 : tensor, tensor - }) {device = "/device:CPU:0"} : () -> (tensor, tensor) + mlir_module = "..."} : () -> (tensor, tensor<2x!tf.string>) + tf_device.return %2#0, %2#1 : tensor, tensor<2x!tf.string> + }) {device = "/device:CPU:0"} : () -> (tensor, tensor<2x!tf.string>) "tf_device.launch"() ( { "tf.TPUCompileSucceededAssert"(%compile#0) : (tensor) -> () tf_device.return @@ -367,7 +367,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr "tf_device.launch"() ( { "tf.TPUExecuteAndUpdateVariables"(%id, %arg31, %compile#1) {device_var_reads_indices = [0, 1], device_var_updates_indices = [0, 1]} - : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor) -> () + : (tensor<*x!tf.resource>>, tensor<*x!tf.resource>>, tensor<2x!tf.string>) -> () tf_device.return }) {device = "TPU_REPLICATED_CORE_0"} : () -> () %ret = "tf.Const"() {value = dense<0> : tensor} : () -> tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir index 1f516a25824..2271bca7382 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_extract_outside_compilation.mlir @@ -512,6 +512,137 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor return %1 : tensor } + // Tests extraction of an outside compiled tf.IfRegion op where the entirety + // of tf.IfRegion op is outside compiled + + // CHECK-LABEL: func @outside_compiled_tf_if + func @outside_compiled_tf_if(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUT:[0-9]*]] = "tf.A" + // CHECK: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + 
// CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:3 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor, tensor) + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#2) + // CHECK: "tf.D"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1, %[[F_OUT]]) + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf._XlaHostComputeMlir"(%[[B_OUTPUT]], %[[A_OUTPUT]], %[[G_OUTPUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + %0 = "tf.A"(%arg0) : (tensor) -> tensor + %7 = "tf.F"() : () -> tensor + + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + "tf.D"(%4, %3, %7) {} : (tensor, tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + + // Tests extraction of an outside compiled tf.IfRegion op where the entirety + // of tf.IfRegion op is outside compiled and wrapped inside another + // tf.IfRegion op + + // CHECK-LABEL: func @outside_compiled_tf_if_nested + func @outside_compiled_tf_if_nested(%arg0: tensor) -> tensor { + // CHECK: %[[A_OUT:[0-9]*]] = "tf.A" + // CHECK: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate + // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute" + // CHECK-NEXT: "tf_device.launch" + // CHECK-NEXT: %[[PLACEHOLDER_KEY:[0-9]*]] = "tf._TPUCompileMlirPlaceholderProgramKey"() + // CHECK-NEXT: %[[RECV_OUTPUT_PREDICATE:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // CHECK-SAME: (tensor<2x!tf.string>) -> tensor + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT_PREDICATE]]) + // CHECK-NEXT: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_args" + // CHECK-SAME: (tensor<2x!tf.string>) -> (tensor, tensor) + // CHECK-NEXT: tf.IfRegion"(%[[RECV_OUTPUT]]#1) + // CHECK-NEXT: "tf.H"(%[[RECV_OUTPUT]]#0, %[[F_OUT]]) + // CHECK: "tf.Yield"() : () -> () + // CHECK: "tf.Yield"() : () -> () + // CHECK: "tf._XlaSendFromHost"(%[[PLACEHOLDER_KEY]]) + // CHECK-SAME: device_ordinal = 0 + // CHECK-SAME: key = "host_compute_channel_cluster1_retvals" + // CHECK: "tf_device.cluster" + // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A" + // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B" + // CHECK: %[[G_OUTPUT:[0-9]*]] = "tf.G" + // CHECK: "tf.XlaSendToHost"(%[[G_OUTPUT]]) + // CHECK-SAME: key = "if_predicate_channel_cluster1_0" + // 
CHECK-SAME: (tensor) -> () + // CHECK-NEXT: "tf.IfRegion"(%[[G_OUTPUT]]) + // CHECK: %[[D_OUT:[0-9]*]] = "tf.D" + // CHECK-NEXT: %[[F_OUT:[0-9]*]] = "tf.F" + // CHECK: "tf._XlaHostComputeMlir"(%[[D_OUT]], %[[F_OUT]]) + // CHECK-SAME: recv_key = "host_compute_channel_cluster1_retvals" + // CHECK-SAME: send_key = "host_compute_channel_cluster1_args" + // CHECK-SAME: tpu_core = 0 + // CHECK: "tf.Yield"() : () -> () + // CHECK: "tf.Yield"() : () -> () + %0 = "tf.A"(%arg0) : (tensor) -> tensor + %7 = "tf.F"() : () -> tensor + + %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor) {n = 2 : i32} { + %2 = "tf_device.cluster"() ( { + %3 = "tf.A"() : () -> (tensor) + %4 = "tf.B"() : () -> (tensor) + %6 = "tf.G"() : () -> (tensor) + + "tf.IfRegion"(%6) ({ + %8 = "tf.D"(%4, %3, %7) {} : (tensor, tensor, tensor) -> (tensor) + %9 = "tf.F"(%4) {} : (tensor) -> (tensor) + + "tf.IfRegion"(%9) ({ + "tf.H"(%8, %7) : (tensor, tensor) -> () + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {_xla_outside_compilation = "cluster1", is_stateless = false} : (tensor) -> () + + "tf.Yield"() : () -> () + }, { + "tf.Yield"() : () -> () + }) {is_stateless = false} : (tensor) -> () + + %5 = "tf.E"() : () -> tensor + tf_device.return %5 : tensor + }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor + tf_device.return %2 : tensor + } + + return %1 : tensor + } + // Tests extraction of a single outside compiled cluster inside a tf.IfRegion // op with return values. diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_identity_pruning.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_identity_pruning.mlir new file mode 100644 index 00000000000..317e7036c42 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_identity_pruning.mlir @@ -0,0 +1,93 @@ +// RUN: tf-opt %s -tf-tpu-identity-pruning | FileCheck %s --dump-input=always + +// Tests Identity op in cluster is pruned away. + +// CHECK-LABEL: func @testIdentity +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @testIdentity(%arg0: tensor) { + // CHECK-NOT: "tf.Identity" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: tf_device.return [[ARG0]] + %0 = "tf_device.cluster"() ( { + %1 = "tf.Identity"(%arg0) : (tensor) -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + return +} + +// Tests IdentityN op in cluster is pruned away. + +// CHECK-LABEL: func @testIdentityN +// CHECK-SAME: ([[ARG0:%.*]]: tensor, [[ARG1:%.*]]: tensor) +func @testIdentityN(%arg0: tensor, %arg1: tensor) { + // CHECK-NOT: "tf.IdentityN" + // CHECK: "tf_device.cluster" + // CHECK-NEXT: tf_device.return [[ARG0]], [[ARG1]] + %0:2 = "tf_device.cluster"() ( { + %1:2 = "tf.IdentityN"(%arg0, %arg1) : (tensor, tensor) -> (tensor, tensor) + tf_device.return %1#0, %1#1 : tensor, tensor + }) : () -> (tensor, tensor) + return +} + +// Tests transitive Identity ops reachable from the cluster are pruned away. 
+ +// CHECK-LABEL: func @testTransitiveIdentity +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @testTransitiveIdentity(%arg0: tensor) { + // CHECK: "tf_device.cluster" + // CHECK: "tf.PartitionedCall"([[ARG0]]) + // CHECK-SAME: f = @callee0 + %0 = "tf_device.cluster"() ( { + %1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee0} : (tensor) -> tensor + tf_device.return %1 : tensor + }) : () -> tensor + return +} + +// CHECK-LABEL: func @callee0 +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @callee0(%arg0: tensor) -> tensor { + // CHECK-NOT: "tf.Identity" + // CHECK: "tf.PartitionedCall"([[ARG0]]) + // CHECK-SAME: f = @callee1 + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee1} : (tensor) -> tensor + return %1 : tensor +} + +// CHECK-LABEL: func @callee1 +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @callee1(%arg0: tensor) -> tensor { + // CHECK-NOT: "tf.Identity" + // CHECK: return [[ARG0]] + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + return %0 : tensor +} + +// Tests Identity ops not reachable from the cluster are not pruned away. + +// CHECK-LABEL: func @testIdentityOutsideCluster +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @testIdentityOutsideCluster(%arg0: tensor) { + // CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]]) + // CHECK: [[CLUSTER:%.*]] = "tf_device.cluster" + // CHECK-NEXT: tf_device.return [[IDENTITY]] + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + %1 = "tf_device.cluster"() ( { + tf_device.return %0 : tensor + }) : () -> tensor + // CHECK: "tf.PartitionedCall"([[CLUSTER]]) + // CHECK-SAME: f = @callee2 + %2 = "tf.PartitionedCall"(%1) {config = "", config_proto = "", executor_type = "", f = @callee2} : (tensor) -> tensor + return +} + +// CHECK-LABEL: func @callee2 +// CHECK-SAME: ([[ARG0:%.*]]: tensor) +func @callee2(%arg0: tensor) -> tensor { + // CHECK: [[IDENTITY:%.*]] = "tf.Identity"([[ARG0]]) + %0 = "tf.Identity"(%arg0) : (tensor) -> tensor + // CHECK: return [[IDENTITY]] + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir index 2a0091ce9bf..ef7b52cd978 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_rewrite.mlir @@ -1262,15 +1262,15 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor // CHECK-NOT:"tf._TPUCompileMlirPlaceholderProgramKey" // CHECK: "tf.E"(%[[COMPILE_OUTPUT]]#1 %3 = "tf_device.parallel_execute"() ( { - %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor - "tf.D"(%program) : (tensor) -> () + %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor<2x!tf.string> + "tf.D"(%program) : (tensor<2x!tf.string>) -> () tf_device.return }, { %4 = "tf_device.cluster_func"(%ri_0) {_tpu_replicate = "cluster0", func = @tpu0_func, num_cores_per_replica = 1, step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", padding_map = ["\08\01\10\02\18\03"], topology = "", device_assignment = [], input_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], output_sharding_configuration = ["\08\01\1A\01\01\22\01\00"], use_spmd_for_xla_partitioning = false} : (tensor) -> tensor tf_device.return %4 : tensor }, { - %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> tensor - "tf.E"(%program) : (tensor) -> () + %program = "tf._TPUCompileMlirPlaceholderProgramKey"() : () -> 
tensor<2x!tf.string> + "tf.E"(%program) : (tensor<2x!tf.string>) -> () tf_device.return }) : () -> (tensor) tf_device.return %3 : tensor diff --git a/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir b/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir index 1770b4e146d..8cc8d273bec 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/visitor-interrupt-util.mlir @@ -69,7 +69,7 @@ func @foo(%arg0: tensor) -> tensor { // Test static filtering // expected-remark@below {{0: before all regions}} // expected-remark@below {{7: walk was interrupted}} -func @foo(%arg0: tensor) -> tensor { +func @foo(%arg0: tensor, %arg1: tensor) -> tensor { // expected-remark@below {{1: before all regions}} %cst = constant dense<1.0> : tensor // expected-remark@below {{2: before all regions}} @@ -77,7 +77,7 @@ func @foo(%arg0: tensor) -> tensor { // expected-remark@below {{8: before all regions}} // expected-remark@below {{9: before region #1}} // expected-remark@below {{10: after all regions}} - %0 = "tf.IfRegion"(%arg0) ({ + %0 = "tf.IfRegion"(%arg1) ({ // expected-remark@below {{3: before all regions}} %1 = "tf.Identity"(%arg0) : (tensor) -> tensor // expected-remark@below {{4: before all regions}} @@ -86,6 +86,6 @@ func @foo(%arg0: tensor) -> tensor { // expected-remark@below {{6: before all regions}} %1 = "tf.Identity"(%arg0) : (tensor) -> tensor "tf.Yield"(%1) { interrupt_after_all = true } : (tensor) -> () - }) {is_stateless = true}: (tensor) -> tensor + }) {is_stateless = true}: (tensor) -> tensor return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir b/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir index d376fad5c33..9a832b7fe8d 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/visitor-util.mlir @@ -77,7 +77,7 @@ func @foo(%arg0: tensor) -> tensor { // Test static filtering // expected-remark@below {{0: before all regions}} // expected-remark@below {{10: after all regions}} -func @foo(%arg0: tensor) -> tensor { +func @foo(%arg0: tensor, %arg1: tensor) -> tensor { // expected-remark@below {{1: before all regions}} %cst = constant dense<1.0> : tensor // expected-remark@below {{2: before all regions}} @@ -86,7 +86,7 @@ func @foo(%arg0: tensor) -> tensor { // expected-remark@below {{11: before all regions}} // expected-remark@below {{12: before region #1}} // expected-remark@below {{13: after all regions}} - %0 = "tf.IfRegion"(%arg0) ({ + %0 = "tf.IfRegion"(%arg1) ({ // expected-remark@below {{3: before all regions}} %1 = "tf.Identity"(%arg0) : (tensor) -> tensor // expected-remark@below {{4: before all regions}} @@ -96,7 +96,7 @@ func @foo(%arg0: tensor) -> tensor { %1 = "tf.Identity"(%arg0) : (tensor) -> tensor // expected-remark@below {{7: before all regions}} "tf.Yield"(%1) : (tensor) -> () - }) {is_stateless = true}: (tensor) -> tensor + }) {is_stateless = true}: (tensor) -> tensor // expected-remark@below {{9: before all regions}} return %0 : tensor } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 5b0a4b4e619..0c21078b0ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -95,16 +95,18 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { func_pm.addPass(CreateTPUHostComputationExpansionPass()); 
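// Editor's note (summary of the reordering in this hunk, not part of the patch):
// after this change the region-based portion of CreateTPUBridgePipeline runs in
// the following order; the two passes marked "added" are newly inserted here:
//
//   pm.addPass(TF::CreateTFShapeInferencePass());
//   pm.addPass(TFDevice::CreateResourceOpLiftingPass());
//   pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
//   pm.addPass(mlir::createInlinerPass());
//   pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass());   // added
//   pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
//   pm.addPass(CreateTPUExtractOutsideCompilationPass());             // added
//   pm.addPass(TF::CreateTFRegionControlFlowToFunctional());
//   pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorConstantSinkingPass());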
func_pm.addPass(CreateTPUUpdateEmbeddingEnqueueOpInputsPass()); } - pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); - pm.addPass(mlir::createInlinerPass()); - pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); - pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); - // Run another shape inference pass because resource decomposition might have // created new partial types. pm.addPass(TF::CreateTFShapeInferencePass()); - pm.addNestedPass(tf_executor::CreateTFExecutorConstantSinkingPass()); pm.addPass(TFDevice::CreateResourceOpLiftingPass()); + pm.addPass(TF::CreateTFFunctionalControlFlowToRegions()); + pm.addPass(mlir::createInlinerPass()); + pm.addPass(TFDevice::CreateMarkOpsForOutsideCompilationPass()); + pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass()); + pm.addPass(CreateTPUExtractOutsideCompilationPass()); + pm.addPass(TF::CreateTFRegionControlFlowToFunctional()); + + pm.addNestedPass(tf_executor::CreateTFExecutorConstantSinkingPass()); pm.addPass(TF::CreateResourceDeviceInferencePass()); pm.addPass(TFDevice::CreateClusterOutliningPass()); pm.addPass(CreateTPUDynamicPaddingMapperPass()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 1429e2b3fd4..3005c78c54f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/c/eager/c_api.h" @@ -68,7 +69,7 @@ static bool ShouldBeFolded(Operation* inst) { LogicalResult ConstantFoldFallbackHook( Operation* inst, ArrayRef operands, - SmallVectorImpl& results) { // NOLINT + SmallVectorImpl& results) { // NOLINT // Instructions with side effects should not be constant folded to preserve // the original semantics. if (inst->getNumRegions() != 0 || !MemoryEffectOpInterface::hasNoEffect(inst)) @@ -126,8 +127,16 @@ LogicalResult ConstantFoldFallbackHook( // TODO(jpienaar): Avoid using global context & mutex here. 
static auto* mu = new tensorflow::mutex(); tensorflow::mutex_lock l(*mu); - return tensorflow::EvaluateOperation(inst, inputs, ctx, &results); + SmallVector constants; + LogicalResult status = + tensorflow::EvaluateOperation(inst, inputs, ctx, &constants); + results.assign(constants.begin(), constants.end()); + return status; } +static bool init_hooks = ([] () { + TensorFlowDialect::RegisterConstantFoldHook(ConstantFoldFallbackHook); +}(), true); + } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h index 69e39080965..887eea745e7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.h @@ -27,7 +27,7 @@ namespace TF { LogicalResult ConstantFoldFallbackHook( Operation *inst, ArrayRef operands, - SmallVectorImpl &results); // NOLINT + SmallVectorImpl &results); // NOLINT } // namespace TF } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc similarity index 74% rename from tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc rename to tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc index 109ceea47e7..d309c6d379f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/dialect_hooks.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc @@ -19,7 +19,6 @@ limitations under the License. #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project -#include "mlir/IR/DialectHooks.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -35,31 +34,22 @@ namespace { // Since this method is passed to MLIR as decode hook it has to conform // to LLVM style used by MLIR. -bool DecodeOpaqueTensorHook(const OpaqueElementsAttr input, - ElementsAttr& output) { // NOLINT +LogicalResult DecodeOpaqueTensorHook(const OpaqueElementsAttr input, + ElementsAttr& output) { // NOLINT Builder builder(input.getType().getContext()); auto decoded_attr_or = tensorflow::DecodeOpaqueTensor(input, builder); if (!decoded_attr_or.ok()) { VLOG(2) << decoded_attr_or.status().error_message(); - return true; + return failure(); } output = decoded_attr_or.ValueOrDie(); - return false; + return success(); } -// Hooks for the TensorFlow dialect. -class TensorFlowHooks : public DialectHooks { - public: - DialectConstantFoldHook getConstantFoldHook() { - return TF::ConstantFoldFallbackHook; - } - DialectConstantDecodeHook getDecodeHook() { return DecodeOpaqueTensorHook; } -}; +static bool init_hooks = ([] () { + TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook); +}(), true); } // anonymous namespace - -// Static initialization for TensorFlow dialect hooks registration. 
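// Editor's note (illustration of the registration idiom, not part of the patch):
// both constant_fold.cc and decode_attributes_hook.cc now register their hooks
// with a namespace-scope static initializer instead of the removed
// mlir::DialectHooks machinery. The pattern is a lambda evaluated via the comma
// operator so registration happens once at load time:
//
//   static bool init_hooks = ([] {
//     TF::TensorFlowDialect::RegisterDecodeConstantHook(DecodeOpaqueTensorHook);
//   }(), true);
//
// The lambda is invoked purely for its side effect; the comma operator discards
// its result and initializes the bool to true.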
-static DialectHooksRegistration tf_hooks_registration("tf"); - } // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc index b47378762a9..cc24c98a786 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_switch.cc @@ -240,7 +240,7 @@ static LogicalResult FoldMergeNodes(FuncOp function, const DeadQueue& queue) { auto def_op = val.getDefiningOp(); #ifndef NDEBUG auto exec_dialect = - function.getContext()->getRegisteredDialect("tf_executor"); + function.getContext()->getLoadedDialect("tf_executor"); assert(def_op->getDialect() == exec_dialect && "unable to forward control dependencies"); #endif diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index d23b977f0e3..11d74e87f96 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -23,6 +23,7 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Verifier.h" // from @llvm-project @@ -31,8 +32,8 @@ limitations under the License. #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define DEBUG_TYPE "tf-functional-cf-to-region" @@ -53,8 +54,8 @@ struct FunctionalControlFlowToRegions // the input arguments are used as is (for IfOp) or block arguments of the same // type as the input arguments are created and then used as call arguments (for // While). -void CreateCall(Operation* op, FuncOp func, Region& caller_region, - ValueRange args, bool use_region_args) { +YieldOp CreateCall(Operation* op, FuncOp func, Region& caller_region, + ValueRange args, bool use_region_args) { assert(caller_region.empty() && "Expected empty region for newly created ops"); OpBuilder builder(caller_region); @@ -76,15 +77,26 @@ void CreateCall(Operation* op, FuncOp func, Region& caller_region, casted_args.push_back(arg); } auto call = builder.create(op->getLoc(), func, casted_args); - builder.create(op->getLoc(), call.getResults()); + return builder.create(op->getLoc(), call.getResults()); +} + +// Converts the condition for an IfOp/WhileOp to a boolean value. +Value ConvertConditionToBoolean(Operation* op, Value cond) { + if (auto ranked_type = cond.getType().dyn_cast()) + if (ranked_type.getRank() == 0 && + ranked_type.getElementType().isSignlessInteger(1)) + return cond; + + OpBuilder builder(op); + return builder.create(op->getLoc(), cond); } // Transform a functional IfOp to a region based IfRegionOp. 
LogicalResult ConvertIfOp(IfOp if_op) { + Value cond = ConvertConditionToBoolean(if_op, if_op.cond()); auto if_region = OpBuilder(if_op).create( - if_op.getLoc(), if_op.getResultTypes(), if_op.cond(), - if_op.is_stateless()); - CopyUnderscoredAttributes(if_op, if_region); + if_op.getLoc(), if_op.getResultTypes(), cond, if_op.is_stateless()); + CopyDeviceAndUnderscoredAttributes(if_op, if_region); CreateCall(if_op, if_op.then_func(), /*caller_region=*/if_region.then_branch(), if_op.input(), @@ -101,11 +113,16 @@ LogicalResult ConvertWhileOp(WhileOp while_op) { auto while_region = OpBuilder(while_op).create( while_op.getLoc(), while_op.getResultTypes(), while_op.input(), while_op.is_stateless(), while_op.parallel_iterations()); - CopyUnderscoredAttributes(while_op, while_region); + CopyDeviceAndUnderscoredAttributes(while_op, while_region); + + YieldOp cond_yield = + CreateCall(while_op, while_op.cond_func(), + /*caller_region=*/while_region.cond(), while_op.input(), + /*use_region_args=*/true); + Value i1_cond = + ConvertConditionToBoolean(cond_yield, cond_yield.getOperand(0)); + cond_yield.setOperand(0, i1_cond); - CreateCall(while_op, while_op.cond_func(), - /*caller_region=*/while_region.cond(), while_op.input(), - /*use_region_args=*/true); CreateCall(while_op, while_op.body_func(), /*caller_region=*/while_region.body(), while_op.input(), /*use_region_args=*/true); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index 175baeb627f..fbe0524ce8b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -91,7 +91,7 @@ struct ReluToFusedBatchNorm : public OpRewritePattern { // Build the newly fused operation to replace the batch norm OperationState state(batch_norm.getLoc(), - FusedBatchNormExOp::getOperationName()); + _FusedBatchNormExOp::getOperationName()); state.addOperands(batch_norm.getOperands()); if (side_input) state.operands.push_back(side_input); state.addTypes(batch_norm.getResultTypes()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc index 9f67a3e7e71..4e507c8e760 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/launch_to_device_attribute.cc @@ -104,7 +104,7 @@ LogicalResult HoistOpsAndAnnotateWithDevice(const Dialect* tf_dialect, } void LaunchToDeviceAttributePass::runOnFunction() { - const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); + const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); if (!tf_dialect) { getFunction().emitError() << "'tf' dialect is not registered"; return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc index 483c84b3e80..6946dc65104 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/core/util/tensor_format.h" @@ -55,18 +56,27 @@ static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, return DenseIntElementsAttr::get(ty, vals); } -// Returns int or float DenseElementsAttr with scalar shape with the given -// element type and the integer value. +// Returns int, float, or complex DenseElementsAttr with scalar shape with the +// given element type and the integer value. static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) { RankedTensorType scalar_ty = RankedTensorType::get({}, ty); if (auto float_ty = ty.dyn_cast_or_null()) { FloatAttr attr = FloatAttr::get(float_ty, raw_value); return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto int_ty = ty.dyn_cast_or_null()) { + IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); + return DenseElementsAttr::get(scalar_ty, attr); + } else if (auto complex_ty = ty.dyn_cast_or_null()) { + Type complex_element_ty = complex_ty.getElementType(); + if (complex_element_ty.isF32()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } else if (complex_element_ty.isF64()) { + return DenseElementsAttr::get( + scalar_ty, static_cast>(raw_value)); + } } - - auto int_ty = ty.cast(); - IntegerAttr attr = IntegerAttr::get(int_ty, raw_value); - return DenseElementsAttr::get(scalar_ty, attr); + llvm_unreachable("unsupported type"); } // Returns float DenseElementsAttr with scalar shape with the specified value. @@ -427,12 +437,38 @@ class LowerSparseMatMulOp : public OpRewritePattern { } }; +// Lowers _UnaryOpsComposition op as a series of original TensorFlow ops that +// were fused together. +class Lower_UnaryOpsComposition + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::_UnaryOpsCompositionOp op, + PatternRewriter &rewriter) const override { + Value result = op.x(); + for (StringRef op_name : + op.op_names().getAsRange()) { + std::string full_name = "tf." + op_name.str(); + // All ops in the sequences have the same result type as the original + // result type. + OperationState state(op.getLoc(), full_name, /*operands=*/{result}, + /*types=*/{op.getType()}, /*attributes=*/{}); + Operation *op = rewriter.createOperation(state); + result = op->getResult(0); + } + rewriter.replaceOp(op, {result}); + return success(); + } +}; + } // namespace void PopulateLoweringTFPatterns(MLIRContext *context, OwningRewritePatternList *patterns) { patterns->insert(context); + LowerPackOp, LowerSparseMatMulOp, Lower_UnaryOpsComposition>( + context); populateWithGenerated(context, patterns); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 6b7d7178ab6..f7a867f3130 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -195,8 +195,7 @@ def : Pat<(TF_PadOp TensorOf<[AnySignlessInteger, AnyFloat]>:$input, $paddings), // Reciprocal op patterns. //===----------------------------------------------------------------------===// -// TODO(hinsu): Support complex and unsigned input types. 
-def LowerReciprocal : Pat<(TF_ReciprocalOp TF_SintOrFpTensor:$x), +def LowerReciprocal : Pat<(TF_ReciprocalOp $x), (TF_DivOp (TF_ConstOp (GetScalarOfType<1> $x)), $x)>; //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc index e538491ae9d..38cbe3f404e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/mark_ops_for_outside_compilation.cc @@ -131,9 +131,28 @@ LogicalResult MarkUncompilableOps( return success(); } +// Unmarks outside compilation for any op that has parents already +// marked for outside compilation since the child will be extracted +// anyways. +void UnmarkChildren(Block* block) { + block->walk([&](Operation* op) { + if (!op->getAttrOfType(kXlaOutsideCompilationAttr)) return; + Operation* iter_op = op; + bool remove_attr = false; + while (auto* parent_op = iter_op->getParentOp()) { + if (parent_op->getAttrOfType(kXlaOutsideCompilationAttr)) { + remove_attr = true; + break; + } + iter_op = parent_op; + } + if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr); + }); +} + void MarkOpsForOutsideCompilation::runOnOperation() { auto module = getOperation(); - const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); + const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); if (!tf_dialect) { getOperation().emitError() << "'tf' dialect is not registered"; return signalPassFailure(); @@ -168,6 +187,17 @@ void MarkOpsForOutsideCompilation::runOnOperation() { }); if (result.wasInterrupted()) return signalPassFailure(); + + module.walk([&](tf_device::ClusterOp cluster) { + // Only if `allow_soft_placement` attribute is true should we unmark ops + // for outside compilation. + auto soft_placement_attr = + cluster.getAttrOfType(kAllowSoftPlacementAttr); + if (!(soft_placement_attr && soft_placement_attr.getValue())) { + return; + } + UnmarkChildren(&cluster.GetBody()); + }); } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index fb2d6e39da3..d93d9ddccaf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -79,6 +79,11 @@ std::unique_ptr> CreateRewriteTPUEmbeddingOpsPass(); // Performs specific fusion for GPU targets. std::unique_ptr> CreateGpuOpFusionPass(); +// Create a pass that convert ops that copy tensors between devices, e.g. +// tf.Identity. +std::unique_ptr> +CreateTensorDeviceCopyConversionPass(); + struct LayoutOptimizationPipelineOptions : public PassPipelineOptions { Option force_data_format{ @@ -271,6 +276,9 @@ namespace TFTPU { // `_tpu_replicate` attribute. std::unique_ptr> CreateTPUClusterFormationPass(); +// Creates a pass that removes Identity/IdentityN ops from a cluster. +std::unique_ptr> CreateTPUIdentityPruningPass(); + // Creates a pass that allows TPU program inputs to have layouts determined at // run time. 
std::unique_ptr> CreateTPUDynamicLayoutPass(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc index ba876e08fbb..1e403bff0eb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/region_control_flow_to_functional.cc @@ -36,8 +36,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" -#include "tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #define DEBUG_TYPE "tf-region-cf-to-functional" @@ -158,9 +158,11 @@ void ExtractSingleBlockRegion(Region& region, StringRef name, } // Returns call for region with single call whose result feeds into the -// terminator of the region. Returns none if the region doesn't contain just -// call and non-truncting casts ops. -llvm::Optional IsSingleCallRegion(Region& region) { +// terminator of the region. if `allow_to_bool` is true, also allows a single +// ToBoolOp between the region yield and the call. Returns none if the region +// does not conform to this pattern. +llvm::Optional IsSingleCallRegion(Region& region, + bool allow_to_bool = false) { if (!llvm::hasSingleElement(region)) return llvm::None; Block& block = region.front(); @@ -169,31 +171,44 @@ llvm::Optional IsSingleCallRegion(Region& region) { if (it == block.rend()) return llvm::None; + // Operation which is expected to consume all the call results. + Operation* call_consumer = yield; + + // Allow a single ToBoolOp between the call and the yield (valid only + // when the yield has a single operand) + if (allow_to_bool && yield.getNumOperands() == 1 && isa(*it)) { + if (it->getResult(0) != yield.getOperand(0)) return llvm::None; + call_consumer = cast(*it); + it++; + } + // Check if there is a Call before the Yield. CallOp call = dyn_cast(*it++); if (!call) return llvm::None; + // All call results should feed into expected consumer + // All results of the call should feed into the yield. + if (call.getNumResults() != call_consumer->getNumOperands()) + return llvm::None; + + for (auto res_it : llvm::zip(call.getResults(), call_consumer->getOperands())) + if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None; + // There can only be non-truncating cast op's prior to the call. for (; it != block.rend(); ++it) { CastOp cast = dyn_cast(*it); if (!cast || cast.Truncate()) return llvm::None; } - // All results of the call should feed into the yield. - if (call.getNumResults() != yield.getNumOperands()) return llvm::None; - - for (auto res_it : llvm::zip(call.getResults(), yield.getOperands())) - if (std::get<0>(res_it) != std::get<1>(res_it)) return llvm::None; - return call; } -using MatcherFn = function_ref; +using ArgMatcherFn = function_ref; // Returns whether the arguments of the given 2 calls are match (after looking // through cast ops). `matcher` is the predicate used to check if two arguments // match. 
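// Editor's note (pseudo-IR, not part of the patch): with allow_to_bool set, a
// while-condition region of the following shape still counts as a "single call
// region", so the trivial (no-outlining) transform below can reuse the existing
// @cond function:
//
//   %result = call @cond(%args...)     // the single call
//   %pred = "tf.ToBool"(%result)       // permitted only when allow_to_bool
//   "tf.Yield"(%pred)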
-bool MatchCallArgs(CallOp first, CallOp second, MatcherFn matcher) { +bool MatchCallArgs(CallOp first, CallOp second, ArgMatcherFn matcher) { if (first.getNumOperands() != second.getNumOperands()) return false; Region& first_region = *first.getParentRegion(); @@ -225,38 +240,37 @@ struct TrivialTransformInfo { // List of callee names (one for each region). llvm::SmallVector callee_names; - // Constructor will analyze the 2 regions. - TrivialTransformInfo(Region& first, Region& second, MatcherFn matcher); + // Analyzes the given calls (from regions attached to the same parent op) to + // check if the parent op be transformed to functional form trivially (i.e., + // reusing existing functions and without outlining). This is possible when + // all the regions are single call regions (checked using matchers outside + // this class) and the all the calls match using the given argument matcher. + // + // If such a trivial transformation is possible, stash the relevant + // information needed for the transformation, else indicate that a trivial + // transformation is not possible by setting `can_transform` to false. + TrivialTransformInfo(llvm::Optional first_call, + llvm::Optional second_call, + ArgMatcherFn arg_matcher) { + if (!first_call || !second_call) return; + + if (!MatchCallArgs(first_call.getValue(), second_call.getValue(), + arg_matcher)) + return; + + can_transform = true; + callee_names = {first_call.getValue().getCallee(), + second_call.getValue().getCallee()}; + } }; -// Analyzes the given set of regions (attached to the same parent op) to check -// if the parent op be transformed to functional form trivially (i.e., reusing -// existing functions and without outlining). This is possible when all the -// regions are single call regions and the all the calls have the same -// arguments. -// -// If such a trivial transformation is possible, stash the relevant information -// needed for the transformation, else indicate that a trivial transformation is -// not possible by setting `can_transform` to false. -TrivialTransformInfo::TrivialTransformInfo(Region& first, Region& second, - MatcherFn matcher) { - auto call0 = IsSingleCallRegion(first); - auto call1 = IsSingleCallRegion(second); - if (!call0 || !call1) return; - - if (!MatchCallArgs(call0.getValue(), call1.getValue(), matcher)) return; - - can_transform = true; - callee_names = {call0.getValue().getCallee(), call1.getValue().getCallee()}; -} - // Transform IfRegionOp to IfOp. LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { llvm::SmallVector extern_values; // For IfOp, arguments of calls in the then and else regions match if they // are the same value. - auto if_matcher = [&](Value first, Region&, Value second, Region&) { + auto if_arg_matcher = [&](Value first, Region&, Value second, Region&) { if (first != second) return false; // collect the call arguments post lookup through cast Op's @@ -264,8 +278,9 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { return true; }; - const TrivialTransformInfo tti(if_region.then_branch(), - if_region.else_branch(), if_matcher); + const TrivialTransformInfo tti(IsSingleCallRegion(if_region.then_branch()), + IsSingleCallRegion(if_region.else_branch()), + if_arg_matcher); std::string then_name, else_name; @@ -293,16 +308,23 @@ LogicalResult RegionControlFlowToFunctional::ConvertIfOp(IfRegionOp if_region) { worklist, /*extern_values_passthrough=*/false); } + // Look through ToBool operations for the condition. 
+ Value cond = if_region.cond(); + auto to_bool = dyn_cast_or_null(cond.getDefiningOp()); + if (to_bool) cond = to_bool.getOperand(); + // Once we have the `then` and `else` functions ready (either outlined or // existing ones), replace the region based op with a functional control flow // op. OpBuilder builder(if_region); auto if_op = builder.create( - if_region.getLoc(), if_region.getResultTypes(), if_region.cond(), - extern_values, then_name, else_name, if_region.is_stateless()); - CopyUnderscoredAttributes(if_region, if_op); + if_region.getLoc(), if_region.getResultTypes(), cond, extern_values, + then_name, else_name, if_region.is_stateless()); + CopyDeviceAndUnderscoredAttributes(if_region, if_op); if_region.replaceAllUsesWith(if_op.getResults()); if_region.erase(); + + if (to_bool && to_bool.use_empty()) to_bool.erase(); return success(); } @@ -315,8 +337,8 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( // cannot do a trivial transformation because post transform, we will need to // pass this extern value as an argument to the function, so we cannot use the // existing function as is. - auto while_matcher = [](Value first, Region& first_region, Value second, - Region& second_region) { + auto while_arg_matcher = [](Value first, Region& first_region, Value second, + Region& second_region) { if (!first.isa() || !second.isa()) return false; BlockArgument first_block_arg = first.cast(); @@ -329,8 +351,9 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( second_block_arg.getParentBlock() == &second_region.front(); }; - const TrivialTransformInfo tti(while_region.cond(), while_region.body(), - while_matcher); + const TrivialTransformInfo tti( + IsSingleCallRegion(while_region.cond(), /*allow_to_bool=*/true), + IsSingleCallRegion(while_region.body()), while_arg_matcher); // All existing inputs to while region are inputs to the functional while. auto new_inputs = llvm::to_vector<4>(while_region.getOperands()); @@ -376,7 +399,7 @@ LogicalResult RegionControlFlowToFunctional::ConvertWhileOp( auto while_op = builder.create( while_region.getLoc(), new_result_types, new_inputs, cond_name, body_name, while_region.parallel_iterations(), while_region.is_stateless()); - CopyUnderscoredAttributes(while_region, while_op); + CopyDeviceAndUnderscoredAttributes(while_region, while_op); // Redirect old results to new results. 
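// Editor's note (not part of the patch): together with ConvertConditionToBoolean
// in functional_control_flow_to_regions.cc, these changes make the functional
// <-> region round trip lossless for non-i1 conditions. The forward pass inserts
// an explicit predicate conversion,
//
//   %pred = "tf.ToBool"(%cond)   // 0-d i1 condition for IfRegion / WhileRegion
//
// and the reverse pass accepts it again (for tf.If by looking through and
// erasing the unused ToBool, for tf.While via the allow_to_bool matcher), so the
// reconstructed functional op sees the original condition value.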
for (auto it : llvm::zip( diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index ef75f90d5c1..d99279c0014 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -438,7 +438,7 @@ LogicalResult CreateIslandsFromReplicate(const Dialect* tf_dialect, void ReplicateToIslandPass::runOnOperation() { auto module = getOperation(); - const Dialect* tf_dialect = getContext().getRegisteredDialect("tf"); + const Dialect* tf_dialect = getContext().getLoadedDialect("tf"); if (!tf_dialect) { module.emitError() << "'tf' dialect is not registered"; return signalPassFailure(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc index bd0e8a94a61..c1ca98bf1f1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_device_inference.cc @@ -26,10 +26,13 @@ limitations under the License. #include "llvm/ADT/StringRef.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/OpImplementation.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project @@ -39,6 +42,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/visitor_util.h" + +#define DEBUG_TYPE "tf-resource-device-inference" namespace mlir { namespace TF { @@ -132,6 +138,13 @@ inline StringRef GetDeviceAttr(Operation* op) { return device_attr ? device_attr.getValue() : ""; } +// Print operation with debug info (to get line number info for debugging) +void dump(StringRef message, Operation* op) { + llvm::dbgs() << message; + op->print(llvm::dbgs(), OpPrintingFlags().enableDebugInfo(true)); + llvm::dbgs() << "\n"; +} + // Propagates device assignment inside a function. LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, PerFunctionResult* result) { @@ -153,26 +166,67 @@ LogicalResult ComputeResourceDevicesInComputation(FuncOp func_op, if (failed(res)) return res; } - auto walk_res = func_op.walk([&](Operation* op) { - if (auto var_handle = dyn_cast(op)) { - // Record VarHandleOp's device attribute. - StringRef device_attr = GetDeviceAttr(op); - if (device_attr.empty()) return WalkResult::advance(); - auto res = AddResourceDeviceAndEmitError(var_handle.resource(), - device_attr, op, result); - if (failed(res)) return WalkResult::interrupt(); - } - if (auto identity = dyn_cast(op)) { - // Try to construct IdentityOp's attribute from recorded assignment. 
- if (!GetDeviceAttr(op).empty()) return WalkResult::advance(); - for (auto output : filter_resources(op->getResults())) { - if (auto device = result->DeviceForResource(output)) - identity.setAttr(kDeviceAttr, builder.getStringAttr(*device)); - } - return WalkResult::advance(); - } - return WalkResult::advance(); - }); + // To support WhileRegion, we need to propagate device attributes from + // WhileRegion operands to body/cond region arguments *prior* to visiting + // these regions. Use tensorflow::walk() instead of MLIR core walker to + // implement such a pre-order walk. + auto walk_res = tensorflow::GenericWalk( + func_op, [&](Operation* op, const tensorflow::WalkStage& stage) { + // We just need to visit operations in pre-order mode. + if (!stage.IsBeforeAllRegions()) return WalkResult::advance(); + + if (auto var_handle = dyn_cast(op)) { + // Record VarHandleOp's device attribute. + StringRef device_attr = GetDeviceAttr(op); + if (device_attr.empty()) return WalkResult::advance(); + auto res = AddResourceDeviceAndEmitError(var_handle.resource(), + device_attr, op, result); + if (failed(res)) return WalkResult::interrupt(); + } else if (auto identity = dyn_cast(op)) { + LLVM_DEBUG(dump("Visiting ", identity)); + // Try to construct IdentityOp's attribute from recorded assignment. + if (!GetDeviceAttr(op).empty()) return WalkResult::advance(); + for (auto output : filter_resources(op->getResults())) { + LLVM_DEBUG(llvm::dbgs() << " Processing output #" + << output.getResultNumber() << "\n"); + if (auto device = result->DeviceForResource(output)) { + LLVM_DEBUG(llvm::dbgs() + << " Setting device = " << *device << "\n"); + identity.setAttr(kDeviceAttr, builder.getStringAttr(*device)); + } + } + } else if (auto while_region = dyn_cast(op)) { + // For WhileRegion, do local analysis prior to visiting the attached + // regions and propagate device annotations to the cond and body + // region arguments. The annotations are the union of annotations + // on the input and result. Resource alias analysis already propagates + // resource ID from the inputs to the results for a while, so just + // need to consider the results. 
+ LLVM_DEBUG(llvm::dbgs() << "Visiting WhileRegion\n"); + + for (auto output : filter_resources(while_region.getResults())) { + auto device = result->DeviceForResource(output); + int output_index = output.getResultNumber(); + if (!device) { + LLVM_DEBUG(llvm::dbgs() + << " No device for output #" << output_index << "\n"); + continue; + } + // Transfer the annotation to both region arguments + for (Region* region : while_region.getRegions()) { + BlockArgument arg = region->getArgument(output_index); + LLVM_DEBUG(llvm::dbgs() + << " Propagating device = '" << *device + << "' to arg #" << output_index << " of region #" + << region->getRegionNumber() << "\n"); + if (failed(AddResourceDeviceAndEmitError(arg, *device, + while_region, result))) + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); return failure(walk_res.wasInterrupted()); } @@ -201,6 +255,10 @@ void ResourceDeviceInference::runOnOperation() { Value arg_operand = caller_operands[arg.getArgNumber()]; auto device = caller_res.DeviceForResource(arg_operand); if (!device) continue; + LLVM_DEBUG(llvm::dbgs() + << "Propagating '" << *device << "' to arg #" + << arg.getArgNumber() << " of function @" + << callee.getName() << "\n"); if (failed(AddResourceDeviceAndEmitError(arg, *device, caller, &callee_res, &callee_needs_recompute))) @@ -240,6 +298,8 @@ void ResourceDeviceInference::runOnOperation() { "call"); return WalkResult::interrupt(); } + LLVM_DEBUG(llvm::dbgs() + << "Visiting call to function @" << func.getName() << "\n"); if (failed(propagate_operands_to_callee_arguments( call, call.getArgOperands(), {func}, func_res))) return WalkResult::interrupt(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 702455d156d..77f672f5ee4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -330,15 +331,6 @@ LogicalResult HoistResourceOpsFromCluster(tf_device::ClusterOp cluster, getUsedValuesDefinedAbove(new_cluster.body(), new_cluster.body(), captured_values); - for (Value v : captured_values) { - auto tensor_type = v.getType().dyn_cast(); - if (!tensor_type) continue; - if (!tensor_type.getElementType().isa()) continue; - - return new_cluster.emitOpError() - << "has remaining resource inputs that can not be lifted"; - } - return success(); } @@ -361,29 +353,23 @@ LogicalResult FindResourceArgUseInfo( ResourceArgUseInfo info; info.used = false; info.updated = false; - bool do_not_touch = false; + bool read_or_assigned = false; for (auto user : arg.getUsers()) { if (user == return_op) continue; + info.used = true; if (auto read = llvm::dyn_cast(user)) { - info.used = true; + read_or_assigned = true; info.data_type = read.getType(); continue; } if (auto assign = llvm::dyn_cast(user)) { - info.used = true; + read_or_assigned = true; info.updated = true; info.data_type = assign.value().getType(); continue; } - if (isa(user)) { - // Stacks will be handled by a separate pass. 
- do_not_touch = true; - break; - } - user->emitOpError("found unsupported operations on resource."); - return failure(); } - if (!do_not_touch) (*result)[arg.getArgNumber()] = info; + if (!info.used || read_or_assigned) (*result)[arg.getArgNumber()] = info; } return success(); } @@ -914,8 +900,8 @@ LogicalResult HandlePartitionedCallOpCallee( // resource-lifted new callee function in lifting_info. template void UpdatePartitionedCallOpWithNewCallee( - CallOpType call_op, const PartitionedCallLiftingInfo& lifting_info) { - if (lifting_info.lifted_callee == nullptr) return; + CallOpType call_op, PartitionedCallLiftingInfo& lifting_info) { + if (!lifting_info.lifted_callee) return; // Replace output resource uses with the aliasing input, so that we can remove // this output. for (const auto& entry : lifting_info.old_outputs_aliasing_old_inputs) { @@ -929,12 +915,10 @@ void UpdatePartitionedCallOpWithNewCallee( auto new_operands = FilterRange(call_op.args(), lifting_info.use_info); auto new_call = builder.create( - call_op.getLoc(), - const_cast(lifting_info.lifted_callee).getType().getResults(), + call_op.getLoc(), lifting_info.lifted_callee.getType().getResults(), new_operands, call_op.getAttrs()); new_call.setAttr( - "f", builder.getSymbolRefAttr( - const_cast(lifting_info.lifted_callee).getName())); + "f", builder.getSymbolRefAttr(lifting_info.lifted_callee.getName())); AddLoadsStoresOutsideControlFlowOp( new_call, lifting_info.arg_data_type_and_updated_output_index); // Replace uses. @@ -949,7 +933,8 @@ void UpdatePartitionedCallOpWithNewCallee( } LogicalResult HoistForFunctionalControlFlow( - Block*, ModuleOp, llvm::SmallDenseMap*); + Block*, ModuleOp, + llvm::SmallDenseMap*); // A templated routine for handling both PartitionedCallOp and // StatefulPartitionedCallOp. If the callee is already lifted, it just updates @@ -958,9 +943,10 @@ LogicalResult HoistForFunctionalControlFlow( template LogicalResult HandlePartitionedCallOp( CallOpType call_op, FuncOp callee, ModuleOp module, - llvm::SmallDenseMap* lifted_callees) { - auto emplace_res = - lifted_callees->try_emplace(callee, PartitionedCallLiftingInfo()); + llvm::SmallDenseMap* + lifted_callees) { + auto emplace_res = lifted_callees->try_emplace(callee.getName(), + PartitionedCallLiftingInfo()); if (emplace_res.second) { // Unseen callee. Perform resource lifting on it. HoistForFunctionalControlFlow(&callee.front(), module, lifted_callees); @@ -977,7 +963,7 @@ LogicalResult HandlePartitionedCallOp( // body/cond/branch/callee functions. LogicalResult HoistForFunctionalControlFlow( Block* block, ModuleOp module, - llvm::SmallDenseMap* + llvm::SmallDenseMap* lifted_partitioned_call_callees) { // Remove identity nodes to avoid aliasing. RemoveIdentity(block); @@ -1056,7 +1042,7 @@ LogicalResult HoistForFunctionalControlFlow( // Returns failure if there are remaining resource-type values that can not be // lifted. 
void ResourceOpLiftingPass::runOnOperation() { - llvm::SmallDenseMap + llvm::SmallDenseMap lifted_partitioned_call_callees; ModuleOp module = getOperation(); auto result = module.walk([&](FuncOp func_op) { @@ -1121,7 +1107,7 @@ LogicalResult ResourceLiftingForFunctionalControlFlow(FuncOp function) { << function.getBlocks().size(); } - llvm::SmallDenseMap + llvm::SmallDenseMap lifted_partitioned_call_callees; return HoistForFunctionalControlFlow(&function.front(), cast(function.getParentOp()), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 597fbe2c0b1..88ad787df3e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -40,6 +40,7 @@ limitations under the License. #include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project @@ -596,7 +597,7 @@ ShapeInference::ShapeInference(int64_t graph_version, MLIRContext* context, bool propagate_caller_callee_constants) : graph_version_(graph_version), propagate_caller_callee_constants_(propagate_caller_callee_constants) { - tf_dialect_ = context->getRegisteredDialect(); + tf_dialect_ = context->getLoadedDialect(); } ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result, @@ -697,11 +698,8 @@ bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) { // TODO(jpienaar): The tf.Cast op, which is uniformly inserted at the // moment, cannot handle arbirary types (e.g., it can't handle quantized // types). This restriction can be relaxed if not only tf.Cast is used. - auto kind = t.getKind(); - return (kind >= Type::FIRST_STANDARD_TYPE && - kind < Type::LAST_STANDARD_TYPE) || - (kind >= Type::FIRST_TENSORFLOW_TYPE && - kind < Type::LAST_TENSORFLOW_TYPE); + return t.getDialect().getNamespace().empty() || + isa(t.getDialect()); }; bool changed = false; @@ -1174,10 +1172,11 @@ LogicalResult ShapeInference::TryToFold(Operation* op) { if (!dialect) return failure(); // Only attempt TF dialect fallback if there are no unknown operands. if (some_unknown && dialect == tf_dialect_) return failure(); - SmallVector constants; - if (failed(dialect->constantFoldHook(op, constant_operands, constants))) + auto* interface = dialect->getRegisteredInterface(); + if (!interface) return failure(); + + if (failed(interface->fold(op, constant_operands, fold_results))) return failure(); - fold_results.assign(constants.begin(), constants.end()); } for (auto result : zip(op->getResults(), fold_results)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc new file mode 100644 index 00000000000..f14efeb91ce --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_device_copy_conversion.cc @@ -0,0 +1,81 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" + +namespace mlir { +namespace TF { +namespace { + +// Deletes the op and forwards the arguments. +template +class PassThroughConversion : public mlir::OpConversionPattern { + public: + explicit PassThroughConversion(MLIRContext *context) + : mlir::OpConversionPattern(context) {} + + LogicalResult matchAndRewrite( + TF_Op op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { // NOLINT + // Just forward the arguments to results. + rewriter.replaceOp(op, operands); + return success(); + } +}; + +class TensorDeviceCopyConversionPass + : public PassWrapper { + public: + void runOnFunction() override { + mlir::OwningRewritePatternList patterns; + mlir::ConversionTarget target(getContext()); + + // TODO(tfrt-devs): when device placer is introduced in the lowering pass, + // we need to check if Identity op and it's previous op are placed on the + // same device. If not, we don't fold Identity op since it's used for tensor + // copying between devices. + patterns.insert, + PassThroughConversion>(&getContext()); + + if (failed(applyPartialConversion(getFunction(), target, patterns))) { + signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +CreateTensorDeviceCopyConversionPass() { + return std::make_unique(); +} + +static mlir::PassRegistration + tensor_device_copy_pass( + "tf-tensor-device-copy", + "Handle ops that copy tensors between devices. E.g., tf.Identity."); + +} // namespace TF +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc index 2a770b2615d..f26887eb276 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_device_assignment.cc @@ -34,7 +34,7 @@ class SimpleTFDeviceAssignmentPass void runOnFunction() override { Builder builder(&getContext()); - Dialect* tf = getContext().getRegisteredDialect(); + Dialect* tf = getContext().getLoadedDialect(); getFunction().walk([&](Operation* op) { if (auto device_attr = op->getAttrOfType("device")) { // We assign default device to ops with device attribute that is empty. 
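The getLoadedDialect/loadDialect substitutions running through this patch track MLIR's switch to per-context dialect loading: a dialect now has to be loaded into the MLIRContext before it can be looked up, instead of being available from global registration alone. A minimal sketch of the new pattern follows, assuming the TF::TensorFlowDialect class; the template arguments inside the angle brackets do not survive in the text above, so the exact instantiations are inferred, not copied from the patch.

// Illustrative sketch only, not part of the patch.
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

mlir::Dialect* GetTfDialect(mlir::MLIRContext& context) {
  // Before this change, context.getRegisteredDialect<mlir::TF::TensorFlowDialect>()
  // returned the dialect as soon as it was globally registered.
  // Now the dialect must be explicitly loaded into this particular context
  // before it can be queried.
  context.loadDialect<mlir::TF::TensorFlowDialect>();
  return context.getLoadedDialect<mlir::TF::TensorFlowDialect>();
}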
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc index 2be6ee7a78c..fed4002bfcf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_head_tail_outside_compilation.cc @@ -331,7 +331,8 @@ void RemoveClusterAliasedOutputs(OpBuilder* builder, for (auto result : llvm::zip(cluster_terminator->getOperands(), cluster.getResults())) { Value cluster_terminator_operand = std::get<0>(result); - if (cluster.getOperation()->isProperAncestor( + if (cluster_terminator_operand.getDefiningOp() && + cluster.getOperation()->isProperAncestor( cluster_terminator_operand.getDefiningOp())) { new_cluster_results.push_back(cluster_terminator_operand); new_cluster_result_types.push_back(cluster_terminator_operand.getType()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc index 8adafe05cd3..b141a7dc792 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_extract_outside_compilation.cc @@ -314,21 +314,41 @@ tf_device::LaunchOp CreateLaunchOpForOutsideCluster( return launch_op; } -// Extracts all externally provided operands of `cluster_ops`. +// Extracts all externally provided operands of `host_cluster_ops`. llvm::SmallSetVector GetExternalOperands( - llvm::ArrayRef cluster_ops) { + tf_device::ClusterOp tpu_cluster, + llvm::ArrayRef host_cluster_ops) { llvm::SmallSetVector external_values; - for (Operation* op : cluster_ops) { - for (Value v : op->getOperands()) { - Operation* defining_op = v.getDefiningOp(); - if (!defining_op) continue; - bool is_external = llvm::none_of(cluster_ops, [&](Operation* cluster_op) { - return defining_op == cluster_op; - }); + for (Operation* host_cluster_op : host_cluster_ops) { + auto cluster_op_parent_region = host_cluster_op->getParentRegion(); + host_cluster_op->walk([&](Operation* op) { + auto region = op->getParentRegion(); - if (is_external) external_values.insert(v); - } + if (region == cluster_op_parent_region) { + // For op operands, add operand defining ops, if they are not included + // in `host_cluster_ops`. + for (Value v : op->getOperands()) { + Operation* defining_op = v.getDefiningOp(); + if (!defining_op) continue; + bool is_external = llvm::none_of( + host_cluster_ops, + [&](Operation* cluster_op) { return defining_op == cluster_op; }); + + if (is_external) external_values.insert(v); + } + } else { + llvm::SetVector external_captured_inputs; + visitUsedValuesDefinedAbove(*region, *region, [&](OpOperand* operand) { + Region* parent_region = operand->get().getParentRegion(); + if (!tpu_cluster.body().isAncestor(parent_region)) return; + + external_captured_inputs.insert(operand->get()); + }); + external_values.insert(external_captured_inputs.begin(), + external_captured_inputs.end()); + } + }); } return external_values; @@ -494,7 +514,7 @@ void CreateParallelExecuteFromOutsideClusters(ModuleOp module, &builder, cluster_ops.back(), host_device); // Determine if there are any inputs that are provided out of cluster. 
- auto external_inputs = GetExternalOperands(cluster_ops); + auto external_inputs = GetExternalOperands(tpu_cluster, cluster_ops); auto external_outputs = GetExternalOutputs(cluster_ops); MoveOutsideCompiledOps(module, tpu_cluster, cluster.value().getFirst(), diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc new file mode 100644 index 00000000000..32b1eb340d6 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_identity_pruning.cc @@ -0,0 +1,113 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Function.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace TFTPU { + +namespace { + +// This pass removes Identity/IdentityN ops from the TPU computation and +// reachable functions. +// TODO(lyandy): Remove this pass once resource op lifting is migrated to use +// resource alias analysis and support region based control flow. Removing +// Identity ops may remove `_XlaSharding` annotation attribute if Identity ops +// are used to propagate such information. + +struct TPUIdentityPruning + : public PassWrapper> { + void runOnOperation() override; +}; + +// Collects all reachable functions (via call ops) from a given region. +SmallVector CollectReachableFunctions(Region& region) { + llvm::SmallPtrSet reachable_funcs; + + auto collect_reachable_funcs = + [&reachable_funcs](Region& src, SmallVectorImpl& funcs_to_visit) { + src.walk([&reachable_funcs, &funcs_to_visit](CallOpInterface call_op) { + auto func = dyn_cast_or_null(call_op.resolveCallable()); + if (func && reachable_funcs.insert(func).second) + funcs_to_visit.push_back(func); + }); + }; + + SmallVector funcs_to_visit; + collect_reachable_funcs(region, funcs_to_visit); + + while (!funcs_to_visit.empty()) { + SmallVector new_funcs_to_visit; + for (FuncOp func_to_visit : funcs_to_visit) { + if (!func_to_visit.getCallableRegion()) continue; + collect_reachable_funcs(*func_to_visit.getCallableRegion(), + new_funcs_to_visit); + } + funcs_to_visit.swap(new_funcs_to_visit); + } + + return llvm::to_vector<4>(reachable_funcs); +} + +// Removes Identity/IdentityN ops from a region and forwards its operands to its +// results. 
+void RemoveIdentityFromRegion(Region& region) { + region.walk([](Operation* op) { + if (isa(op)) { + op->replaceAllUsesWith(op->getOperands()); + op->erase(); + } + }); +} + +void TPUIdentityPruning::runOnOperation() { + SmallVector clusters; + getOperation().walk( + [&](tf_device::ClusterOp cluster) { clusters.push_back(cluster); }); + + for (tf_device::ClusterOp cluster : clusters) { + RemoveIdentityFromRegion(cluster.body()); + auto reachable_funcs = CollectReachableFunctions(cluster.body()); + for (FuncOp reachable_func : reachable_funcs) + RemoveIdentityFromRegion(*reachable_func.getCallableRegion()); + } +} + +} // anonymous namespace + +std::unique_ptr> CreateTPUIdentityPruningPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-identity-pruning", + "Removes Identity/IdentityN ops from the TPU computation"); + +} // namespace TFTPU +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc index ca77feafc05..21ad457a7a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_rewrite_pass.cc @@ -409,12 +409,15 @@ Operation* BuildCompileOp( std::string txt_module; if (failed(EncapsulateFuncAndSerialize(func, &txt_module))) return nullptr; - auto result_type = + auto compilation_status_type = RankedTensorType::get({}, builder->getType()); + auto program_type = + RankedTensorType::get({2}, builder->getType()); auto compile_op = builder->create( - cluster_func.getLoc(), /*compilation_status=*/result_type, /*program=*/ - llvm::SmallVector(num_cores_per_replica, result_type), + cluster_func.getLoc(), + /*compilation_status=*/compilation_status_type, /*program=*/ + llvm::SmallVector(num_cores_per_replica, program_type), compile_op_operands, txt_module, txt_metadata); return WrapOpInLaunch(builder, compile_op.getLoc(), compile_op, @@ -598,9 +601,9 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // func @main(%arg0: tensor) { // %0 = "tf.Shape"(%arg0) : (tensor) -> tensor // %1:2 = "tf._TPUCompileMlir"(%0) {device = "/CPU:0"} : -// (tensor) -> (tensor, tensor) +// (tensor) -> (tensor, tensor<2x!tf.string>) // %2 = "tf.TPUExecute"(%arg0, %1#0) {device = "/TPU:0"} : -// (tensor, tensor) -> tensor +// (tensor, tensor<2x!tf.string>) -> tensor // return // } // @@ -624,9 +627,9 @@ void BuildTPUCompileSucceededAssertOp(Operation* compile_op, // {n = 2 : i32, devices = ["/TPU:0", "/TPU:1"]} { // %1 = "tf.Shape"(%ri) : (tensor) -> tensor // %2:2 = "tf._TPUCompileMlir"(%1) {device = "/CPU:0"} : -// (tensor) -> (tensor, tensor) +// (tensor) -> (tensor, tensor<2x!tf.string>) // %3 = "tf.TPUExecute"(%ri, %2#0) : -// (tensor, tensor) -> tensor +// (tensor, tensor<2x!tf.string>) -> tensor // tf_device.return %3 : tensor // } // return diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index 571d5e3e715..631553b381e 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -726,7 +726,7 @@ Status Exporter::Convert(mlir::ModuleOp module, mlir::Identifier::get("main", module.getContext()); absl::optional entry_func; FunctionDefLibrary flib; - auto tf_dialect = module.getContext()->getRegisteredDialect("tf"); + auto tf_dialect = module.getContext()->getLoadedDialect("tf"); for (auto 
function : module.getOps()) { if (function.isExternal()) return errors::FailedPrecondition("External functions not supported"); @@ -799,7 +799,7 @@ StatusOr> ConvertMlirToGraphdef( stream_executor::port::Status ConvertMlirFunctionToFunctionLibraryDef( mlir::FuncOp func, const GraphExportConfig& configs, FunctionDef* function_def) { - Dialect* tf_dialect = func.getContext()->getRegisteredDialect("tf"); + Dialect* tf_dialect = func.getContext()->getLoadedDialect("tf"); FunctionDefLibrary flib; TF_RETURN_IF_ERROR( Exporter::ConvertLibFunction(configs, tf_dialect, func, &flib)); diff --git a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc index 94ddf76736e..692d0eaf962 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc @@ -177,7 +177,8 @@ Status UpgradeLegacyGraph(Graph* graph, FunctionLibraryDefinition* flib_def, restrict_functionalization_to_tpu_nodes ? [](const Node* n) { return n->attrs().Find(kTpuReplicateAttr); } : NodeFilter{}; - return FunctionalizeControlFlow(graph, flib_def, node_filter); + return FunctionalizeControlFlow(graph, flib_def, node_filter, + /*include_functions=*/true); } // Stateful helper class to import a TensorFlow model into an MLIR Module. @@ -2135,6 +2136,11 @@ StatusOr GraphDefImporter::Convert( mlir::MLIRContext* context, const Graph& graph, const GraphDebugInfo& debug_info, const FunctionLibraryDefinition& flib_def, const GraphImportConfig& specs, llvm::StringRef func_name) { + // Load dialects involved in the conversion + context->loadDialect(); + context->loadDialect(); + context->loadDialect(); + mlir::OwningModuleRef module = mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); std::unordered_map tf_name_to_mlir_name; diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc index 1c7988d3a40..58377661a23 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.cc @@ -219,22 +219,18 @@ StatusOr GraphdefToSplattedMlirTranslateFunction( if (auto attr = inst.getAttrOfType(attr_id)) { mlir::Attribute rand_val; mlir::Type element_type = attr.getType().getElementType(); + if (element_type.isa()) { + rand_val = mlir::IntegerAttr::get(element_type, std::rand()); + } else if (element_type.isF16() || element_type.isF32() || + element_type.isF64()) { + rand_val = mlir::FloatAttr::get(element_type, + std::rand() * 1.0 / RAND_MAX); - switch (element_type.getKind()) { - case mlir::StandardTypes::Integer: - rand_val = mlir::IntegerAttr::get(element_type, std::rand()); - break; - case mlir::StandardTypes::F16: - case mlir::StandardTypes::F32: - case mlir::StandardTypes::F64: - rand_val = mlir::FloatAttr::get(element_type, - std::rand() * 1.0 / RAND_MAX); - break; - default: - inst.emitWarning() - << "Skipping splat conversion for " - << "an unsupported attribute type " << element_type; - continue; + } else { + inst.emitWarning() + << "Skipping splat conversion for " + << "an unsupported attribute type " << element_type; + continue; } auto new_attr = mlir::DenseElementsAttr::get(attr.getType(), rand_val); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h similarity index 66% rename from 
tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h rename to tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h index 599a8df63d7..bd81cae5730 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/attribute_utils.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -36,7 +36,18 @@ inline void CopyUnderscoredAttributes(Operation *from, Operation *to) { }); } +// Copies attributes that are either `device` or whose name begins with an _ +// from `from` to `to`. +// TODO(b/158769932): This should be a general feature instead post some policy +// discussion. +inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) { + auto device = mlir::Identifier::get("device", from->getContext()); + CopyAttributes(from, to, [&device](const NamedAttribute &attr) { + return attr.first.strref().front() == '_' || attr.first == device; + }); +} + } // namespace TF } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_ATTRIBUTE_UTILS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 99a5e32adc2..f7a9823a1a8 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -420,6 +420,7 @@ Status CompileSerializedMlirToXlaHlo( std::vector> custom_legalization_passes) { RegisterDialects(); mlir::MLIRContext mlir_context; + mlir_context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef mlir_module; TF_RETURN_IF_ERROR( @@ -509,6 +510,7 @@ Status CompileGraphToXlaHlo( RegisterDialects(); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); GraphImportConfig config; config.graph_as_function = true; // Disable shape inference during import as some TensorFlow op fails during diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc index 359314a64b0..05e1f059029 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.cc @@ -36,8 +36,8 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tstring.h" @@ -161,7 +161,7 @@ StatusOr ConvertTensor(const Tensor& input_tensor, default: // TODO(shpeisman): restructure code to reuse dialect pointer across // calls. 
- auto* dialect = builder->getContext()->getRegisteredDialect("tf"); + auto* dialect = builder->getContext()->getLoadedDialect("tf"); return OpaqueElementsAttr::get(dialect, type, MangleTensor(input_tensor)); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc index bf96e3d1df4..4917d73ba2a 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_tensor_test.cc @@ -43,6 +43,7 @@ static void RegisterDialects() { TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::Builder b(&context); PartialTensorShape output_shape = @@ -52,6 +53,7 @@ TEST(ConvertTypeToTensorTypeTest, UnrankedTensorType) { TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::Builder b(&context); PartialTensorShape output_shape = ConvertTypeToTensorShape( @@ -61,6 +63,7 @@ TEST(ConvertTypeToTensorTypeTest, NonFullyDefinedRankedTensorType) { TEST(ConvertTypeToTensorTypeTest, FullyDefinedRankedTensorType) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::Builder b(&context); PartialTensorShape output_shape = ConvertTypeToTensorShape( diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index 0caceb69510..0d035e8f864 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -91,64 +91,62 @@ Status ConvertDataType(DataType dtype, Builder builder, Type* type) { } Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { - switch (type.getKind()) { - case mlir::StandardTypes::F16: - *dtype = DT_HALF; - return Status::OK(); - case mlir::StandardTypes::F32: - *dtype = DT_FLOAT; - return Status::OK(); - case mlir::StandardTypes::F64: - *dtype = DT_DOUBLE; - return Status::OK(); - case mlir::StandardTypes::BF16: - *dtype = DT_BFLOAT16; - return Status::OK(); - case mlir::StandardTypes::Integer: { - const auto& itype = type.cast(); - switch (itype.getWidth()) { - case 1: - *dtype = DT_BOOL; - return Status::OK(); - case 8: - *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8; - return Status::OK(); - case 16: - *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16; - return Status::OK(); - case 32: - *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32; - return Status::OK(); - case 64: - *dtype = itype.isUnsigned() ? 
DT_UINT64 : DT_INT64; - return Status::OK(); - default: - return errors::Unimplemented( - absl::StrCat("Converting ", debugString(type), " to DataType")); - } - } - case mlir::StandardTypes::Complex: { - auto etype = type.cast().getElementType(); - if (etype.isF32()) { - *dtype = DT_COMPLEX64; - return Status::OK(); - } else if (etype.isF64()) { - *dtype = DT_COMPLEX128; - return Status::OK(); - } - return errors::Unimplemented( - absl::StrCat("Converting ", debugString(type), " to DataType")); - } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - case mlir::TF::TensorFlowTypes::enumerant: \ - *dtype = DT_##enumerant; \ + if (type.isF16()) { + *dtype = DT_HALF; return Status::OK(); + } else if (type.isF32()) { + *dtype = DT_FLOAT; + return Status::OK(); + } else if (type.isF64()) { + *dtype = DT_DOUBLE; + return Status::OK(); + } else if (type.isBF16()) { + *dtype = DT_BFLOAT16; + return Status::OK(); + } else if (auto itype = type.dyn_cast()) { + switch (itype.getWidth()) { + case 1: + *dtype = DT_BOOL; + return Status::OK(); + case 8: + *dtype = itype.isUnsigned() ? DT_UINT8 : DT_INT8; + return Status::OK(); + case 16: + *dtype = itype.isUnsigned() ? DT_UINT16 : DT_INT16; + return Status::OK(); + case 32: + *dtype = itype.isUnsigned() ? DT_UINT32 : DT_INT32; + return Status::OK(); + case 64: + *dtype = itype.isUnsigned() ? DT_UINT64 : DT_INT64; + return Status::OK(); + default: + return errors::Unimplemented( + absl::StrCat("Converting ", debugString(type), " to DataType")); + } + } else if (auto complex_type = type.dyn_cast()) { + auto etype = complex_type.getElementType(); + if (etype.isF32()) { + *dtype = DT_COMPLEX64; + return Status::OK(); + } else if (etype.isF64()) { + *dtype = DT_COMPLEX128; + return Status::OK(); + } + return errors::Unimplemented( + absl::StrCat("Converting ", debugString(type), " to DataType")); + } + +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (type.isa()) { \ + *dtype = DT_##enumerant; \ + return Status::OK(); \ + } // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" - default: - return errors::Unimplemented( - absl::StrCat("Converting ", debugString(type), " to DataType")); - } + + return errors::Unimplemented( + absl::StrCat("Converting ", debugString(type), " to DataType")); } Status ConvertToDataType(Type type, DataType* dtype) { diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc index 07f6b129a41..5b791752eb0 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type_test.cc @@ -36,6 +36,7 @@ std::string ConvertToMlirString(const std::vector& dims, } mlir::MLIRContext context; mlir::Builder b(&context); + context.loadAllGloballyRegisteredDialects(); auto status_or = ConvertToMlirTensorType(shape, dtype, &b); std::string buf; llvm::raw_string_ostream os(buf); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc index 1da1f5973f6..e41b62ddccd 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/device_util_test.cc @@ -60,6 +60,7 @@ class FakeDevice : public Device { TEST(DeviceUtilTest, AddDeviceToOp) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); @@ -101,6 +102,7 @@ 
TEST(DeviceUtilTest, AddDeviceToOp) { TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); @@ -110,6 +112,7 @@ TEST(DeviceUtilTest, AddDeviceToOpNullDeviceSet) { TEST(DeviceUtilTest, GetDevicesFromOpNoDevicesAttribute) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc index c77107c8de7..4fcf036b160 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_graph.cc @@ -66,6 +66,7 @@ Status DumpTextualIRToFile(const MlirDumpConfig& config, const Graph& graph, WritableFile* file) { WritableFileRawStream os(std::move(file)); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module; if (flib_def) { flib_def = &graph.flib_def(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index c0d109f7569..dee499605e1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -28,6 +28,7 @@ namespace { TEST(DumpMlirModuleTest, NoEnvPrefix) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); unsetenv("TF_DUMP_GRAPH_PREFIX"); @@ -38,6 +39,7 @@ TEST(DumpMlirModuleTest, NoEnvPrefix) { TEST(DumpMlirModuleTest, LogInfo) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); setenv("TF_DUMP_GRAPH_PREFIX", "-", 1); @@ -48,6 +50,7 @@ TEST(DumpMlirModuleTest, LogInfo) { TEST(DumpMlirModuleTest, Valid) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); setenv("TF_DUMP_GRAPH_PREFIX", testing::TmpDir().c_str(), 1); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc index b174ad40a3b..832bc04fdaa 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/error_util_test.cc @@ -29,6 +29,7 @@ using testing::HasSubstr; TEST(ErrorUtilTest, StatusScopedDiagnosticHandler) { MLIRContext context; + context.loadAllGloballyRegisteredDialects(); auto id = Identifier::get("test.cc", &context); auto loc = FileLineColLoc::get(id, 0, 0, &context); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc index ad9ddb277d7..67c2aebf121 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/export_utils.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -368,65 +369,36 @@ Status ConvertAttributes( name = mangling_util::DemangleAttributeName(name); } AttrValue value; - switch (attr.getKind()) { - case mlir::StandardAttributes::SymbolRef: { - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - func_call_attrs[string(name)] = value; - continue; - } - case mlir::StandardAttributes::Integer: - if (auto boolAttr = attr.dyn_cast()) { - TF_RETURN_IF_ERROR(ConvertAttribute(boolAttr, &value)); - } else { - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - } - break; - case mlir::StandardAttributes::Float: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::String: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::Array: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::DenseIntOrFPElements: - case mlir::StandardAttributes::DenseStringElements: - case mlir::StandardAttributes::OpaqueElements: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::Type: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case mlir::StandardAttributes::Unit: - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case static_cast(mlir::TF::AttrKind::SHAPE): - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - break; - case static_cast(mlir::TF::AttrKind::FUNC): { - TF_RETURN_IF_ERROR( - ConvertAttribute(attr.cast(), &value)); - func_call_attrs[string(name)] = value; - continue; - } - // AffineMap kind is not implemented. - case mlir::StandardAttributes::AffineMap: - return errors::Unimplemented("AffineMap attribute (needed for '", - name_strref, "') unimplemented"); - default: - return errors::Unimplemented("Unhandled attribute kind for attribute '", - name_strref, '\''); + if (auto symbol_ref = attr.dyn_cast()) { + TF_RETURN_IF_ERROR( + ConvertAttribute(symbol_ref.cast(), &value)); + func_call_attrs[string(name)] = value; + continue; } + if (auto func_attr = attr.dyn_cast()) { + TF_RETURN_IF_ERROR(ConvertAttribute(func_attr, &value)); + func_call_attrs[string(name)] = value; + continue; + } + if (attr.isa()) { + // AffineMapAttr is not implemented. + return errors::Unimplemented("AffineMap attribute (needed for '", + name_strref, "') unimplemented"); + } + TF_RETURN_IF_ERROR( + llvm::TypeSwitch(attr) + .Case( + [&](auto derived_attr) { + return ConvertAttribute(derived_attr, &value); + }) + .Default([&](mlir::Attribute) { + return errors::Unimplemented( + "Unhandled attribute kind for attribute '", name_strref, + '\''); + })); + // According to the NodeDef proto definition, an attribute name from the // input TensorFlow GraphDef shouldn't contain '.'. 
If it does appear in // the attribute from MLIR, it is treated as an attribute from function diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc index b23fbe7d73c..fc206ca08f9 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util_test.cc @@ -602,6 +602,7 @@ TEST(TPURewriteDeviceUtilTest, ValidGeneralDeviceAssignmentMesh1x2x1x3) { TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::Builder builder(&context); auto device_assignment_attr = builder.getI64ArrayAttr({1, 2, 3}); auto status_or_device_coodinates = @@ -615,6 +616,7 @@ TEST(TPURewriteDeviceUtilTest, TestGetDeviceCoordinates) { TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::Builder builder(&context); auto device_assignment_attr = builder.getF32ArrayAttr({1.0, 2.0, 3.0}); auto status_or_device_coodinates = @@ -627,6 +629,7 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) { TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) { mlir::registerDialect(); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module_ref = mlir::ModuleOp::create(mlir::UnknownLoc::get(&context)); mlir::OpBuilder builder(module_ref->getBodyRegion()); diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 1416ac038d6..144e22750ca 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -17,77 +17,36 @@ limitations under the License. 
#include "llvm/Support/InitLLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/IR/AsmState.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project #include "mlir/Support/MlirOptMain.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" -// NOLINTNEXTLINE -static llvm::cl::opt input_filename(llvm::cl::Positional, - llvm::cl::desc(""), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt output_filename( - "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"), - llvm::cl::init("-")); - -// NOLINTNEXTLINE -static llvm::cl::opt split_input_file( - "split-input-file", - llvm::cl::desc("Split the input file into pieces and process each " - "chunk independently"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verify_diagnostics( - "verify-diagnostics", - llvm::cl::desc("Check that emitted diagnostics match " - "expected-* lines on the corresponding line"), - llvm::cl::init(false)); - -// NOLINTNEXTLINE -static llvm::cl::opt verify_passes( - "verify-each", - llvm::cl::desc("Run the verifier after each transformation pass"), - llvm::cl::init(true)); - -// NOLINTNEXTLINE -static llvm::cl::opt allowUnregisteredDialects( - "allow-unregistered-dialect", - llvm::cl::desc("Allow operation with no registered dialects"), - llvm::cl::init(false)); - int main(int argc, char **argv) { tensorflow::InitMlir y(&argc, &argv); - // Register various MLIR command line options. - mlir::registerAsmPrinterCLOptions(); - mlir::registerMLIRContextCLOptions(); - mlir::registerPassManagerCLOptions(); + mlir::registerAllPasses(); - // Parse pass names in main to ensure static initialization completed. - mlir::PassPipelineCLParser pass_pipeline("", "Compiler passes to run"); - - llvm::cl::ParseCommandLineOptions(argc, argv, - "TF MLIR modular optimizer driver\n"); - - // Set up the input file. 
- std::string error_message; - auto file = mlir::openInputFile(input_filename, &error_message); - QCHECK(file) << error_message; - - auto output = mlir::openOutputFile(output_filename, &error_message); - QCHECK(output) << error_message; - - if (failed(mlir::MlirOptMain(output->os(), std::move(file), pass_pipeline, - split_input_file, verify_diagnostics, - verify_passes, allowUnregisteredDialects))) - return 1; - output->keep(); - return 0; + mlir::DialectRegistry registry; + mlir::registerAllDialects(registry); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + return failed( + mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry)); } diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index caac8ea1eeb..9b0b3aaa82b 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -111,6 +111,7 @@ int main(int argc, char** argv) { if (import_saved_model_object_graph) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); auto module_or = tensorflow::SavedModelObjectGraphToMlirImport( input_filename, tags, exported_names, &context); @@ -119,6 +120,7 @@ int main(int argc, char** argv) { module_or.ConsumeValueOrDie()->print(output->os()); } else if (import_saved_model_signature_defs) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); auto module_or = tensorflow::SavedModelSignatureDefsToMlirImport( input_filename, tags, exported_names, &context, upgrade_legacy); @@ -139,6 +141,7 @@ int main(int argc, char** argv) { llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc()); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::SourceMgrDiagnosticHandler diagnostic_handler(sourceMgr, &context); return (*requested_translation)(sourceMgr, os, &context); }; diff --git a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc index e735a3c7b8c..915fb91a8df 100644 --- a/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc +++ b/tensorflow/compiler/mlir/tfjs/translate/tf_tfjs_translate.cc @@ -125,6 +125,7 @@ int main(int argc, char** argv) { "TF GraphDef to TFJS JSON converter\n"); MLIRContext context; + context.loadAllGloballyRegisteredDialects(); llvm::SourceMgr source_mgr; mlir::SourceMgrDiagnosticHandler sourceMgrHandler(source_mgr, &context); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD index 5befdcdc513..e01c059ad90 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/BUILD @@ -72,7 +72,6 @@ tf_cc_binary( "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", "@llvm-project//mlir:IR", "@llvm-project//mlir:MlirOptLib", - "@llvm-project//mlir:MlirOptMain", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", ], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index 82b0e613f90..5f358c61cc2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -261,6 +261,7 @@ StatusOr> tensorflow::kernel_gen::GenerateCubinForTfCode( llvm::ArrayRef unroll_factors) { RegisterDialects(); mlir::MLIRContext context; + 
context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context); TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get())); diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc index 5b7a19a3eac..8c02a734f1d 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc @@ -48,13 +48,11 @@ Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const { /// Print a type registered to this dialect. void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const { - switch (type.getKind()) { - case TFFrameworkTypes::OpKernelContextType: - os << "op_kernel_context"; - return; - default: - llvm_unreachable("unexpected TF Framework type kind"); + if (type.isa()) { + os << "op_kernel_context"; + return; } + llvm_unreachable("unexpected TF Framework type kind"); } template diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h index a4c588a41f5..d2612a38799 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h @@ -30,22 +30,12 @@ namespace mlir { namespace kernel_gen { namespace tf_framework { -namespace TFFrameworkTypes { -enum Kind { - OpKernelContextType = Type::FIRST_TF_FRAMEWORK_TYPE, -}; -} // namespace TFFrameworkTypes - /// OpKernelContextType corresponds to C++ class OpKernelContext defined in /// tensorflow/core/framework/op_kernel.h class OpKernelContextType : public Type::TypeBase { public: using Base::Base; - - static OpKernelContextType get(MLIRContext *context) { - return Base::get(context, TFFrameworkTypes::Kind::OpKernelContextType); - } }; #define GET_OP_CLASSES diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc index c1af35617b1..4fb169a9729 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tools/kernel-gen-opt/kernel-gen-opt.cc @@ -90,8 +90,9 @@ int main(int argc, char **argv) { if (showDialects) { mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); llvm::outs() << "Registered Dialects:\n"; - for (mlir::Dialect *dialect : context.getRegisteredDialects()) { + for (mlir::Dialect *dialect : context.getLoadedDialects()) { llvm::outs() << dialect->getNamespace() << "\n"; } return 0; @@ -111,9 +112,12 @@ int main(int argc, char **argv) { exit(1); } - if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, + mlir::DialectRegistry registry; + registerAllDialects(registry); + if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry, splitInputFile, verifyDiagnostics, verifyPasses, - allowUnregisteredDialects))) { + allowUnregisteredDialects, + /*preloadDialectsInContext=*/true))) { return 1; } // Keep the output file if the invocation of MlirOptMain was successful. 
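The tf-mlir-opt and kernel-gen-opt drivers above now pass a DialectRegistry to MlirOptMain instead of relying on global registration and hand-rolled command-line parsing. A minimal sketch of that driver shape, assuming the usual TensorFlow dialect class for the insert<> call (the concrete dialect list in the patch is not visible here because the angle-bracketed arguments were dropped):

// Illustrative sketch only, not part of the patch.
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MlirOptMain.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

int main(int argc, char** argv) {
  mlir::registerAllPasses();

  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);             // core MLIR dialects
  registry.insert<mlir::TF::TensorFlowDialect>();  // plus the project dialects

  // MlirOptMain parses the command line, attaches the registry to the
  // context, and runs the requested pass pipeline.
  return mlir::failed(
      mlir::MlirOptMain(argc, argv, "Example pass driver\n", registry));
}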
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 71e18af498b..4c14bcf8960 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -55,6 +55,7 @@ cc_library( "transforms/passes.h", ], deps = [ + ":attribute_importer", ":type_to_shape", ":xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/mlir/hlo", @@ -69,7 +70,7 @@ cc_library( "//tensorflow/compiler/xla/client/lib:conv_grad_size_util", "//tensorflow/core:framework", "//tensorflow/core/kernels:conv_grad_shape_utils", - "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/platform:bfloat16", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", "@llvm-project//mlir:Dialect", @@ -95,6 +96,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_context", diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h index db981bb0227..e0cc89004cf 100644 --- a/tensorflow/compiler/mlir/xla/hlo_function_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/types/optional.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project @@ -62,7 +63,10 @@ class HloFunctionImporter { : context_(module.getContext()), module_(module), builder_(builder), - function_map_(function_map) {} + function_map_(function_map) { + context_->loadDialect(); + context_->loadDialect(); + } // Imports the given computation as a new function, if it hasn't been already // imported. diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc index dd045da3899..9db5861934f 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -30,6 +30,12 @@ limitations under the License. namespace xla { +HloModuleImporter::HloModuleImporter(mlir::ModuleOp module) + : module_(module), builder_(module.getContext()) { + module.getContext()->loadDialect(); + module.getContext()->loadDialect(); +} + Status HloModuleImporter::Import(const xla::HloModule& module) { // TODO(hinsu): Only import the entry computation here once all HLO ops with // reference to other computation are updated to have a region instead of a diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h index 69ac1e28219..401299484ed 100644 --- a/tensorflow/compiler/mlir/xla/hlo_module_importer.h +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -38,8 +38,7 @@ class Shape; // dialect. HloModuleImporter does not take ownership. class HloModuleImporter { public: - explicit HloModuleImporter(mlir::ModuleOp module) - : module_(module), builder_(module.getContext()) {} + explicit HloModuleImporter(mlir::ModuleOp module); // Import the HloModule into the MLIR Module. 
Status Import(const xla::HloModule& module); diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index cf78c81908d..b9d563a659d 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -22,7 +22,7 @@ limitations under the License. #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "tensorflow/compiler/xla/literal.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -83,6 +83,9 @@ StatusOr> GetPermutationIfAvailable( strides[dim] = accumulated_stride; accumulated_stride *= shape.dimensions(dim); } + if (accumulated_stride == 0) { + return llvm::SmallVector{}; + } return llvm::SmallVector{ makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())}; } diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index c94110d9102..ac5e01a0abf 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -312,6 +312,16 @@ StatusOr MlirHloBuilder::RngOpInternal( return CreateOp(op_name, shape, operands); } +StatusOr MlirHloBuilder::RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state) { + TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType( + full_result_shape, builder_)); + auto op = builder_.create( + loc_, ty, builder_.getI32IntegerAttr(algorithm), GetValue(initial_state)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) { @@ -351,6 +361,13 @@ StatusOr MlirHloBuilder::InDimBroadcast( return MakeXlaOp(op.getResult()); } +StatusOr MlirHloBuilder::AddInstruction( + HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) { + return Unimplemented("MlirHloBuilder does not support op %s", + HloOpcodeString(opcode)); +} + StatusOr MlirHloBuilder::Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) { @@ -382,6 +399,31 @@ XlaOp MlirHloBuilder::CreateToken() { }); } +StatusOr MlirHloBuilder::TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + auto op = builder_.create( + loc_, result_ty, GetValue(a), GetValue(b), + builder_.getBoolAttr(options.left_side()), + builder_.getBoolAttr(options.lower()), + builder_.getBoolAttr(options.unit_diagonal()), + builder_.getStringAttr( + TriangularSolveOptions::Transpose_Name(options.transpose_a()))); + return MakeXlaOp(op); +} + +StatusOr MlirHloBuilder::CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_ty, + ConvertShapeToType(shape, builder_)); + auto op = builder_.create( + loc_, result_ty, GetValue(a), builder_.getBoolAttr(lower)); + return MakeXlaOp(op); +} + StatusOr MlirHloBuilder::InfeedWithTokenInternal( const Shape& infeed_instruction_shape, XlaOp token, const string& config) { TF_ASSIGN_OR_RETURN(mlir::Type result_type, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index a12eb723465..00b7aa4d0b0 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ 
b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -124,6 +124,13 @@ class MlirHloBuilder : public XlaBuilder { FftType fft_type, absl::Span fft_length) override; + StatusOr TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, + TriangularSolveOptions options) override; + + StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) override; + StatusOr CustomCallInternal( const string& call_target_name, absl::Span operands, const Shape& shape, const string& opaque, @@ -176,6 +183,9 @@ class MlirHloBuilder : public XlaBuilder { StatusOr RngOpInternal(RandomDistribution distribution, absl::Span parameters, const Shape& shape) override; + StatusOr RngBitGeneratorInternal(const Shape& full_result_shape, + RandomAlgorithm algorithm, + XlaOp initial_state) override; StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, int64 inferred_dimension) override; @@ -189,6 +199,9 @@ class MlirHloBuilder : public XlaBuilder { const Shape& shape, XlaOp operand, absl::Span broadcast_dimensions) override; + StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, + absl::Span operands) override; + StatusOr Compare(const Shape& shape, XlaOp lhs, XlaOp rhs, ComparisonDirection direction) override; diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt index 3630d2d45e4..a83e36cff64 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt @@ -8,6 +8,6 @@ HloModule TestModule ENTRY TestComputation { x = f32[3, 2]{1,0} parameter(0) - // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () + // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () ROOT x.copy = f32[3, 2]{0,1} copy(x) } diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir index 69eaeeb946d..cffb15022b0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-BatchMatMulV2.mlir @@ -17,9 +17,7 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> // CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] // CHECK: [[LHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> // CHECK: [[RHSBCASTSHAPE:%.*]] = shape.concat [[BCASTHEAD]], [[RHSTAIL]] -// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] -// CHECK: [[RHSBCAST:%.*]] = "mhlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> -// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> +// CHECK: [[RESULT:%.*]] = "mhlo.dot_general"([[LHSBCAST]], [[RHS]]) {dot_dimension_numbers = 
{lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> // CHECK: return [[RESULT]] : tensor<3x4x4xf32> // CHECK: } @@ -29,7 +27,6 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) -> func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_lhs_batch -// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, @@ -43,7 +40,6 @@ func @batchmatmulv2_lhs_batch(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x4xf32>) func @batchmatmulv2_rhs_batch(%arg0: tensor<4x2xf32>, %arg1: tensor<3x2x4xf32>) -> tensor<3x4x4xf32> { // CHECK-LABEL: func @batchmatmulv2_rhs_batch // CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} -// CHECK: "mhlo.dynamic_broadcast_in_dim"({{.*}}, {{.*}}) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} // CHECK: "mhlo.dot_general"({{.*}}, {{.*}}) {dot_dimension_numbers = { // CHECK-SAME: lhs_batching_dimensions = dense<0> : tensor<1xi64>, // CHECK-SAME: lhs_contracting_dimensions = dense<2> : tensor<1xi64>, diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir index 550b2ba4da3..876a1bf03e7 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-communication.mlir @@ -169,7 +169,7 @@ func @send_to_host(%arg0: tensor) { // CHECK: "mhlo.send"([[ARG0]], [[INIT_TOKEN]]) // CHECK-SAME: channel_id = {handle = 1 : i64, type = 2 : i64} // CHECK-SAME: is_host_transfer = true - // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key"} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "send_key_dtoh_0"} // CHECK-SAME: (tensor, !mhlo.token) -> !mhlo.token "tf.XlaSendToHost"(%arg0) {key = "send_key"} : (tensor) -> () return @@ -186,7 +186,7 @@ func @recv_from_host() -> tensor { // CHECK: [[RECV_TUPLE:%.*]] = "mhlo.recv"([[INIT_TOKEN]]) // CHECK-SAME: channel_id = {handle = 1 : i64, type = 3 : i64} // CHECK-SAME: is_host_transfer = true - // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key"} + // CHECK-SAME: mhlo.frontend_attributes = {_xla_host_transfer_original_type = "s32", _xla_host_transfer_rendezvous = "recv_key_htod_0"} // CHECK-SAME: (!mhlo.token) -> tuple, !mhlo.token> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir index 5a9089756a9..93eac3821b2 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-control-flow.mlir @@ -44,7 +44,7 @@ attributes {tf._input_shapes = ["tfshape$", "tfshape$"]} { // CHECK-LABEL: func @case // CHECK-SAME: 
%[[BRANCH_INDEX:.*]]: tensor, %[[ARG0:.*]]: tensor, %[[ARG1:.*]]: tensor) -> (tensor, tensor) func @case(%index: tensor, %arg0: tensor, %arg1: tensor) -> (tensor, tensor) { - %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor]} : (tensor, tensor, tensor) -> (tensor, tensor) + %0:2 = "tf.Case"(%index, %arg0, %arg1) {branches = [@exponential, @log, @floor], is_stateless = true} : (tensor, tensor, tensor) -> (tensor, tensor) // CHECK: %[[TUPLE_INPUT:.*]] = "mhlo.tuple"(%[[ARG0]], %[[ARG1]]) : (tensor, tensor) -> tuple, tensor> // CHECK: %[[CASE:.*]]:2 = "mhlo.case"(%[[BRANCH_INDEX]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]], %[[TUPLE_INPUT]]) ( { // CHECK: ^bb0(%[[TUPLE_ARG:.*]]: tuple, tensor>): diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index cd351447303..df4f0303a84 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -265,6 +265,47 @@ func @non_max_suppression_v4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2 return %0#0 : tensor<2xi32> } +// CHECK-LABEL: bessel_i0e +func @bessel_i0e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) { + // CHECK-NOT: tf.BesselI0e + %0 = "tf.BesselI0e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>) + %1 = "tf.BesselI0e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>) + %2 = "tf.BesselI0e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>) + return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64> +} + +// CHECK-LABEL: bessel_i1e +func @bessel_i1e(%arg0: tensor<3xf16>, %arg1: tensor<3xf32>, %arg2: tensor<3xf64>) -> (tensor<3xf16>, tensor<3xf32>, tensor<3xf64>) { + // CHECK-NOT: tf.BesselI1e + %0 = "tf.BesselI1e"(%arg0) : (tensor<3xf16>) -> (tensor<3xf16>) + %1 = "tf.BesselI1e"(%arg1) : (tensor<3xf32>) -> (tensor<3xf32>) + %2 = "tf.BesselI1e"(%arg2) : (tensor<3xf64>) -> (tensor<3xf64>) + return %0, %1, %2 : tensor<3xf16>, tensor<3xf32>, tensor<3xf64> +} + +// CHECK-LABEL: diag +func @diag(%arg0: tensor<2xf32>) -> tensor<2x2xf32> { + // CHECK-NOT: tf.Diag + %0 = "tf.Diag"(%arg0) : (tensor<2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// CHECK-LABEL: random_uniform_int +func @random_uniform_int(%arg0: tensor, %arg1: tensor) -> tensor<1000xi32> { + %0 = "tf.Const"() {value = dense<1000> : tensor<1xi32>} : () -> tensor<1xi32> + // CHECK-NOT: tf.RandomUniformInt + %1 = "tf.RandomUniformInt"(%0, %arg0, %arg1) {seed = 0 : i64, seed2 = 0 : i64} : (tensor<1xi32>, tensor, tensor) -> tensor<1000xi32> + return %1 : tensor<1000xi32> +} + +// CHECK-LABEL: multinomial +func @multinomial(%arg0: tensor<2x4xf32>, %seed: tensor, %seed2: tensor) -> tensor<2x10xi32> { + // CHECK-NOT: tf.Multinomial + %samples = "tf.Const"() { value = dense<10> : tensor } : () -> tensor + %1 = "tf.Multinomial"(%arg0, %samples) {seed = 0, seed2 = 0}: (tensor<2x4xf32>, tensor) -> tensor<2x10xi32> + return %1 : tensor<2x10xi32> +} + // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // available but doesn't support this instance. 
} diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 9b32fb97260..56d4236c0a0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -1499,6 +1499,35 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te return %arg1, %arg0 : tensor, tensor } +//===----------------------------------------------------------------------===// +// Elu op legalizations. +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @elu +func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %arg0, %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"} + // CHECK-DAG: %[[EXP:.*]] = "mhlo.exponential_minus_one"(%arg0) + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %arg0, %[[EXP]]) + // CHECK: return %[[RESULT]] + %0 = "tf.Elu"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + return %0: tensor<1xf32> +} + +// CHECK-LABEL: func @elu_grad +// CHECK-SAME: (%[[GRADIENTS:.*]]: tensor<4x8xf32>, %[[FEATURES:.*]]: tensor) +func @elu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO]] {broadcast_dimensions = dense<> : tensor<0xi64>, comparison_direction = "GT"} + // CHECK-DAG: %[[ADD1:.*]] = chlo.broadcast_add %[[FEATURES]], %[[ONE]] {broadcast_dimensions = dense<> : tensor<0xi64>} + // CHECK-DAG: %[[MULGRAD:.*]] = "mhlo.multiply"(%[[GRADIENTS]], %[[ADD1]]) + // CHECK: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[MULGRAD]]) + // CHECK: return %[[RESULT]] + %2 = "tf.EluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} + //===----------------------------------------------------------------------===// // Relu op legalizations. 
//===----------------------------------------------------------------------===// @@ -3484,6 +3513,20 @@ func @conv3d_backprop_filter(%input: tensor<2x8x8x8x1xf32>, %out_backprop: tenso return %result : tensor<2x8x8x8x1xf32> } +// CHECK-LABEL: @collective_permute +func @collective_permute(%arg0: tensor<128x32xf32>) -> tensor<128x32xf32> { + %source_target_pairs = "tf.Const" () { + value = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi32> + } : () -> tensor<3x2xi32> + + // CHECK: "mhlo.collective_permute" + // CHECK-SAME: source_target_pairs = dense<{{\[}}[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> + %0 = "tf.CollectivePermute"(%arg0, %source_target_pairs) { + } : (tensor<128x32xf32>, tensor<3x2xi32>) -> tensor<128x32xf32> + + return %0 : tensor<128x32xf32> +} + // CHECK-LABEL: @cross_replica_sum func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { %replica_groups = "tf.Const" () { @@ -3504,8 +3547,9 @@ func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> { func @size_scalar_i32(%input: tensor) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor + // CHECK: %[[CAST:.*]] = tensor_cast %[[CONST]] : tensor to tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor) -> tensor - // CHECK: return %[[CONST]] + // CHECK: return %[[CAST]] return %size : tensor } @@ -3513,8 +3557,9 @@ func @size_scalar_i32(%input: tensor) -> (tensor) { func @size_scalar_i64(%input: tensor) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK-SAME: tensor + // CHECK: %[[CAST:.*]] = tensor_cast %[[CONST]] : tensor to tensor %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT64"} : (tensor) -> tensor - // CHECK: return %[[CONST]] + // CHECK: return %[[CAST]] return %size : tensor } @@ -3775,7 +3820,7 @@ func @unsorted_segment_prod(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<0x7F800000> : tensor + // CHECK: mhlo.constant dense<3.40282347E+38> : tensor // CHECK: mhlo.scatter // CHECK: mhlo.minimum %0 = "tf.UnsortedSegmentMin"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) @@ -3785,7 +3830,7 @@ func @unsorted_segment_min(%data: tensor<8x?x64xf32>, %segment_ids : tensor, %segment_ids : tensor) -> (tensor<4x?xf32>) { %num_segments = "tf.Const"() {value = dense<4> : tensor} : () -> tensor - // CHECK: mhlo.constant dense<0xFF800000> : tensor + // CHECK: mhlo.constant dense<-3.40282347E+38> : tensor // CHECK: mhlo.scatter // CHECK: mhlo.maximum %0 = "tf.UnsortedSegmentMax"(%data, %segment_ids, %num_segments) : (tensor<8x?x64xf32>, tensor, tensor) -> (tensor<4x?xf32>) @@ -4668,6 +4713,20 @@ func @cumsum_dynamic(%arg0: tensor, %arg1: tensor) -> tensor return %0 : tensor } +//===----------------------------------------------------------------------===// +// Cumprod op legalizations. 
+//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @cumprod +func @cumprod(%arg0: tensor<4xf32>) -> tensor<4xf32> { + // CHECK: [[INIT:%.*]] = mhlo.constant dense<1.000000e+00> : tensor + // CHECK: "mhlo.reduce_window"({{.*}}, [[INIT]]) ( { + // CHECK: mhlo.mul + %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<0> : tensor} : () -> tensor + %1 = "tf.Cumprod"(%arg0, %0) {exclusive = false, reverse = false} : (tensor<4xf32>, tensor) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + //===----------------------------------------------------------------------===// // Qr op legalization //===----------------------------------------------------------------------===// @@ -4766,3 +4825,37 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { // CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64> return %0 : tensor<8x16xf64> } + +// CHECK-LABEL: @xla_gather +func @xla_gather(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> { + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> + + // CHECK: "mhlo.gather" + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64> + // CHECK-SAME: index_vector_dim = 1 : i64 + // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64> + // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64> + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + + %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<10x1x300xf32> + return %0 : tensor<10x1x300xf32> +} + +// CHECK-LABEL: @xla_gather_i32 +func @xla_gather_i32(%arg0: tensor<200x100x300xf32>, %arg1: tensor<10x2xi32>) -> tensor<10x1x300xf32> { + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi32> } : () -> tensor<3xi32> + + // CHECK: "mhlo.gather" + // CHECK-SAME: dimension_numbers = + // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64> + // CHECK-SAME: index_vector_dim = 1 : i64 + // CHECK-SAME: offset_dims = dense<1> : tensor<1xi64> + // CHECK-SAME: start_index_map = dense<0> : tensor<1xi64> + // CHECK-SAME: indices_are_sorted = true + // CHECK-SAME: slice_sizes = dense<[1, 1, 300]> : tensor<3xi64> + + %0 = "tf.XlaGather"(%arg0, %arg1, %cst) {dimension_numbers = "\0A\01\01\12\01\00\1A\01\00 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi32>) -> tensor<10x1x300xf32> + return %0 : tensor<10x1x300xf32> +} diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 5fe933ee635..3462b3b7a5a 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" +#include "tensorflow/compiler/mlir/xla/attribute_importer.h" #include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/xla/client/lib/conv_grad_size_util.h" #include "tensorflow/compiler/xla/client/padding.h" @@ -57,7 +58,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -262,49 +263,21 @@ tensorflow::TensorShape ToTensorShape( sizes.begin(), sizes.end())); } -// Returns minimal value for the given int or float element type. -static ConstOp GetMinValueForType(Type ty, Location loc, - PatternRewriter *rewriter) { - RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - - DenseElementsAttr attr; - if (auto float_ty = ty.dyn_cast_or_null()) { - APFloat neg_inf = - APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/true); - attr = DenseElementsAttr::get(scalar_ty, neg_inf); - } else { - auto int_ty = ty.cast(); - APInt min_val = APInt::getSignedMinValue(int_ty.getWidth()); - attr = DenseElementsAttr::get(scalar_ty, min_val); - } - return rewriter->create(loc, attr); -} - -// Returns maximal value for the given int or float element type. -static ConstOp GetMaxValueForType(Type ty, Location loc, - PatternRewriter *rewriter) { - RankedTensorType scalar_ty = RankedTensorType::get({}, ty); - - DenseElementsAttr attr; - if (auto float_ty = ty.dyn_cast_or_null()) { - APFloat pos_inf = - APFloat::getInf(float_ty.getFloatSemantics(), /*negative=*/false); - attr = DenseElementsAttr::get(scalar_ty, pos_inf); - } else { - auto int_ty = ty.cast(); - APInt max_val = APInt::getSignedMaxValue(int_ty.getWidth()); - attr = DenseElementsAttr::get(scalar_ty, max_val); - } - return rewriter->create(loc, attr); -} - -// Returns int or float scalar DenseElementsAttr attribute with the given -// element type and the value. +// Returns int, float, or complex scalar DenseElementsAttr attribute with the +// given element type and the value. static ConstOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, OpBuilder *builder) { return builder->create(loc, hlo::GetScalarOfType(ty, raw_value)); } +// Returns a limit scalar const op for the given type. +// Requires FloatType or IntegerType +static ConstOp GetScalarLimitConstOfType(Type ty, Location loc, + hlo::ScalarLimit limit, + OpBuilder *builder) { + return builder->create(loc, hlo::GetScalarLimitOfType(ty, limit)); +} + // Creates an mhlo::SliceOp where the major dimensions have full size, and // the minor dimensions have the provided offsets and sizes. static Value SliceInMinorDims(Location loc, Value v, @@ -1065,6 +1038,21 @@ static void BuildSortComparisonBody(llvm::ArrayRef element_types, builder->create(loc, compare); } +//===----------------------------------------------------------------------===// +// XlaGather op utilities. +//===----------------------------------------------------------------------===// + +bool HasValidGatherDims(StringAttr attr) { + ::xla::GatherDimensionNumbers dims; + return dims.ParseFromString(attr.getValue().str()); +} + +GatherDimensionNumbers GetGatherDimNumsAttr(StringAttr attr, Builder *builder) { + ::xla::GatherDimensionNumbers dims; + if (!dims.ParseFromString(attr.getValue().str())) return {}; + return ::xla::ConvertGatherDimensionNumbers(dims, builder); +} + //===----------------------------------------------------------------------===// // Op converters. 
//===----------------------------------------------------------------------===// @@ -2385,15 +2373,16 @@ class ConvertMaxPoolOp : public OpRewritePattern { op.input().getType().template cast().getElementType(); if (!element_type.isSignlessIntOrFloat()) return failure(); Location loc = op.getLoc(); - ConstOp init = GetMinValueForType(element_type, loc, &rewriter); + ConstOp init = GetScalarLimitConstOfType(element_type, loc, + hlo::kInfinityLowest, &rewriter); auto input_ty = op.input().getType().template dyn_cast(); if (!input_ty) return failure(); DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.ksize(), op.strides(), op.padding(), &rewriter); auto reduce = rewriter.create( - loc, op.getType(), op.input(), init.getResult(), - GetI64ElementsAttr(op.ksize()), GetI64ElementsAttr(op.strides()), + loc, op.getType(), op.input(), init, GetI64ElementsAttr(op.ksize()), + GetI64ElementsAttr(op.strides()), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); BuildReduceBody(element_type, &reduce.body(), &rewriter); @@ -3636,7 +3625,8 @@ class ConvertMaxOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMinValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, rewriter); } }; @@ -3653,7 +3643,8 @@ class ConvertMinOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMaxValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityMax, rewriter); } }; @@ -3789,7 +3780,8 @@ class ConvertArgMaxOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter &rewriter) { - return GetMinValueForType(reduce_element_type, loc, &rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, + hlo::kInfinityLowest, &rewriter); } static StringRef GetDirection() { return "GT"; } @@ -4728,7 +4720,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { auto output_type = RankedTensorType::get(output_shape, data_type.getElementType()); - // Broadccast the initial value for reduction. This will become the + // Broadcast the initial value for reduction. This will become the // 'operand' parameter to scatter to for the final scatter op. Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), op.getLoc(), &rewriter); @@ -4768,7 +4760,8 @@ class ConvertUnsortedSegmentMaxOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMinValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kLowest, + rewriter); } }; @@ -4781,7 +4774,8 @@ class ConvertUnsortedSegmentMinOp static Value GetInitialValue(Type reduce_element_type, Location loc, PatternRewriter *rewriter) { - return GetMaxValueForType(reduce_element_type, loc, rewriter); + return GetScalarLimitConstOfType(reduce_element_type, loc, hlo::kMax, + rewriter); } }; @@ -5092,17 +5086,19 @@ class ConvertXlaDynamicUpdateSliceOp } }; -/// Converts the Cumsum TensorFlow op to the HLO ReduceWindow op by setting -/// appropriate window dimensions, with 'add' as the reduction function. The -/// input tensor needs to have a static shape, and 'axis' must be const. 
The -/// TableGen pattern is not used for this rewrite because it involves regions. -class ConvertCumsumOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +// Converts the Cumsum or Cumprod TensorFlow op to the HLO ReduceWindow op by +// setting appropriate window dimensions, with the given aggregation op as the +// reduction function. The input tensor needs to have a static shape, and 'axis' +// must be const. The TableGen pattern is not used for this rewrite because it +// involves regions. +template +class ConvertCumOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TF::CumsumOp op, + LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { auto input = op.x(); - auto input_type = input.getType().dyn_cast(); + auto input_type = input.getType().template dyn_cast(); if (!input_type || !input_type.hasStaticShape()) { return failure(); } @@ -5135,6 +5131,10 @@ class ConvertCumsumOp : public OpRewritePattern { // Convert if we need to enlarge the element type's bitwidth to avoid // precision loss. Type input_element_type = input_type.getElementType(); + + // TODO(hinsu): Handle complex element types. + if (!input_element_type.isIntOrFloat()) return failure(); + Type sum_element_type = GetSumAccumulationType(input_element_type); input = rewriter.create(op.getLoc(), input, sum_element_type); @@ -5148,8 +5148,9 @@ class ConvertCumsumOp : public OpRewritePattern { RankedTensorType::get({rank, 2}, rewriter.getIntegerType(64)), paddings); - Value init = - GetScalarConstOfType(sum_element_type, op.getLoc(), 0, &rewriter); + int64_t init_value = (std::is_same::value) ? 0 : 1; + Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, + &rewriter); auto reduce = rewriter.create( op.getLoc(), input_type, input, init, @@ -5157,7 +5158,7 @@ class ConvertCumsumOp : public OpRewritePattern { GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), /*base_dilations=*/DenseIntElementsAttr(), /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(sum_element_type, &reduce.body(), &rewriter); + BuildReduceBody(sum_element_type, &reduce.body(), &rewriter); Value result = reduce.getResult(); if (op.exclusive()) { @@ -5193,6 +5194,9 @@ class ConvertCumsumOp : public OpRewritePattern { } }; +using ConvertCumsumOp = ConvertCumOp; +using ConvertCumprodOp = ConvertCumOp; + // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard // dialect lowerings. 
This involves extracting the shape type, extracting and // converting each dimension to a known integer type, and repacking into a final @@ -5857,7 +5861,7 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertConv2DOp, ConvertConv3DOp, ConvertDepthConv2DOp, ConvertConv2DBackpropFilterOp, ConvertConv3DBackpropFilterOp, ConvertConv2DBackpropInputOp, ConvertConv3DBackpropInputOp, - ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, + ConvertCumprodOp, ConvertCumsumOp, ConvertDiagPartOp, ConvertEinsumOp, ConvertFusedBatchNormGradOp, ConvertFusedBatchNormGradV2Op, ConvertFusedBatchNormGradV3Op, ConvertFusedBatchNormV2Op, ConvertFusedBatchNormV3Op, ConvertInfeedDequeueTupleOp, diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc index 1d6ce36300f..1f884b1bdea 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_communication.cc @@ -215,11 +215,17 @@ void SetOpSharding(Operation* op, int64_t tpu_core) { } // Assigns frontend attributes holding information about data type and -// TensorFlow rendezvous channel name. -void SetFrontendAttributes(Operation* op, StringRef key, Type type) { +// TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is +// handled differently as individual names are used per data send and receive. +void SetFrontendAttributes(Operation* op, int32_t index, StringRef key, + Type type, bool device_to_host) { MLIRContext* context = op->getContext(); - auto rendezvous_name = StringAttr::get(key, context); + std::string formatted_key = + device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str() + : llvm::formatv("{0}_htod_{1}", key, index).str(); + + auto rendezvous_name = StringAttr::get(formatted_key, context); auto rendezvous_name_attr = NamedAttribute( Identifier::get(kXlaHostTransferRendezvousNameAttr, context), rendezvous_name); @@ -239,24 +245,10 @@ void SetFrontendAttributes(Operation* op, StringRef key, Type type) { op->setAttr(kFrontendAttributesAttr, frontend_attributes); } -// Assigns frontend attributes holding information about data type and -// TensorFlow rendezvous channel name specific to `tf._XlaHostComputeMlir`. -// TensorFlow rendezvous channel name is handled differently as individual names -// are used per data send and receive. -void SetFrontendAttributes(Operation* op, int32_t index, StringRef key, - Type type, bool device_to_host) { - std::string formatted_key = - device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str() - : llvm::formatv("{0}_htod_{1}", key, index).str(); - - return SetFrontendAttributes(op, formatted_key, type); -} - -// Creates a `mhlo.send` op for sending value `operand`. If `index` is set, -// `key` will be rewritten with a suffix and index. If `tpu_core` is set, op -// sharding for the respective device will be set. +// Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set, +// op sharding for the respective device will be set. 
Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, - Value operand, StringRef key, const Optional& index, + Value operand, StringRef key, size_t index, const Optional& tpu_core, Value token) { // type 2 == DEVICE_TO_HOST auto channel_handle = ChannelHandle::get( @@ -266,23 +258,18 @@ Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc, loc, token.getType(), operand, token, channel_handle, /*is_host_transfer=*/builder.getBoolAttr(true)); - if (index) { - SetFrontendAttributes(send, *index, key, operand.getType(), - /*device_to_host=*/true); - } else { - SetFrontendAttributes(send, key, operand.getType()); - } + SetFrontendAttributes(send, index, key, operand.getType(), + /*device_to_host=*/true); if (tpu_core) SetOpSharding(send, *tpu_core); return send.getResult(); } -// Creates a `mhlo.recv` op for receiving a value. If `index` is set, `key` will -// be rewritten with a suffix and index. If `tpu_core` is set, op sharding for -// the respective device will be set. +// Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op +// sharding for the respective device will be set. Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, - Value result, StringRef key, const Optional& index, + Value result, StringRef key, size_t index, const Optional& tpu_core, Value token) { // type 3 == HOST_TO_DEVICE auto channel_handle = ChannelHandle::get( @@ -294,12 +281,10 @@ Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc, auto recv = builder.create(loc, recv_result_type, token, channel_handle, /*is_host_transfer=*/builder.getBoolAttr(true)); - if (index) { - SetFrontendAttributes(recv, *index, key, result_type, - /*device_to_host=*/false); - } else { - SetFrontendAttributes(recv, key, result.getType()); - } + + SetFrontendAttributes(recv, index, key, result_type, + /*device_to_host=*/false); + if (tpu_core) SetOpSharding(recv, *tpu_core); auto get_tuple_element = @@ -369,7 +354,7 @@ Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id, builder.setInsertionPoint(send_to_host); token = CreateSendOp(builder, channel_id, send_to_host.getLoc(), send_to_host.input(), send_to_host.key(), - /*index=*/llvm::None, /*tpu_core=*/llvm::None, token); + /*index=*/0, /*tpu_core=*/llvm::None, token); send_to_host.erase(); return token; @@ -381,7 +366,7 @@ Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id, builder.setInsertionPoint(recv_from_host); token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(), recv_from_host.output(), recv_from_host.key(), - /*index=*/llvm::None, /*tpu_core=*/llvm::None, token); + /*index=*/0, /*tpu_core=*/llvm::None, token); recv_from_host.erase(); return token; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td index 1d4c9503afa..73ce305091c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_patterns.td @@ -51,6 +51,10 @@ def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "$0, (*$1.begin()).getType().cast().getRank(), " "&$_builder)">; +def CastElementsToI64Elements : NativeCodeCall< + "hlo::ConvertElementsAttr(" + "$0, $_builder.getIntegerType(64)).cast()">; + def : Pattern< (TF_FusedBatchNormOp:$root $x, $scale, $offset, $mean, $variance, $epsilon, $exponential_avg_factor, $data_format, @@ -255,12 +259,16 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis)), 
[(HasRankedFirstOperand $inputs)]>; //===----------------------------------------------------------------------===// -// CrossReplicaSum op patterns. +// CollectivePermute op patterns. //===----------------------------------------------------------------------===// -def CastElementsToI64Elements : NativeCodeCall< - "hlo::ConvertElementsAttr(" - "$0, $_builder.getIntegerType(64)).cast()">; +def : Pat<(TF_CollectivePermuteOp $input, (TF_ConstOp $source_target_pairs)), + (HLO_CollectivePermuteOp $input, + (CastElementsToI64Elements $source_target_pairs))>; + +//===----------------------------------------------------------------------===// +// CrossReplicaSum op patterns. +//===----------------------------------------------------------------------===// def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)), (HLO_CrossReplicaSumOp $input, @@ -427,6 +435,35 @@ def : Pat<(TF_ConstOp:$res ElementsAttr:$value), (TensorCastOp (HLO_ConstOp $value)), [(HLO_Tensor $res)]>; +//===----------------------------------------------------------------------===// +// Elu op patterns. +//===----------------------------------------------------------------------===// + +def : Pat<(TF_EluOp AnyRankedTensor:$features), + (HLO_SelectOp + (HLOClient_BroadcastCompareOp + $features, + (HLO_ConstOp:$zero (GetScalarOfType<0> $features)), + (BinBroadcastDimensions $zero, $features), + HLO_COMPARISON_DIRECTION_GT), + $features, + (HLO_Expm1Op $features))>; + +def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), + (HLO_SelectOp + (HLOClient_BroadcastCompareOp + $features, + (HLO_ConstOp:$zero (GetScalarOfType<0> $features)), + (BinBroadcastDimensions $zero, $features), + HLO_COMPARISON_DIRECTION_GT), + $gradients, + (HLO_MulOp + $gradients, + (HLOClient_BroadcastAddOp + $features, + (HLO_ConstOp:$one (GetScalarOfType<1> $features)), + (BinBroadcastDimensions $one, $features))))>; + //===----------------------------------------------------------------------===// // Relu op patterns. //===----------------------------------------------------------------------===// @@ -660,3 +697,19 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), ), (replaceWithValue $output) ]>; + +//===----------------------------------------------------------------------===// +// XlaGather op. +//===----------------------------------------------------------------------===// + +def ToGatherDimNumsAttr : NativeCodeCall<"GetGatherDimNumsAttr($0, &$_builder)">; + +def HasValidGatherDims : Constraint>; + +def : Pat<(TF_XlaGatherOp $operand, $start_indices, (TF_ConstOp $slice_sizes), + $dimension_numbers, $indices_are_sorted), + (HLO_GatherOp $operand, $start_indices, + (ToGatherDimNumsAttr $dimension_numbers), + (CastElementsToI64Elements $slice_sizes), + $indices_are_sorted), + [(HasValidGatherDims $dimension_numbers)]>; diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 904b80e05b1..2f73d1a54df 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -40,6 +40,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/translate/export_tf_dialect_op.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" @@ -81,6 +82,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for // all tf2xla kernels. // clang-format off + static llvm::SmallDenseSet ops = { TypeID::get(), TypeID::get(), @@ -102,6 +104,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -110,12 +115,17 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -124,6 +134,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -139,6 +150,9 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -149,26 +163,38 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + // TODO(hinsu): Canonicalize QuantizeAndDequantize and + // QuantizeAndDequantizeV2 to QuantizeAndDequantizeV3 by converting + // attributes to operands. + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), - TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -177,6 +203,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -190,9 +217,15 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -200,6 +233,7 @@ bool IsOpAllowedTf2XlaFallback(Operation* op) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get(), TypeID::get(), diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index cc74d82839b..22462428367 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -34,7 +34,6 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" @@ -182,7 +181,10 @@ template StatusOr LhloDialectEmitter::CreateOpWithoutAttrs( HloInstruction* instr) { Location loc = getLocation(instr); - ArrayRef> attrs; + std::pair attrs[] = { + {Identifier::get("name", builder_.getContext()), + builder_.getStringAttr(instr->name())}, + }; ArrayRef rets{}; llvm::SmallVector operands; @@ -252,15 +254,14 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { return Status::OK(); } -StatusOr LhloDialectEmitter::EmitSortOp( - HloInstruction* instr) { +StatusOr LhloDialectEmitter::EmitSortOp(HloInstruction* instr) { TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr); sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension())); sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable())); TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( *sort_instr->called_computations()[0], &sort.comparator(), &builder_)); - return sort.getOperation(); + return sort; } Status LhloDialectEmitter::HandleSort(HloInstruction* instr) { diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index b191d53840d..89514116254 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -41,7 +42,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { builder_(module.getContext()), i8_type_(builder_.getIntegerType(8)) {} - ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); + ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); private: template diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc index afc36916348..b725f56b455 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -43,47 +43,41 @@ using xla::ShapeUtil; namespace xla { PrimitiveType TypeToPrimitiveType(mlir::Type type) { - switch (type.getKind()) { - case mlir::StandardTypes::BF16: - return PrimitiveType::BF16; - case mlir::StandardTypes::Complex: { - mlir::Type element_ty = type.cast().getElementType(); - switch (element_ty.getKind()) { - case mlir::StandardTypes::F32: - return PrimitiveType::C64; - case mlir::StandardTypes::F64: - return PrimitiveType::C128; - default: - return PrimitiveType::PRIMITIVE_TYPE_INVALID; - } + if (type.isBF16()) { + return PrimitiveType::BF16; + } else if (type.isF16()) { + return PrimitiveType::F16; + } else if (type.isF32()) { + return PrimitiveType::F32; + } else if (type.isF64()) { + return PrimitiveType::F64; + } else if (auto complex_type = type.dyn_cast()) { + mlir::Type element_ty = complex_type.getElementType(); + if (element_ty.isF32()) { + return PrimitiveType::C64; + + } else if (element_ty.isF64()) { + return PrimitiveType::C128; } - case mlir::StandardTypes::F16: - return PrimitiveType::F16; - case mlir::StandardTypes::F32: - return PrimitiveType::F32; - case mlir::StandardTypes::F64: - return PrimitiveType::F64; - case mlir::StandardTypes::Integer: { - const auto integer = type.cast(); - bool is_unsigned = integer.isUnsigned(); - switch (integer.getWidth()) { - case 1: - return PrimitiveType::PRED; - case 8: - return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8; - case 16: - return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16; - case 32: - return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32; - case 64: - return is_unsigned ? PrimitiveType::U64 : PrimitiveType::S64; - default: - return PrimitiveType::PRIMITIVE_TYPE_INVALID; - } + return PrimitiveType::PRIMITIVE_TYPE_INVALID; + } else if (auto integer_type = type.dyn_cast()) { + bool is_unsigned = integer_type.isUnsigned(); + switch (integer_type.getWidth()) { + case 1: + return PrimitiveType::PRED; + case 8: + return is_unsigned ? PrimitiveType::U8 : PrimitiveType::S8; + case 16: + return is_unsigned ? PrimitiveType::U16 : PrimitiveType::S16; + case 32: + return is_unsigned ? PrimitiveType::U32 : PrimitiveType::S32; + case 64: + return is_unsigned ? 
PrimitiveType::U64 : PrimitiveType::S64; + default: + return PrimitiveType::PRIMITIVE_TYPE_INVALID; } - default: - return PrimitiveType::PRIMITIVE_TYPE_INVALID; } + return PrimitiveType::PRIMITIVE_TYPE_INVALID; } StatusOr TypeToShape( @@ -108,108 +102,89 @@ Shape TypeToShape(mlir::Type type) { if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) return ShapeUtil::MakeShape(ptype, {}); - switch (type.getKind()) { - case mlir::StandardTypes::BF16: - case mlir::StandardTypes::F32: - case mlir::StandardTypes::F64: - case mlir::StandardTypes::Integer: { - auto* context = type.getContext(); - mlir::emitError(mlir::UnknownLoc::get(context)) - << "lowering should have been handled by primitive type lowering for " - << debugString(type); - break; + if (type.isBF16() || type.isF32() || type.isF64() || + type.isa()) { + auto* context = type.getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "lowering should have been handled by primitive type lowering for " + << debugString(type); + } else if (auto v = type.dyn_cast()) { + llvm::SmallVector span(v.getShape().begin(), v.getShape().end()); + mlir::Type element_type = v.getElementType(); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(primitive_type, span); + } else if (auto m = type.dyn_cast()) { + llvm::SmallVector span(m.getShape().begin(), m.getShape().end()); + mlir::Type element_type = m.getElementType(); + // Treat a memref of a vector as if it was a memref of primitive type with + // the vector dimensions at the end. + if (auto v = element_type.dyn_cast()) { + element_type = v.getElementType(); + span.insert(span.end(), v.getShape().begin(), v.getShape().end()); } - case mlir::StandardTypes::Vector: { - const auto v = type.cast(); - llvm::SmallVector span(v.getShape().begin(), - v.getShape().end()); - mlir::Type element_type = v.getElementType(); - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); - if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) - return ShapeUtil::MakeShape(primitive_type, span); - break; - } - case mlir::StandardTypes::MemRef: { - const auto m = type.cast(); - llvm::SmallVector span(m.getShape().begin(), - m.getShape().end()); - mlir::Type element_type = m.getElementType(); - // Treat a memref of a vector as if it was a memref of primitive type with - // the vector dimensions at the end. - if (auto v = element_type.dyn_cast()) { - element_type = v.getElementType(); - span.insert(span.end(), v.getShape().begin(), v.getShape().end()); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) return {}; + // For the primitive type case, the shape of the memref is similar to the + // vector type case (i.e., it is, modulo the layout, the same dimensions + // and primitive type). 
+ if (m.getAffineMaps().empty()) + return ShapeUtil::MakeShape(primitive_type, span); + + if (m.getAffineMaps().size() == 1) { + llvm::SmallVector strides; + int64_t offset; + if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {}; + + llvm::SmallVector, 4> strides_with_indices; + for (const auto& e : llvm::enumerate(strides)) { + strides_with_indices.push_back({e.value(), e.index()}); } - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); - if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) break; - // For the primitive type case, the shape of the memref is similar to the - // vector type case (i.e., it is, modulo the layout, the same dimensions - // and primitive type). - if (m.getAffineMaps().empty()) - return ShapeUtil::MakeShape(primitive_type, span); + std::sort(strides_with_indices.begin(), strides_with_indices.end()); - if (m.getAffineMaps().size() == 1) { - llvm::SmallVector strides; - int64_t offset; - if (failed(mlir::getStridesAndOffset(m, strides, offset))) return {}; + llvm::SmallVector minor_to_major; + int64_t stride = 1; + for (const auto& pr : strides_with_indices) { + minor_to_major.push_back(pr.second); - llvm::SmallVector, 4> strides_with_indices; - for (const auto& e : llvm::enumerate(strides)) { - strides_with_indices.push_back({e.value(), e.index()}); - } - std::sort(strides_with_indices.begin(), strides_with_indices.end()); + // Either the affine map is not perfectly strided, or the dimensions + // recovered from strides don't match the actual dimensions in shapes. + if (stride != pr.first) return {}; - llvm::SmallVector minor_to_major; - int64_t stride = 1; - for (const auto& pr : strides_with_indices) { - minor_to_major.push_back(pr.second); - - // Either the affine map is not perfectly strided, or the dimensions - // recovered from strides don't match the actual dimensions in shapes. - if (stride != pr.first) return {}; - - stride *= m.getShape()[pr.second]; - } - - llvm::SmallVector dimensions(m.getShape().begin(), - m.getShape().end()); - return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, - minor_to_major); + stride *= m.getShape()[pr.second]; } - break; + + llvm::SmallVector dimensions(m.getShape().begin(), + m.getShape().end()); + return ::xla::ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, + minor_to_major); } - case mlir::StandardTypes::RankedTensor: { - // TODO(jpienaar): This is only handling the base case with primitive - // element type. - const auto t = type.cast(); - llvm::SmallVector span(t.getShape().begin(), - t.getShape().end()); - // Only fully static shapes are supported. - // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. - if (std::find(t.getShape().begin(), t.getShape().end(), -1) != - t.getShape().end()) - break; - mlir::Type element_type = t.getElementType(); - PrimitiveType primitive_type = TypeToPrimitiveType(element_type); - // Only primitive element type supported. - if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) - return ShapeUtil::MakeShape(primitive_type, span); - break; + } else if (auto t = type.dyn_cast()) { + // TODO(jpienaar): This is only handling the base case with primitive + // element type. + llvm::SmallVector span(t.getShape().begin(), t.getShape().end()); + // Only fully static shapes are supported. + // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. 
+ if (std::find(t.getShape().begin(), t.getShape().end(), -1) != + t.getShape().end()) + return {}; + mlir::Type element_type = t.getElementType(); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + // Only primitive element type supported. + if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(primitive_type, span); + } else if (auto tuple_type = type.dyn_cast()) { + llvm::SmallVector shapes; + shapes.reserve(tuple_type.size()); + for (mlir::Type sub_type : tuple_type.getTypes()) { + shapes.push_back(TypeToShape(sub_type)); } - case mlir::StandardTypes::Tuple: { - const auto t = type.cast(); - llvm::SmallVector shapes; - shapes.reserve(t.size()); - for (mlir::Type sub_type : t.getTypes()) { - shapes.push_back(TypeToShape(sub_type)); - } - return ShapeUtil::MakeTupleShape(shapes); - } - case mlir::mhlo::HLOTypes::Token: - return ShapeUtil::MakeTokenShape(); - default: - break; + return ShapeUtil::MakeTupleShape(shapes); + + } else if (type.isa()) { + return ShapeUtil::MakeTokenShape(); } + // Return empty XLA shape to signify error. No MLIR Type maps to a empty // Shape. return {}; diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc index a4a2bc42d99..ce709b10462 100644 --- a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -64,6 +64,7 @@ inline ::testing::PolymorphicMatcher EqualsProto( TEST(TypeToShapeTest, ConvertPrimitiveTypes) { MLIRContext context; + context.loadAllGloballyRegisteredDialects(); Builder b(&context); EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32); @@ -74,6 +75,7 @@ TEST(TypeToShapeTest, ConvertPrimitiveTypes) { TEST(TypeToShapeTest, ConvertBasicTypesToTypes) { MLIRContext context; + context.loadAllGloballyRegisteredDialects(); Builder b(&context); EXPECT_TRUE( @@ -95,6 +97,7 @@ TEST(TypeToShapeTest, ConvertBasicTypesToTypes) { TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) { MLIRContext context; + context.loadAllGloballyRegisteredDialects(); Builder b(&context); // Memref without any affine map. Note: memory space is ignored for shape. 
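The memref branch of TypeToShape above recovers an XLA minor-to-major layout from memref strides: it sorts (stride, dimension-index) pairs and accepts the layout only if each stride equals the product of the sizes of all more-minor dimensions, otherwise it returns an empty Shape. Below is a minimal standalone C++ sketch of just that check using plain standard-library containers; `MinorToMajorFromStrides`, its signature, and the example values are illustrative stand-ins for this note, not part of the MLIR or XLA APIs, and corner cases such as zero-sized dimensions are ignored.

// Standalone illustration (not TF/XLA code) of the stride-to-layout check
// performed for memrefs in TypeToShape.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>

// Returns the minor-to-major dimension order implied by `strides` for a buffer
// with extents `dims`, or std::nullopt when the strides are not perfectly
// packed (the case where the code above gives up and returns an empty Shape).
std::optional<std::vector<int64_t>> MinorToMajorFromStrides(
    const std::vector<int64_t>& dims, const std::vector<int64_t>& strides) {
  std::vector<std::pair<int64_t, int64_t>> stride_and_dim;
  for (size_t i = 0; i < strides.size(); ++i) {
    stride_and_dim.push_back({strides[i], static_cast<int64_t>(i)});
  }
  std::sort(stride_and_dim.begin(), stride_and_dim.end());

  std::vector<int64_t> minor_to_major;
  int64_t expected_stride = 1;
  for (const auto& [stride, dim] : stride_and_dim) {
    // Each stride must equal the product of the sizes of all more-minor dims.
    if (stride != expected_stride) return std::nullopt;
    minor_to_major.push_back(dim);
    expected_stride *= dims[dim];
  }
  return minor_to_major;
}

int main() {
  // A 3x2 row-major buffer has strides {2, 1}; the recovered layout is {1, 0}.
  auto layout = MinorToMajorFromStrides({3, 2}, {2, 1});
  if (layout) {
    for (int64_t d : *layout) std::cout << d << ' ';
    std::cout << '\n';  // prints: 1 0
  }
  return 0;
}

For the same {3, 2} shape, strides such as {3, 1} are rejected because they are not perfectly packed, which corresponds to the empty-Shape return path in the hunk above.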
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 7f099540f39..30b8a7e5561 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -265,6 +265,7 @@ tf_xla_py_test( name = "categorical_op_test", size = "small", srcs = ["categorical_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -283,6 +284,7 @@ tf_xla_py_test( name = "cholesky_op_test", size = "medium", srcs = ["cholesky_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -347,6 +349,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["searchsorted_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -389,6 +392,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_inverse_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -411,6 +415,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_solve_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -429,6 +434,7 @@ tf_xla_py_test( size = "small", timeout = "moderate", srcs = ["matrix_triangular_solve_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -469,7 +475,6 @@ tf_xla_py_test( enable_mlir_bridge = True, python_version = "PY3", tags = [ - "many_xla_args", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_rocm", ], @@ -533,6 +538,7 @@ tf_xla_py_test( name = "depthwise_conv_op_test", size = "medium", srcs = ["depthwise_conv_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -632,6 +638,7 @@ tf_xla_py_test( name = "extract_image_patches_op_test", size = "small", srcs = ["extract_image_patches_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -688,6 +695,7 @@ tf_xla_py_test( name = "fft_test", size = "medium", srcs = ["fft_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 6, tags = [ @@ -783,6 +791,7 @@ tf_xla_py_test( name = "listdiff_op_test", size = "small", srcs = ["listdiff_op_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -821,6 +830,7 @@ tf_xla_py_test( name = "manip_ops_test", size = "small", srcs = ["manip_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -928,6 +938,7 @@ tf_xla_py_test( name = "pooling_ops_test", size = "medium", srcs = ["pooling_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 20, tags = [ @@ -1006,6 +1017,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], + enable_mlir_bridge = True, python_version = "PY3", shard_count = 5, tags = [ @@ -1032,6 +1044,7 @@ tf_xla_py_test( "cpu", "cpu_ondemand", ], + enable_mlir_bridge = True, python_version = "PY3", 
shard_count = 5, tags = [ @@ -1114,6 +1127,7 @@ tf_xla_py_test( name = "reverse_ops_test", size = "medium", srcs = ["reverse_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1165,6 +1179,7 @@ tf_xla_py_test( name = "scan_ops_test", size = "medium", srcs = ["scan_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1183,6 +1198,7 @@ tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", srcs = ["segment_reduction_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1281,6 +1297,7 @@ tf_xla_py_test( name = "stateless_random_ops_test", size = "medium", srcs = ["stateless_random_ops_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1564,6 +1581,7 @@ tf_xla_py_test( name = "xla_device_test", size = "small", srcs = ["xla_device_test.py"], + enable_mlir_bridge = True, python_version = "PY3", tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1886,6 +1904,7 @@ tf_xla_py_test( name = "special_math_test", size = "medium", srcs = ["special_math_test.py"], + enable_mlir_bridge = True, shard_count = 5, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py index 0202c582ef3..9d278cfbb28 100644 --- a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -135,6 +135,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): self._VerifyTriangularSolve( a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) + @test_util.disable_mlir_bridge("Error handling") def testNonSquareCoefficientMatrix(self): rng = np.random.RandomState(0) for dtype in self.float_types: @@ -145,6 +146,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): linalg_ops.matrix_triangular_solve(a, b) @test_util.run_v2_only # Different error types + @test_util.disable_mlir_bridge("Error handling") def testWrongDimensionsV2(self): randn = np.random.RandomState(0).randn for dtype in self.float_types: @@ -156,6 +158,7 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase): linalg_ops.matrix_triangular_solve(lhs, rhs) @test_util.run_v1_only("Different error types") + @test_util.disable_mlir_bridge("Error handling") def testWrongDimensionsV1(self): randn = np.random.RandomState(0).randn for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 9f963110cf3..0f19affc8e3 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -63,9 +63,9 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/public/session_options.h" diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py index 7c36f8b13ca..440b7672d98 100644 --- a/tensorflow/compiler/tests/scan_ops_test.py +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -24,6 +24,7 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -129,6 +130,7 @@ class CumsumTest(xla_test.XLATestCase): for axis in range(-6, 6, 3): self._compareAll(x, axis) + @test_util.disable_mlir_bridge("Error handling") def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) with self.session(), self.test_scope(): @@ -207,6 +209,7 @@ class CumprodTest(xla_test.XLATestCase): for axis in range(-6, 6, 3): self._compareAll(x, axis) + @test_util.disable_mlir_bridge("Error handling") def testInvalidAxis(self): x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) with self.session(), self.test_scope(): diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 7bbfecff403..4109fdc64a5 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -214,7 +214,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): upper, expected=np.minimum(np.maximum(x, lower), upper)) - @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetaincSanity(self): # This operation is only supported for float32 and float64. for dtype in self.numeric_types & {np.float32, np.float64}: @@ -252,7 +251,6 @@ class TernaryOpsTest(xla_test.XLATestCase, parameterized.TestCase): 'atol': 2e-4 }, ) - @test_util.disable_mlir_bridge('Enable tf.Betainc Compilation') def testBetainc(self, sigma, rtol, atol): # This operation is only supported for float32 and float64. 
for dtype in self.numeric_types & {np.float32, np.float64}: diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index eb022da6895..b5f82bcff12 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -96,7 +96,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllEqual(result, expected) @test_util.disable_mlir_bridge( - "MlirHloBuilder::Iota missing required for xla::Diag") + "Handle complex element type in DiagPart lowering") def testAllTypeOps(self): for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( @@ -538,8 +538,6 @@ class UnaryOpsTest(xla_test.XLATestCase): np.array([-40, 40], dtype=dtype), expected=np.array([1.0, 0.025], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.QuantizeAndDequantize compilation") def testQuantizeAndDequantize(self): for dtype in self.float_types: @@ -1070,8 +1068,6 @@ class UnaryOpsTest(xla_test.XLATestCase): ], equality_test=self.ListsAreClose) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.DepthToSpace compilation") def testDepthToSpace(self): def make_op(data_format): @@ -1118,14 +1114,12 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( make_op("NCHW_VECT_C"), np.arange(32, dtype=dtype).reshape((1, 8, 1, 1, 4)), - expected=np.array([[[[[0, 1], [8, 9]], [[16, 17], [24, 25]]], - [[[2, 3], [10, 11]], [[18, 19], [26, 27]]], - [[[4, 5], [12, 13]], [[20, 21], [28, 29]]], - [[[6, 7], [14, 15]], [[22, 23], [30, 31]]]]], + expected=np.array([[[[[0, 1, 2, 3], [8, 9, 10, 11]], + [[16, 17, 18, 19], [24, 25, 26, 27]]], + [[[4, 5, 6, 7], [12, 13, 14, 15]], + [[20, 21, 22, 23], [28, 29, 30, 31]]]]], dtype=dtype)) - @test_util.disable_mlir_bridge( - "TODO(b/153812660): Handle tf.SpaceToDepth compilation") def testSpaceToDepth(self): def make_op(data_format): @@ -1172,11 +1166,11 @@ class UnaryOpsTest(xla_test.XLATestCase): self._assertOpOutputMatchesExpected( make_op("NCHW_VECT_C"), np.arange(32, dtype=dtype).reshape((1, 2, 2, 2, 4)), - expected=np.array([[[[[0, 1, 2, 3, 16, 17, 18, 19]]], - [[[4, 5, 6, 7, 20, 21, 22, 23]]], - [[[8, 9, 10, 11, 24, 25, 26, 27]]], - [[[12, 13, 14, 15, 28, 29, 30, 31]]]]], - dtype=dtype)) + expected=np.array( + [[[[[0, 1, 2, 3]]], [[[16, 17, 18, 19]]], [[[4, 5, 6, 7]]], + [[[20, 21, 22, 23]]], [[[8, 9, 10, 11]]], [[[24, 25, 26, 27]]], + [[[12, 13, 14, 15]]], [[[28, 29, 30, 31]]]]], + dtype=dtype)) def _assertSoftplusMatchesExpected(self, features, diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 0718bd8cd65..44fb5513886 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -11,7 +11,6 @@ load( "tf_custom_op_library_additional_deps", "tf_gen_op_libs", "tf_gen_op_wrapper_py", - "tf_gpu_kernel_library", ) # buildifier: disable=same-origin-load @@ -81,6 +80,7 @@ tf_cuda_cc_test( cc_library( name = "common_utils", + srcs = ["common/utils.cc"], hdrs = ["common/utils.h"], copts = tf_copts(), deps = [ @@ -539,20 +539,6 @@ tf_cuda_cc_test( ], ) -tf_gpu_kernel_library( - name = "plugin_cast", - srcs = ["plugin/plugin_cast.cu.cc"], - deps = [ - ":trt_plugins", - "@com_google_absl//absl/strings", - "//tensorflow/core/platform:logging", - "//tensorflow/core:framework_lite", - ] + if_tensorrt([ - "@local_config_cuda//cuda:cuda_headers", - "@local_config_tensorrt//:tensorrt", - ]), -) - tf_cuda_library( name = "trt_plugins", srcs = 
["plugin/trt_plugin.cc"], @@ -602,6 +588,7 @@ pybind_extension( link_in_framework = True, module_name = "_pywrap_py_utils", deps = [ + ":common_utils", ":py_utils", "//tensorflow/core/platform:env", "//tensorflow/core/platform:logging", diff --git a/tensorflow/compiler/tf2tensorrt/common/utils.cc b/tensorflow/compiler/tf2tensorrt/common/utils.cc new file mode 100644 index 00000000000..6679ca04513 --- /dev/null +++ b/tensorflow/compiler/tf2tensorrt/common/utils.cc @@ -0,0 +1,99 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +#include "absl/base/call_once.h" +#include "absl/strings/str_join.h" +#include "third_party/tensorrt/NvInferPlugin.h" +#endif + +namespace tensorflow { +namespace tensorrt { + +std::tuple GetLinkedTensorRTVersion() { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + return std::tuple{NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, + NV_TENSORRT_PATCH}; +#else + return std::tuple{0, 0, 0}; +#endif +} + +std::tuple GetLoadedTensorRTVersion() { +#if GOOGLE_CUDA && GOOGLE_TENSORRT + int ver = getInferLibVersion(); + int major = ver / 1000; + ver = ver - major * 1000; + int minor = ver / 100; + int patch = ver - minor * 100; + return std::tuple{major, minor, patch}; +#else + return std::tuple{0, 0, 0}; +#endif +} + +} // namespace tensorrt +} // namespace tensorflow + +#if GOOGLE_CUDA && GOOGLE_TENSORRT +namespace tensorflow { +namespace tensorrt { +namespace { + +void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { + LOG(INFO) << "Linked TensorRT version: " + << absl::StrJoin(GetLinkedTensorRTVersion(), "."); + LOG(INFO) << "Loaded TensorRT version: " + << absl::StrJoin(GetLoadedTensorRTVersion(), "."); + + bool plugin_initialized = initLibNvInferPlugins(trt_logger, ""); + if (!plugin_initialized) { + LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may " + "fail later."; + } + + int num_trt_plugins = 0; + nvinfer1::IPluginCreator* const* trt_plugin_creator_list = + getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); + if (!trt_plugin_creator_list) { + LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry."; + } else { + VLOG(1) << "Found the following " << num_trt_plugins + << " TensorRT plugins in registry:"; + for (int i = 0; i < num_trt_plugins; ++i) { + if (!trt_plugin_creator_list[i]) { + LOG_WARNING_WITH_PREFIX + << "TensorRT plugin at index " << i + << " is not accessible (null pointer returned by " + "getPluginCreatorList for this plugin)"; + } else { + VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); + } + } + } +} + +} // namespace + +void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { + static absl::once_flag once; + absl::call_once(once, InitializeTrtPlugins, trt_logger); +} + +} // namespace tensorrt +} // namespace tensorflow +#endif diff --git 
a/tensorflow/compiler/tf2tensorrt/common/utils.h b/tensorflow/compiler/tf2tensorrt/common/utils.h index b428733ecd4..b76b75de783 100644 --- a/tensorflow/compiler/tf2tensorrt/common/utils.h +++ b/tensorflow/compiler/tf2tensorrt/common/utils.h @@ -16,15 +16,33 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ #define TENSORFLOW_COMPILER_TF2TENSORRT_COMMON_UTILS_H_ +#include + +namespace tensorflow { +namespace tensorrt { +// Returns the compile time TensorRT library version information +// {Maj, Min, Patch}. +std::tuple GetLinkedTensorRTVersion(); + +// Returns the runtime time TensorRT library version information +// {Maj, Min, Patch}. +std::tuple GetLoadedTensorRTVersion(); +} // namespace tensorrt +} // namespace tensorflow + #if GOOGLE_CUDA && GOOGLE_TENSORRT #include "tensorflow/core/platform/logging.h" +#include "third_party/tensorrt/NvInfer.h" namespace tensorflow { namespace tensorrt { #define LOG_WARNING_WITH_PREFIX LOG(WARNING) << "TF-TRT Warning: " +// Initializes the TensorRT plugin registry if this hasn't been done yet. +void MaybeInitializeTrtPlugins(nvinfer1::ILogger* trt_logger); + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc index f80c0f42eca..c0c3f25177e 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc @@ -1197,42 +1197,6 @@ Status TrtNodeValidator::ConvertConstToWeights( return status; } -static void InitializeTrtPlugins(nvinfer1::ILogger* trt_logger) { - static mutex plugin_mutex(LINKER_INITIALIZED); - static bool plugin_initialized = false; - mutex_lock lock(plugin_mutex); - if (plugin_initialized) return; - - LOG(INFO) << "Linked TensorRT version: " << GetLinkedTensorRTVersion(); - LOG(INFO) << "Loaded TensorRT version: " << GetLoadedTensorRTVersion(); - - plugin_initialized = initLibNvInferPlugins(trt_logger, ""); - if (!plugin_initialized) { - LOG(ERROR) << "Failed to initialize TensorRT plugins, and conversion may " - "fail later."; - } - - int num_trt_plugins = 0; - nvinfer1::IPluginCreator* const* trt_plugin_creator_list = - getPluginRegistry()->getPluginCreatorList(&num_trt_plugins); - if (!trt_plugin_creator_list) { - LOG_WARNING_WITH_PREFIX << "Can not find any TensorRT plugins in registry."; - } else { - VLOG(1) << "Found the following " << num_trt_plugins - << " TensorRT plugins in registry:"; - for (int i = 0; i < num_trt_plugins; ++i) { - if (!trt_plugin_creator_list[i]) { - LOG_WARNING_WITH_PREFIX - << "TensorRT plugin at index " << i - << " is not accessible (null pointer returned by " - "getPluginCreatorList for this plugin)"; - } else { - VLOG(1) << " " << trt_plugin_creator_list[i]->getPluginName(); - } - } - } -} - // static StatusOr> Converter::Create( TrtPrecisionMode precision_mode, bool use_calibration, @@ -1249,7 +1213,7 @@ Converter::Converter(TrtPrecisionMode precision_mode, bool use_calibration, : precision_mode_(precision_mode), use_calibration_(use_calibration), use_implicit_batch_(use_implicit_batch) { - InitializeTrtPlugins(trt_logger); + MaybeInitializeTrtPlugins(trt_logger); this->RegisterOpConverters(); } @@ -1434,7 +1398,8 @@ Status Converter::BuildCudaEngine( TF_RETURN_IF_ERROR( TrtPrecisionModeToName(precision_mode_, &precision_mode_str)); string trt_network_name = StrCat( - "TF:", TF_VERSION_STRING, ", ", "TRT:", GetLoadedTensorRTVersion(), "-", + "TF:", TF_VERSION_STRING, ", ", + 
"TRT:", absl::StrJoin(GetLoadedTensorRTVersion(), "."), "-", "Precision:", precision_mode_str, ", ", "Calibration:", use_calibration_, ", ", "Max-Batch-Size:", max_batch_size, ", ", "Max-Workspace-Size:", max_workspace_size_bytes); diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc index aeae44a5562..72348c3cede 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc @@ -5374,7 +5374,9 @@ TEST_P(OpConverterTest1, ConvertReduce) { expected_output_dims.erase(std::remove(expected_output_dims.begin(), expected_output_dims.end(), 0), expected_output_dims.end()); - VLOG(2) << "out dims " << expected_output_dims; + VLOG(2) << "out dims " + << absl::StrCat("[", absl::StrJoin(expected_output_dims, ","), + "]"); std::vector expected_values = CalcReduce( op.name, p.helper_array, p.stride, op.val_func, op.init_val); TestOpConverter("my_reduce", node_def, expected_output_dims, diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index a69960005fc..1fc0d13c993 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -241,36 +241,6 @@ int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine) { #endif -string GetLinkedTensorRTVersion() { - int major, minor, patch; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - major = NV_TENSORRT_MAJOR; - minor = NV_TENSORRT_MINOR; - patch = NV_TENSORRT_PATCH; -#else - major = 0; - minor = 0; - patch = 0; -#endif - return absl::StrCat(major, ".", minor, ".", patch); -} - -string GetLoadedTensorRTVersion() { - int major, minor, patch; -#if GOOGLE_CUDA && GOOGLE_TENSORRT - int ver = getInferLibVersion(); - major = ver / 1000; - ver = ver - major * 1000; - minor = ver / 100; - patch = ver - minor * 100; -#else - major = 0; - minor = 0; - patch = 0; -#endif - return absl::StrCat(major, ".", minor, ".", patch); -} - absl::string_view GetDeviceName(const Node* node) { if (node->has_assigned_device_name()) { return node->assigned_device_name(); diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index a0505c3f922..7570dff1c9d 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -117,14 +117,6 @@ Status TrtDimsToTensorShape(const nvinfer1::Dims trt_dims, Status TfTypeToTrtType(DataType tf_type, nvinfer1::DataType* trt_type); Status TrtTypeToTfType(nvinfer1::DataType trt_type, DataType* tf_type); -// Returns a string that includes compile time TensorRT library version -// information {Maj, Min, Patch}. -string GetLinkedTensorRTVersion(); - -// Returns a string that includes runtime time TensorRT library version -// information {Maj, Min, Patch}. -string GetLoadedTensorRTVersion(); - // Returns true if an engine built for cached_shapes can also run actual_shapes. 
bool AreShapesCompatible(const std::vector& actual_shapes, const std::vector& cached_shapes); diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc index 58d1c611463..5b2ae822d59 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc @@ -800,6 +800,9 @@ StatusOr> TRTEngineOp::GetEngine( TrtUniquePtrType infer(nvinfer1::createInferRuntime(logger)); infer->setGpuAllocator(allocator); + // Need to initialize plugins in order to deserialize engines that contain + // plugins. + MaybeInitializeTrtPlugins(&logger); TrtUniquePtrType static_engine( infer->deserializeCudaEngine(serialized_segment_.c_str(), serialized_segment_.size(), nullptr)); diff --git a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc b/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc deleted file mode 100644 index 141a7d1f462..00000000000 --- a/tensorflow/compiler/tf2tensorrt/plugin/plugin_cast.cu.cc +++ /dev/null @@ -1,236 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "absl/strings/str_cat.h" -#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin.h" -#include "tensorflow/core/platform/logging.h" - -#if GOOGLE_CUDA && GOOGLE_TENSORRT -#define EIGEN_USE_GPU // For definition of Eigen::GpuDevice. 
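The plugin setup above is now guarded by absl::call_once instead of the removed mutex-plus-bool pattern, and trt_engine_op.cc invokes it before deserializing an engine so that engines containing plugins can be restored. A minimal sketch of the same one-time-initialization idiom; InitOnce and EnsureInitialized are illustrative names, not TF-TRT symbols.

#include <iostream>

#include "absl/base/call_once.h"

// Expensive setup that must run exactly once, regardless of how many threads
// or call sites race to trigger it (cf. InitializeTrtPlugins above).
void InitOnce(const char* reason) {
  std::cout << "initializing (" << reason << ")\n";
}

void EnsureInitialized(const char* reason) {
  static absl::once_flag once;
  // Arguments after the callable are forwarded to it; later calls return
  // immediately without invoking InitOnce again.
  absl::call_once(once, InitOnce, reason);
}

int main() {
  EnsureInitialized("engine deserialization");
  EnsureInitialized("graph conversion");  // no-op: already initialized
  return 0;
}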
-#include "third_party/gpus/cuda/include/cuda_runtime_api.h" -#include "tensorflow/core/util/gpu_kernel_helper.h" -#include "third_party/tensorrt/NvInfer.h" - -namespace tensorflow { -namespace tensorrt { -using nvinfer1::DataType; -using nvinfer1::Dims; -using nvinfer1::IPluginCreator; -using nvinfer1::IPluginV2; -using nvinfer1::IPluginV2Ext; -using nvinfer1::PluginField; -using nvinfer1::PluginFieldCollection; -using nvinfer1::PluginFieldType; -using nvinfer1::PluginFormat; - -template -__global__ void Cast(const SrcT* input, int num_elements, DstT* output) { - for (int i : CudaGridRangeX(num_elements)) { - output[i] = static_cast(input[i]); - } -} - -template -void RunCast(const SrcT* d_input, int num_elements, DstT* d_output, - cudaStream_t stream) { - const int threads_per_block = 256; - const int blocks_per_grid = - (num_elements + threads_per_block - 1) / threads_per_block; - TF_CHECK_OK(CudaLaunchKernel(Cast, threads_per_block, - blocks_per_grid, 0, stream, d_input, - num_elements, d_output)); -} - -const char* kPluginName = "TfTrtPluginCast"; - -class CastPlugin : public TrtPlugin { - public: - CastPlugin(DataType src_type, DataType dst_type) - : src_type_(src_type), dst_type_(dst_type) {} - - CastPlugin(const void* serialized_data, size_t length) - : TrtPlugin(serialized_data, length) { - const char* buffer = static_cast(serialized_data); - src_type_ = ReadFromBuffer(&buffer); - dst_type_ = ReadFromBuffer(&buffer); - src_dims_ = ReadFromBuffer(&buffer); - } - - CastPlugin(const CastPlugin& rhs) - : TrtPlugin(rhs), - src_type_(rhs.src_type_), - dst_type_(rhs.dst_type_), - src_dims_(rhs.src_dims_) {} - - // Methods from IPluginV2Ext. - - DataType getOutputDataType(int index, const DataType* input_types, - int num_inputs) const override { - DCHECK_EQ(0, index); - DCHECK_EQ(1, num_inputs); - return dst_type_; - } - - bool isOutputBroadcastAcrossBatch(int output_index, - const bool* input_is_broadcasted, - int num_inputs) const override { - return false; - } - - bool canBroadcastInputAcrossBatch(int input_index) const override { - return false; - } - - void configurePlugin(const Dims* input_dims, int num_inputs, - const Dims* output_dims, int num_outputs, - const DataType* input_types, - const DataType* output_types, - const bool* input_is_broadcast, - const bool* output_is_broadcast, - PluginFormat float_format, int max_batch_size) override { - DCHECK_EQ(1, num_inputs); - DCHECK_EQ(1, num_outputs); - DCHECK(src_type_ == input_types[0]); - DCHECK(dst_type_ == output_types[0]); - src_dims_ = input_dims[0]; - } - - IPluginV2Ext* clone() const override { return new CastPlugin(*this); } - - // Methods from IPluginV2. 
- - const char* getPluginType() const override { return kPluginName; }; - - const char* getPluginVersion() const override { return kTfTrtPluginVersion; }; - - int getNbOutputs() const override { return 1; } - - Dims getOutputDimensions(int index, const Dims* inputs, - int num_input_dims) override { - DCHECK_EQ(0, index); - DCHECK_EQ(1, num_input_dims); - return inputs[0]; - } - - bool supportsFormat(DataType type, PluginFormat format) const override { - return type == DataType::kFLOAT || type == DataType::kINT32; - } - - size_t getWorkspaceSize(int max_batch_size) const override { return 0; } - - int enqueue(int batch_size, const void* const* inputs, void** outputs, void*, - cudaStream_t stream) override { - int num_elements = batch_size; - for (int i = 0; i < src_dims_.nbDims; i++) { - num_elements *= src_dims_.d[i]; - } - const void* input = inputs[0]; - void* output = outputs[0]; - DCHECK_NE(static_cast(src_type_), static_cast(dst_type_)); - - switch (src_type_) { - case DataType::kFLOAT: - RunCast(reinterpret_cast(input), num_elements, - reinterpret_cast(output), stream); - break; - case DataType::kINT32: - RunCast(reinterpret_cast(input), num_elements, - reinterpret_cast(output), stream); - break; - default: - return 1; // Indicates a failure. - } - return 0; - } - - size_t getSerializationSize() const override { - return 2 * sizeof(DataType) + sizeof(Dims); - } - - void serialize(void* serialized_data) const override { - char* buffer = static_cast(serialized_data); - WriteToBuffer(src_type_, &buffer); - WriteToBuffer(dst_type_, &buffer); - WriteToBuffer(src_dims_, &buffer); - } - - private: - DataType src_type_; - DataType dst_type_; - Dims src_dims_; -}; - -class CastPluginCreator : public IPluginCreator { - public: - CastPluginCreator() { - setPluginNamespace(kTfTrtPluginNamespace); - plugin_fields_.emplace_back( - PluginField("SrcT", nullptr, PluginFieldType::kINT32, 1)); - plugin_fields_.emplace_back( - PluginField("DstT", nullptr, PluginFieldType::kINT32, 1)); - - field_collection_.nbFields = plugin_fields_.size(); - field_collection_.fields = plugin_fields_.data(); - } - - const char* getPluginName() const override { return kPluginName; } - - const char* getPluginVersion() const override { return kTfTrtPluginVersion; } - - const PluginFieldCollection* getFieldNames() override { - return &field_collection_; - } - - IPluginV2* createPlugin( - const char* name, - const PluginFieldCollection* field_collection) override { - const PluginField* fields = field_collection->fields; - DataType src_type, dst_type; - for (int i = 0; i < field_collection->nbFields; ++i) { - const char* attr_name = fields[i].name; - if (!strcmp(attr_name, "SrcT")) { - src_type = *static_cast(fields[i].data); - } else if (!strcmp(attr_name, "DstT")) { - dst_type = *static_cast(fields[i].data); - } else { - return nullptr; - } - } - return new CastPlugin(src_type, dst_type); - } - - IPluginV2* deserializePlugin(const char* name, const void* serial_data, - size_t serial_len) override { - return new CastPlugin(serial_data, serial_len); - } - - void setPluginNamespace(const char* plugin_namespace) override { - namespace_ = plugin_namespace; - } - - const char* getPluginNamespace() const override { return namespace_.c_str(); } - - private: - PluginFieldCollection field_collection_; - std::vector plugin_fields_; - std::string namespace_; -}; - -REGISTER_TFTRT_PLUGIN(CastPluginCreator); - -} // namespace tensorrt -} // namespace tensorflow - -#endif // GOOGLE_CUDA && GOOGLE_TENSORRT diff --git 
a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc index a8e24aa8983..3f8a11f7410 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.cc @@ -41,31 +41,5 @@ bool IsGoogleTensorRTEnabled() { #endif } -void GetLinkedTensorRTVersion(int* major, int* minor, int* patch) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - *major = NV_TENSORRT_MAJOR; - *minor = NV_TENSORRT_MINOR; - *patch = NV_TENSORRT_PATCH; -#else - *major = 0; - *minor = 0; - *patch = 0; -#endif -} - -void GetLoadedTensorRTVersion(int* major, int* minor, int* patch) { -#if GOOGLE_CUDA && GOOGLE_TENSORRT - int ver = getInferLibVersion(); - *major = ver / 1000; - ver = ver - *major * 1000; - *minor = ver / 100; - *patch = ver - *minor * 100; -#else - *major = 0; - *minor = 0; - *patch = 0; -#endif -} - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h index f52bb6f1bad..9b24eb36cf9 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils.h +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils.h @@ -21,12 +21,6 @@ namespace tensorrt { bool IsGoogleTensorRTEnabled(); -// Return compile time TensorRT library version information {Maj, Min, Patch}. -void GetLinkedTensorRTVersion(int* major, int* minor, int* patch); - -// Return runtime time TensorRT library version information {Maj, Min, Patch}. -void GetLoadedTensorRTVersion(int* major, int* minor, int* patch); - } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc index 03f77c6bd5f..52252f125ac 100644 --- a/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc +++ b/tensorflow/compiler/tf2tensorrt/utils/py_utils_wrapper.cc @@ -16,18 +16,15 @@ limitations under the License. 
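With the version getters now returning a std::tuple of three ints, the pybind11 wrapper below can hand the result straight back to Python, since pybind11 converts std::tuple to a Python tuple out of the box. A stripped-down sketch of that pattern; the module and function names here are illustrative only.

#include <tuple>

#include "pybind11/pybind11.h"

// Stand-in for GetLinkedTensorRTVersion(); returns {major, minor, patch}.
std::tuple<int, int, int> linked_version_demo() { return {7, 1, 3}; }

PYBIND11_MODULE(_version_demo, m) {
  // Exposed to Python as a 3-element tuple, e.g. (7, 1, 3).
  m.def("get_linked_tensorrt_version", &linked_version_demo);
}

From Python this reads naturally as major, minor, patch = _version_demo.get_linked_tensorrt_version().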
#include #include "pybind11/pybind11.h" +#include "tensorflow/compiler/tf2tensorrt/common/utils.h" #include "tensorflow/compiler/tf2tensorrt/utils/py_utils.h" std::tuple get_linked_tensorrt_version() { - int major, minor, patch; - tensorflow::tensorrt::GetLinkedTensorRTVersion(&major, &minor, &patch); - return std::tuple{major, minor, patch}; + return tensorflow::tensorrt::GetLinkedTensorRTVersion(); } std::tuple get_loaded_tensorrt_version() { - int major, minor, patch; - tensorflow::tensorrt::GetLoadedTensorRTVersion(&major, &minor, &patch); - return std::tuple{major, minor, patch}; + return tensorflow::tensorrt::GetLoadedTensorRTVersion(); } PYBIND11_MODULE(_pywrap_py_utils, m) { diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index ac999d875de..e9bcbcc6d83 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -337,7 +337,6 @@ cc_library( visibility = [":friends"], deps = [ ":common", - ":frontend_attributes_util", ":host_compute_metadata_proto_cc", ":rearrange_function_argument", ":sharding_util", @@ -353,23 +352,16 @@ cc_library( "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", - "//tensorflow/compiler/jit:xla_cluster_util", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", - "//tensorflow/compiler/tf2xla/lib:util", - "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/client", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", @@ -378,11 +370,8 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:ops", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/memory", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:variant", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 10b26f9801c..596fa8e8e38 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -46,12 +46,254 @@ limitations under the License. namespace tensorflow { +// Helper functions for functionalizing control flow in functions. + +// Maps function name to +// - new function name, if the function body was functionalized +// - absl::nullopt, if not +using FuncMap = std::map>; +using FuncMapIter = std::map>::const_iterator; + +// Returns whether function has been processed before. +bool FunctionHasBeenProcessed(FuncMapIter func_iter, const FuncMap* func_map) { + return func_iter != func_map->end(); +} + +// Returns whether function has been modified (i.e., functionalized) before. 
+bool FunctionHasBeenModified(FuncMapIter func_iter) { + return func_iter->second.has_value(); +} + +// Returns a name for the new functionalized version of a function. +string GetNewFunctionName( + const string& func_name, Node* n, + AssociatedFunctionInfo::AssociatedFunctionType func_type, + FunctionLibraryDefinition* fld) { + // For SymbolicGradient, `func_name` is always "SymbolicGradient" which + // is not very informative. Use node name instead. + return ( + func_type == + AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient + ? fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")) + : fld->UniqueFunctionName(absl::StrCat(func_name, "_f15n_"))); +} + +// Returns name to which a modified function has been mapped. +const string& GetMappedFunctionName(FuncMapIter func_iter) { + DCHECK(func_iter->second.has_value()); + return func_iter->second.value(); +} + +// Updates `func_map` with function given by `canonicalized_name`. +void UpdateFunctionMap(FuncMap* func_map, const string& canonicalized_name, + const string& new_func_name, bool function_modified) { + // If function was modified store its new name, otherwise add empty entry to + // record that function has been processed and does not need to be rewritten. + (*func_map)[canonicalized_name] = + function_modified ? absl::make_optional(new_func_name) : absl::nullopt; +} + +// Adds new function def to graph's function library if necessary. +Status AddFunctionDefToGraphLibrary( + const string& func_name, const AssociatedFunctionInfo& associated_function, + Graph* graph, FunctionLibraryDefinition* fld) { + const OpRegistrationData* op_reg_data; + // We have to be careful with adding the function def since there are three + // different `OpRegistryInterface`s involved here: + // `fld`, `graph->flib_def()` and `graph->flib_def().default_registry()`. + // We have already added the function def to `fld` before calling this + // function but for the subsequent `RewriteAssociatedFunction` call we need + // the function def to be in one of the other two registries, otherwise + // `RewriteAssociatedFunction` will fail for the `kFunctionCallNode` case + // because it cannot find the associated function def. + // On the other hand, we should not add the function def if it is already + // contained in one of the last two registries, this would lead to errors when + // the function def is already in one registry and we try to add it to the + // other one (if we try to add it to the same it's fine). This can happen in + // cases where one of the last two registries is identical to `fld` (which we + // already updated). + // Therefore, before adding the function def we have to check if it's already + // contained in either `graph->flib_def()` or + // `graph->flib_def().default_registry()` which is done in the following line + // (we have to use `LookUp` instead of `Contains` or `Find` because the latter + // both don't check the default registry). + if (graph->flib_def().LookUp(func_name, &op_reg_data).ok()) + return Status::OK(); + + const FunctionDef* new_fdef = fld->Find(func_name); + DCHECK(new_fdef != nullptr); + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = *new_fdef; + return graph->AddFunctionLibrary(fdef_lib); +} + +// Functionalizes function given by `func_name`. Update `func_map` accordingly. 
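Taken together, FuncMap and the helpers above memoize per-function functionalization results: the presence of an entry means "already processed", and a non-empty value carries the rewritten name. The template arguments of FuncMap were lost in the hunk text above; the sketch below assumes the likely std::map<string, absl::optional<string>> shape and uses illustrative names (ProcessOnce, rewrite_fn) rather than the real pass.

#include <iostream>
#include <map>
#include <string>

#include "absl/types/optional.h"

// Assumed shape of FuncMap: canonicalized function name -> new name if the
// function was rewritten, absl::nullopt if it was processed but unchanged.
using FuncMap = std::map<std::string, absl::optional<std::string>>;

// Processes `name` at most once; `rewrite_fn` returns the new name or nullopt.
template <typename RewriteFn>
absl::optional<std::string> ProcessOnce(const std::string& name,
                                        RewriteFn rewrite_fn,
                                        FuncMap* func_map) {
  auto it = func_map->find(name);
  if (it != func_map->end()) return it->second;  // already processed
  absl::optional<std::string> result = rewrite_fn(name);
  (*func_map)[name] = result;
  return result;
}

int main() {
  FuncMap func_map;
  auto rewrite = [](const std::string& name) {
    return absl::make_optional(name + "_f15n_");
  };
  ProcessOnce("cond_fn", rewrite, &func_map);
  // The second call is answered from the map without rewriting again.
  std::cout << ProcessOnce("cond_fn", rewrite, &func_map).value() << "\n";
  return 0;
}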
+Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + FuncMap* func_map, bool* function_modified, + const NodeFilter& node_filter = {}); + +// Functionalizes all functions that are (directly or indirectly) associated to +// any node in `graph`. Adds processed functions to `func_map`. +Status FunctionalizeControlFlowForNodeAssociatedFunctions( + FuncMap* func_map, Graph* graph, FunctionLibraryDefinition* fld, + FunctionLibraryRuntime* flr, bool* any_function_modified, + const NodeFilter& node_filter) { + std::vector>> + nodes_to_associated_functions; + for (auto* n : graph->nodes()) { + auto associated_functions = GetAssociatedFunctions(*n, fld); + if (!associated_functions.empty()) { + nodes_to_associated_functions.push_back({n, associated_functions}); + } + } + for (const auto& pair : nodes_to_associated_functions) { + Node* n = pair.first; + auto associated_functions = pair.second; + for (auto& associated_function : associated_functions) { + // Note that if `n` is a function call node, then potential calls of + // `RewriteAssociatedFunction` below might delete `n` and create a new + // node instead, making `n` an invalid pointer. That's fine because in + // that case `n` only has one associated function, so this loop has only + // one iteration and we don't use `n` again after the rewrite. + // The invariant is guaranteed by `GetAssociatedFunctions` and confirmed + // below. + DCHECK(associated_function.type() != + AssociatedFunctionInfo::kFunctionCallNode || + associated_functions.size() == 1); + + // Process one node-function-pair. + string func_name = associated_function.func_name(); + string canonicalized_name = + Canonicalize(func_name, AttrSlice(&associated_function.attrs())); + auto func_iter = func_map->find(canonicalized_name); + string new_func_name; + if (FunctionHasBeenProcessed(func_iter, func_map)) { + if (FunctionHasBeenModified(func_iter)) { + *any_function_modified = true; + new_func_name = GetMappedFunctionName(func_iter); + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + graph, n, fld, associated_function, new_func_name)); + } + continue; + } + // Function is processed for the first time. + bool function_modified = false; + new_func_name = + GetNewFunctionName(func_name, n, associated_function.type(), fld); + // Perform functionalization for current function. + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( + func_name, new_func_name, associated_function.attrs(), fld, flr, + func_map, &function_modified, node_filter)); + UpdateFunctionMap(func_map, canonicalized_name, new_func_name, + function_modified); + if (function_modified) { + *any_function_modified = true; + TF_RETURN_IF_ERROR(AddFunctionDefToGraphLibrary( + new_func_name, associated_function, graph, fld)); + TF_RETURN_IF_ERROR(RewriteAssociatedFunction( + graph, n, fld, associated_function, new_func_name)); + } + } + } + return Status::OK(); +} + +Status FunctionalizeControlFlowForFunction( + const string& func_name, const string& new_func_name, + const protobuf::Map& attrs, + FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, + FuncMap* func_map, bool* function_modified, const NodeFilter& node_filter) { + *function_modified = false; + + // Convert the function to a graph. 
+ FunctionLibraryRuntime::Handle handle; + TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); + Status ret_status = Status::OK(); + auto cleanup_handle = gtl::MakeCleanup([&]() { + auto s = flr->ReleaseHandle(handle); + if (!s.ok()) { + ret_status.Update(s); + } + }); + const FunctionBody* body = flr->GetFunctionBody(handle); + Graph* g = body->graph; + + // Check if the graph has Switch or Merge node. + bool has_switch_or_merge = false; + for (Node* n : body->graph->nodes()) { + // Skip nodes that are filtered out. + if (node_filter && !node_filter(n)) continue; + if (n->type_string() == "Switch" || n->type_string() == "Merge") { + has_switch_or_merge = true; + break; + } + } + // Before functionalizing control flow in `g` we functionalize control flow + // in functions (directly or indirectly) associated with nodes in `g`. + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions( + func_map, g, fld, flr, function_modified, node_filter)); + + if (has_switch_or_merge) { + *function_modified = true; + + // Functionalize the function body. + if (VLOG_IS_ON(4)) { + DumpGraphToFile( + absl::StrCat("functionalize_control_flow_before_fdef_", func_name), + *g, fld); + } + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld, node_filter)); + if (VLOG_IS_ON(4)) { + DumpGraphToFile( + absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, + fld); + } + } + if (*function_modified) { + // Add rewritten FunctionDef into library. + FunctionDef functionalized_fdef; + TF_RETURN_IF_ERROR( + GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); + if (func_name == new_func_name) { + VLOG(2) << "Replacing function " << func_name; + TF_RETURN_IF_ERROR( + fld->ReplaceFunction(new_func_name, functionalized_fdef)); + } else { + VLOG(2) << "Adding function " << new_func_name; + TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); + } + } + + return ret_status; +} + Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { + const NodeFilter& node_filter, + bool include_functions) { VLOG(2) << "FunctionalizeControlFlow (initial): " << DumpGraphToFile("functionalize_initial", *graph, library); + if (include_functions) { + // Functionalize control flow in functions that are (directly or indirectly) + // associated with a node in `graph`. + auto pflr = absl::make_unique( + /*device_mgr=*/nullptr, tensorflow::Env::Default(), + /*config=*/nullptr, TF_GRAPH_DEF_VERSION, library, + tensorflow::OptimizerOptions()); + // `pflr` has only one `FunctionLibraryRuntime`, for `kDefaultFLRDevice` + // (because we constructed it with `device_mgr = nullptr`). + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + + FuncMap func_map; + bool modified = false; + TF_RETURN_IF_ERROR(FunctionalizeControlFlowForNodeAssociatedFunctions( + &func_map, graph, library, flr, &modified, node_filter)); + } // Functionalize and remove while loops from graph. 
TF_RETURN_IF_ERROR(FunctionalizeWhileLoop(graph, library, node_filter)); @@ -68,153 +310,19 @@ Status FunctionalizeControlFlow(Graph* graph, Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library, - const NodeFilter& node_filter) { + const NodeFilter& node_filter, + bool include_functions) { FunctionDefLibrary function_lib = graph_def->library(); Graph graph(OpRegistry::Global()); TF_RETURN_IF_ERROR(ConvertGraphDefToGraph({}, *graph_def, &graph)); - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter)); + TF_RETURN_IF_ERROR(FunctionalizeControlFlow(&graph, library, node_filter, + include_functions)); graph.ToGraphDef(graph_def); std::swap(*graph_def->mutable_library(), function_lib); return Status::OK(); } -Status FunctionalizeControlFlowForFunction( - const string& func_name, const string& new_func_name, - const protobuf::Map& attrs, - FunctionLibraryDefinition* fld, FunctionLibraryRuntime* flr, - std::map>* canonicalized_name_to_new_name, - bool* modified) { - *modified = false; - - // Convert the function to Graph. - FunctionLibraryRuntime::Handle handle; - TF_RETURN_IF_ERROR(flr->Instantiate(func_name, AttrSlice(&attrs), &handle)); - Status ret_status = Status::OK(); - auto cleanup_handle = gtl::MakeCleanup([&]() { - auto s = flr->ReleaseHandle(handle); - if (!s.ok()) { - ret_status.Update(s); - } - }); - const FunctionBody* body = flr->GetFunctionBody(handle); - Graph* g = body->graph; - - // Check if the graph has Switch or Merge node. - bool has_switch_or_merge = false; - for (Node* n : body->graph->nodes()) { - if (n->type_string() == "Switch" || n->type_string() == "Merge") { - has_switch_or_merge = true; - break; - } - } - // We cannot return here directly if the graph has no Switch/Merge. - // It might contain function call nodes, or If/While nodes with Switch/Merge - // in function body. We still need to rewrite those functions and modify - // corresponding nodes. - - // If any node has associated functions, functionalize them first. - // Gather nodes with associated functions first, because rewriting those nodes - // might involve node deletion/addition. Avoid modifying nodes while iterating - // it. - std::vector>> - nodes_to_associated_functions; - for (auto* n : g->nodes()) { - auto associated_functions = GetAssociatedFunctions(*n, fld); - if (!associated_functions.empty()) { - nodes_to_associated_functions.push_back({n, associated_functions}); - } - } - for (const auto& iter : nodes_to_associated_functions) { - Node* n = iter.first; - auto associated_functions = iter.second; - for (auto& associated_function : associated_functions) { - string name = associated_function.func_name(); - string canonicalized_name = - Canonicalize(name, AttrSlice(&associated_function.attrs())); - auto iter = canonicalized_name_to_new_name->find(canonicalized_name); - string new_name; - bool function_modified; - if (iter != canonicalized_name_to_new_name->end()) { - // If we already processed this function, check if it was rewritten. If - // the function was rewritten, the entry will be non-empty. Otherwise - // the entry will be empty. - function_modified = iter->second.has_value(); - if (function_modified) { - new_name = iter->second.value(); - } - } else { - if (associated_function.type() == - AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) { - // For SymbolicGradient, `name` is always "SymbolicGradient", - // which is not very informative. Use node name instead. 
- new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_")); - } else { - new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_")); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( - name, new_name, associated_function.attrs(), fld, flr, - canonicalized_name_to_new_name, &function_modified)); - if (function_modified) { - // If the function was rewritten, add an non-empty entry. So later we - // know we have processed this function, and it was rewritten into - // another function. - (*canonicalized_name_to_new_name)[canonicalized_name] = new_name; - } else { - // If the function was not rewritten, add an empty entry. So later - // we know we have processed this function, and it does not need to be - // rewritten. - (*canonicalized_name_to_new_name)[canonicalized_name] = absl::nullopt; - } - } - if (function_modified) { - *modified = true; - - // Notice that if "n" is a function call, RewriteAssociatedFunction() - // will delete it and create a new node instead, making "n" an invalid - // pointer. That's fine because in that case, associated_functions will - // only have one member and the loop will only run once. - TF_RETURN_IF_ERROR(RewriteAssociatedFunction( - g, n, fld, associated_function, new_name)); - } - } - } - - if (has_switch_or_merge) { - *modified = true; - - // Functionalize the function body. - if (VLOG_IS_ON(4)) { - DumpGraphToFile( - absl::StrCat("functionalize_control_flow_before_fdef_", func_name), - *g, fld); - } - TF_RETURN_IF_ERROR(FunctionalizeControlFlow(g, fld)); - if (VLOG_IS_ON(4)) { - DumpGraphToFile( - absl::StrCat("functionalize_control_flow_after_fdef_", func_name), *g, - fld); - } - } - - if (*modified) { - // Add rewritten FunctionDef into library. - FunctionDef functionalized_fdef; - TF_RETURN_IF_ERROR( - GraphToFunctionDef(*g, new_func_name, &functionalized_fdef)); - if (func_name == new_func_name) { - VLOG(2) << "Replacing function " << func_name; - TF_RETURN_IF_ERROR( - fld->ReplaceFunction(new_func_name, functionalized_fdef)); - } else { - VLOG(2) << "Adding function " << new_func_name; - TF_RETURN_IF_ERROR(fld->AddFunctionDef(functionalized_fdef)); - } - } - - return ret_status; -} - Status FunctionalizeControlFlowForXlaPass::Run( const GraphOptimizationPassOptions& options) { Graph* graph = options.graph->get(); @@ -241,7 +349,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( // XlaLaunch ops are generated by EncapsulateXlaComputationsPass. {"XlaLaunch", "function"}, }; - std::map> canonicalized_name_to_new_name; + FuncMap func_map; bool fld_modified = false; for (Node* n : graph->nodes()) { auto it = kNodeTypeToFunctionAttrMapping->find(n->type_string()); @@ -258,7 +366,7 @@ Status FunctionalizeControlFlowForXlaPass::Run( bool modified; TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction( func.name(), new_func_name, func.attr(), options.flib_def, flr, - &canonicalized_name_to_new_name, &modified)); + &func_map, &modified)); if (modified) { n->ClearAttr(func_attr); func.set_name(new_func_name); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index f9e751e2d67..46abae27878 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -30,6 +30,13 @@ namespace tensorflow { // // If `node_filter` is defined, then only loops and conditions for whose // nodes `node_filter` returns true are functionalized. 
+ +// If `include_functions` is true, then loops and conditions inside of functions +// that are associated with nodes in `graph` (e.g., a function called from a +// node in `graph`) are also functionalized, otherwise they are not. +// This also handles transitive cases, e.g., a function body will be +// functionalized when it is called in another function that is called by some +// node in `graph` (and so on). The node filter also applies here. // // Precondition: // For any node in a loop or condition for which `node_filter` returns true, @@ -43,11 +50,13 @@ namespace tensorflow { // satisfies the above conditions. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); + const NodeFilter& node_filter = {}, + bool include_functions = false); Status FunctionalizeControlFlowForGraphDef(GraphDef* graph_def, FunctionLibraryDefinition* library, - const NodeFilter& node_filter = {}); + const NodeFilter& node_filter = {}, + bool include_functions = false); // This pass looks at the graph, and turns V1 control flow structure // (Switch/Merge/etc.) into V2 control flow structure (If/While). diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 79a042ad680..951ebdd7ec1 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -27,12 +27,15 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/validate.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/version.h" +#include "tensorflow/core/util/dump_graph.h" #include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { @@ -63,18 +66,41 @@ Status FindIfThenAndElse(const GraphDef& graph, string* op_name, // math_ops.less(y, x), lambda: math_ops.multiply(y, 17), // lambda: math_ops.add(x, 23)) // -// Tests different node filters. -class ConditionalTestFixture : public ::testing::TestWithParam { +// Tests different node filters and functionalization inside of a function. 
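Callers opt into the new behaviour through the extra include_functions argument added to both entry points above. A short, non-compiling call-site sketch; graph, graph_def, library, and node_filter stand for whatever the caller already has in scope.

// Functionalize control flow in `graph` and, transitively, in every function
// reachable from its nodes; passing false (the default) leaves function
// bodies untouched, matching the old behaviour.
TF_RETURN_IF_ERROR(FunctionalizeControlFlow(graph, &library, node_filter,
                                            /*include_functions=*/true));

TF_RETURN_IF_ERROR(FunctionalizeControlFlowForGraphDef(
    &graph_def, &library, node_filter, /*include_functions=*/true));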
+class ConditionalTestFixture + : public ::testing::TestWithParam> { protected: - void SetUp() override { restrict_to_tpu_nodes_ = GetParam(); } + void SetUp() override { + restrict_to_tpu_nodes_ = std::get<0>(GetParam()); + wrap_condition_in_function_ = std::get<1>(GetParam()); + } void RunTest(); private: + void BuildCondGraph(Graph* cond_graph); + void CheckGraphDef(const GraphDef& graph_def, + const FunctionLibraryDefinition& library); + bool restrict_to_tpu_nodes_ = false; + bool wrap_condition_in_function_ = false; }; -void ConditionalTestFixture::RunTest() { - Graph graph(OpRegistry::Global()); +TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } + +INSTANTIATE_TEST_SUITE_P( + FunctionalizeControlFlow, ConditionalTestFixture, + ::testing::Combine(::testing::Bool(), ::testing::Bool()), + [](const ::testing::TestParamInfo& + info) { + bool restrict_to_tpu_nodes = std::get<0>(info.param); + bool wrap_cond_in_function = std::get<1>(info.param); + string name = + absl::StrCat(restrict_to_tpu_nodes ? "with_filter" : "without_filter", + wrap_cond_in_function ? "_in_function" : "_in_graph"); + return name; + }); + +void ConditionalTestFixture::BuildCondGraph(Graph* cond_graph) { { Scope scope = Scope::NewRootScope().ExitOnError(); @@ -102,13 +128,117 @@ void ConditionalTestFixture::RunTest() { auto merge = ops::Merge(scope.WithOpName("cond/Merge"), std::initializer_list{add, mul}); - TF_EXPECT_OK(scope.ToGraph(&graph)); + TF_EXPECT_OK(scope.ToGraph(cond_graph)); // Set `_tpu_replicate` attribute for all nodes. - for (Node* n : graph.nodes()) { + for (Node* n : cond_graph->nodes()) { n->AddAttr("_tpu_replicate", "cluster"); } } +} + +void ConditionalTestFixture::CheckGraphDef( + const GraphDef& graph_def, const FunctionLibraryDefinition& library) { + string op_name; + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = + ops::If(scope.WithOpName(op_name), less, + std::initializer_list{less, y, x}, {DT_INT32}, then_fn, + else_fn, ops::If::OutputShapes({PartialTensorShape()})); + auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // then body. 
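The fixture above now takes a std::tuple<bool, bool> parameter and instantiates all four combinations via ::testing::Combine with a custom name generator. The same pattern in isolation; DemoTest and the parameter names are illustrative only.

#include <string>
#include <tuple>

#include "gtest/gtest.h"

class DemoTest : public ::testing::TestWithParam<std::tuple<bool, bool>> {};

TEST_P(DemoTest, Runs) {
  bool restrict_to_tpu_nodes;
  bool wrap_in_function;
  std::tie(restrict_to_tpu_nodes, wrap_in_function) = GetParam();
  // A real test would branch on the two flags; this one just records which
  // combination ran.
  SUCCEED() << restrict_to_tpu_nodes << " " << wrap_in_function;
}

// Generates the four combinations and readable suffixes such as
// "with_filter_in_function", mirroring the instantiation above.
INSTANTIATE_TEST_SUITE_P(
    AllCombinations, DemoTest,
    ::testing::Combine(::testing::Bool(), ::testing::Bool()),
    [](const ::testing::TestParamInfo<DemoTest::ParamType>& info) {
      return std::string(std::get<0>(info.param) ? "with_filter"
                                                 : "without_filter") +
             (std::get<1>(info.param) ? "_in_function" : "_in_graph");
    });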
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +void ConditionalTestFixture::RunTest() { + Graph graph(OpRegistry::Global()); + if (wrap_condition_in_function_) { + // Wrap condition in a function which is called from `graph`. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + + Graph cond_graph(OpRegistry::Global()); + BuildCondGraph(&cond_graph); + + FunctionDef cond_fdef; + TF_ASSERT_OK(GraphToFunctionDef(cond_graph, "cond_fn", &cond_fdef)); + + FunctionDefLibrary fdef_lib; + *(fdef_lib.add_function()) = cond_fdef; + TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib)); + NodeDef cond_fn; + cond_fn.set_name("cond_node"); + cond_fn.set_op("cond_fn"); + *(cond_fn.add_input()) = "source"; + Status status; + scope.graph()->AddNode(cond_fn, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.ToGraph(&graph)); + } else { + // Build condition in `graph`. + BuildCondGraph(&graph); + } + FunctionLibraryDefinition library(graph.flib_def()); // If `restrict_to_tpu_nodes_` is true let filter function return true for // `_tpu_replicate` nodes. NodeFilter node_filter = @@ -116,99 +246,47 @@ void ConditionalTestFixture::RunTest() { ? 
[](const Node* n) { return n->attrs().Find("_tpu_replicate"); } : NodeFilter{}; - FunctionLibraryDefinition library(OpRegistry::Global(), {}); GraphDef optimized_graph_def; graph.ToGraphDef(&optimized_graph_def); - TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef(&optimized_graph_def, - &library, node_filter)); - TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library, node_filter)); - GraphDef converted_graph_def; - graph.ToGraphDef(&converted_graph_def); + TF_ASSERT_OK(FunctionalizeControlFlowForGraphDef( + &optimized_graph_def, &library, node_filter, + /*include_functions=*/wrap_condition_in_function_)); + TF_ASSERT_OK(FunctionalizeControlFlow( + &graph, &library, node_filter, + /*include_functions=*/wrap_condition_in_function_)); - for (const GraphDef& graph_def : {optimized_graph_def, converted_graph_def}) { - string op_name; - NameAttrList then_fn; - NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); - InstantiationResultForTest else_result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + if (wrap_condition_in_function_) { + // Check if function body was functionalized. + auto pflr = absl::make_unique( + /*device_mgr=*/nullptr, tensorflow::Env::Default(), + /*config=*/nullptr, TF_GRAPH_DEF_VERSION, &library, + tensorflow::OptimizerOptions()); + FunctionLibraryRuntime* flr = + pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); + FunctionLibraryRuntime::Handle handle; - // Outer graph - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); - auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); - auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = - ops::If(scope.WithOpName(op_name), less, - std::initializer_list{less, y, x}, {DT_INT32}, then_fn, - else_fn, ops::If::OutputShapes({PartialTensorShape()})); - auto id = ops::Identity(scope.WithOpName("cond/Merge"), if_op.output[0]); - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - TF_EXPECT_GRAPH_EQ(expected, graph_def); - } - - // then body. - { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); - auto cond = ops::Const( - scope.WithOpName("cond").WithControlDependencies(identity), 17); - auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), mul, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(then_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), - result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); - } - - // else body. 
- { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto arg_0 = ops::_Arg(scope.WithOpName("arg0"), DT_BOOL, 0); - auto arg_1 = ops::_Arg(scope.WithOpName("arg1"), DT_INT32, 1); - auto arg_2 = ops::_Arg(scope.WithOpName("arg2"), DT_INT32, 2); - auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); - auto cond_1 = ops::Const( - scope.WithOpName("cond_1").WithControlDependencies(identity), 23); - auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); - auto retval0 = ops::_Retval(scope.WithOpName("retval0_RetVal"), add, 0); - - GraphDef expected; - TF_EXPECT_OK(scope.ToGraphDef(&expected)); - - InstantiationResultForTest result; - TF_EXPECT_OK( - InstantiateFunctionForTest(else_fn.name(), library, &result)); - - EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); - EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), - result.arg_types); - TF_EXPECT_GRAPH_EQ(expected, result.gdef); + // Functionalized function name is the type string of `cond_node`. + string func_name; + for (Node* n : graph.nodes()) { + if (n->name() == "cond_node") { + func_name = n->type_string(); + break; + } } + TF_ASSERT_OK(flr->Instantiate(func_name, AttrSlice(), &handle)); + const FunctionBody* body = flr->GetFunctionBody(handle); + GraphDef graph_def; + body->graph->ToGraphDef(&graph_def); + CheckGraphDef(graph_def, library); + } else { + // Check if graphs were functionalized. + CheckGraphDef(optimized_graph_def, library); + GraphDef converted_graph_def; + graph.ToGraphDef(&converted_graph_def); + CheckGraphDef(converted_graph_def, library); } } -TEST_P(ConditionalTestFixture, ConditionalTests) { RunTest(); } - -INSTANTIATE_TEST_SUITE_P( - FunctionalizeControlFlow, ConditionalTestFixture, ::testing::Bool(), - [](const ::testing::TestParamInfo& - info) { return info.param ? "with_filter" : "without_filter"; }); - // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, diff --git a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc index d7a8e67dd33..807c061b60f 100644 --- a/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/broadcast_to_op.cc @@ -16,6 +16,7 @@ limitations under the License. 
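The single-bool instantiation removed above names cases only by the node-filter flag; with the added `wrap_condition_in_function_` parameter the fixture presumably becomes a two-bool parameterized test. A hypothetical instantiation using ::testing::Combine is sketched here; the parameter order and naming lambda are assumptions, not the actual code of this change:

    INSTANTIATE_TEST_SUITE_P(
        FunctionalizeControlFlow, ConditionalTestFixture,
        ::testing::Combine(::testing::Bool(), ::testing::Bool()),
        [](const ::testing::TestParamInfo<ConditionalTestFixture::ParamType>& info) {
          // Assumed meaning: first bool restricts functionalization to TPU
          // nodes, second bool wraps the condition in a function.
          return absl::StrCat(
              std::get<0>(info.param) ? "with_filter" : "without_filter",
              std::get<1>(info.param) ? "_in_function" : "_in_graph");
        });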
#include "tensorflow/compiler/tf2xla/lib/broadcast.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -28,13 +29,26 @@ class BroadcastToOp : public XlaOpKernel { : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* context) override { - const TensorShape input_shape = context->InputShape(0); TensorShape output_shape; OP_REQUIRES_OK(context, context->ConstantInputAsShape(1, &output_shape)); + auto output_status_or = + BroadcastTo(context->Input(0), output_shape.dim_sizes()); + OP_REQUIRES_OK(context, output_status_or.status()); + auto output = output_status_or.ValueOrDie(); + std::vector dynamic_dims; + OP_REQUIRES_OK( + context, context->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); + for (int64 dim = 0; dim < dynamic_dims.size(); ++dim) { + if (dynamic_dims[dim]) { + output = xla::SetDimensionSize( + output, + xla::Reshape(xla::Slice(context->Input(1), {dim}, {dim + 1}, {1}), + {}), + dim); + } + } - auto output = BroadcastTo(context->Input(0), output_shape.dim_sizes()); - OP_REQUIRES_OK(context, output.status()); - context->SetOutput(0, output.ValueOrDie()); + context->SetOutput(0, output); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc index 46585a26769..71920372cde 100644 --- a/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/replica_id_op.cc @@ -30,7 +30,8 @@ class XlaReplicaIdOp : public XlaOpKernel { }; void XlaReplicaIdOp::Compile(XlaOpKernelContext* ctx) { - ctx->SetOutput(0, xla::ReplicaId(ctx->builder())); + ctx->SetOutput( + 0, xla::ConvertElementType(xla::ReplicaId(ctx->builder()), xla::S32)); } REGISTER_XLA_OP(Name("XlaReplicaId"), XlaReplicaIdOp); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index a85ba547179..213045e428a 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -19,8 +19,10 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -108,38 +110,73 @@ class ReshapeOp : public XlaOpKernel { VLOG(2) << "Reshape from " << input_shape.DebugString() << " to " << shape.DebugString() << ", unknown_index=" << unknown_index; + auto input_xla_shape = ctx->InputXlaShape(0); + if (input_xla_shape->is_static()) { + ctx->SetOutput(0, xla::Reshape(ctx->Input(0), shape.dim_sizes())); + return; + } + // Handing dynamic reshapes if input contains a dynamic dimension. + std::vector output_dim_sizes; + std::vector dims_are_dynamic; + for (int64 i = 0; i < shape.dims(); ++i) { + output_dim_sizes.push_back( + xla::Reshape(xla::Slice(ctx->Input(1), {i}, {i + 1}, {1}), {})); + } + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(1, &dims_are_dynamic)); + if (unknown_index == -1) { + // No unknown index. 
+ ctx->SetOutput(0, + xla::DynamicReshape(ctx->Input(0), output_dim_sizes, + shape.dim_sizes(), dims_are_dynamic)); + return; + } + auto common_factors = + xla::CommonFactors(input_shape.dim_sizes(), shape.dim_sizes()); - int dynamic_dimension = -1; - if (ctx->InputXlaShape(0)->is_dynamic()) { - std::vector dynamic_dims; - OP_REQUIRES_OK(ctx, - ctx->ResolveInputDynamismIntoPredVector(1, &dynamic_dims)); - for (int d = 0; d < num_dims; ++d) { - const bool dim_is_dynamic = dynamic_dims[d]; - if (dim_is_dynamic) { - dynamic_dimension = d; + // Find common_factors that the input belongs to. + for (int64 i = 0; i < common_factors.size() - 1; ++i) { + auto start = common_factors[i]; + auto end = common_factors[i + 1]; + bool input_is_dynamic = false; + // product of all input dims in this group. E.g., in + // reshape(Tensor([2, 3, 3]), [3, -1, 3]) product of the group + // containing -1 will be 6. + xla::XlaOp product = xla::One(ctx->builder(), xla::S32); + for (int64 dim = start.first; dim < end.first; ++dim) { + if (input_xla_shape->is_dynamic_dimension(dim)) { + input_is_dynamic = true; + } + product = xla::Mul(product, xla::GetDimensionSize(ctx->Input(0), dim)); + } + bool unknown_dim_in_group = false; + // The real size for the -1 dimension in a reshape. E.g., in + // reshape(Tensor([2, 3, 3]), [3, -1, 3]) this will be 2. + xla::XlaOp unknown_dim_size = product; + for (int64 dim = start.second; dim < end.second; ++dim) { + if (dim == unknown_index) { + unknown_dim_in_group = true; + } else { + unknown_dim_size = xla::Div(unknown_dim_size, output_dim_sizes[dim]); } } - // When reshaping from dynamic dimension, unkwown index is considered - // dynamic. E.g., - // [<=10] - // | - // Reshape - // | - // [2, -1] - // The second dimension is dynamic. - if (dynamic_dimension == -1) { - dynamic_dimension = unknown_index; + if (unknown_dim_in_group) { + // If input dim is dynamic, output dim at the -1 position must be + // dynamic. Similarly, if input dim is static, output dim has to be + // static at the -1 dimension. + dims_are_dynamic[unknown_index] = input_is_dynamic; + output_dim_sizes[unknown_index] = unknown_dim_size; + + ctx->SetOutput( + 0, xla::DynamicReshape(ctx->Input(0), output_dim_sizes, + shape.dim_sizes(), dims_are_dynamic)); + VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() + << " to " << xla::VectorString(shape.dim_sizes()) + << ", dynamic_dims=" << xla::VectorString(dims_are_dynamic); + return; } - VLOG(2) << "Reshape from " << ctx->InputXlaShape(0)->ToString() << " to " - << xla::VectorString(shape.dim_sizes()) - << ", dynamic_dim=" << dynamic_dimension; } - // Pass unknown_index to Xla::Reshape as a hint for dynamic shape inference - // in XLA to know which output dimension is dynamic. - ctx->SetOutput(0, xla::ReshapeWithInferredDimension( - ctx->Input(0), shape.dim_sizes(), dynamic_dimension)); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 97359f81eee..d63b8146491 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -74,12 +74,44 @@ class UnsortedSegmentReduce : public XlaOpKernel { " vs. ", indices_shape.dim_size(d))); } xla::XlaBuilder* builder = ctx->builder(); + // data shape = [indices_shape, segment_shape] + // buffer shape = [num_segment, segment_shape] + // We now create the buffer shape by reverse enginerring data shape into + // indices shape and segment shape. 
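A small worked example of the shape bookkeeping described in the comment above, with hypothetical sizes (data [8, 3, 16, 5], indices [8, 3], num_segments = 10); this is a standalone sketch, not the kernel code:

    #include "tensorflow/core/framework/tensor_shape.h"

    tensorflow::TensorShape data_shape({8, 3, 16, 5});      // [indices_shape, segment_shape]
    tensorflow::TensorShape indices_shape({8, 3});
    tensorflow::TensorShape buffer_shape = data_shape;
    buffer_shape.RemoveDimRange(0, indices_shape.dims());   // segment_shape part: [16, 5]
    buffer_shape.InsertDim(0, /*size=*/10);                  // [num_segments, segment_shape]: [10, 16, 5]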
TensorShape buffer_shape = data_shape; buffer_shape.RemoveDimRange(0, indices_shape.dims()); buffer_shape.InsertDim(0, num_segments); + auto buffer = xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); + // Build dynamic dim sizes for buffer, as well as whether each dimension + // size is dynamic or static. We build two parts: num_sgement part and + // segment_shape part. + std::vector buffer_dims; + std::vector buffer_dims_are_dynamic; + // Build the "num_segment" part. + bool num_segments_is_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPred(2, &num_segments_is_dynamic)); + + buffer_dims.insert(buffer_dims.begin(), ctx->Input(2)); + buffer_dims_are_dynamic.insert(buffer_dims_are_dynamic.begin(), + num_segments_is_dynamic); + // Build the segment shape part. + for (int64 i = indices_shape.dims(); i < data_shape.dims(); ++i) { + buffer_dims.push_back(xla::GetDimensionSize(data, i)); + buffer_dims_are_dynamic.push_back( + ctx->InputXlaShape(0)->is_dynamic_dimension(i)); + } + + for (int64 i = 0; i < buffer_dims.size(); ++i) { + if (buffer_dims_are_dynamic[i]) { + // For each dynamic dimension, call set-dimension-size on it. + buffer = xla::SetDimensionSize(buffer, buffer_dims[i], i); + } + } + auto combiner = [this](xla::XlaOp a, xla::XlaOp b, xla::XlaBuilder* builder) { return Combine(a, b); }; diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 784b790767c..72cb746f5ff 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -15,6 +15,9 @@ limitations under the License. #include "tensorflow/core/util/strided_slice_op.h" +#include + +#include "absl/algorithm/container.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" @@ -23,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/ops_util.h" #include "tensorflow/core/framework/register_types.h" @@ -33,6 +37,7 @@ limitations under the License. namespace tensorflow { namespace { +using errors::InvalidArgument; class StridedSliceOp : public XlaOpKernel { public: @@ -48,7 +53,7 @@ class StridedSliceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); const TensorShape begin_shape = ctx->InputShape("begin"); - + VLOG(0) << "strided slice"; OP_REQUIRES( ctx, begin_shape.dims() == 1, errors::InvalidArgument("'begin' input has to be a rank 1 vector")); @@ -78,20 +83,24 @@ class StridedSliceOp : public XlaOpKernel { TensorShape final_shape; PartialTensorShape dummy_processing_shape, partial_final_shape; bool dummy = false; - OP_REQUIRES_OK(ctx, ValidateStridedSliceOp( - begin_is_constant ? &begin_tensor : nullptr, - end_is_constant ? &end_tensor : nullptr, - strides_tensor, input_shape, begin_mask_, end_mask_, - ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &dummy_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); + absl::InlinedVector output_to_sparse_mapping; + absl::InlinedVector output_to_processing_mapping; + OP_REQUIRES_OK( + ctx, + ValidateStridedSliceOp( + begin_is_constant ? 
&begin_tensor : nullptr, + end_is_constant ? &end_tensor : nullptr, strides_tensor, + input_shape, begin_mask_, end_mask_, ellipsis_mask_, new_axis_mask_, + shrink_axis_mask_, &dummy_processing_shape, &partial_final_shape, + &dummy, &dummy, &dummy, &begin, &end, &strides, + &output_to_sparse_mapping, &output_to_processing_mapping)); - OP_REQUIRES(ctx, partial_final_shape.AsTensorShape(&final_shape), - errors::InvalidArgument( - "XLA can't deduce compile time constant output " - "shape for strided slice: ", - partial_final_shape.DebugString(), - ", output shape must be a compile-time constant")); + OP_REQUIRES( + ctx, partial_final_shape.AsTensorShape(&final_shape), + InvalidArgument("XLA can't deduce compile time constant output " + "shape for strided slice: ", + partial_final_shape.DebugString(), + ", output shape must be a compile-time constant")); xla::XlaOp slice = ctx->Input(0); if (begin_is_constant && end_is_constant) { @@ -119,69 +128,84 @@ class StridedSliceOp : public XlaOpKernel { auto operand_shape_or = ctx->builder()->GetShape(ctx->Input(0)); OP_REQUIRES_OK(ctx, operand_shape_or.status()); xla::Shape xla_shape = operand_shape_or.ValueOrDie(); - if (xla_shape.is_static()) { - // Static output shape, return a static slice. - slice = xla::Reshape(slice, final_shape.dim_sizes()); + std::vector begins_are_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(1, &begins_are_dynamic)); + std::vector ends_are_dynamic; + OP_REQUIRES_OK( + ctx, ctx->ResolveInputDynamismIntoPredVector(2, &ends_are_dynamic)); + bool begins_are_static = absl::c_all_of( + begins_are_dynamic, [](bool dynamic) { return !dynamic; }); + OP_REQUIRES(ctx, begins_are_static, + errors::InvalidArgument( + "XLA can't use dynamic begin values for slice.")); + bool ends_are_static = absl::c_all_of( + ends_are_dynamic, [](bool dynamic) { return !dynamic; }); + // Static output shape, return a static slice. + slice = xla::Reshape(slice, final_shape.dim_sizes()); + if (xla_shape.is_static() && ends_are_static) { ctx->SetOutput(0, slice); return; } - auto input_dim_sizes = input_shape.dim_sizes(); - for (int64 i = 0; i < xla_shape.rank(); ++i) { - if (xla_shape.is_dynamic_dimension(i)) { - input_dim_sizes[i] = -1; + for (int64 i = 0; i < final_shape.dims(); ++i) { + int64 input_index = output_to_processing_mapping[i]; + if (input_index == -1) { + continue; } - } - PartialTensorShape input_partial_shape(input_dim_sizes); - partial_final_shape.Clear(); - end.clear(); - strides.clear(); - begin.clear(); - // Run shape inferenference again with partial shape. - OP_REQUIRES_OK(ctx, ValidateStridedSliceOp( - &begin_tensor, &end_tensor, strides_tensor, - input_partial_shape, begin_mask_, end_mask_, - ellipsis_mask_, new_axis_mask_, shrink_axis_mask_, - &dummy_processing_shape, &partial_final_shape, - &dummy, &dummy, &dummy, &begin, &end, &strides)); - if (partial_final_shape.AsTensorShape(&final_shape)) { - // Static output shape, return a static slice. - slice = xla::Reshape(slice, final_shape.dim_sizes()); - ctx->SetOutput(0, slice); - return; - } + bool input_is_dynamic = xla_shape.is_dynamic_dimension(input_index); - // We consider slicing a dynamic tensor t with negative indices as a - // dynamic sized slice. 
E.g., t[: -n], the result length is shape(t) - n - for (int64 i = 0; i < partial_final_shape.dims(); ++i) { - bool dynamic_dim = partial_final_shape.dim_size(i) - 1; - bool backward_slice = end[i] < 0; - if (dynamic_dim && backward_slice) { + int64 sparse_index = output_to_sparse_mapping[i]; + bool end_is_dynamic = + sparse_index == -1 ? false : ends_are_dynamic[sparse_index]; + bool backward_slice = sparse_index == -1 + ? false + : end_literal.Get({sparse_index}) < 0; + if ((input_is_dynamic && backward_slice) || end_is_dynamic) { OP_REQUIRES( - ctx, strides[i] == 1, + ctx, strides[input_index] == 1, errors::InvalidArgument("XLA has not implemented dynamic " "sized slice with non-trival stride yet. " "Please file a bug against XLA")); - - OP_REQUIRES(ctx, begin[i] >= 0, - errors::InvalidArgument( - "XLA has not implemented dynamic " - "sized slice with negative begin index %lld. " - "Please file a bug against XLA", - begin[i])); // If there is a dynamic dimension, properly set dimension size of // the result. - auto operand_size = xla::GetDimensionSize(ctx->Input(0), i); - - operand_size = xla::Add( - operand_size, xla::ConstantR0(ctx->builder(), end[i])); + auto operand_size = xla::GetDimensionSize(ctx->Input(0), input_index); + if (backward_slice) { + // We consider slicing a dynamic tensor t with negative indices as + // a dynamic sized slice. E.g., t[: -n], the result length is + // shape(t) - n. + OP_REQUIRES(ctx, !end_is_dynamic, + errors::InvalidArgument( + "XLA has not implemented dynamic " + "sized slice with dynamic negative index %lld. ")); + operand_size = xla::Add( + operand_size, + xla::ConstantR0(ctx->builder(), + end_literal.Get({sparse_index}))); + } else { + // The end of slice with dynamic slice size is the min of operand + // shape and slice size. E.g., t[:end_size], result size is + // min(shape(t), end_size). + xla::XlaOp end_size; + if (end_is_dynamic) { + end_size = xla::Reshape(xla::Slice(ctx->Input(2), {sparse_index}, + {sparse_index + 1}, {1}), + {}); + } else { + end_size = + xla::ConstantR0(ctx->builder(), end[input_index]); + } + operand_size = xla::Min(operand_size, end_size); + } slice = xla::SetDimensionSize( slice, - xla::Sub(operand_size, - xla::ConstantR0(ctx->builder(), begin[i])), + xla::Sub(operand_size, xla::ConstantR0( + ctx->builder(), begin[input_index])), i); } } + ctx->SetOutput(0, slice); + return; } else { // When output shape is fully defined, it must be a size one slice: // @@ -239,9 +263,9 @@ class StridedSliceOp : public XlaOpKernel { std::vector output_shape_dim_sizes; slice = xla::DynamicSlice(slice, start_indices, slice_sizes); + slice = xla::Reshape(slice, final_shape.dim_sizes()); + ctx->SetOutput(0, slice); } - slice = xla::Reshape(slice, final_shape.dim_sizes()); - ctx->SetOutput(0, slice); } private: diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index e5913a8bbf3..eb1ab79d165 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -62,7 +62,7 @@ xla::StatusOr Expand(xla::XlaOp input, int64 dim) { std::vector expanded_shape = xla::SpanToVector(input_shape.dimensions()); expanded_shape[dim] /= 4; - expanded_shape.insert(expanded_shape.begin() + dim, 4); + expanded_shape.insert(expanded_shape.begin() + dim + 1, 4); // Move the newly created dimension to the end with a transpose. 
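The one-line data_format.cc fix above changes where the factor-of-4 dimension is inserted when a packed dimension is split. A worked example with hypothetical sizes, splitting dim = 3 of a [8, 5, 5, 8] shape:

    std::vector<int64_t> expanded_shape = {8, 5, 5, 8};
    const int dim = 3;
    expanded_shape[dim] /= 4;                                    // {8, 5, 5, 2}
    expanded_shape.insert(expanded_shape.begin() + dim + 1, 4);  // {8, 5, 5, 2, 4}
    // The previous insert position (begin() + dim) produced {8, 5, 5, 4, 2},
    // i.e. the size-4 sub-dimension landed before the reduced dimension
    // instead of after it.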
std::vector permutation; diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index abaeb305104..db1a6929934 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -152,6 +152,7 @@ Status ConvertGraphDefToXlaViaMlir( RegisterDialects(); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); TF_ASSIGN_OR_RETURN( mlir::OwningModuleRef module, ConvertGraphdefToMlir(pruned_graph_def, debug_info, specs, &context)); diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index 242a2b04ab9..3cf9df64b0b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -137,7 +137,6 @@ Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) { const auto& it = node.attr().find("allowed_devices"); if (it != node.attr().end()) { if (!it->second.list().s().empty()) { - // TODO(b/149512838): Support non-empty allowed devices. return errors::InvalidArgument( "VarHandleOp with non-empty allowed devices is not supported."); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 635b7170d82..f8319cd446a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -36,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/common_runtime/device.h" @@ -990,20 +991,6 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } - for (int i = 0, end = input_to_args->size(); i < end; ++i) { - const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; - for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { - int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); - VLOG(1) << "Setting dynamic binding " << i << " -> " - << dynamic_size_param_index; - - TF_RETURN_IF_ERROR(builder->SetDynamicBinding( - /*dynamic_size_param_num=*/0, {dynamic_size_param_index}, - /*target_param_num=*/0, /*target_param_index=*/{i}, - dim_and_arg_num.first)); - } - } - for (std::vector::size_type i = 0; i < input_to_args->size(); ++i) { auto it = arg_shardings.find(i); xla::XlaScopedShardingAssignment assign_sharding( @@ -1035,16 +1022,17 @@ Status XlaCompiler::BuildArguments( absl::StrCat("arg", i)); } } + } - for (int i = 0, end = input_to_args->size(); i < end; ++i) { - const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; - for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { - int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); - TF_RETURN_IF_ERROR(builder->SetDynamicBinding( - /*dynamic_size_param_num=*/dynamic_size_param_index, {}, - /*target_param_num=*/i, /*target_param_index=*/{}, - dim_and_arg_num.first)); - } + for (int i = 0, end = input_to_args->size(); i < end; ++i) { + const XlaCompiler::Argument& arg = args[input_to_args->at(i)]; + for (const auto& dim_and_arg_num : arg.dynamic_dim_to_arg_num_map) { + int dynamic_size_param_index = arg_to_inputs.at(dim_and_arg_num.second); + VLOG(1) << "Setting dynamic size " << i << " -> " + << dynamic_size_param_index; + arg_handles[i] = 
xla::SetDimensionSize( + arg_handles[i], arg_handles[dynamic_size_param_index], + dim_and_arg_num.first); } } @@ -1370,8 +1358,15 @@ Status XlaCompiler::SetDeviceToHostMetadata( const string& key, absl::Span types, absl::Span shapes) { if (host_compute_sends_.find(key) != host_compute_sends_.end()) { - return errors::InvalidArgument( - "Duplicate calls to SetDeviceToHostMetadata with key ", key); + tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key]; + tf2xla::HostTransferMetadata new_transfer; + SetTransfer(key, types, shapes, &new_transfer); + if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { + return Status::OK(); + } else { + return errors::InvalidArgument( + "Duplicate calls to SetDeviceToHostMetadata with key ", key); + } } tf2xla::HostTransferMetadata& transfer = host_compute_sends_[key]; SetTransfer(key, types, shapes, &transfer); @@ -1396,9 +1391,16 @@ Status XlaCompiler::GetDeviceToHostShapes( Status XlaCompiler::SetHostToDeviceMetadata( const string& key, absl::Span types, absl::Span shapes) { - if (host_compute_recvs_.find(key) != host_compute_sends_.end()) { - return errors::InvalidArgument( - "Duplicate calls to SetHostToDeviceMetadata with key ", key); + if (host_compute_recvs_.find(key) != host_compute_recvs_.end()) { + tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key]; + tf2xla::HostTransferMetadata new_transfer; + SetTransfer(key, types, shapes, &new_transfer); + if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { + return Status::OK(); + } else { + return errors::InvalidArgument( + "Duplicate calls to SetHostToDeviceMetadata with key ", key); + } } tf2xla::HostTransferMetadata& transfer = host_compute_recvs_[key]; SetTransfer(key, types, shapes, &transfer); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 5df508d60b3..f348552050b 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -1897,5 +1897,63 @@ TEST_F(XlaCompilerTest, AliasResourceUpdates) { EXPECT_EQ(alias.entries(0).parameter_number(), 0); } +// Tests that passing in an exact duplicate input to SetDeviceToHostMeatadata +// is not an error. +TEST_F(XlaCompilerTest, SetDeviceToHostMetadataExactDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + + TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); + TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); +} + +// Tests that passing in a mismatched duplicate input to +// SetDeviceToHostMeatadata is not an error. +TEST_F(XlaCompilerTest, SetDeviceToHostMetadataMismatchedDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + std::vector types2{DT_FLOAT}; + std::vector shapes2{TensorShape({1})}; + + TF_ASSERT_OK(compiler.SetDeviceToHostMetadata(key, types, shapes)); + Status status = compiler.SetDeviceToHostMetadata(key, types2, shapes2); + EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); +} + +// Tests that passing in an exact duplicate input to SetHostToDeviceMeatadata +// is not an error. 
+TEST_F(XlaCompilerTest, SetHostToDeviceMetadataExactDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + + TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); + TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); +} + +// Tests that passing in a mismatched duplicate input to +// SetHostToDeviceMeatadata is not an error. +TEST_F(XlaCompilerTest, SetHostToDeviceMetadataMismatchedDuplicate) { + XlaCompiler compiler(DefaultOptions()); + + const string& key = "comm_key"; + std::vector types{DT_INT32}; + std::vector shapes{TensorShape({2})}; + std::vector types2{DT_FLOAT}; + std::vector shapes2{TensorShape({1})}; + + TF_ASSERT_OK(compiler.SetHostToDeviceMetadata(key, types, shapes)); + Status status = compiler.SetHostToDeviceMetadata(key, types2, shapes2); + EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/xla/bit_cast.h b/tensorflow/compiler/xla/bit_cast.h index 90e9a5c25dd..feb548c9433 100644 --- a/tensorflow/compiler/xla/bit_cast.h +++ b/tensorflow/compiler/xla/bit_cast.h @@ -29,7 +29,7 @@ limitations under the License. #include "absl/base/casts.h" #include "third_party/eigen3/Eigen/Core" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/types.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 20d9930341f..744cdcea14c 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -137,7 +137,7 @@ XlaComputation CreateMinMaxComputation(XlaBuilder* outer_builder, arg_max = Select(eq, tie_id, arg_max); } Tuple(b, {max, arg_max}); - return b->Build().ConsumeValueOrDie(); + return b->BuildAndNoteError(); } XlaOp ArgMinMax(XlaOp input, PrimitiveType output_type, int axis, bool is_min, diff --git a/tensorflow/compiler/xla/client/lib/comparators.cc b/tensorflow/compiler/xla/client/lib/comparators.cc index cd594a5cf39..c9d6cea740d 100644 --- a/tensorflow/compiler/xla/client/lib/comparators.cc +++ b/tensorflow/compiler/xla/client/lib/comparators.cc @@ -84,7 +84,12 @@ XlaComputation CreateScalarComparisonComputation( CHECK_NE(parameter_count, 0); - Shape shape = b->GetShape(lhs_params[0]).ValueOrDie(); + auto shape_or = b->GetShape(lhs_params[0]); + if (!shape_or.ok()) { + b->ReportError(shape_or.status()); + return {}; + } + Shape shape = shape_or.ValueOrDie(); shape.set_element_type(PRED); XlaOp param_equal = Broadcast(One(b.get(), shape.element_type()), AsInt64Slice(shape.dimensions())); diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc index 6fdaab58686..cd9f88a74ce 100644 --- a/tensorflow/compiler/xla/client/lib/math.cc +++ b/tensorflow/compiler/xla/client/lib/math.cc @@ -1111,11 +1111,28 @@ XlaOp RoundToEven(XlaOp x) { // acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) if x != -1 // pi if x == -1 +// For complex: +// acos(x) = -(i * log(x + i * sqrt((1 + x) * (1 - x)))) XlaOp Acos(XlaOp x) { - return Select(Ne(x, FullLike(x, -1)), - ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), - ScalarLike(x, 1.0) + x), - FullLike(x, M_PI)); + XlaBuilder* b = x.builder(); + return b->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(auto shape, 
b->GetShape(x)); + + if (primitive_util::IsComplexType(shape.element_type())) { + auto one = ScalarLike(x, 1); + auto imag_one = Complex( + Zero(b, primitive_util::ComplexComponentType(shape.element_type())), + One(b, primitive_util::ComplexComponentType(shape.element_type()))); + + auto result = + Neg(imag_one * Log(x + imag_one * Sqrt((one + x) * (one - x)))); + return result; + } + return Select(Ne(x, FullLike(x, -1)), + ScalarLike(x, 2.0) * Atan2(Sqrt(ScalarLike(x, 1.0) - x * x), + ScalarLike(x, 1.0) + x), + FullLike(x, M_PI)); + }); } // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc index cb79b2ef7db..ae4d839d8fa 100644 --- a/tensorflow/compiler/xla/client/lib/math_test.cc +++ b/tensorflow/compiler/xla/client/lib/math_test.cc @@ -660,5 +660,19 @@ XLA_TEST_F(MathTest, BesselI1eDouble) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } +XLA_TEST_F(MathTest, AcosComplexValues) { + XlaBuilder builder(TestName()); + auto x = ConstantR1>( + &builder, {{0, 0}, {0, 1}, {1, 1}, {0.8, 0.2}}); + + Acos(x); + std::vector> expected = { + {1.5707963267948966, 0}, + {1.5707963267948966, -0.881373587019543}, + {0.9045568943023814, -1.0612750619050357}, + {0.7011246914497526, -0.30527648462436596}}; + ComputeAndCompareR1>(&builder, expected, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/prng.cc b/tensorflow/compiler/xla/client/lib/prng.cc index 044a742eddd..cc5639f1be1 100644 --- a/tensorflow/compiler/xla/client/lib/prng.cc +++ b/tensorflow/compiler/xla/client/lib/prng.cc @@ -426,32 +426,36 @@ RngOutput PhiloxRngBit64(XlaOp op_key, XlaOp initial_state, XlaOp ConvertRandomBitsToUniformFloatingPoint(XlaOp bits, XlaOp minval, XlaOp maxval) { XlaBuilder* builder = bits.builder(); - PrimitiveType value_type = - builder->GetShape(minval).ConsumeValueOrDie().element_type(); - PrimitiveType bit_type = - builder->GetShape(bits).ConsumeValueOrDie().element_type(); - CHECK((value_type == F32 && bit_type == U32) || - (value_type == F64 && bit_type == U64)); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* minval_shape, + builder->GetShapePtr(minval)); + TF_ASSIGN_OR_RETURN(const Shape* bits_shape, builder->GetShapePtr(bits)); + PrimitiveType value_type = minval_shape->element_type(); + PrimitiveType bit_type = bits_shape->element_type(); + CHECK((value_type == F32 && bit_type == U32) || + (value_type == F64 && bit_type == U64)); - // Form random mantissa bits for float/double, with a leading 1 bit. - int num_float_bits = primitive_util::BitWidth(value_type); - // Subtract one as SignificandWidth includes the leading 1 bit. - int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; + // Form random mantissa bits for float/double, with a leading 1 bit. + int num_float_bits = primitive_util::BitWidth(value_type); + // Subtract one as SignificandWidth includes the leading 1 bit. + int num_mantissa_bits = primitive_util::SignificandWidth(value_type) - 1; - // Ignore the exponent bits and convert the mantissa bits to the floating - // point type. - bits = ShiftRightLogical( - bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); + // Ignore the exponent bits and convert the mantissa bits to the floating + // point type. 
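The mantissa trick described in the comments above can be illustrated outside of XLA with plain integer and float arithmetic. A minimal standalone sketch for the F32/U32 case (not the XLA client code):

    #include <cmath>
    #include <cstdint>

    // Keep the top 23 random bits as an integer in [0, 2^23), convert to
    // float, scale by 2^-23 into [0, 1), then map affinely onto [minval, maxval).
    float BitsToUniformFloat(uint32_t bits, float minval, float maxval) {
      constexpr int kNumFloatBits = 32;             // BitWidth(F32)
      constexpr int kNumMantissaBits = 23;          // SignificandWidth(F32) - 1
      bits >>= (kNumFloatBits - kNumMantissaBits);  // drop the would-be exponent bits
      float value = static_cast<float>(bits) * std::ldexp(1.0f, -kNumMantissaBits);
      return value * (maxval - minval) + minval;
    }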
+ bits = ShiftRightLogical( + bits, ScalarLike(bits, num_float_bits - num_mantissa_bits)); - // We have an integer-valued floating point number in the range - // [0, 2**{num_mantissa_bits}). - XlaOp values = ConvertElementType(bits, value_type); + // We have an integer-valued floating point number in the range + // [0, 2**{num_mantissa_bits}). + XlaOp values = ConvertElementType(bits, value_type); - // Divide by 2**{-num_mantissa_bits} to get a number in the range [0.0, 1.0). - values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); + // Divide by 2**{-num_mantissa_bits} to get a number in the range + // [0.0, 1.0). + values = values * ScalarLike(values, std::ldexp(1., -num_mantissa_bits)); - // Multiply and add to shift to the range [minval, maxval). - return values * (maxval - minval) + minval; + // Multiply and add to shift to the range [minval, maxval). + return values * (maxval - minval) + minval; + }); } XlaOp ConvertRandomBitsToUniformInt(XlaOp bits, XlaOp minval, XlaOp maxval, diff --git a/tensorflow/compiler/xla/client/lib/quantize.h b/tensorflow/compiler/xla/client/lib/quantize.h index 26dbbd5b00b..320dfcbf062 100644 --- a/tensorflow/compiler/xla/client/lib/quantize.h +++ b/tensorflow/compiler/xla/client/lib/quantize.h @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" namespace xla { diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc index 1c0680b883a..58905e4ca6f 100644 --- a/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc +++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eig.cc @@ -228,7 +228,7 @@ StatusOr> WhileLoopFn( auto max_sweeps = ScalarLike(k, max_sweep_updates); auto sweep_update_cond = Gt(max_sweeps, k); - auto norms = ComputeFrobeniusNorms(values[2]).ValueOrDie(); + TF_ASSIGN_OR_RETURN(auto norms, ComputeFrobeniusNorms(values[2])); auto tol = norms.total_norm * values[3]; auto tol_cond = ReduceAll(Lt(tol, norms.off_diagonal_norm), xla::ConstantR0(cond_builder, false), @@ -400,7 +400,7 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, return result; }; auto shape_with_status = builder->GetShape(a); - if (!shape_with_status.status().ok()) { + if (!shape_with_status.ok()) { return return_error(shape_with_status.status()); } Shape a_shape = shape_with_status.ValueOrDie(); @@ -450,7 +450,7 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, S32, // "CyclicJacobi", // builder); - if (!output_with_status.status().ok()) { + if (!output_with_status.ok()) { return return_error(output_with_status.status()); } @@ -460,7 +460,11 @@ SelfAdjointEigResult SelfAdjointEig(XlaOp a, bool lower, int64 max_iter, result.v = output[1]; result.w = GetMatrixDiagonal(output[2]); - return SortByEigenvalues(result).ValueOrDie(); + auto result_or = SortByEigenvalues(result); + if (!result_or.ok()) { + return return_error(result_or.status()); + } + return result_or.ValueOrDie(); } } // namespace xla diff --git a/tensorflow/compiler/xla/client/lib/svd.cc b/tensorflow/compiler/xla/client/lib/svd.cc index 646875a20a2..80ea4d644c0 100644 --- a/tensorflow/compiler/xla/client/lib/svd.cc +++ b/tensorflow/compiler/xla/client/lib/svd.cc @@ -837,8 +837,11 @@ SVDResult SVD(XlaOp a, int64 max_iter, float epsilon, auto eps = 
ScalarLike(a, epsilon); - SVDResult svd_result = - HouseHolderBidiagonalization(a, eps, precision).ValueOrDie(); + auto svd_result_or = HouseHolderBidiagonalization(a, eps, precision); + if (!svd_result_or.ok()) { + return return_error(svd_result_or.status()); + } + SVDResult svd_result = svd_result_or.ValueOrDie(); auto output_with_status = WhileLoopFn( { @@ -861,7 +864,13 @@ SVDResult SVD(XlaOp a, int64 max_iter, float epsilon, svd_result.u = output[1]; svd_result.v = output[2]; svd_result.d = output[3]; - svd_result = SortBySingularValuesAndPostProcessing(svd_result).ValueOrDie(); + + svd_result_or = SortBySingularValuesAndPostProcessing(svd_result); + if (!svd_result_or.ok()) { + return return_error(svd_result_or.status()); + } + svd_result = svd_result_or.ValueOrDie(); + if (maybe_transpose) { std::swap(svd_result.u, svd_result.v); } diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 2b69c71042d..34d78f9d933 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -28,6 +28,7 @@ limitations under the License. #include "absl/strings/match.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" +#include "absl/types/span.h" #include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/comparison_util.h" @@ -78,16 +79,13 @@ ShapeProto ConvertShapeProtoToPred(const ShapeProto& shape_proto) { return ShapeUtil::ChangeElementType(Shape(shape_proto), PRED).ToProto(); } -HloInstructionProto CreateConstantInstruction(int64 id, const Shape& shape, - bool pred) { - HloInstructionProto const_instr; +void SetInstructionAsConstant(HloInstructionProto* instr, int64 id, + const Shape& shape, bool pred) { Literal literal = LiteralUtil::CreateR0(pred); Literal literal_broadcast = literal.Broadcast(shape, {}).ValueOrDie(); - *const_instr.mutable_shape() = shape.ToProto(); - *const_instr.mutable_literal() = literal_broadcast.ToProto(); - *const_instr.mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); - const_instr.set_id(id); - return const_instr; + *instr->mutable_shape() = shape.ToProto(); + *instr->mutable_literal() = literal_broadcast.ToProto(); + *instr->mutable_opcode() = HloOpcodeString(HloOpcode::kConstant); } // Converts a HloComputation into ReducerOr with predicate types. 
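The self_adjoint_eig and svd changes above all follow one shape: check the StatusOr and forward its Status instead of calling ValueOrDie(), so failures surface as error values rather than process aborts. A toy, self-contained version of that pattern with hypothetical functions (not library code):

    #include "tensorflow/compiler/xla/statusor.h"
    #include "tensorflow/compiler/xla/util.h"

    xla::StatusOr<int> Half(int x) {
      if (x % 2 != 0) return xla::InvalidArgument("x must be even, got %d", x);
      return x / 2;
    }

    xla::StatusOr<int> Quarter(int x) {
      auto half_or = Half(x);
      if (!half_or.ok()) {
        return half_or.status();  // propagate instead of half_or.ValueOrDie()
      }
      return Half(half_or.ValueOrDie());
    }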
@@ -1083,6 +1081,36 @@ XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, }); } +XlaOp XlaBuilder::DynamicReshape(XlaOp operand, + absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic) { + return ReportErrorOrReturn([&]() -> StatusOr { + TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); + std::vector dim_size_shape_ptrs; + TF_ASSIGN_OR_RETURN(const auto& dim_size_shapes, + GetOperandShapes(dim_sizes)); + + absl::c_transform(dim_size_shapes, std::back_inserter(dim_size_shape_ptrs), + [](const Shape& shape) { return &shape; }); + TF_ASSIGN_OR_RETURN(const Shape shape, + ShapeInference::InferDynamicReshapeShape( + *operand_shape, dim_size_shape_ptrs, + new_size_bounds, dims_are_dynamic)); + TF_RETURN_IF_ERROR(first_error_); + std::vector operands; + operands.reserve(1 + dim_sizes.size()); + operands.push_back(operand); + for (const XlaOp& dim_size : dim_sizes) { + operands.push_back(dim_size); + } + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kDynamicReshape, + operands); + }); +} + XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { @@ -1425,6 +1453,25 @@ StatusOr XlaBuilder::FftInternal( return AddInstruction(std::move(instr), HloOpcode::kFft, {operand}); } +StatusOr XlaBuilder::TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options) { + HloInstructionProto instr; + *instr.mutable_triangular_solve_options() = std::move(options); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kTriangularSolve, {a, b}); +} + +StatusOr XlaBuilder::CholeskyInternal(const Shape& shape, XlaOp a, + bool lower) { + HloInstructionProto instr; + xla::CholeskyOptions& options = *instr.mutable_cholesky_options(); + options.set_lower(lower); + *instr.mutable_shape() = shape.ToProto(); + + return AddInstruction(std::move(instr), HloOpcode::kCholesky, {a}); +} + XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; @@ -1935,7 +1982,6 @@ XlaOp XlaBuilder::RngUniform(XlaOp a, XlaOp b, const Shape& shape) { XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape)); TF_ASSIGN_OR_RETURN(Shape state_shape, GetShape(initial_state)); Shape output_shape = shape; @@ -1954,14 +2000,22 @@ XlaOp XlaBuilder::RngBitGenerator(RandomAlgorithm algorithm, return InvalidArgument("Unsupported shape for RngBitGenerator: %s", PrimitiveType_Name(output_shape.element_type())); } - *instr.mutable_shape() = - ShapeUtil::MakeTupleShape({state_shape, output_shape}).ToProto(); - instr.set_rng_algorithm(algorithm); - return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, - {initial_state}); + return RngBitGeneratorInternal( + ShapeUtil::MakeTupleShape({state_shape, output_shape}), algorithm, + initial_state); }); } +StatusOr XlaBuilder::RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state) { + HloInstructionProto instr; + *instr.mutable_shape() = full_result_shape.ToProto(); + instr.set_rng_algorithm(algorithm); + return AddInstruction(std::move(instr), HloOpcode::kRngBitGenerator, 
+ {initial_state}); +} + XlaOp XlaBuilder::While(const XlaComputation& condition, const XlaComputation& body, XlaOp init) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -2527,6 +2581,7 @@ XlaOp XlaBuilder::AllToAll(XlaOp operand, int64 split_dimension, } *(shape.mutable_tuple_shapes(i)->mutable_layout()) = *layout; } + instr.set_constrain_layout(true); } *instr.mutable_shape() = shape.ToProto(); @@ -2914,27 +2969,12 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { *program_shape->mutable_result() = ShapeUtil::ChangeElementType(Shape(root->shape()), PRED).ToProto(); - std::set seen; - struct WorkItem { - explicit WorkItem(int64 handle, bool need_rewrite) - : handle(handle), need_rewrite(need_rewrite) {} - int64 handle; - // If need_rewrite is true, the instruction will be copied and rewrite into - // a pred instruction indicating if each value is dynamic. If need_rewrite - // is false, simply copy the instruction to the output graph. - // E.g., - // For select(P, A, B), we need to rewrite A and B into predicates, but - // don't need to rewrite P. - bool need_rewrite; - }; - std::queue worklist; - worklist.push(WorkItem(root->id(), true)); - entry.set_root_id(root->id()); std::vector called_computatons; - // Rewritre instruction with id "from" into the new graph. - // Returns more work items that need to finish. - auto rewrite_instruction = - [&](int64 from, bool need_rewrite) -> StatusOr> { + // Process instruction and copy it into the new graph. The new node in the new + // graph with have id set to `id`. + auto process_instruction = [&](const HloInstructionProto* instr_proto, + bool need_rewrite, int64 id, + absl::Span operand_ids) { // Rewrite the instruction with following rules: // - Unary ops: Convert into bitcast (identity) with type Pred. // - Binary ops: Convert into binary or. @@ -2947,22 +2987,20 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { // - Constant: Convert to constant False. // - Other ops: Not supported. // Create the instruction for the new handle. 
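As a concrete, hand-written instance of these rewrite rules: for a value computed as add(parameter, constant), the dynamism graph replaces the add with a pred-typed or, rewrites the parameter to constant true (a parameter may be dynamic) and the constant to constant false, so the result is marked dynamic exactly where the parameter contributes; for select(p, a, b) only a and b are rewritten, since the predicate p itself does not make the selected value dynamic.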
- TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, - LookUpInstructionByHandle(from)); - TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(instr_proto->opcode())); - std::vector operands_todo; auto* new_instr = entry.add_instructions(); *new_instr = *instr_proto; - for (auto operand_id : new_instr->operand_ids()) { - operands_todo.emplace_back(operand_id, need_rewrite); + new_instr->set_id(id); + new_instr->mutable_operand_ids()->Clear(); + for (auto operand_id : operand_ids) { + new_instr->mutable_operand_ids()->Add(operand_id); } if (!need_rewrite) { *new_instr->mutable_name() = - GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id()); - return operands_todo; + GetFullName(instr_proto->opcode(), kNameSeparator, id); + return Status::OK(); } *new_instr->mutable_shape() = ConvertShapeProtoToPred(instr_proto->shape()); Shape new_shape(new_instr->shape()); @@ -3017,10 +3055,8 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { *new_instr->mutable_opcode() = HloOpcodeString(HloOpcode::kOr); break; case HloOpcode::kSelect: - operands_todo[0].need_rewrite = false; break; case HloOpcode::kGather: - operands_todo[1].need_rewrite = false; break; case HloOpcode::kReduce: { int64 reducer_id = new_instr->called_computation_ids(0); @@ -3042,39 +3078,101 @@ StatusOr XlaBuilder::BuildDynamicInferenceGraph(XlaOp root_op) { TF_ASSIGN_OR_RETURN(const HloInstructionProto* operand_proto, LookUpInstructionByHandle(operand_handle)); - *new_instr = CreateConstantInstruction( - from, new_shape, + SetInstructionAsConstant( + new_instr, id, new_shape, operand_proto->shape().is_dynamic_dimension(dimension)); - operands_todo.clear(); break; } case HloOpcode::kConstant: - *new_instr = CreateConstantInstruction(from, new_shape, false); + SetInstructionAsConstant(new_instr, id, new_shape, false); break; case HloOpcode::kParameter: - *new_instr = CreateConstantInstruction(from, new_shape, true); + SetInstructionAsConstant(new_instr, id, new_shape, true); break; default: return InvalidArgument("Dynamic inferencing %s is not supported", instr_proto->DebugString()); } *new_instr->mutable_name() = - GetFullName(instr_proto->opcode(), kNameSeparator, instr_proto->id()); - return operands_todo; + GetFullName(instr_proto->opcode(), kNameSeparator, id); + return Status::OK(); }; + struct WorkItem { + explicit WorkItem(int64 handle, bool need_rewrite) + : handle(handle), need_rewrite(need_rewrite), visited(false) {} + int64 handle; + // If need_rewrite is true, the instruction will be copied and rewrite into + // a pred instruction indicating if each value is dynamic. If need_rewrite + // is false, simply copy the instruction to the output graph. + // E.g., + // For select(P, A, B), we need to rewrite A and B into predicates, but + // don't need to rewrite P. + bool need_rewrite; + // Used in dfs to remember the ids of processed operands of this item. + std::vector processed_operands; + // Whether this node been visited before or not. + bool visited; + }; + // Only copy each pair of {handle, need_rewrite} once. Value is the id in the + // new graph. + absl::flat_hash_map, int64> seen; + // Monotonically increasing id to assign to new instructions. + int64 global_id = 0; + // The result id of the last rewritten item -- return value of last stack + // item. 
+ int64 stacktop_id = -1; + std::vector worklist; + worklist.push_back(WorkItem(root->id(), true)); while (!worklist.empty()) { - WorkItem item = worklist.front(); - worklist.pop(); - if (!seen.insert(item.handle).second) { + WorkItem& item = worklist.back(); + auto item_key = std::make_pair(item.handle, item.need_rewrite); + auto iter = seen.find(item_key); + // Already processed this item. Return previous results. + if (iter != seen.end()) { + stacktop_id = iter->second; + worklist.pop_back(); continue; } - TF_ASSIGN_OR_RETURN(auto todos, - rewrite_instruction(item.handle, item.need_rewrite)); - for (WorkItem& todo : todos) { - worklist.push(todo); + + int64 next_operand = item.processed_operands.size(); + TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto, + LookUpInstructionByHandle(item.handle)); + VLOG(3) << "Processing" << instr_proto->name(); + if (!item.visited) { + item.visited = true; + } else { + // Record previous processed operand. + item.processed_operands.push_back(stacktop_id); + next_operand++; } + TF_ASSIGN_OR_RETURN(HloOpcode opcode, + StringToHloOpcode(instr_proto->opcode())); + if (next_operand >= instr_proto->operand_ids_size() || + opcode == HloOpcode::kGetDimensionSize) { + // No more operands to process, process self. + int64 new_id = ++global_id; + VLOG(3) << "new_id: " << new_id << "instr: " << instr_proto->name(); + TF_RETURN_IF_ERROR(process_instruction(instr_proto, item.need_rewrite, + new_id, item.processed_operands)); + stacktop_id = new_id; + seen[item_key] = stacktop_id; + worklist.pop_back(); + continue; + } + + WorkItem next_item(instr_proto->operand_ids(next_operand), true); + if (opcode == HloOpcode::kSelect && next_operand == 0) { + next_item.need_rewrite = false; + } + if (opcode == HloOpcode::kGather && next_operand == 1) { + next_item.need_rewrite = false; + } + // Push next operand into worklist. 
+ worklist.push_back(next_item); } + TF_RET_CHECK(stacktop_id != -1); + entry.set_root_id(stacktop_id); absl::c_sort(*entry.mutable_instructions(), [](const HloInstructionProto& p1, const HloInstructionProto& p2) { return p1.id() < p2.id(); }); @@ -3466,6 +3564,13 @@ XlaOp Reshape(const Shape& shape, XlaOp operand) { return operand.builder()->Reshape(shape, operand); } +XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic) { + return operand.builder()->DynamicReshape(operand, dim_sizes, new_size_bounds, + dims_are_dynamic); +} + XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension) { @@ -3684,36 +3789,26 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, TriangularSolveOptions::Transpose transpose_a) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a)); TF_ASSIGN_OR_RETURN(const Shape* b_shape, builder->GetShapePtr(b)); - xla::TriangularSolveOptions& options = - *instr.mutable_triangular_solve_options(); + xla::TriangularSolveOptions options; options.set_left_side(left_side); options.set_lower(lower); options.set_unit_diagonal(unit_diagonal); options.set_transpose_a(transpose_a); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTriangularSolveShape( *a_shape, *b_shape, options)); - *instr.mutable_shape() = shape.ToProto(); - - return builder->AddInstruction(std::move(instr), - HloOpcode::kTriangularSolve, {a, b}); + return builder->TriangularSolveInternal(shape, a, b, std::move(options)); }); } XlaOp Cholesky(XlaOp a, bool lower) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* a_shape, builder->GetShapePtr(a)); - xla::CholeskyOptions& options = *instr.mutable_cholesky_options(); - options.set_lower(lower); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCholeskyShape(*a_shape)); - *instr.mutable_shape() = shape.ToProto(); - - return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a}); + return builder->CholeskyInternal(shape, a, lower); }); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 6d30195d3d0..f841a1a75a0 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -366,6 +366,7 @@ class XlaBuilder { // // TODO(b/119520625): Remove this API once we have more dynamic shape infra // ready. 
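The worklist rewrite above replaces the breadth-first queue with an explicit depth-first stack, so every operand receives its new id before the instruction that uses it, and (handle, need_rewrite) pairs are memoized. A generic, self-contained sketch of that traversal shape, using a toy graph type rather than the builder code:

    #include <cstddef>
    #include <map>
    #include <vector>

    struct ToyNode { std::vector<int> operands; };

    // Emits node ids of `graph` in post order starting from `root`: each node
    // is revisited after each operand finishes and emitted only once all of
    // its operands have been emitted.
    std::vector<int> PostOrder(const std::vector<ToyNode>& graph, int root) {
      std::vector<int> order;
      std::map<int, bool> emitted;
      struct Item { int node; size_t next_operand; };
      std::vector<Item> stack = {{root, 0}};
      while (!stack.empty()) {
        Item& item = stack.back();
        if (emitted.count(item.node)) { stack.pop_back(); continue; }
        const ToyNode& n = graph[item.node];
        if (item.next_operand < n.operands.size()) {
          stack.push_back({n.operands[item.next_operand++], 0});
          continue;
        }
        emitted[item.node] = true;
        order.push_back(item.node);
        stack.pop_back();
      }
      return order;
    }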
+ ABSL_DEPRECATED("Use SetDimensionSize to set a dynamic dimension.") Status SetDynamicBinding(int64 dynamic_size_param_num, ShapeIndex dynamic_size_param_index, int64 target_param_num, @@ -454,6 +455,10 @@ class XlaBuilder { XlaOp Reshape(const Shape& shape, XlaOp operand, int64 inferred_dimension = -1); + XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + XlaOp Collapse(XlaOp operand, absl::Span dimensions); XlaOp Slice(XlaOp operand, absl::Span start_indices, @@ -553,6 +558,12 @@ class XlaBuilder { FftType fft_type, absl::Span fft_length); + virtual StatusOr TriangularSolveInternal( + const Shape& shape, XlaOp a, XlaOp b, TriangularSolveOptions options); + + virtual StatusOr CholeskyInternal(const Shape& shape, XlaOp a, + bool lower); + XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config); virtual StatusOr InfeedWithTokenInternal( @@ -701,6 +712,11 @@ class XlaBuilder { XlaOp RngBitGenerator(RandomAlgorithm algorithm, XlaOp initial_state, const Shape& shape); + // Internal variant for the op with the full result shape containing both data + // and state shape as a tuple. + virtual StatusOr RngBitGeneratorInternal( + const Shape& full_result_shape, RandomAlgorithm algorithm, + XlaOp initial_state); XlaOp While(const XlaComputation& condition, const XlaComputation& body, XlaOp init); @@ -773,8 +789,13 @@ class XlaBuilder { XlaOp RemoveDynamicDimension(XlaOp operand, int64 dimension); - StatusOr AddInstruction(HloInstructionProto&& instr, HloOpcode opcode, - absl::Span operands = {}); + virtual StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode, + absl::Span operands); + StatusOr AddInstruction(HloInstructionProto&& instr, + HloOpcode opcode) { + return AddInstruction(std::move(instr), opcode, /*operands=*/{}); + } void AddCalledComputation(const XlaComputation& computation, HloInstructionProto* instr); @@ -940,6 +961,10 @@ class XlaBuilder { friend XlaOp Reshape(const Shape& shape, XlaOp operand); + friend XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension); @@ -1453,9 +1478,16 @@ XlaOp Pad(XlaOp operand, XlaOp padding_value, XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes); -// Enqueues an operation onto the computation that collapses the operand, from -// first to last dimension (C order), then reshapes it to the given dimension -// sizes. Conceptually, this is a limited form of "shape casting". +// Enqueues a dynamic reshape operation. The dynamic reshape takes additional +// XlaOps as sizes for the result dimension. The result dim i is a dynamic +// dimension dimension if dims_are_dynamic[i] is true. +XlaOp DynamicReshape(XlaOp operand, absl::Span dim_sizes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + +// Enqueues an operation onto the computation that collapses the operand, +// from first to last dimension (C order), then reshapes it to the given +// dimension sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(XlaOp operand, absl::Span new_sizes); // Enqueues a Reshape op that uses an explicit target shape. 
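A hedged usage sketch for the DynamicReshape entry point documented above, combined with SetDimensionSize; the parameter names and shapes are illustrative assumptions, not code from this change:

    xla::XlaBuilder b("dynamic_reshape_example");
    auto x = xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {3, 4}), "x");
    auto rows = xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "rows");
    // Mark the first dimension of x as dynamic, with runtime size `rows` (<= 3).
    auto x_dynamic = xla::SetDimensionSize(x, rows, /*dimension=*/0);
    // Flatten to a rank-1 result with bound 12 whose runtime size is rows * 4.
    auto flat_size = xla::Mul(rows, xla::ConstantR0<int32_t>(&b, 4));
    auto flat = xla::DynamicReshape(x_dynamic, /*dim_sizes=*/{flat_size},
                                    /*new_size_bounds=*/{12},
                                    /*dims_are_dynamic=*/{true});

The result keeps the static bound of 12 elements while carrying the dynamic size rows * 4, which is the behavior the reshape and strided-slice kernel changes earlier in this diff rely on.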
diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.cc b/tensorflow/compiler/xla/pjrt/cpu_device.cc index be70c16fc12..e2543bda7df 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/cpu_device.cc @@ -25,8 +25,8 @@ static const char kCpuPlatformName[] = "cpu"; CpuDevice::CpuDevice(int id, std::unique_ptr local_device_state) - : Device(id, std::move(local_device_state), kCpuPlatformName, - /*device_kind=*/kCpuPlatformName) {} + : PjRtDevice(id, std::move(local_device_state), kCpuPlatformName, + /*device_kind=*/kCpuPlatformName) {} StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(se::Platform * platform, @@ -39,7 +39,7 @@ StatusOr> GetCpuClient(bool asynchronous) { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; for (int i = 0; i < client->device_count(); ++i) { se::StreamExecutorConfig config; config.ordinal = i; diff --git a/tensorflow/compiler/xla/pjrt/cpu_device.h b/tensorflow/compiler/xla/pjrt/cpu_device.h index c70d90ae228..ad0079b1c4a 100644 --- a/tensorflow/compiler/xla/pjrt/cpu_device.h +++ b/tensorflow/compiler/xla/pjrt/cpu_device.h @@ -23,7 +23,7 @@ limitations under the License. namespace xla { -class CpuDevice : public Device { +class CpuDevice : public PjRtDevice { public: CpuDevice(int id, std::unique_ptr local_device_state); }; diff --git a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc index d54be61fbb8..298c41c7f58 100644 --- a/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc +++ b/tensorflow/compiler/xla/pjrt/gpu_multistream_test.cc @@ -32,7 +32,7 @@ TEST(GpuMultiStream, Basics) { GetNvidiaGpuClient(/*asynchronous=*/true, GpuAllocatorConfig(), /*distributed_client=*/nullptr, /*node_id=*/0)); - Device* device = client->local_devices().at(0); + PjRtDevice* device = client->local_devices().at(0); int n = 1024; Shape shape = ShapeUtil::MakeShape(S32, {n}); diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.cc b/tensorflow/compiler/xla/pjrt/interpreter_device.cc index f7138a8c181..c1149f2dbf9 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.cc +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.cc @@ -25,8 +25,8 @@ static const char kInterpreterPlatformName[] = "interpreter"; InterpreterDevice::InterpreterDevice( int id, std::unique_ptr local_device_state) - : Device(id, std::move(local_device_state), kInterpreterPlatformName, - /*device_kind=*/kInterpreterPlatformName) {} + : PjRtDevice(id, std::move(local_device_state), kInterpreterPlatformName, + /*device_kind=*/kInterpreterPlatformName) {} StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(se::Platform * platform, @@ -40,7 +40,7 @@ StatusOr> GetInterpreterClient() { TF_ASSIGN_OR_RETURN(LocalClient * client, ClientLibrary::GetOrCreateLocalClient(options)); - std::vector> devices; + std::vector> devices; se::StreamExecutor* executor = client->backend().stream_executor(0).ValueOrDie(); auto device_state = absl::make_unique( diff --git a/tensorflow/compiler/xla/pjrt/interpreter_device.h b/tensorflow/compiler/xla/pjrt/interpreter_device.h index 58b210ad762..cf732f70124 100644 --- a/tensorflow/compiler/xla/pjrt/interpreter_device.h +++ b/tensorflow/compiler/xla/pjrt/interpreter_device.h @@ -23,7 +23,7 @@ limitations under the License. 
namespace xla { -class InterpreterDevice : public Device { +class InterpreterDevice : public PjRtDevice { public: InterpreterDevice(int id, std::unique_ptr local_device_state); diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc index edffaf6c877..512ff81ef6e 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.cc @@ -207,9 +207,9 @@ StatusOr NcclIdStore::GetNcclUniqueId(const NcclCliqueKey& key) { return cache_.emplace(key_string, result.ValueOrDie()).first->second; } -std::vector> BuildLocalDevices( +std::vector> BuildLocalDevices( std::vector> local_device_states) { - std::vector> devices; + std::vector> devices; for (auto& local_device : local_device_states) { int device_ordinal = local_device->device_ordinal(); const se::DeviceDescription& description = @@ -225,7 +225,7 @@ std::vector> BuildLocalDevices( Status BuildDistributedDevices( std::vector> local_device_states, std::shared_ptr distributed_client, int node_id, - std::vector>* devices, + std::vector>* devices, GpuExecutableRunOptions* gpu_executable_run_options) { LocalTopologyProto local_topology; local_topology.set_node_id(node_id); @@ -286,8 +286,8 @@ Status BuildDistributedDevices( GpuDevice::GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id) - : Device(id, std::move(local_device_state), kGpuPlatformName, - std::move(device_kind), node_id) {} + : PjRtDevice(id, std::move(local_device_state), kGpuPlatformName, + std::move(device_kind), node_id) {} StatusOr> GetNvidiaGpuClient( bool asynchronous, const GpuAllocatorConfig& allocator_config, @@ -302,7 +302,7 @@ StatusOr> GetNvidiaGpuClient( auto host_memory_allocator = GetGpuHostAllocator(local_device_states.front()->executor()); - std::vector> devices; + std::vector> devices; auto gpu_run_options = absl::make_unique(); if (distributed_client) { TF_RETURN_IF_ERROR(BuildDistributedDevices( diff --git a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h index bf59ddef3a9..4f22a169bd8 100644 --- a/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h +++ b/tensorflow/compiler/xla/pjrt/nvidia_gpu_device.h @@ -25,7 +25,7 @@ limitations under the License. namespace xla { -class GpuDevice : public Device { +class GpuDevice : public PjRtDevice { public: GpuDevice(int id, std::unique_ptr local_device_state, std::string device_kind, int node_id); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index c5dce4a37f7..099c7729679 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -112,19 +112,19 @@ limitations under the License. 
namespace xla { -StatusOr Device::GetLocalDeviceState() const { +StatusOr PjRtDevice::GetLocalDeviceState() const { if (local_device_state_) { return local_device_state_.get(); } return InvalidArgument("Device %s is not a local device.", DebugString()); } -std::string Device::DebugString() const { +std::string PjRtDevice::DebugString() const { return absl::StrCat(platform_name(), ":", id()); } StatusOr DevicesToDeviceAssignment( - absl::Span> devices) { + absl::Span> devices) { if (devices.empty()) { return InvalidArgument( "Device assignment passed to Compile() must be non-empty."); @@ -175,7 +175,7 @@ class CpuAllocator : public tensorflow::Allocator { PjRtClient::PjRtClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, @@ -201,7 +201,7 @@ PjRtClient::PjRtClient( host_memory_allocator_ = std::make_unique(); } - for (const std::unique_ptr& device : devices_) { + for (const std::unique_ptr& device : devices_) { CHECK(id_to_device_.insert({device->id(), device.get()}).second) << "Duplicate device id: " << device->id(); @@ -376,8 +376,9 @@ void RecordUsage(PjRtBuffer::ScopedHold device_buffer, // It is safe to delete the returned PjRtBuffer without further // synchronization if an error occurs before the buffer is used. StatusOr> AllocateDestinationBuffer( - const Shape& on_host_shape, Device* device, LocalDeviceState* local_device, - se::Stream* copy_stream, bool is_uninitialized_create, PjRtClient* client) { + const Shape& on_host_shape, PjRtDevice* device, + LocalDeviceState* local_device, se::Stream* copy_stream, + bool is_uninitialized_create, PjRtClient* client) { if (on_host_shape.IsTuple() && on_host_shape.tuple_shapes_size() == 0) { return InvalidArgument("Can't make a buffer from an empty tuple"); } @@ -574,7 +575,7 @@ StatusOr> PjRtBuffer::FromHostBuffer( const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, std::shared_ptr buffer_reference, PjRtClient* client, - Device* device) { + PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer"); VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString() << " device: " << device->DebugString(); @@ -736,7 +737,7 @@ StatusOr> PjRtBuffer::FromHostBuffer( /* static */ StatusOr> PjRtBuffer::CreateUninitialized( - const Shape& shape, PjRtClient* client, Device* device) { + const Shape& shape, PjRtClient* client, PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized"); VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString() << " device: " << device->DebugString(); @@ -755,7 +756,7 @@ StatusOr> PjRtBuffer::CreateUninitialized( /* static */ StatusOr> PjRtBuffer::FromHostLiteral( - const LiteralSlice& literal, PjRtClient* client, Device* device) { + const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral"); VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: " << literal.shape().ToString() << " device: " << device->DebugString(); @@ -815,7 +816,7 @@ StatusOr> PjRtBuffer::FromHostLiteral( } /*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers( - absl::Span shapes, PjRtClient* client, Device* device, + absl::Span shapes, PjRtClient* client, PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier) { if (shapes.empty()) { notifier(InvalidArgument( @@ -849,7 
+850,7 @@ StatusOr> PjRtBuffer::FromHostLiteral( PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, - PjRtClient* client, Device* device) + PjRtClient* client, PjRtDevice* device) : client_(client), on_host_shape_(std::move(on_host_shape)), on_device_shape_(std::move(on_device_shape)), @@ -1189,7 +1190,7 @@ PjRtBuffer::ScopedHold PjRtBuffer::GetBufferWithHold(ScopedHold::Type type) { StatusOr, std::shared_ptr>> PjRtBuffer::CopyToDeviceHelper( - Device* dst_device, LocalDeviceState* dst_local_device, + PjRtDevice* dst_device, LocalDeviceState* dst_local_device, LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, std::shared_ptr src_device_buffer) { TF_ASSIGN_OR_RETURN( @@ -1249,7 +1250,7 @@ PjRtBuffer::CopyToDeviceHelper( } StatusOr> PjRtBuffer::CopyToDevice( - Device* dst_device) { + PjRtDevice* dst_device) { tensorflow::profiler::TraceMe traceme("PjRtBuffer::CopyToDevice"); if (dst_device == device_) { return InvalidArgument( @@ -1342,8 +1343,6 @@ namespace { // Helper struct for the tuple that is transiently constructed to hold the // arguments of an execution. struct TupleHandle { - // The tuple's shape on the host. - Shape on_host_shape; // The ExecutionInput describing the tuple. ExecutionInput execution_input; // A definition event that has been recorded on the host_to_device stream @@ -1414,8 +1413,7 @@ StatusOr MakeTupleHelper( auto transfer_event = std::make_shared(); transfer_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); - return TupleHandle({std::move(on_host_shape), std::move(execution_input), - std::move(transfer_event)}); + return TupleHandle({std::move(execution_input), std::move(transfer_event)}); } // Converts a ScopedShapedBuffer returned from an execution into a @@ -1423,20 +1421,20 @@ StatusOr MakeTupleHelper( std::unique_ptr OutputBufferHelper( ScopedShapedBuffer* result_buffer, std::shared_ptr definition_event, PjRtClient* client, - Device* device, LocalDeviceState* local_device) { + PjRtDevice* device, LocalDeviceState* local_device) { std::shared_ptr out_buffer = TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer, {definition_event}); - auto py_buffer = absl::make_unique( + auto pjrt_buffer = absl::make_unique( result_buffer->on_host_shape(), result_buffer->on_device_shape(), std::move(out_buffer), client, device); - RecordUsage(py_buffer->GetBufferWithUsageHold(), local_device, local_device, + RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device, definition_event, local_device->compute_stream(), /*prefer_to_retain_reference=*/false); - return py_buffer; + return pjrt_buffer; } -static Device* LookupDevice(const PjRtClient& client, int device_id) { +static PjRtDevice* LookupDevice(const PjRtClient& client, int device_id) { auto it = client.id_to_device().find(device_id); CHECK(it != client.id_to_device().end()) << "Unknown device id: " << device_id; @@ -1450,7 +1448,7 @@ PjRtExecutable::PjRtExecutable( bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, std::vector> local_logical_device_ids, - std::vector local_devices, PjRtClient* client) + std::vector local_devices, PjRtClient* client) : client_(client), device_assignment_(std::move(device_assignment)), parameter_is_tupled_arguments_(parameter_is_tupled_arguments), @@ -1508,15 +1506,64 @@ const std::string& PjRtExecutable::name() const { } } +bool PjRtExecutable::MustDonateParameter(int executable_idx, + int parameter) const { + return 
parameters_that_must_be_donated_[executable_idx].contains(parameter); +} + +StatusOr> +PjRtExecutable::MakeExecutionInputsAndWaitForEvents( + int device_ordinal, const ExecuteOptions& options, + absl::Span argument_handles, + absl::Span device_buffers, + absl::flat_hash_set& events) const { + std::vector execution_inputs; + LocalDeviceState* device_state = &client_->device_state(device_ordinal); + // Lift tuple_handle outside the conditional so that the event it returns is + // not destroyed until after the loop below that waits on events. + absl::optional tuple_handle; + if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { + TF_ASSIGN_OR_RETURN(tuple_handle, + MakeTupleHelper(client_, device_state, argument_handles, + device_buffers, device_ordinal)); + events.insert(tuple_handle->event.get()); + execution_inputs.emplace_back(std::move(tuple_handle->execution_input)); + } else { + execution_inputs.reserve(argument_handles.size()); + for (int i = 0; i < argument_handles.size(); ++i) { + PjRtBuffer* handle = argument_handles[i]; + + // Make an ExecutionInput from the device buffer. + execution_inputs.emplace_back(handle->on_device_shape(), + handle->on_host_shape()); + ExecutionInput& execution_input = execution_inputs.back(); + ShapeTree::iterator input_iterator = + execution_input.MutableBuffers()->begin(); + ShapeTree::iterator iterator_end = + execution_input.MutableBuffers()->end(); + device_buffers[i].AddToInput(&input_iterator, iterator_end, + &execution_input, client_->allocator()); + CHECK(input_iterator == iterator_end); + } + } + + for (BufferSequencingEvent* event : events) { + event->WaitForEventOnStream(device_state->compute_stream()); + } + + return execution_inputs; +} + // Enqueues a computation onto the compute stream. Each buffer returned in // device_buffers has a usage hold added that must be dropped on error or // converted on success. StatusOr PjRtExecutable::EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, const ExecuteOptions& options, - Device* device, std::vector* device_buffers, + PjRtDevice* device, std::vector* device_buffers, std::shared_ptr device_assignment) const { int device_ordinal = device->local_device_state()->device_ordinal(); + LocalDeviceState* device_state = &client_->device_state(device_ordinal); tensorflow::profiler::TraceMeConsumer activity( "LocalExecutable::Execute", tensorflow::profiler::ContextType::kPjRt, run_id.ToInt()); @@ -1524,10 +1571,7 @@ StatusOr PjRtExecutable::EnqueueExecution( << " mapped to device ordinal for execution: " << device_ordinal; absl::flat_hash_set events; - std::vector execution_inputs; device_buffers->reserve(argument_handles.size()); - const absl::flat_hash_set& parameters_that_must_be_donated = - parameters_that_must_be_donated_[executable_idx]; for (int i = 0; i < argument_handles.size(); ++i) { PjRtBuffer* handle = argument_handles[i]; if (handle->device() != device) { @@ -1536,8 +1580,7 @@ StatusOr PjRtExecutable::EnqueueExecution( "device %s, but replica is assigned to device %s.", i, replica, handle->device()->DebugString(), device->DebugString()); } - bool must_donate = parameters_that_must_be_donated.find(i) != - parameters_that_must_be_donated.end(); + bool must_donate = MustDonateParameter(executable_idx, i); device_buffers->emplace_back(handle->GetBufferWithHold( must_donate ? 
PjRtBuffer::ScopedHold::kDonation : PjRtBuffer::ScopedHold::kUsage)); @@ -1571,37 +1614,10 @@ StatusOr PjRtExecutable::EnqueueExecution( } } - LocalDeviceState* device_state = &client_->device_state(device_ordinal); - absl::optional tuple_handle; - if (parameter_is_tupled_arguments_ && !options.arguments_are_tupled) { - TF_ASSIGN_OR_RETURN(tuple_handle, - MakeTupleHelper(client_, device_state, argument_handles, - *device_buffers, device_ordinal)); - events.insert(tuple_handle->event.get()); - execution_inputs.emplace_back(std::move(tuple_handle->execution_input)); - } else { - execution_inputs.reserve(argument_handles.size()); - for (int i = 0; i < argument_handles.size(); ++i) { - PjRtBuffer* handle = argument_handles[i]; - - const PjRtBuffer::ScopedHold& device_buffer = (*device_buffers)[i]; - // Make an ExecutionInput from the device buffer. - execution_inputs.emplace_back(handle->on_device_shape(), - handle->on_host_shape()); - ExecutionInput& execution_input = execution_inputs.back(); - ShapeTree::iterator input_iterator = - execution_input.MutableBuffers()->begin(); - ShapeTree::iterator iterator_end = - execution_input.MutableBuffers()->end(); - device_buffer.AddToInput(&input_iterator, iterator_end, &execution_input, - client_->allocator()); - CHECK(input_iterator == iterator_end); - } - } - - for (BufferSequencingEvent* event : events) { - event->WaitForEventOnStream(device_state->compute_stream()); - } + TF_ASSIGN_OR_RETURN( + std::vector execution_inputs, + MakeExecutionInputsAndWaitForEvents( + device_ordinal, options, argument_handles, *device_buffers, events)); ExecutableRunOptions run_options; run_options.set_stream(device_state->compute_stream()); @@ -1676,11 +1692,45 @@ StatusOr PjRtExecutable::EnqueueExecution( return result_buffer_or_status.ConsumeValueOrDie().ConsumeResult(); } +std::vector> PjRtExecutable::MakeOutputBuffers( + int device_ordinal, const ExecuteOptions& options, + ScopedShapedBuffer result_buffer, + std::shared_ptr definition_event, + PjRtDevice* device) const { + std::vector> outputs; + LocalDeviceState* device_state = &client_->device_state(device_ordinal); + if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { + int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); + outputs.reserve(tuple_count); + // Take ownership of each of the output values, leaving only the root table + // in result_buffer. + for (int i = 0; i < tuple_count; ++i) { + ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i}); + outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event, + client_, device, device_state)); + } + if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { + // Don't release the root buffer until after execution completes. 
+ ShapedBuffer root_buffer_holder = result_buffer.release(); + se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer(); + device_state->ThenExecuteOnCallbackThread( + device_state->compute_stream(), + [root_buffer, allocator{client_->allocator()}, device_ordinal]() { + TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer)); + }); + } + } else { + outputs.push_back(OutputBufferHelper(&result_buffer, definition_event, + client_, device, device_state)); + } + return outputs; +} + StatusOr>> PjRtExecutable::ExecuteHelper(absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, - Device* device) const { + PjRtDevice* device) const { std::shared_ptr device_assignment; if (device == nullptr) { CHECK(device_assignment_ != nullptr); @@ -1737,31 +1787,9 @@ PjRtExecutable::ExecuteHelper(absl::Span argument_handles, } auto definition_event = std::make_shared(); definition_event->SetSequencingEvent(event_or.ConsumeValueOrDie(), stream); - std::vector> outputs; - if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) { - int tuple_count = result_buffer.on_host_shape().tuple_shapes_size(); - outputs.reserve(tuple_count); - // Take ownership of each of the output values, leaving only the root table - // in result_buffer. - for (int i = 0; i < tuple_count; ++i) { - ScopedShapedBuffer tuple_buffer = result_buffer.TakeSubTree({i}); - outputs.push_back(OutputBufferHelper(&tuple_buffer, definition_event, - client_, device, device_state)); - } - if (device_state->allocation_model() == LocalDeviceState::kSynchronous) { - // Don't release the root buffer until after execution completes. - ShapedBuffer root_buffer_holder = result_buffer.release(); - se::DeviceMemoryBase root_buffer = root_buffer_holder.root_buffer(); - device_state->ThenExecuteOnCallbackThread( - device_state->compute_stream(), - [root_buffer, allocator{client_->allocator()}, device_ordinal]() { - TF_CHECK_OK(allocator->Deallocate(device_ordinal, root_buffer)); - }); - } - } else { - outputs.push_back(OutputBufferHelper(&result_buffer, definition_event, - client_, device, device_state)); - } + std::vector> outputs = + MakeOutputBuffers(device_ordinal, options, std::move(result_buffer), + definition_event, device); for (PjRtBuffer::ScopedHold& b : device_buffers) { // prefer_to_retain_reference=false because when using the @@ -1801,7 +1829,7 @@ StatusOr>> PjRtExecutable::Execute( StatusOr>> PjRtExecutable::ExecuteOnLocalDevice( - absl::Span argument_handles, Device* device, + absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options) const { if (device_assignment_ == nullptr) { VLOG(1) << "Executing portable single-core program on " @@ -1867,7 +1895,7 @@ PjRtExecutable::ExecuteOnLocalDevices( for (int i = 0; i < num_local_devices; ++i) { const int replica = local_logical_device_ids_[i].first; const int partition = local_logical_device_ids_[i].second; - Device* device = local_devices_[i]; + PjRtDevice* device = local_devices_[i]; const LocalDeviceState& device_state = *device->local_device_state(); device_state.execute_thread()->Schedule([&, replica, partition, i] { results[i] = ExecuteHelper(argument_handles[i], replica, partition, @@ -2114,12 +2142,12 @@ StatusOr, Shape>> GetShardedProgramShapes( build_options.set_result_layout(result_layout); std::vector> local_logical_device_ids; - std::vector local_devices; + std::vector local_devices; if (device_assignment != nullptr) { for (int replica = 0; replica < num_replicas; ++replica) { for 
(int partition = 0; partition < num_partitions; ++partition) { int device_id = (*device_assignment)(replica, partition); - Device* device = LookupDevice(*client, device_id); + PjRtDevice* device = LookupDevice(*client, device_id); if (device->host_id() != client->host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index bb9093a8bf7..1bed959e3e6 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -52,17 +52,18 @@ namespace xla { class PjRtClient; -class Device { +class PjRtDevice { public: - explicit Device(int id, std::unique_ptr local_device_state, - std::string platform_name, std::string device_kind, - int host_id = 0) + explicit PjRtDevice(int id, + std::unique_ptr local_device_state, + std::string platform_name, std::string device_kind, + int host_id = 0) : id_(id), local_device_state_(std::move(local_device_state)), host_id_(host_id), platform_name_(std::move(platform_name)), device_kind_(std::move(device_kind)) {} - virtual ~Device() {} + virtual ~PjRtDevice() {} // The ID of this device. IDs are unique among devices of this type // (e.g. CPUs, GPUs). On multi-host platforms, this will be unique across all @@ -130,7 +131,7 @@ class PjRtClient { // `allocator` may null, in which case the platform default allocator is used. explicit PjRtClient( std::string platform_name, LocalClient* client, - std::vector> devices, int host_id, + std::vector> devices, int host_id, std::unique_ptr allocator, std::unique_ptr host_memory_allocator, bool should_stage_host_to_device_transfers, @@ -142,11 +143,15 @@ class PjRtClient { int device_count() const { return devices_.size(); } int local_device_count() const { return local_devices_.size(); } - const std::vector>& devices() const { + const std::vector>& devices() const { return devices_; } - const std::vector& local_devices() const { return local_devices_; } - const std::map& id_to_device() const { return id_to_device_; } + const std::vector& local_devices() const { + return local_devices_; + } + const std::map& id_to_device() const { + return id_to_device_; + } int host_id() const { return host_id_; } const std::string& platform_name() const { return platform_name_; } @@ -210,11 +215,11 @@ class PjRtClient { std::unique_ptr host_memory_allocator_; // Includes all devices, including non-local devices on multi-host platforms. - std::vector> devices_; + std::vector> devices_; // Maps Device::id() to the corresponding Device. Includes all devices. - std::map id_to_device_; + std::map id_to_device_; // Local devices indexed by local device ordinal. - std::vector local_devices_; + std::vector local_devices_; int host_id_; se::DeviceMemoryAllocator* allocator_; @@ -233,7 +238,7 @@ class PjRtClient { // Converts a 2D set of Device objects indexed by [replica][partition] into an // xla::DeviceAssignment. StatusOr DevicesToDeviceAssignment( - absl::Span> devices); + absl::Span> devices); // Holds a reference from Python to a tuple of device buffers. A PjRtBuffer // can be either valid or invalid. An invalid buffer is one that has never been @@ -417,7 +422,7 @@ class PjRtBuffer { // Returns a buffer with uninitialized contents. 
static StatusOr> CreateUninitialized( - const Shape& shape, PjRtClient* client, Device* device); + const Shape& shape, PjRtClient* client, PjRtDevice* device); // Describes the semantics the caller to FromHostBuffer expects from the // runtime, in a total order from most restrictive to least restrictive. @@ -449,13 +454,13 @@ class PjRtBuffer { const void* data, const Shape& shape, HostBufferSemantics host_buffer_semantics, std::shared_ptr buffer_reference, PjRtClient* client, - Device* device); + PjRtDevice* device); // Note that literal must remain in scope until the transfer has completed, so // the caller should, for example, wait for BlockHostUntilReady() completes on // the return value before letting literal go out of scope. static StatusOr> FromHostLiteral( - const LiteralSlice& literal, PjRtClient* client, Device* device); + const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device); // Asynchronously makes a vector of PjRtBuffers that can be used to receive // cross host transfers using `client` on `device'. `shapes` must be the exact @@ -467,12 +472,13 @@ class PjRtBuffer { // sending host and used in a call to CopyToRemoteDevice. None of the recv // buffers will become ready until *all* of the sends have completed. static void MakeCrossHostReceiveBuffers(absl::Span shapes, - PjRtClient* client, Device* device, + PjRtClient* client, + PjRtDevice* device, PjRtCrossHostRecvNotifier&& notifier); PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, - PjRtClient* client, Device* device); + PjRtClient* client, PjRtDevice* device); ~PjRtBuffer(); PjRtBuffer(const PjRtBuffer&) = delete; @@ -482,7 +488,7 @@ class PjRtBuffer { const Shape& on_host_shape() const { return on_host_shape_; } const Shape& on_device_shape() const { return on_device_shape_; } - Device* device() const { return device_; } + PjRtDevice* device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } PjRtClient* client() const { return client_; } bool IsEmptyTuple() const { @@ -556,7 +562,7 @@ class PjRtBuffer { // Copies the buffer to device `dst_device`. Returns an error if the buffer is // already on dst_device. - StatusOr> CopyToDevice(Device* dst_device); + StatusOr> CopyToDevice(PjRtDevice* dst_device); // Copies the buffer to the remote device encoded in serialized_descriptor. 
// This call must be preceded by a call to MakeCrossHostReceiveBuffers on the @@ -629,7 +635,7 @@ class PjRtBuffer { StatusOr, std::shared_ptr>> - CopyToDeviceHelper(Device* dst_device, LocalDeviceState* dst_local_device, + CopyToDeviceHelper(PjRtDevice* dst_device, LocalDeviceState* dst_local_device, LocalDeviceState* transfer_local_device, se::Stream* transfer_stream, std::shared_ptr src_device_buffer); @@ -637,7 +643,7 @@ class PjRtBuffer { PjRtClient* const client_; const Shape on_host_shape_; const Shape on_device_shape_; - Device* const device_; + PjRtDevice* const device_; mutable absl::Mutex mu_; std::shared_ptr device_buffer_ TF_GUARDED_BY(mu_); @@ -668,6 +674,11 @@ struct CompileOptions { bool compile_portable_executable = false; }; +class ExecuteContext { + public: + virtual ~ExecuteContext() = default; +}; + struct ExecuteOptions { // If true, the client must pass a single PjRtBuffer which contains all of // the arguments as a single XLA tuple, otherwise each argument must be @@ -682,6 +693,9 @@ struct ExecuteOptions { // multi-host programs are launched in different orders on different hosts, // the launch IDs may be used by the runtime to detect the mismatch. int32 launch_id = 0; + // If non-null, an opaque context passed to an execution that may be used to + // supply additional arguments to a derived class of PjRtExecutable. + ExecuteContext* context = nullptr; }; // Represents a compiled computation that can be executed given handles to @@ -699,7 +713,7 @@ class PjRtExecutable { bool parameter_is_tupled_arguments, std::shared_ptr device_assignment, std::vector> local_logical_device_ids, - std::vector local_devices, PjRtClient* client); + std::vector local_devices, PjRtClient* client); virtual ~PjRtExecutable() = default; @@ -733,14 +747,16 @@ class PjRtExecutable { return local_logical_device_ids_; } - const std::vector& local_devices() const { return local_devices_; } + const std::vector& local_devices() const { + return local_devices_; + } StatusOr>> Execute( absl::Span argument_handles, const ExecuteOptions& options) const; StatusOr>> ExecuteOnLocalDevice( - absl::Span argument_handles, Device* device, + absl::Span argument_handles, PjRtDevice* device, const ExecuteOptions& options) const; // Execute on local devices. Takes a sequence of argument lists (one argument @@ -756,22 +772,42 @@ class PjRtExecutable { const string& name() const; + protected: + bool parameter_is_tupled_arguments() const { + return parameter_is_tupled_arguments_; + } + private: // Initializes information about which arguments to which executables must be // donated due to aliases that were specified by the computation. 
Status SetUpDonation(PjRtClient* client, bool tuple_inputs); + virtual bool MustDonateParameter(int executable_idx, int parameter) const; + + virtual StatusOr> + MakeExecutionInputsAndWaitForEvents( + int device_ordinal, const ExecuteOptions& options, + absl::Span argument_handles, + absl::Span device_buffers, + absl::flat_hash_set& events) const; + StatusOr EnqueueExecution( absl::Span argument_handles, int replica, int partition, int executable_idx, const RunId& run_id, - const ExecuteOptions& options, Device* device, + const ExecuteOptions& options, PjRtDevice* device, std::vector* device_buffers, std::shared_ptr device_assignment) const; + virtual std::vector> MakeOutputBuffers( + int device_ordinal, const ExecuteOptions& options, + ScopedShapedBuffer result_buffer, + std::shared_ptr definition_event, + PjRtDevice* device) const; + StatusOr>> ExecuteHelper( absl::Span argument_handles, int replica, int partition, const RunId& run_id, const ExecuteOptions& options, - Device* device = nullptr) const; + PjRtDevice* device = nullptr) const; // Create shared pointers so we can free them after the execution: with // asynchronous execution, the process being executed can outlive the @@ -800,7 +836,7 @@ class PjRtExecutable { // assigned. // shared_ptrs instead of unique_ptrs to play well with the Python bindings // (see xla.cc). - std::vector local_devices_; + std::vector local_devices_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 1330dca6402..046fadb405b 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -155,7 +155,7 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/core/lib/bfloat16", + "//tensorflow/core/platform:bfloat16", "//tensorflow/core/platform:logging", "//third_party/py/numpy:headers", "//third_party/python_runtime:headers", # buildcleaner: keep @@ -242,6 +242,33 @@ cc_library( ], ) +cc_library( + name = "jax_jit", + srcs = ["jax_jit.cc"], + hdrs = ["jax_jit.h"], + copts = [ + "-fexceptions", + "-fno-strict-aliasing", + ], + features = ["-use_header_modules"], + visibility = ["//visibility:private"], + deps = [ + ":py_client", + ":pytree", + ":types", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/pjrt:pjrt_client", + "//tensorflow/core/platform:status", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:inlined_vector", + "@com_google_absl//absl/types:optional", + "@pybind11", + ], +) + cc_library( name = "ops", srcs = ["ops.cc"], @@ -367,6 +394,7 @@ pybind_extension( deps = [ ":bfloat16", ":dlpack", + ":jax_jit", ":ops", ":py_client", ":pytree", diff --git a/tensorflow/compiler/xla/python/bfloat16.cc b/tensorflow/compiler/xla/python/bfloat16.cc index 1f21b3fb242..b70244cc3ef 100644 --- a/tensorflow/compiler/xla/python/bfloat16.cc +++ b/tensorflow/compiler/xla/python/bfloat16.cc @@ -27,7 +27,7 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/logging.h" namespace xla { diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc index 4fc17172ea7..974816407ee 100644 --- a/tensorflow/compiler/xla/python/dlpack.cc +++ b/tensorflow/compiler/xla/python/dlpack.cc @@ -193,7 +193,7 @@ StatusOr> StridesToLayout(absl::Span dims, return minor_to_major; } -StatusOr DLDeviceTypeForDevice(const Device& device) { +StatusOr DLDeviceTypeForDevice(const PjRtDevice& device) { const se::Platform* platform = device.local_device_state()->executor()->platform(); if (platform->id() == se::host::kHostPlatformId) { @@ -205,15 +205,15 @@ StatusOr DLDeviceTypeForDevice(const Device& device) { device.DebugString()); } -StatusOr DLContextForDevice(const Device& device) { +StatusOr DLContextForDevice(const PjRtDevice& device) { DLContext context; TF_ASSIGN_OR_RETURN(context.device_type, DLDeviceTypeForDevice(device)); context.device_id = device.local_device_state()->device_ordinal(); return context; } -StatusOr DeviceForDLContext(const PjRtClient& client, - const DLContext& context) { +StatusOr DeviceForDLContext(const PjRtClient& client, + const DLContext& context) { se::Platform::Id platform_id; switch (context.device_type) { case kDLCPU: @@ -226,7 +226,7 @@ StatusOr DeviceForDLContext(const PjRtClient& client, return InvalidArgument("Unknown/unsupported DLPack device type %d", context.device_type); } - auto it = absl::c_find_if(client.local_devices(), [&](Device* device) { + auto it = absl::c_find_if(client.local_devices(), [&](PjRtDevice* device) { return device->local_device_state()->executor()->platform()->id() == platform_id && device->local_device_state()->device_ordinal() == context.device_id; @@ -313,7 +313,7 @@ StatusOr> DLPackManagedTensorToBuffer( dlmt->dl_tensor.ndim); } TF_ASSIGN_OR_RETURN( - Device * device, + PjRtDevice * device, DeviceForDLContext(*client->pjrt_client(), dlmt->dl_tensor.ctx)); absl::Span dimensions( reinterpret_cast(dlmt->dl_tensor.shape), dlmt->dl_tensor.ndim); diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc new file mode 100644 index 00000000000..96cf1e64b85 --- /dev/null +++ b/tensorflow/compiler/xla/python/jax_jit.cc @@ -0,0 +1,708 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This files implements the `jax.jit` dispatch and just-in-time feature. +// +// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward +// based on passed arguments dtypes/shapes/identity) the execution to a +// just-in-time compiled XLA Executable. All of that is done in C++ for +// performance reasons. 
+// +// This file contains the utilities to: +// (a) inspect arguments and describe their structure, dtype/shapes, etc. +// (b) keep a mapping from function signatures to compiled XLA Executables. + +#include "tensorflow/compiler/xla/python/jax_jit.h" + +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/inlined_vector.h" +#include "absl/types/optional.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "tensorflow/compiler/xla/pjrt/pjrt_client.h" +#include "tensorflow/compiler/xla/python/py_buffer.h" +#include "tensorflow/compiler/xla/python/py_executable.h" +#include "tensorflow/compiler/xla/python/pytree.h" +#include "tensorflow/compiler/xla/python/types.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/status.h" + +namespace xla { + +namespace py = pybind11; + +// TODO(phawkins): Add support for Tracers. +// TODO(jblespiau): Add support for donate_argnums. +// TODO(jblespiau): Use absl Status. + +namespace { + +// Describes the abstract shape and dtype of an argument. +struct ArgSignature { + // This is the XLA dtype of the object. + xla::PrimitiveType dtype; + // JAX arguments can be of weak type, if and only if they are Python scalars + // or `DeviceArray` values such that `aval.weak_type` is true. + bool weak_type; + absl::InlinedVector shape; + bool operator==(const ArgSignature& other) const { + return std::tie(dtype, weak_type, shape) == + std::tie(other.dtype, other.weak_type, other.shape); + } + bool operator!=(const ArgSignature& other) const { return !(*this == other); } + + std::string DebugString() const { + std::string result = ""; + if (weak_type) { + absl::StrAppend(&result, "weak_"); + } + absl::StrAppend(&result, xla::PrimitiveType_Name(dtype)); + absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]"); + return result; + } +}; + +template +H AbslHashValue(H h, const ArgSignature& s) { + h = H::combine(std::move(h), s.dtype); + if (!s.shape.empty()) { + h = H::combine_contiguous(std::move(h), &s.shape.front(), s.shape.size()); + } + return h; +} + +// The signature of Python jitted function call, partitioned into: +// - dynamic positional arguments (i.e. positional args which are not static) +// - static positional arguments (i.e. the args associated to static_argnums) +// - keyword arguments +// The CallSignature should unambiguously identify a function call, thus, +// equality is based on: +// (a) Same PyTree for all dynamic positional arguments and keyword arguments +// (a) equality of the arguments and keyword arguments ArgSignature +// (a) equality (delegated to Python) of the static arguments. +struct CallSignature { + struct KwargEntry { + // To avoid comparing strings, we intern the kwargs strings. + // The compilation cache holds a reference to all the keys. + py::handle key; + PyTreeDef value_treedef; + bool operator==(const KwargEntry& other) const { + return key.ptr() == other.key.ptr() && + value_treedef == other.value_treedef; + } + bool operator!=(const KwargEntry& other) const { return !(*this == other); } + }; + + // Only contains the arguments associated to `static_argnums`, sorted in the + // order of their argnum index. + std::vector static_args; + // A PyTreeDef for each positional dynamic (i.e. not static) argument. 
+ std::vector dynamic_positional_args_treedef; + // Keyword arguments. Sorted by the interned keyword pointers. + std::vector keyword_args; + // Shape and dtype for both the dynamic positional arguments and the keyword + // arguments (sorted by interned keyword pointers). + std::vector dynamic_args_signatures; + + bool operator==(const CallSignature& other) const { + return std::tie(dynamic_positional_args_treedef, static_args, keyword_args, + dynamic_args_signatures) == + std::tie(other.dynamic_positional_args_treedef, other.static_args, + other.keyword_args, other.dynamic_args_signatures); + } + bool operator!=(const CallSignature& other) const { + return !(*this == other); + } + + // To be used when we want to keep ownership of Python values referenced by + // the `CallSignature` (i.e. when we insert an entry). + void IncRef() const; + // The destructor of the cache should call this on all entries. + void DecRef() const; + + std::string DebugString() const; +}; + +void CallSignature::IncRef() const { + for (const auto& kw : keyword_args) { + kw.key.inc_ref(); + } +} + +void CallSignature::DecRef() const { + for (const auto& kw : keyword_args) { + kw.key.dec_ref(); + } +} + +template +H AbslHashValue(H h, const CallSignature::KwargEntry& kw) { + h = H::combine(std::move(h), kw.key.ptr(), kw.value_treedef); + return h; +} + +template +H AbslHashValue(H h, const CallSignature& s) { + // /!\ important: We cannot include static arguments to the hash, because + // the py::object must be hashable for absl. We can try delegating to the + // Python __hash__, but there are many non-hashable Python types such as + // np.ndarray. + // TODO(jblespiau): We should either ban non-hashable objects from jit or we + // should hash them by object identity. + h = H::combine_contiguous(std::move(h), + &s.dynamic_positional_args_treedef.front(), + s.dynamic_positional_args_treedef.size()); + h = H::combine_contiguous(std::move(h), &s.keyword_args.front(), + s.keyword_args.size()); + h = H::combine_contiguous(std::move(h), &s.dynamic_args_signatures.front(), + s.dynamic_args_signatures.size()); + return h; +} + +std::string CallSignature::DebugString() const { + std::vector static_args_str; + static_args_str.reserve(static_args.size()); + for (auto& static_arg : static_args) { + static_args_str.emplace_back(py::cast(static_arg.str())); + } + + std::vector signature_str; + signature_str.reserve(dynamic_args_signatures.size()); + + for (auto& arg_signature : dynamic_args_signatures) { + signature_str.emplace_back(arg_signature.DebugString()); + } + std::vector tree_def_str; + signature_str.reserve(dynamic_positional_args_treedef.size()); + for (auto& tree_def : dynamic_positional_args_treedef) { + tree_def_str.emplace_back(tree_def.ToString()); + } + std::vector keyword_names; + keyword_names.reserve(keyword_args.size()); + for (auto& kwarg_entry : keyword_args) { + keyword_names.emplace_back(py::cast(kwarg_entry.key)); + tree_def_str.emplace_back(kwarg_entry.value_treedef.ToString()); + } + return absl::StrCat( + static_args.size(), " static_args: ", absl::StrJoin(static_args_str, ","), + "\n", // new line + keyword_args.size(), " keyword args:", absl::StrJoin(keyword_names, ","), + "\n", // new-line + dynamic_positional_args_treedef.size(), " positional args.\n", + dynamic_args_signatures.size(), + " dynamic args (positional+keyword):\n - ", + absl::StrJoin(signature_str, ", "), "\n - ", + absl::StrJoin(tree_def_str, " | ")); +} + +struct CacheEntry { + std::shared_ptr executable; + xla::PjRtDevice* device; + 
PyTreeDef out_pytree_def; + // These are the objects required to create a `DeviceArray` object. + // We use Python types within the vector because this is what we will be + // returning to Python. No need to convert back and forth. + // We need py::object to maintain the objects alive. + std::vector out_avals; + std::vector out_lazy_exprs; +}; + +// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the +// bookkeeping of the different signatures used and the dispatch of calls to +// the correct underlying `PyExecutable`. +class CompiledFunction { + public: + CompiledFunction(py::function cache_miss_fun, py::function python_f_jitted, + bool jax_enable_x64, std::vector static_argnums, + std::shared_ptr pyclient, + xla::PjRtDevice* device); + ~CompiledFunction(); + + // This function will: + // (a) flatten the inputs using pytree + // (b) get buffer objects from the arguments + // (c) call the executable + // (d) construct `DeviceArray` objects from the outputs + // (e) reconstruct the `PyTree`. + py::object Call(py::args args, py::kwargs kwargs); + + private: + CacheEntry& GetCacheEntry(const py::args& args, const py::kwargs& kwargs, + const CallSignature& signature); + + // The Python function in charge of returning a `xla::PyExecutable` from + // the arguments passed to `jitted_f`. + const py::function cache_miss_fun_; + // A function to call as fallback. This is the result of calling the Python + // `jax.jit`. + // TODO(jblespiau): Delete this when the C++ codepath supports all features. + const py::function python_f_jitted_; + + // The value of the Python flag when the object was created. + const bool jax_enable_x64_; + + // We need to know the static arguments to remove them from the arguments + // passed to the underlying PyExecutable. In sorted order. + std::vector static_argnums_; + // We need a `unique_ptr` here to ensure value pointer stability. + absl::flat_hash_map> executables_; + + const std::shared_ptr pyclient_; + xla::PjRtDevice* const default_device_; +}; + +CompiledFunction::CompiledFunction(py::function cache_miss_fun, + py::function python_f_jitted, + bool jax_enable_x64, + std::vector static_argnums, + std::shared_ptr pyclient, + xla::PjRtDevice* device) + : cache_miss_fun_(std::move(cache_miss_fun)), + python_f_jitted_(std::move(python_f_jitted)), + jax_enable_x64_(jax_enable_x64), + static_argnums_(std::move(static_argnums)), + pyclient_(std::move(pyclient)), + default_device_(device) { + std::sort(static_argnums_.begin(), static_argnums_.end()); +} + +CompiledFunction::~CompiledFunction() { + for (const auto& entry : executables_) { + entry.first.DecRef(); + } +} + +namespace { + +// The resulting information of the parsing and conversion of the arguments. +struct ParsedArgumentsAsBuffers { + // The call signature will be filled during 2 steps: + // - `FlattenArguments` will fill the static arguments and the pytree + // structures + // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. + CallSignature signature; + // The concatenation of the dynamic positional arguments and the sorted + // keyword arguments. We do not need ownership, thus the py::handle. + // TODO(jblespiau): We do not need py::object here and py::handle suffice and + // will prevent any counter increment. + std::vector flat_dynamic_args; + std::vector keep_alive_objects; + + // The following is only valid if the parsing succeeds. 
+ std::vector arg_buffers; + // We may need to keep some objects around, because: + // (a) we need to extend the lifetime of objects created within + // `ConvertArgsToBuffers` + // (b) `arg_buffers` do not maintain ownership + std::vector, + std::unique_ptr>> + keep_alive; +}; + +// Filter out static arguments, flatten and concatenate other arguments (i.e. +// dynamic positional and keyword arguments), filling `arguments` in place. +void FlattenArguments(const py::args& args, const py::kwargs& py_kwargs, + absl::Span static_argnums, + ParsedArgumentsAsBuffers& arguments) { + arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() - + static_argnums.size()); + arguments.signature.dynamic_positional_args_treedef.reserve( + args.size() - static_argnums.size()); + + // Positional arguments. + for (size_t i = 0; i < args.size(); ++i) { + if (std::find(static_argnums.begin(), static_argnums.end(), i) == + static_argnums.end()) { + PyTreeDef pytree_def; + pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args); + arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def); + } else { + arguments.signature.static_args.emplace_back( + // borrow is mandatory here. + py::reinterpret_borrow(args[i])); + } + } + + // Keyword arguments. + std::vector> kwargs(py_kwargs.begin(), + py_kwargs.end()); + // We first intern the keys, then sort them (by pointer) and then create + // the signatures. + arguments.signature.keyword_args.resize(kwargs.size()); + for (size_t i = 0; i < kwargs.size(); ++i) { + // Intern the key if not already interned. + if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) { + PyObject* key = kwargs[i].first.ptr(); + kwargs[i].first.inc_ref(); + PyUnicode_InternInPlace(&key); + arguments.keep_alive_objects.push_back( + py::reinterpret_steal(key)); + kwargs[i].first = py::handle(key); + } + } + + std::sort(kwargs.begin(), kwargs.end(), + [](const std::pair& a, + const std::pair& b) { + return a.first.ptr() < b.first.ptr(); + }); + for (size_t i = 0; i < kwargs.size(); ++i) { + arguments.signature.keyword_args[i].key = kwargs[i].first; + arguments.signature.keyword_args[i].value_treedef.FlattenInto( + kwargs[i].second, arguments.flat_dynamic_args); + } +} + +template +std::unique_ptr ConvertToScalarBuffer( + const py::handle& scalar, xla::PjRtClient* client, + xla::PjRtDevice* device) { + CppType data = py::cast(scalar); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); + return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + &data, shape, + xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, + client, device)); +} + +// Convert a scalar to the associated PjRtBuffer or raises an error if it is +// not convertible (thus, this must be called after other checks). +StatusOr> ScalarToBuffer( + py::handle scalar, bool jax_enable_x64, xla::PjRtClient* client, + xla::PjRtDevice* device) { + // Important: In Python, isinstance(True, int) returns True. Thus, we have + // to check for bool before int. 
+ if (py::isinstance(scalar)) { + return ConvertToScalarBuffer(scalar, client, device); + } else if (py::isinstance(scalar)) { + if (jax_enable_x64) { + return ConvertToScalarBuffer(scalar, client, device); + } else { + return ConvertToScalarBuffer(scalar, client, device); + } + } else if (py::isinstance(scalar)) { + if (jax_enable_x64) { + return ConvertToScalarBuffer(scalar, client, device); + + } else { + return ConvertToScalarBuffer(scalar, client, device); + } + } else if (PyComplex_Check(scalar.ptr())) { + Py_complex result = PyComplex_AsCComplex(scalar.ptr()); + if (result.real == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + throw std::runtime_error("Could not convert the complex number"); + } + if (jax_enable_x64) { + xla::complex128 data(result.real, result.imag); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); + return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + &data, shape, + xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, + nullptr, client, device)); + } else { + xla::complex64 data(result.real, result.imag); + xla::Shape shape = xla::ShapeUtil::MakeShapeWithType({}); + return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer( + &data, shape, + xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, + nullptr, client, device)); + } + } + return InvalidArgument( + "%s", absl::StrCat( + "Not supported: The C++ jax jit execution path, only accepts " + "DeviceArray, Numpy arrays, or Python scalars. Got type ", + py::cast(scalar.get_type().str()))); +} + +const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) { + static const auto* int64_dt = new py::dtype("int64"); + static const auto* int32_dt = new py::dtype("int32"); + static const auto* uint64_dt = new py::dtype("uint64"); + static const auto* uint32_dt = new py::dtype("uint32"); + static const auto* float64_dt = new py::dtype("float64"); + static const auto* float32_dt = new py::dtype("float32"); + static const auto* complex64_dt = new py::dtype("complex64"); + static const auto* complex128_dt = new py::dtype("complex128"); + + if (dtype == *int64_dt) { + return int32_dt; + } + if (dtype == *float64_dt) { + return float32_dt; + } + if (dtype == *uint64_dt) { + return uint32_dt; + } + if (dtype == *complex128_dt) { + return complex64_dt; + } + + return nullptr; +} + +// Converts flattened arguments contained in ParsedArgumentsAsBuffers in +// place. If arguments are `DeviceArray`, they must all be on the same `Device`. +// +// Returns `OkStatus()` on success. +Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient, + xla::PjRtDevice* default_device, + ParsedArgumentsAsBuffers& arguments) { + std::vector& arg_buffers = arguments.arg_buffers; + auto& keep_alive = arguments.keep_alive; + + int num_flat_dynamic_args = arguments.flat_dynamic_args.size(); + arg_buffers.reserve(num_flat_dynamic_args); + arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args); + + static const auto* xla_module = + new py::module(py::module::import("jax.interpreters.xla")); + const auto& device_array = xla_module->attr("DeviceArray"); + + static const auto* numpy_module = new py::module(py::module::import("numpy")); + const auto& array = numpy_module->attr("array"); + + // TODO(phawkins): consider device stickiness. + // We first check whether any `DeviceArray` is present and whether they are + // attached to any specific device. 
See also + // https://github.com/google/jax/pull/1884 + // https://github.com/google/jax/pull/1916 for the rationale why the + // computation follows the data locality. + // It's also similar to PyTorch's behavior. + xla::PjRtDevice* data_device = nullptr; + for (py::handle arg : arguments.flat_dynamic_args) { + if (py::isinstance(arg, device_array)) { + xla::PyBuffer* buffer = + py::cast(arg.attr("device_buffer")); + xla::PjRtDevice* device = buffer->buffer()->device(); + if (data_device && (device != data_device)) { + return InvalidArgument( + "%s", + absl::StrCat( + "Arguments to a jit-compiled function must be colocated on the " + "same device. Arguments were found to be on the two following " + "different devices: ", + device->DebugString(), " and ", data_device->DebugString())); + } else { + data_device = device; + } + } + } + if (!data_device) { + // No `DeviceArray` were found default to `default_device`. + data_device = default_device; + } + xla::PjRtClient* pjrt_client = data_device->client(); + + for (py::handle arg : arguments.flat_dynamic_args) { + // We do not support here d2d transparent transfers. + // We assumes all the `DeviceArray` are already on the correct and shared + // device. + if (py::isinstance(arg, device_array)) { + xla::PyBuffer* buffer = + py::cast(arg.attr("device_buffer")); + arg_buffers.push_back(buffer->buffer()); + ArgSignature sig; + sig.dtype = buffer->shape().element_type(); + sig.shape.assign(buffer->shape().dimensions().begin(), + buffer->shape().dimensions().end()); + sig.weak_type = py::cast(arg.attr("aval").attr("weak_type")); + arguments.signature.dynamic_args_signatures.push_back(std::move(sig)); + } else if (py::isinstance(arg)) { + // TODO(jblespiau): Can we improve this call? Do we need the underlying + // GlobalPyRefManager() and co? + py::array numpy_array = py::cast(arg); + // If jax_enable_x64 is not set, we need to coerce 32 bits types. + // Note that this is calling back to Python! + // TODO(jblespiau): We can remove this complexity when we delete + // jax_enable_x64 mode. + if (!jax_enable_x64) { + const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype()); + if (to_dtype) { + numpy_array = array(numpy_array, to_dtype); + } + } + std::unique_ptr buffer = + ValueOrThrow(pyclient.BufferFromPyval( + numpy_array, data_device, + /*force_copy=*/false, /*host_buffer_semantics=*/ + xla::PjRtBuffer::HostBufferSemantics::kZeroCopy)); + arg_buffers.push_back(buffer->buffer()); + + ArgSignature sig; + sig.dtype = buffer->shape().element_type(); + sig.shape.assign(buffer->shape().dimensions().begin(), + buffer->shape().dimensions().end()); + arguments.signature.dynamic_args_signatures.push_back(sig); + + keep_alive.emplace_back(std::move(buffer)); + } else { + StatusOr> buffer = + ScalarToBuffer(arg, jax_enable_x64, pjrt_client, data_device); + if (!buffer.ok()) { + return buffer.status(); + } + arg_buffers.push_back(buffer.ValueOrDie().get()); + ArgSignature sig; + sig.dtype = buffer.ValueOrDie()->on_host_shape().element_type(); + sig.weak_type = true; + arguments.signature.dynamic_args_signatures.push_back(sig); + + keep_alive.emplace_back(std::move(buffer).ValueOrDie()); + } + } + return Status::OK(); +} + +} // namespace + +CacheEntry& CompiledFunction::GetCacheEntry(const py::args& args, + const py::kwargs& kwargs, + const CallSignature& signature) { + auto found_iterator = executables_.find(signature); + if (found_iterator != executables_.end()) { // Cache hit! 
+ return *(found_iterator->second); + } + + // We need to insert the element. + auto result = executables_.emplace(signature, std::make_unique<CacheEntry>()); + auto it = result.first; + + // CallSignatures in the cache own their keyword argument reference. + result.first->first.IncRef(); + + // Cache miss? Call the Python cache miss function. + py::tuple executable_and_pytree = cache_miss_fun_(*args, **kwargs); + if (executable_and_pytree.size() != 4) { + throw std::runtime_error( + "AssertionError: The cache miss function should return 4 " + "arguments."); + } + it->second->executable = py::cast<std::shared_ptr<xla::PyExecutable>>( + std::move(executable_and_pytree[0])); + int num_devices = + it->second->executable->pjrt_executable().local_devices().size(); + if (num_devices != 1) { + throw std::runtime_error(absl::StrCat( + "Running on more than a single device is not currently supported. " + "The underlying PjRtExecutable has ", + num_devices)); + } + it->second->device = + it->second->executable->pjrt_executable().local_devices()[0]; + it->second->out_pytree_def = py::cast<PyTreeDef>(executable_and_pytree[1]); + + py::list shaped_arrays = + py::reinterpret_borrow<py::list>(executable_and_pytree[2]); + py::list lazy_expressions = + py::reinterpret_borrow<py::list>(executable_and_pytree[3]); + + it->second->out_avals.reserve(shaped_arrays.size()); + it->second->out_lazy_exprs.reserve(lazy_expressions.size()); + + int num_outputs = shaped_arrays.size(); + for (int i = 0; i < num_outputs; ++i) { + py::object shaped_array = + py::reinterpret_borrow<py::object>(shaped_arrays[i]); + py::object lazy_expr = + py::reinterpret_borrow<py::object>(lazy_expressions[i]); + + it->second->out_avals.push_back(shaped_array); + it->second->out_lazy_exprs.push_back(lazy_expr); + } + + return *(it->second); +} + +py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) { + ParsedArgumentsAsBuffers arguments; + FlattenArguments(args, kwargs, static_argnums_, arguments); + + // The C++ jit does not support Tracer arguments yet. The Python-based jit + // function will be called if any of the dynamic arguments is unsupported.
+ if (!ConvertArgsToBuffers(jax_enable_x64_, *pyclient_, default_device_, + arguments) + .ok()) { + return python_f_jitted_(*args, **kwargs); + } + + CacheEntry& cache_entry = GetCacheEntry(args, kwargs, arguments.signature); + + std::vector> outputs = + ValueOrThrow(cache_entry.executable->PjRtExecute(arguments.arg_buffers)); + + static const auto* xla_module = + new py::module(py::module::import("jax.interpreters.xla")); + const auto& device_array = xla_module->attr("DeviceArray"); + + const std::vector& out_avals = cache_entry.out_avals; + const std::vector& out_lazy_exprs = cache_entry.out_lazy_exprs; + + py::list flat_device_arrays; + for (int i = 0; i < outputs.size(); ++i) { + flat_device_arrays.append(device_array( + /*aval=*/out_avals[i], /*device=*/outputs[i]->device(), + /*lazy_expr=*/out_lazy_exprs[i], + /*device_buffer=*/std::move(outputs[i]))); + } + return cache_entry.out_pytree_def.Unflatten(flat_device_arrays); +} + +} // namespace + +void BuildJaxjitSubmodule(pybind11::module& m) { + py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library"); + + py::class_> cfun( + jitlib, "CompiledFunction"); + cfun.def("__call__", &CompiledFunction::Call); + + jitlib.def("jit", + [](py::function cache_miss_fun, + py::function fallback_on_unsupported_argument, + bool jax_enable_x64, std::vector static_argnums, + xla::ClientAndPtr client_and_device) + -> std::unique_ptr { + return std::make_unique( + std::move(cache_miss_fun), + std::move(fallback_on_unsupported_argument), jax_enable_x64, + std::move(static_argnums), client_and_device.client, + client_and_device.contents); + }); + + // Only for testing purposes + jitlib.def("_ScalarToBuffer", [](py::handle scalar, bool jax_enable_x64, + std::shared_ptr client) { + xla::PjRtClient* pjrt_client = client->pjrt_client(); + + return std::make_unique( + client, + ScalarToBuffer(scalar, jax_enable_x64, pjrt_client, + pjrt_client->local_devices()[0]) + .ValueOrDie(), + nullptr); + }); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/jax_jit.h b/tensorflow/compiler/xla/python/jax_jit.h new file mode 100644 index 00000000000..2b1603aac27 --- /dev/null +++ b/tensorflow/compiler/xla/python/jax_jit.h @@ -0,0 +1,27 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ + +#include "pybind11/pybind11.h" + +namespace xla { + +void BuildJaxjitSubmodule(pybind11::module& m); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index 7c029ca7d19..f6067e650c0 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -101,14 +101,14 @@ uint32_t constexpr kOutfeedCidShutdown = 0; // Encapsulates data received from a device outfeed. class OutfeedData { public: - OutfeedData(Device* device, uint32_t consumer_id, Shape shape) + OutfeedData(PjRtDevice* device, uint32_t consumer_id, Shape shape) : device_(device), consumer_id_(consumer_id), shape_(shape), literal_(nullptr), literal_size_bytes_(0) {} - Device* device() { return device_; } + PjRtDevice* device() { return device_; } uint32_t consumer_id() const { return consumer_id_; } Shape shape() const { return shape_; } std::unique_ptr literal() { @@ -123,7 +123,7 @@ class OutfeedData { std::string DebugString() const; private: - Device* device_; + PjRtDevice* device_; uint32_t consumer_id_; Shape shape_; std::unique_ptr literal_; @@ -187,8 +187,8 @@ class OutfeedReceiverImpl { Status SendShutdownOutfeedHeader(int device_idx); // Receives a raw Literal from a device outfeed. - StatusOr> ReceiveRawFromOutfeed(const Device* device, - const Shape& shape); + StatusOr> ReceiveRawFromOutfeed( + const PjRtDevice* device, const Shape& shape); // Enqueues received data in the callbaback queue. void EnqueueReceivedData(std::unique_ptr received) @@ -200,7 +200,7 @@ class OutfeedReceiverImpl { OutfeedReceiver::Callback callback_; // The devices on which we are listening. - std::vector devices_; + std::vector devices_; // Maximum bytes capacity of the callback queue. uint64_t max_callback_queue_size_bytes_; @@ -283,7 +283,7 @@ void OutfeedReceiverImpl::DeviceListenerThreadLoop(int device_idx) { absl::MutexLock lock(&mu_); ++num_listening_threads_; } - Device* device = devices_[device_idx]; + PjRtDevice* device = devices_[device_idx]; while (true) { Shape header_shape = ShapeUtil::MakeShape(U32, {kOutfeedHeaderWords}); std::unique_ptr header = @@ -339,7 +339,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData( } StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( - const Device* device, const Shape& shape) { + const PjRtDevice* device, const Shape& shape) { std::shared_ptr literal_shared; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, @@ -390,7 +390,7 @@ void OutfeedReceiverImpl::CallbackThreadLoop() { } Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) { - const Device* device = devices_[device_idx]; + const PjRtDevice* device = devices_[device_idx]; constexpr int consumer_id = kOutfeedCidShutdown; VLOG(2) << "[" << device->DebugString() << "] SendSpecialHeader cons=" << consumer_id; diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.h b/tensorflow/compiler/xla/python/outfeed_receiver.h index a8dcc559810..46e2e5d9526 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.h +++ b/tensorflow/compiler/xla/python/outfeed_receiver.h @@ -33,7 +33,7 @@ class OutfeedReceiver { public: // A callback takes: device, consumer id, received. 
using Callback = - std::function)>; + std::function)>; // Constructs the receiver for the given clients and callback function. // diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_py.cc b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc index d297df332ff..a732ab8e21a 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver_py.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver_py.cc @@ -40,7 +40,7 @@ class OutfeedReceiverForPython { public: // A callback to Python takes: consumer id, received literal. using CallbackToPython = - std::function, uint32_t, pybind11::object)>; + std::function, uint32_t, pybind11::object)>; OutfeedReceiverForPython(CallbackToPython callback_python, std::vector> clients, @@ -48,7 +48,7 @@ class OutfeedReceiverForPython { : callback_python_(std::move(callback_python)), clients_(std::move(clients)) { OutfeedReceiver::Callback callback = - [this](Device* device, uint32_t consumer_id, + [this](PjRtDevice* device, uint32_t consumer_id, std::shared_ptr literal) { this->Callback(device, consumer_id, std::move(literal)); }; @@ -86,7 +86,7 @@ class OutfeedReceiverForPython { arrays); } - void Callback(Device* device, uint32_t consumer_id, + void Callback(PjRtDevice* device, uint32_t consumer_id, std::shared_ptr literal) { { absl::MutexLock lock(&mu_); @@ -106,7 +106,7 @@ class OutfeedReceiverForPython { LiteralToPython(std::move(literal)).ValueOrDie(); // The callback_ should handle all exceptions in user-code. If we get // an exception here, it is a bug in the callback and we should stop. - callback_python_(WrapWithClient(*it, device), consumer_id, + callback_python_(WrapWithClient(*it, device), consumer_id, std::move(literal_python)); } diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc index e8a5063b70b..919dafe2e0b 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc @@ -78,11 +78,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -111,11 +111,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoComputations) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -156,11 +156,11 @@ TEST(OutfeedReceiverTest, ReceiveOutfeedTwoOutfeed) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t 
consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -199,11 +199,11 @@ TEST(OutfeedReceiverTest, DifferentShapeForConsumerIdError) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); @@ -233,11 +233,11 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) { std::vector clients{cpu_client.get()}; auto receiver = absl::make_unique(); - OutfeedReceiver::Callback callback = [&receiver]( - Device* device, uint32_t consumer_id, - std::shared_ptr data) { - receiver->Receive(consumer_id, data); - }; + OutfeedReceiver::Callback callback = + [&receiver](PjRtDevice* device, uint32_t consumer_id, + std::shared_ptr data) { + receiver->Receive(consumer_id, data); + }; auto outfeed_receiver = std::make_shared(callback, clients, 128); outfeed_receiver->Start(); diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc index ed4787310b4..b32fe047530 100644 --- a/tensorflow/compiler/xla/python/py_buffer.cc +++ b/tensorflow/compiler/xla/python/py_buffer.cc @@ -51,12 +51,12 @@ PyBuffer::~PyBuffer() { } } -ClientAndPtr PyBuffer::device() const { +ClientAndPtr PyBuffer::device() const { return WrapWithClient(client_, buffer_->device()); } StatusOr> PyBuffer::CopyToDevice( - const ClientAndPtr& dst_device) const { + const ClientAndPtr& dst_device) const { CHECK(dst_device.get() != nullptr); GlobalPyRefManager()->CollectGarbage(); std::unique_ptr out; diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h index 76791e969cb..d7906574ec1 100644 --- a/tensorflow/compiler/xla/python/py_buffer.h +++ b/tensorflow/compiler/xla/python/py_buffer.h @@ -38,12 +38,12 @@ class PyBuffer { std::shared_ptr client() const { return client_; } PjRtBuffer* buffer() const { return buffer_.get(); } - ClientAndPtr device() const; + ClientAndPtr device() const; const std::string& platform_name() const { return buffer_->platform_name(); } bool is_deleted() const { return buffer_->IsDeleted(); } StatusOr> CopyToDevice( - const ClientAndPtr& dst_device) const; + const ClientAndPtr& dst_device) const; void Delete() { return buffer_->Delete(); } diff --git a/tensorflow/compiler/xla/python/py_client.cc b/tensorflow/compiler/xla/python/py_client.cc index 9b95f8e03de..6df11322564 100644 --- a/tensorflow/compiler/xla/python/py_client.cc +++ b/tensorflow/compiler/xla/python/py_client.cc @@ -33,8 +33,8 @@ namespace pprof = tensorflow::tfprof::pprof; PyClient::PyClient(std::shared_ptr pjrt_client) : pjrt_client_(std::move(pjrt_client)) {} -std::vector> PyClient::Devices() { - std::vector> devices; +std::vector> PyClient::Devices() { + std::vector> devices; devices.reserve(pjrt_client_->devices().size()); for (const auto& device : pjrt_client_->devices()) { devices.push_back(WrapWithClient(shared_from_this(), device.get())); @@ -42,21 +42,21 @@ std::vector> PyClient::Devices() { return devices; } -std::vector> PyClient::LocalDevices() { - std::vector> devices; +std::vector> 
PyClient::LocalDevices() { + std::vector> devices; devices.reserve(pjrt_client_->local_devices().size()); - for (Device* device : pjrt_client_->local_devices()) { + for (PjRtDevice* device : pjrt_client_->local_devices()) { devices.push_back(WrapWithClient(shared_from_this(), device)); } return devices; } -StatusOr>>> +StatusOr>>> PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) { TF_ASSIGN_OR_RETURN( DeviceAssignment device_assignment, pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions)); - std::vector>> result; + std::vector>> result; result.resize(num_replicas); for (int r = 0; r < num_replicas; ++r) { result[r].resize(num_partitions); @@ -70,12 +70,12 @@ PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) { return result; } -StatusOr>> +StatusOr>> PyClient::GetDefaultDeviceAssignment1D(int num_replicas) { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, pjrt_client_->GetDefaultDeviceAssignment( num_replicas, /*num_partitions=*/1)); - std::vector> result; + std::vector> result; for (int i = 0; i < num_replicas; ++i) { int device_id = device_assignment(i, 0); auto iter = pjrt_client_->id_to_device().find(device_id); @@ -86,7 +86,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) { } StatusOr> PyClient::BufferFromPyval( - const pybind11::object& argument, Device* device, bool force_copy, + const pybind11::object& argument, PjRtDevice* device, bool force_copy, PjRtBuffer::HostBufferSemantics host_buffer_semantics) { if (device == nullptr) { TF_RET_CHECK(!pjrt_client_->local_devices().empty()); @@ -206,7 +206,7 @@ namespace { struct HeapProfileKey { Traceback* traceback; int64 size; - Device* device; + PjRtDevice* device; bool operator==(const HeapProfileKey& other) const; }; diff --git a/tensorflow/compiler/xla/python/py_client.h b/tensorflow/compiler/xla/python/py_client.h index e41415c42f2..f12a4ae4f0a 100644 --- a/tensorflow/compiler/xla/python/py_client.h +++ b/tensorflow/compiler/xla/python/py_client.h @@ -100,14 +100,14 @@ class PyClient : public std::enable_shared_from_this { int device_count() const { return pjrt_client_->device_count(); } int host_id() const { return pjrt_client_->host_id(); } - std::vector> Devices(); - std::vector> LocalDevices(); + std::vector> Devices(); + std::vector> LocalDevices(); - StatusOr>>> + StatusOr>>> GetDefaultDeviceAssignment(int num_replicas, int num_partitions); // TODO(skye): delete after all callers can handle 2D output - StatusOr>> GetDefaultDeviceAssignment1D( + StatusOr>> GetDefaultDeviceAssignment1D( int num_replicas); StatusOr CreateChannelHandle() { @@ -121,7 +121,7 @@ class PyClient : public std::enable_shared_from_this { } StatusOr> BufferFromPyval( - const pybind11::object& argument, Device* device, bool force_copy, + const pybind11::object& argument, PjRtDevice* device, bool force_copy, PjRtBuffer::HostBufferSemantics host_buffer_semantics); StatusOr> Compile( diff --git a/tensorflow/compiler/xla/python/py_executable.cc b/tensorflow/compiler/xla/python/py_executable.cc index ed524f1cb33..53891b96846 100644 --- a/tensorflow/compiler/xla/python/py_executable.cc +++ b/tensorflow/compiler/xla/python/py_executable.cc @@ -58,10 +58,10 @@ PyExecutable::~PyExecutable() { } } -std::vector> PyExecutable::LocalDevices() const { - std::vector> devices; +std::vector> PyExecutable::LocalDevices() const { + std::vector> devices; devices.reserve(executable_->local_devices().size()); - for (Device* device : executable_->local_devices()) { + for (PjRtDevice* device : 
executable_->local_devices()) { devices.push_back(WrapWithClient(client_, device)); } return devices; diff --git a/tensorflow/compiler/xla/python/py_executable.h b/tensorflow/compiler/xla/python/py_executable.h index 24f177261e7..2e51548ae51 100644 --- a/tensorflow/compiler/xla/python/py_executable.h +++ b/tensorflow/compiler/xla/python/py_executable.h @@ -47,7 +47,7 @@ class PyExecutable { return executable_->local_logical_device_ids(); } - std::vector> LocalDevices() const; + std::vector> LocalDevices() const; int64 SizeOfGeneratedCodeInBytes() const { return executable_->SizeOfGeneratedCodeInBytes(); diff --git a/tensorflow/compiler/xla/python/pytree.cc b/tensorflow/compiler/xla/python/pytree.cc index 58d6a585b08..bf0bb1a8d93 100644 --- a/tensorflow/compiler/xla/python/pytree.cc +++ b/tensorflow/compiler/xla/python/pytree.cc @@ -107,7 +107,7 @@ bool PyTreeDef::operator==(const PyTreeDef& other) const { } void PyTreeDef::FlattenInto(py::handle handle, - std::vector& leaves) { + std::vector& leaves) { Node node; int start_num_nodes = traversal_.size(); int start_num_leaves = leaves.size(); @@ -158,23 +158,19 @@ void PyTreeDef::FlattenInto(py::handle handle, } } else { assert(node.kind == Kind::kLeaf); - leaves.push_back(handle); + leaves.push_back(pybind11::reinterpret_borrow(handle)); } node.num_nodes = traversal_.size() - start_num_nodes + 1; node.num_leaves = leaves.size() - start_num_leaves; traversal_.push_back(std::move(node)); } -/*static*/ std::pair> PyTreeDef::Flatten( - py::handle x) { - std::vector leaves; +/*static*/ std::pair, std::unique_ptr> +PyTreeDef::Flatten(py::handle x) { + std::vector leaves; auto tree = absl::make_unique(); tree->FlattenInto(x, leaves); - py::list outputs(leaves.size()); - for (int i = 0; i < leaves.size(); ++i) { - outputs[i] = py::reinterpret_borrow(leaves[i]); - } - return std::make_pair(std::move(outputs), std::move(tree)); + return std::make_pair(std::move(leaves), std::move(tree)); } /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) { diff --git a/tensorflow/compiler/xla/python/pytree.h b/tensorflow/compiler/xla/python/pytree.h index 76fd76fad6a..69cd93a7d08 100644 --- a/tensorflow/compiler/xla/python/pytree.h +++ b/tensorflow/compiler/xla/python/pytree.h @@ -84,12 +84,12 @@ class PyTreeDef { PyTreeDef() = default; // Flattens a Pytree into a list of leaves and a PyTreeDef. - static std::pair> Flatten( - pybind11::handle x); + static std::pair, std::unique_ptr> + Flatten(pybind11::handle x); // Recursive helper used to implement Flatten(). void FlattenInto(pybind11::handle handle, - std::vector& leaves); + std::vector& leaves); // Tests whether the given list is a flat list of leaves. 
static bool AllLeaves(const pybind11::iterable& x); diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc index e78f04ff980..e4fb2cdfd41 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.cc @@ -37,8 +37,8 @@ namespace xla { TpuDevice::TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip) - : xla::Device(id, /*local_device_state=*/nullptr, kTpuPlatform, - /*device_kind=*/"Cloud TPU", host_id), + : xla::PjRtDevice(id, /*local_device_state=*/nullptr, kTpuPlatform, + /*device_kind=*/"Cloud TPU", host_id), coords_(coords), core_on_chip_(core_on_chip) {} @@ -47,9 +47,9 @@ std::string TpuDevice::DebugString() const { coords_[0], coords_[1], coords_[2], core_on_chip_); } -xla::StatusOr>> +xla::StatusOr>> TpuDevice::GetTpuDevices(const tpu_driver::SystemInfo& system_info) { - std::vector> devices; + std::vector> devices; for (const auto& chip : system_info.tpu_chip()) { auto& coord = chip.chip_coord(); std::array coords_array = {coord.x(), coord.y(), coord.z()}; @@ -78,7 +78,7 @@ StatusOr> PyTpuClient::Get( tpu_driver::SystemInfo system_info; client->QuerySystemInfo(&system_info); - TF_ASSIGN_OR_RETURN(std::vector> devices, + TF_ASSIGN_OR_RETURN(std::vector> devices, TpuDevice::GetTpuDevices(system_info)); return std::make_shared(kTpuPlatform, std::move(client), @@ -88,13 +88,13 @@ StatusOr> PyTpuClient::Get( PyTpuClient::PyTpuClient(std::string platform_name, std::unique_ptr driver, - std::vector> devices, + std::vector> devices, int host_id) : platform_name_(std::move(platform_name)), driver_(std::move(driver)), devices_(std::move(devices)), host_id_(host_id) { - for (const std::shared_ptr& device : devices_) { + for (const std::shared_ptr& device : devices_) { CHECK(id_to_device_.insert({device->id(), device}).second) << "Duplicate device id: " << device->id(); @@ -173,7 +173,7 @@ static Status CheckDataType(xla::PrimitiveType dtype) { StatusOr> PyTpuBuffer::FromLiterals( std::vector leaves, const Shape& tuple_shape, std::shared_ptr leaves_references, - std::shared_ptr client, std::shared_ptr device) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::FromLiterals"); VLOG(1) << "PyTpuBuffer::FromLiterals: shape: " << tuple_shape.DebugString() << " device: " << device->DebugString(); @@ -229,7 +229,7 @@ StatusOr> PyTpuBuffer::FromLiterals( /* static */ StatusOr> PyTpuBuffer::MakeTuple( absl::Span buffers, std::shared_ptr client, - std::shared_ptr device) { + std::shared_ptr device) { std::vector child_shapes; std::vector> child_device_buffers; std::vector child_handle_ptrs; @@ -388,7 +388,7 @@ PyTpuBuffer::DestructureTuple() { } StatusOr> PyTpuBuffer::CopyToDevice( - std::shared_ptr dst_device) { + std::shared_ptr dst_device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CopyToDevice"); if (on_host_shape_.IsTuple()) { return Unimplemented("CopyToDevice for tuples is not supported."); @@ -433,7 +433,7 @@ Status PyTpuBuffer::BlockHostUntilReady() { /* static */ StatusOr> PyTpuBuffer::AllocateBuffer( const Shape& shape, std::shared_ptr client, - std::shared_ptr device) { + std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::AllocateBuffer"); VLOG(1) << "PyTpuBuffer::AllocateBuffer: shape: " << shape.DebugString() << " device: " << device->DebugString(); @@ -465,7 +465,7 @@ StatusOr> PyTpuBuffer::AllocateBuffer( 
/*static*/ StatusOr> PyTpuBuffer::CreateBuffer( const Shape& non_tuple_shape, absl::optional initializer, - std::shared_ptr client, std::shared_ptr device) { + std::shared_ptr client, std::shared_ptr device) { tensorflow::profiler::TraceMe traceme("PyTpuBuffer::CreateBuffer"); VLOG(1) << "PyTpuBuffer::CreateBuffer: shape: " << non_tuple_shape.DebugString() @@ -493,8 +493,8 @@ StatusOr> PyTpuBuffer::CreateBuffer( std::vector>(), client); } -static std::shared_ptr LookupDevice(const PyTpuClient& client, - int device_id) { +static std::shared_ptr LookupDevice(const PyTpuClient& client, + int device_id) { auto it = client.id_to_device().find(device_id); CHECK(it != client.id_to_device().end()) << "Unknown device id: " << device_id; @@ -516,7 +516,7 @@ PyTpuExecutable::PyTpuExecutable( for (int replica = 0; replica < num_replicas; ++replica) { for (int partition = 0; partition < num_partitions; ++partition) { int device_id = device_assignment_(replica, partition); - std::shared_ptr device = LookupDevice(*client_, device_id); + std::shared_ptr device = LookupDevice(*client_, device_id); if (device->host_id() != client_->host_id()) { VLOG(3) << "Non-local device: " << device_id; continue; @@ -541,7 +541,7 @@ PyTpuExecutable::ExecuteResult PyTpuExecutable::ExecuteHelper( absl::Span this_core_arguments, int replica, int partition, const RunId& run_id) { const int device_id = device_assignment_(replica, partition); - std::shared_ptr device = LookupDevice(*client_, device_id); + std::shared_ptr device = LookupDevice(*client_, device_id); CHECK_EQ(device->host_id(), client_->host_id()); tensorflow::profiler::TraceMe traceme("PyTpuExecutable::Execute"); VLOG(3) << "Replica " << replica << ", partition " << partition diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 4c45df181db..c2a424677fd 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -38,7 +38,7 @@ namespace xla { constexpr char kTpuPlatform[] = "tpu"; -class TpuDevice : public Device { +class TpuDevice : public PjRtDevice { public: TpuDevice(int id, int host_id, const std::array& coords, int core_on_chip); @@ -48,8 +48,8 @@ class TpuDevice : public Device { std::string DebugString() const override; - static xla::StatusOr>> GetTpuDevices( - const tpu_driver::SystemInfo& system_info); + static xla::StatusOr>> + GetTpuDevices(const tpu_driver::SystemInfo& system_info); private: const std::array coords_; @@ -66,7 +66,7 @@ class PyTpuClient { explicit PyTpuClient(std::string platform_name, std::unique_ptr driver, - std::vector> devices, + std::vector> devices, int host_id); virtual ~PyTpuClient() = default; @@ -83,11 +83,11 @@ class PyTpuClient { int device_count() const { return devices_.size(); } int local_device_count() const { return local_devices_.size(); } - const std::vector>& devices() { return devices_; } - const std::vector>& local_devices() { + const std::vector>& devices() { return devices_; } + const std::vector>& local_devices() { return local_devices_; } - const std::map>& id_to_device() const { + const std::map>& id_to_device() const { return id_to_device_; } int host_id() const { return host_id_; } @@ -110,11 +110,11 @@ class PyTpuClient { std::unique_ptr driver_; // Includes all devices, including non-local devices on multi-host platforms. - std::vector> devices_; + std::vector> devices_; // Maps Device::id() to the corresponding Device. 
Includes all devices. - std::map> id_to_device_; + std::map> id_to_device_; // Local devices indexed by local device ordinal. - std::vector> local_devices_; + std::vector> local_devices_; int host_id_; // A thread pool for scheduling core executions in parallel. @@ -128,7 +128,7 @@ struct TpuSharedBuffer final { TpuSharedBuffer(tpu_driver::TpuDriver* driver, std::unique_ptr handle, std::vector> wait_for_use, - std::shared_ptr src_device) + std::shared_ptr src_device) : driver(driver), device(std::move(src_device)), handle(std::move(handle)), @@ -143,7 +143,7 @@ struct TpuSharedBuffer final { } tpu_driver::TpuDriver* const driver; - const std::shared_ptr device; + const std::shared_ptr device; std::unique_ptr handle; std::vector> wait_for_use; @@ -162,12 +162,12 @@ class PyTpuBuffer { static StatusOr> FromLiterals( std::vector leaves_literals, const Shape& tuple_shape, std::shared_ptr leaves_reference, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr client, std::shared_ptr device); // Supports nested tuple creation. static StatusOr> MakeTuple( absl::Span buffers, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr client, std::shared_ptr device); PyTpuBuffer() = delete; PyTpuBuffer(Shape on_host_shape, @@ -181,7 +181,7 @@ class PyTpuBuffer { PyTpuBuffer& operator=(PyTpuBuffer&&) = delete; const Shape& on_host_shape() const { return on_host_shape_; } - std::shared_ptr device() const { return device_; } + std::shared_ptr device() const { return device_; } const std::string& platform_name() const { return client_->platform_name(); } std::shared_ptr client() const { return client_; } @@ -210,7 +210,7 @@ class PyTpuBuffer { // Copies the buffer to target device `dst_device` and returns a PyTpuBuffer // object holding the context to the target device buffer. StatusOr> CopyToDevice( - std::shared_ptr dst_device); + std::shared_ptr dst_device); // Blocks the host until the buffer's value has been computed and is ready for // immediate use on the device. Useful in particular for timing benchmarks. @@ -220,7 +220,7 @@ class PyTpuBuffer { // tuple, the returned buffer corresponds to the root tuple buffer. static StatusOr> AllocateBuffer( const Shape& shape, std::shared_ptr client, - std::shared_ptr device); + std::shared_ptr device); private: // Initializes a just allocated device buffer. The returned event will be @@ -231,11 +231,11 @@ class PyTpuBuffer { static StatusOr> CreateBuffer( const Shape& non_tuple_shape, absl::optional initializer, - std::shared_ptr client, std::shared_ptr device); + std::shared_ptr client, std::shared_ptr device); const std::shared_ptr client_; const Shape on_host_shape_; - const std::shared_ptr device_; + const std::shared_ptr device_; // If this is a tuple, `device_buffer_` stores the tuple buffer and // `child_buffers_` stores the child buffers; else, `device_buffer_` stores @@ -302,7 +302,7 @@ class PyTpuExecutable { return local_logical_device_ids_; } - const std::vector>& local_devices() const { + const std::vector>& local_devices() const { return local_devices_; } @@ -350,7 +350,7 @@ class PyTpuExecutable { // assigned. // shared_ptrs instead of unique_ptrs to play well with the Python bindings // (see xla.cc). 
- std::vector> local_devices_; + std::vector> local_devices_; xla::Shape result_shape_; }; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc index 9a794b79c5c..5d526b51899 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.cc @@ -40,11 +40,12 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def("host_id", &PyTpuClient::host_id) .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas, int num_partitions) - -> StatusOr>>> { + -> StatusOr< + std::vector>>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( num_replicas, num_partitions)); - std::vector>> result; + std::vector>> result; result.resize(num_replicas); for (int r = 0; r < num_replicas; ++r) { result[r].resize(num_partitions); @@ -60,11 +61,11 @@ PYBIND11_MODULE(tpu_client_extension, m) { // TODO(skye): delete after all callers can handle 2D output .def("get_default_device_assignment", [](PyTpuClient* client, int num_replicas) - -> StatusOr>> { + -> StatusOr>> { TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, client->GetDefaultDeviceAssignment( num_replicas, /*num_partitions=*/1)); - std::vector> result; + std::vector> result; for (int i = 0; i < num_replicas; ++i) { int device_id = device_assignment(i, 0); auto iter = client->id_to_device().find(device_id); @@ -96,7 +97,8 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def( "buffer_from_pyval", [](std::shared_ptr client, - const pybind11::object& argument, std::shared_ptr device, + const pybind11::object& argument, + std::shared_ptr device, bool force_copy) -> StatusOr> { if (device == nullptr) { TF_RET_CHECK(!client->local_devices().empty()); @@ -145,7 +147,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { py::class_(m, "PyTpuBuffer") .def_property_readonly("client", &PyTpuBuffer::client) .def("copy_to_device", - [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { + [](PyTpuBuffer* buffer, std::shared_ptr dst_device) { CHECK(dst_device != nullptr); GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; @@ -202,7 +204,7 @@ PYBIND11_MODULE(tpu_client_extension, m) { .def_property_readonly("traceback", [](PyTpuExecutable*) { return py::none(); }); - py::class_>(m, "TpuDevice") + py::class_>(m, "TpuDevice") .def_property_readonly("coords", &TpuDevice::coords) .def_property_readonly("core_on_chip", &TpuDevice::core_on_chip) .def("__repr__", [](const TpuDevice& device) { diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index e3bbc49f85c..d5977f4f0cf 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/python/bfloat16.h" #include "tensorflow/compiler/xla/python/dlpack.h" +#include "tensorflow/compiler/xla/python/jax_jit.h" #include "tensorflow/compiler/xla/python/ops.h" #include "tensorflow/compiler/xla/python/outfeed_receiver_py.h" #include "tensorflow/compiler/xla/python/py_buffer.h" @@ -438,26 +439,26 @@ PYBIND11_MODULE(xla_extension, m) { device_assignment); }); - py::class_>( + py::class_>( m, "Device", "A descriptor of an available device.\n\nSubclasses are used to " "represent specific types of devices, e.g. CPUs, GPUs. 
Subclasses may " "have additional properties specific to that device type.") .def_property_readonly( - "id", &Device::id, + "id", &PjRtDevice::id, "Integer ID of this device.\n\nUnique across all available devices " "of this type, including remote devices on multi-host platforms.") - .def_property_readonly("host_id", &Device::host_id, + .def_property_readonly("host_id", &PjRtDevice::host_id, "Integer ID of this device's host.\n\n" "This is always 0 except on multi-host platforms.") - .def_property_readonly("platform", &Device::platform_name) - .def_property_readonly("device_kind", &Device::device_kind) + .def_property_readonly("platform", &PjRtDevice::platform_name) + .def_property_readonly("device_kind", &PjRtDevice::device_kind) .def_property_readonly( "client", - [](const ClientAndPtr& device) { return device.client; }) - .def("__str__", &Device::DebugString) + [](const ClientAndPtr& device) { return device.client; }) + .def("__str__", &PjRtDevice::DebugString) .def("transfer_to_infeed", - [](const Device& device, const LiteralSlice& literal) { + [](const PjRtDevice& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, @@ -467,7 +468,8 @@ PYBIND11_MODULE(xla_extension, m) { }) .def( "transfer_from_outfeed", - [](const Device& device, const Shape& shape) -> StatusOr { + [](const PjRtDevice& device, + const Shape& shape) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); std::shared_ptr literal_shared; { @@ -491,12 +493,12 @@ PYBIND11_MODULE(xla_extension, m) { return LiteralToPython(std::move(literal_shared)); }); - py::class_>(m, "CpuDevice") + py::class_>(m, "CpuDevice") .def("__repr__", [](const CpuDevice& device) { return absl::StrFormat("CpuDevice(id=%i)", device.id()); }); - py::class_>(m, "GpuDevice") + py::class_>(m, "GpuDevice") .def("__repr__", [](const GpuDevice& device) { return absl::StrFormat("GpuDevice(id=%i)", device.id()); }); @@ -738,7 +740,7 @@ PYBIND11_MODULE(xla_extension, m) { .def(py::init([](const py::bytes& serialized_hlo_module_proto) -> std::unique_ptr { HloModuleProto proto; - proto.ParseFromString(serialized_hlo_module_proto); + proto.ParseFromString(std::string(serialized_hlo_module_proto)); return absl::make_unique(proto); })) .def("get_hlo_module", &GetHloModule) @@ -899,6 +901,7 @@ PYBIND11_MODULE(xla_extension, m) { BuildProfilerSubmodule(&m); BuildOutfeedReceiverSubmodule(&m); BuildPytreeSubmodule(m); + BuildJaxjitSubmodule(m); py::class_> diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index f5618b95c3e..dd16bd32dd1 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1431,6 +1431,7 @@ cc_library( ":hlo_live_range", ":hlo_ordering", ":hlo_proto_cc", + ":memory_space_assignment_repacking", ":tuple_points_to_analysis", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -1842,6 +1843,7 @@ cc_library( ":hlo", ":hlo_creation_utils", ":hlo_pass", + ":op_expander_pass", ":while_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", @@ -2684,6 +2686,7 @@ cc_library( ":hlo_casting_utils", ":hlo_dce", ":hlo_pass", + ":shape_inference", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -2707,7 +2710,6 @@ xla_test( ":dynamic_padder", ":hlo", ":hlo_dce", - ":hlo_get_dimension_size_rewriter", ":hlo_matchers", ":hlo_parser", 
"//tensorflow/compiler/xla:debug_options_flags", @@ -3435,6 +3437,26 @@ cc_library( ], ) +cc_library( + name = "memory_space_assignment_best_fit_repacker", + srcs = ["memory_space_assignment_best_fit_repacker.cc"], + hdrs = ["memory_space_assignment_best_fit_repacker.h"], + deps = [ + ":heap_simulator", + ":memory_space_assignment_repacking", + ], +) + +tf_cc_test( + name = "memory_space_assignment_best_fit_repacker_test", + srcs = ["memory_space_assignment_best_fit_repacker_test.cc"], + deps = [ + ":memory_space_assignment_best_fit_repacker", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "memory_space_assignment", srcs = ["memory_space_assignment.cc"], @@ -3997,42 +4019,6 @@ tf_cc_test( ], ) -cc_library( - name = "hlo_get_dimension_size_rewriter", - srcs = ["hlo_get_dimension_size_rewriter.cc"], - hdrs = ["hlo_get_dimension_size_rewriter.h"], - deps = [ - ":dynamic_dimension_inference", - ":hlo", - ":hlo_pass", - ":shape_inference", - "//tensorflow/compiler/xla:literal_util", - "@com_google_absl//absl/algorithm:container", - ], -) - -tf_cc_test( - name = "hlo_get_dimension_size_rewriter_test", - srcs = ["hlo_get_dimension_size_rewriter_test.cc"], - deps = [ - ":hlo", - ":hlo_get_dimension_size_rewriter", - ":hlo_matchers", - ":hlo_parser", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto_cc", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - cc_library( name = "maybe_owning_device_memory", srcs = [ diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fa4d0e47a5d..214cbfa93a7 100755 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -2500,6 +2500,20 @@ Status AlgebraicSimplifierVisitor::HandleGather(HloInstruction* gather) { if (ShapeUtil::IsZeroElementArray(operand_shape)) { return ReplaceInstruction(gather, MakeScalarLike(gather, 0)); } + + // Gathering from a scalar operand is simply a broadcast of that scalar + if (ShapeUtil::IsEffectiveScalar(operand_shape)) { + HloInstruction* new_operand = gather->mutable_operand(0); + if (operand_shape.rank()) { + TF_ASSIGN_OR_RETURN(new_operand, + MakeReshapeHlo(ShapeUtil::MakeScalarShape( + operand_shape.element_type()), + new_operand)); + } + HloInstruction* new_gather = + MakeBroadcastHlo(new_operand, {}, gather->shape()); + return ReplaceInstruction(gather, new_gather); + } // If the operand of a gather is very small, it is easier to fuse a // sequence of selects. 
const Shape& index_shape = gather->operand(1)->shape(); @@ -2712,7 +2726,7 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { // Mul(Mul(x, constant1), Mul(y, constant2)) => Mul(Mul(x, y), // constant1*constant2) if (Match(multiply, - m::Multiply( + m::MultiplyAnyOrder( m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), m::MultiplyAnyOrder(m::NonConstant(&b), m::Constant(&c2))))) { TF_ASSIGN_OR_RETURN(auto* product_of_constants, @@ -2734,6 +2748,29 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { } } + { + HloInstruction *a, *c1, *c2; + // Mul(Mul(a, constant1), constant2) => Mul(a, constant1*constant2) + if (Match(multiply, + m::MultiplyAnyOrder( + m::MultiplyAnyOrder(m::NonConstant(&a), m::Constant(&c1)), + m::Constant(&c2)))) { + TF_ASSIGN_OR_RETURN(auto* product_of_constants, + MakeBinaryHlo(HloOpcode::kMultiply, c1, c2)); + if (ShapeUtil::IsScalar(product_of_constants->shape()) && + !ShapeUtil::IsScalar(multiply->shape())) { + product_of_constants = + computation_->AddInstruction(HloInstruction::CreateBroadcast( + multiply->shape(), product_of_constants, {})); + } + + return ReplaceWithNewInstruction( + multiply, + HloInstruction::CreateBinary(multiply->shape(), HloOpcode::kMultiply, + a, product_of_constants)); + } + } + { HloInstruction *a, *b, *constant, *op; // Mul(Mul(a, constant1), Broadcast(b)) => diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 95700b2a994..70147f6ecad 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -140,6 +140,26 @@ TEST_F(AlgebraicSimplifierTest, MultiplyChain) { m::MultiplyAnyOrder(m::ConstantScalar(2), m::ConstantScalar(4))))); } +// (a*C1)*C2 => a*(C1*C2) +TEST_F(AlgebraicSimplifierTest, MultiplyChain2) { + const char* kModuleStr = R"( + HloModule m + test { + p0 = f32[] parameter(0) + a = f32[] constant(2) + b = f32[] constant(4) + c = f32[] multiply(p0, a) + ROOT y = f32[] multiply(c, b) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kModuleStr)); + ASSERT_TRUE(AlgebraicSimplifier(default_options_).Run(m.get()).ValueOrDie()); + EXPECT_THAT(m->entry_computation()->root_instruction(), + GmockMatch(m::MultiplyAnyOrder( + m::Parameter(0), m::MultiplyAnyOrder(m::ConstantScalar(2), + m::ConstantScalar(4))))); +} + // MUL(MUL(X, BROADCAST(constant)), BROADCAST(Y)) ==> // MUL(X, BROADCAST(MUL(Y, BROADCAST(constant)))) TEST_F(AlgebraicSimplifierTest, MultiplyBroadcastReassoc) { @@ -5627,6 +5647,30 @@ INSTANTIATE_TEST_SUITE_P( DotOfGatherSimplificationTestInstantiation, DotOfGatherSimplificationTest, ::testing::ValuesIn(DotOfGatherPositiveNegativeTests())); +TEST_F(AlgebraicSimplifierTest, GatherOfScalarToBroadcast) { + const char* hlo_string = R"( + HloModule repeat + + ENTRY main { + o = f32[1,1] parameter(0) + i = s32[100,2] parameter(1) + ROOT g = f32[100] gather(o, i), collapsed_slice_dims={0,1}, + start_index_map={0,1}, + index_vector_dim=1, + offset_dims={}, + slice_sizes={1,1} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AlgebraicSimplifierOptions options; + AlgebraicSimplifier simplifier(options); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, GmockMatch(m::Broadcast(m::Reshape(m::Parameter(0))))); +} + 
TEST_F(AlgebraicSimplifierTest, TupleReduceReshape) { const char* hlo_string = R"( HloModule module diff --git a/tensorflow/compiler/xla/service/all_reduce_combiner.cc b/tensorflow/compiler/xla/service/all_reduce_combiner.cc index 9d8f03c92ca..5fb4935a4b1 100644 --- a/tensorflow/compiler/xla/service/all_reduce_combiner.cc +++ b/tensorflow/compiler/xla/service/all_reduce_combiner.cc @@ -268,6 +268,11 @@ StatusOr AllReduceCombiner::Run(HloModule* module) { VLOG(1) << "Running AllReduceCombiner with threshold of " << combine_threshold_in_bytes_ << " bytes"; + if (combine_threshold_in_bytes_ <= 0 || combine_threshold_count_ <= 0) { + VLOG(1) << "Skip AllReduceCombiner because the threshold is zero"; + return false; + } + if (hlo_query::ContainsLayoutConstrainedAllReduce(*module)) { VLOG(1) << "Skip AllReduceCombiner because the module contains all-reduce " "with constrained layouts"; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 6cd58b86f0c..a0989d5765e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1424,13 +1424,16 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( // Returns a heap algorithm that chooses the best result from several // algorithms. auto get_heap_algorithm = [&](int64 alignment) { - auto algorithms = - absl::make_unique>>(); - algorithms->push_back(absl::make_unique( - alignment, GlobalDecreasingSizeBestFitHeap::kSpatial)); - algorithms->push_back(absl::make_unique( - alignment, GlobalDecreasingSizeBestFitHeap::kTemporal)); - return absl::make_unique(std::move(algorithms)); + auto algorithms = absl::make_unique< + std::vector>>>(); + algorithms->push_back( + absl::make_unique>( + alignment, GlobalDecreasingSizeBestFitHeap::kSpatial)); + algorithms->push_back( + absl::make_unique>( + alignment, GlobalDecreasingSizeBestFitHeap::kTemporal)); + return absl::make_unique>( + std::move(algorithms)); }; if (run_whole_module_heap_simulation) { @@ -1461,7 +1464,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, + HeapSimulator::Result result, HeapSimulator::Run( get_heap_algorithm(alignment), assignment->module(), schedule, assignment->alias_analysis(), assignment->buffer_size_, options)); @@ -1487,7 +1490,7 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( HeapSimulator::Options options; options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, + HeapSimulator::Result result, HeapSimulator::Run(get_heap_algorithm(alignment), *computation, *instruction_sequence, assignment->alias_analysis(), @@ -1582,7 +1585,7 @@ std::vector ComputePeakMemoryLogicalBuffers( } // namespace void BufferAssigner::AssignBuffersFromHeapSimulator( - const HeapSimulator::Result& result, BufferAssignment* assignment, + const HeapSimulator::Result& result, BufferAssignment* assignment, BufferValue::Color color) { if (assignment->stats_.preallocated_temp_fragmentation_bytes == -1) { assignment->stats_.preallocated_temp_fragmentation_bytes = diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 50a4750601b..60422965832 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -661,9 +661,9 @@ class BufferAssigner { // Uses 
the results of the heap simulator to create a single allocation, with // LogicalBuffers packed to specific offsets. - void AssignBuffersFromHeapSimulator(const HeapSimulator::Result& result, - BufferAssignment* assignment, - LogicalBuffer::Color color); + void AssignBuffersFromHeapSimulator( + const HeapSimulator::Result& result, + BufferAssignment* assignment, LogicalBuffer::Color color); // Tries to assign the given instruction to the given buffer. Returns if the // assignment was successful. diff --git a/tensorflow/compiler/xla/service/conditional_code_motion.cc b/tensorflow/compiler/xla/service/conditional_code_motion.cc index cdda0aeb925..ce80b4cfc15 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion.cc @@ -100,7 +100,7 @@ class BoundaryVisitor { // of reuses This is used as a placeholder only, assuming all // instructions can be fused to enable data reuses int64 ReusesCarriedBy(HloInstruction* op, HloInstruction* user) { - VLOG(1) << "ConditionalCodeMotion: Add reuses carried by instr: " + VLOG(2) << "ConditionalCodeMotion: Add reuses carried by instr: " << op->ToString() << "=>" << user->ToString() << "\n"; switch (user->opcode()) { case HloOpcode::kGetTupleElement: @@ -432,7 +432,8 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( if (to_move_out.empty()) { return false; } - VLOG(1) << "number of boundaries to move out:" << to_move_out.size() << "\n"; + VLOG(1) << "Modifying code--number of boundaries to move out:" + << to_move_out.size() << "\n"; HloComputation* conditional_parent = conditional->parent(); // save the old users before add new conditional user instructions std::vector old_conditional_users = conditional->users(); @@ -441,7 +442,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( absl::flat_hash_map hoisted_instructions; // Insert GetTupleElement before the instructions whose operands might still // be within the conditional. - VLOG(2) << "before opt:" + VLOG(1) << "before opt:" << conditional_parent->ToString(HloPrintOptions::Fingerprint()) << "\n"; int64 op_index = 0; @@ -470,16 +471,22 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( HloInstruction* old_root = conditional->branch_computation(0)->root_instruction(); for (auto user_instr : old_conditional_users) { + VLOG(2) << "Checking conditional user: " << user_instr->ToString() << "\n"; CHECK(user_instr->opcode() == HloOpcode::kGetTupleElement); auto tuple_opd = static_cast(user_instr); int64 index = tuple_opd->tuple_index(); + CHECK(old_root->operands().size() > index); HloInstruction* old_opd = old_root->operands()[index]; + CHECK(ContainsKey(hoisted_instructions, old_opd)); HloInstruction* new_opd = hoisted_instructions[old_opd].operands()[0]; CHECK(old_opd != nullptr); CHECK(new_opd != nullptr); + VLOG(2) << "Try replace all uses of :" << old_opd->ToString() << "\n"; TF_RETURN_IF_ERROR(user_instr->ReplaceAllUsesWith(new_opd)); TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(user_instr)); } + VLOG(2) << "Done changing conditional users\n" + << conditional_parent->ToString() << "\n"; // Create tuple element within each branch and set it as root. 
int64 branch_count = conditional->branch_count(); for (int i = 0; i < branch_count; i++) { @@ -487,9 +494,8 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( std::vector elements; for (auto b1 : new_boundaries) { HloInstruction* op = b1.operands()[i]; - VLOG(1) << "branch count=" << i << "\n"; CHECK(op != nullptr); - VLOG(1) << "Adding to root " << i << " with " << op->ToString() << "\n"; + VLOG(2) << "Adding to root " << i << " with " << op->ToString() << "\n"; elements.push_back(op); } HloInstruction* tuple = @@ -507,7 +513,7 @@ StatusOr ConditionalCodeMotion::MoveInstructionOut( conditional->branch_computation(0)->root_instruction(); *conditional->mutable_shape() = new_root->shape(); // - VLOG(2) << "done moving instructions out of branches\n" + VLOG(1) << "done moving instructions out of branches\n" << conditional_parent->ToString(HloPrintOptions::Fingerprint()) << "\n"; return true; @@ -520,48 +526,79 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( if (to_move_in.empty()) { return false; } - VLOG(1) << "number of boundaries to move in:" << to_move_in.size() << "\n"; - HloComputation* conditional_parent = conditional->parent(); - VLOG(2) << "before opt:" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + VLOG(1) << "Modifying code---number of boundaries to move in:" + << to_move_in.size() << "\n"; + VLOG(1) << "before opt:" + << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) << "\n"; // Mapping instructions to be moved to their new representations. absl::flat_hash_map hoisted_instructions; int64 to_move_in_size = to_move_in.size(); int64 branch_count = conditional->branch_count(); - int64 op_index = conditional->shape().tuple_shapes_size(); - // Map conditional to its old root, then create a new root instruction in each - // branch. - Boundary b(Boundary::Position::kInsideBranch); + // Number of old conditional entries still to be used outside. + // If conditional shape is not tuple, will create a tuple and use subscript + // 0 to save the old operand being used. + int64 op_index = conditional->shape().IsTuple() + ? conditional->shape().tuple_shapes_size() - 1 + : 0; + HloGetTupleElementInstruction* tuple_use = + dynamic_cast(to_move_in[0].operands()[0]); + int64 use_index = (tuple_use != nullptr) ? tuple_use->tuple_index() : -1; + VLOG(2) << "Tuple use index = " << use_index << "\n"; + // Use to map the tuple_use instruction to its operand; + Boundary b_opd_use(Boundary::Position::kInsideBranch); + Boundary b_old_root(Boundary::Position::kInsideBranch); + // Create a new root instruction in each branch. for (int i = 0; i < branch_count; i++) { auto computation = conditional->branch_computation(i); auto old_root = computation->root_instruction(); - b.mutable_operands().push_back(old_root); - HloInstruction* new_root = nullptr; + b_old_root.mutable_operands().push_back(old_root); + std::vector operands; if (old_root->opcode() == HloOpcode::kTuple) { - new_root = computation->AddInstruction(old_root->Clone()); - } else { - std::vector operands; - if (!old_root->shape().IsTuple()) { - operands.push_back(old_root); - } else { - const Shape& old_shape = old_root->shape(); - for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) { - auto element = - computation->AddInstruction(HloInstruction::CreateGetTupleElement( - old_shape.tuple_shapes(i), old_root, i)); - operands.push_back(element); + // Use operands of old_root directly, so old_root can be removed later. 
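+      // Illustrative example: with old_root = tuple(a, b, c) and +      // use_index == 1, `operands` collects {a, c} for the new root while b +      // is recorded in `b_opd_use`, so instructions moved into the branch +      // read b directly instead of going through the conditional.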
+ for (int i = 0; i < old_root->operand_count(); ++i) { + if (i != use_index) { + operands.push_back(old_root->operands()[i]); + } else { // Map conditional use to the tuple operand. + b_opd_use.mutable_operands().push_back(old_root->operands()[i]); } } - new_root = - computation->AddInstruction(HloInstruction::CreateTuple(operands)); + } else if (old_root->shape().IsTuple()) { + // If old_root is not a kTuple but has tuple shape, elements within the + // tuple must be extracted first to be used by the new instructions. + const Shape& old_shape = old_root->shape(); + for (int64 i = 0; i < old_shape.tuple_shapes_size(); ++i) { + auto element = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + old_shape.tuple_shapes(i), old_root, i)); + if (i != use_index) { + operands.push_back(element); + } else { + b_opd_use.mutable_operands().push_back(element); + } + } + } else { + // If old_root is not a tuple and does not have tuple shape, use it + // to replace the conditional directly in the new computation. + b_opd_use.mutable_operands().push_back(conditional); } + HloInstruction* new_root = + computation->AddInstruction(HloInstruction::CreateTuple(operands)); VLOG(2) << "setting new root: " << new_root->ToString() << "\n"; - computation->set_root_instruction(new_root); + computation->set_root_instruction(new_root, + /*accept_different_shape*/ true); + if (old_root->opcode() == HloOpcode::kTuple) { + TF_RETURN_IF_ERROR(computation->RemoveInstruction(old_root)); + } VLOG(2) << "new branch computation: " << computation->ToString() << "\n"; } - hoisted_instructions[conditional] = b; - for (int64 i = 0; i < to_move_in_size; i++) { + hoisted_instructions[conditional] = b_old_root; + int64 cp_start = 0; + if (use_index >= 0) { + hoisted_instructions[tuple_use] = b_opd_use; + cp_start = 1; + } + for (int64 i = cp_start; i < to_move_in_size; i++) { Boundary b_to_move = to_move_in[i]; HloInstruction* op = b_to_move.operands()[0]; CHECK(op != nullptr); @@ -591,12 +628,12 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( } if (to_be_used_outside) { // Modify uses of instructions outside of the conditionals - HloInstruction* gtr = conditional_parent->AddInstruction( + HloInstruction* gtr = conditional->parent()->AddInstruction( HloInstruction::CreateGetTupleElement(op->shape(), conditional, op_index++)); TF_RETURN_IF_ERROR(op->ReplaceAllUsesWith(gtr)); - if (conditional_parent->root_instruction() == op) { - conditional_parent->set_root_instruction(gtr); + if (conditional->parent()->root_instruction() == op) { + conditional->parent()->set_root_instruction(gtr); } } } @@ -606,8 +643,8 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( HloInstruction* new_root = conditional->branch_computation(0)->root_instruction(); *conditional->mutable_shape() = new_root->shape(); - VLOG(2) << "Before removing instructions:" << conditional_parent->ToString() - << "\n"; + VLOG(2) << "Before removing instructions:" + << conditional->parent()->ToString() << "\n"; // Remove hoisted instructions from the branches. 
for (int64 i = to_move_in_size - 1; i >= 0; i--) { Boundary boundary_to_move_in = to_move_in[i]; @@ -616,10 +653,10 @@ StatusOr ConditionalCodeMotion::MoveInstructionIn( for (auto user : op->users()) { VLOG(2) << "Has User: " << user->ToString() << "\n"; } - TF_RETURN_IF_ERROR(conditional_parent->RemoveInstruction(op)); + TF_RETURN_IF_ERROR(conditional->parent()->RemoveInstruction(op)); } - VLOG(2) << "Done moving instructions inside branches\n" - << conditional_parent->ToString(HloPrintOptions::Fingerprint()) + VLOG(1) << "Done moving instructions inside branches\n" + << conditional->parent()->ToString(HloPrintOptions::Fingerprint()) << "\n"; return true; } @@ -631,6 +668,7 @@ class GroupConnectedBoundaries { HloInstruction* conditional_; HloComputation* conditional_parent_; bool is_layout_sensitive_; + // Instructions that have been visited but are not going to be moved. absl::flat_hash_set visited_; public: @@ -663,7 +701,7 @@ class GroupConnectedBoundaries { case HloOpcode::kReshape: return true; default: - VLOG(1) << "Instruction is convert and its operand is not know to " + VLOG(2) << "Instruction is convert and its operand is not know to " "be worth hoisting\n"; return false; } @@ -680,24 +718,28 @@ class GroupConnectedBoundaries { case HloOpcode::kGetTupleElement: return true; default: - VLOG(1) << "Instruction is not known to be worth hoisting\n"; + VLOG(2) << "Instruction is not known to be worth hoisting\n"; return false; } } int64 ReusesBeforeBoundary(HloInstruction* user) { int64 reuses = 0; for (auto op : user->operands()) { + // The operand must be an instruction that is not going to be moved (if + // user is inside the conditional); otherwise it must be the conditional + // itself and its user must be outside of the conditional. + if (!ContainsKey(visited_, op) && op != conditional_) { + continue; + } // Only consider single-user cases as reuseable. - if (ContainsKey(visited_, op) && op->user_count() == 1) { + if (user->opcode() == HloOpcode::kGetTupleElement && + user->user_count() == 1) { + reuses += ReusesCarriedBy(op, user->users()[0]); + } else if (op->user_count() == 1) { reuses += ReusesCarriedBy(op, user); - } else if (op->opcode() == HloOpcode::kConditional && - user->opcode() == HloOpcode::kGetTupleElement) { - if (user->user_count() == 1) { - reuses += ReusesCarriedBy(op, user->users()[0]); - } } } - VLOG(1) << "Reuses before instruction " << user->ToString() << ":" << reuses + VLOG(2) << "Reuses before instruction " << user->ToString() << ":" << reuses << "\n"; return reuses; } @@ -735,7 +777,7 @@ class GroupConnectedBoundaries { } else if (ContainsKey(visited_, op)) { reuses += ReusesCarriedBy(user, op); } - VLOG(1) << "reuses after instruction " << user->ToString() << ":" + VLOG(2) << "reuses after instruction " << user->ToString() << ":" << reuses << "\n"; return reuses; } @@ -744,7 +786,8 @@ class GroupConnectedBoundaries { int64 BenefitForMovingBoundaries(const std::vector& boundaries) { int64 reuses_before = 0, reuses_after = 0; - if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch()) { + if (boundaries.size() == 1 && boundaries[0].IsOutsideBranch() && + boundaries[0].operands()[0]->opcode() == HloOpcode::kGetTupleElement) { // The only boundary of moving-in is the get_tuple_element op. 
return -1; } @@ -754,16 +797,16 @@ class GroupConnectedBoundaries { continue; } reuses_before += ReusesBeforeBoundary(op); - VLOG(1) << "Reuses before boundary so far: " << reuses_before << "\n"; + VLOG(2) << "Reuses before boundary so far: " << reuses_before << "\n"; reuses_after += ReusesAfterBoundary(op); - VLOG(1) << "Reuese after boundary so far : " << reuses_after << "\n"; + VLOG(2) << "Reuese after boundary so far : " << reuses_after << "\n"; } if (reuses_after == 0 && reuses_before == 0) { return -1; } else if (boundaries[0].IsInsideBranch()) { return reuses_after - reuses_before; } else { - return reuses_before - reuses_after; + return reuses_before - reuses_after - 1; } } @@ -800,12 +843,12 @@ class GroupConnectedBoundaries { visitor.AddToWorkList(boundary); while (visitor.HasNextBoundary()) { Boundary b = visitor.PopNextBoundary(); - VLOG(1) << "visiting boundary " << b.ToString() << "\n"; + VLOG(2) << "visiting boundary " << b.ToString() << "\n"; if ((b.IsOutsideBranch() || InstructionWithinBranchIdentical( b.operands(), is_layout_sensitive_)) && WorthHoisting(b.operands()[0])) { connected_boundaries_.push_back(b); - VLOG(1) << "boundary can be moved\n"; + VLOG(2) << "boundary can be moved\n"; int64 operand_count = (b.IsInsideBranch()) ? b.operands()[0]->operand_count() : b.operands()[0]->users().size(); @@ -829,7 +872,7 @@ class GroupConnectedBoundaries { } } } else { - VLOG(1) << "boundary cannot be moved\n"; + VLOG(2) << "boundary cannot be moved\n"; visited_.insert(b.operands()[0]); new_boundaries_.push_back(b); } @@ -876,7 +919,7 @@ ConditionalCodeMotion::Decision ConditionalCodeMotion::ConsiderCodeMotion( auto move_in_or_out = connect.BoundariesToMoveInOrOut(cur_boundary); if (!move_in_or_out.empty()) { auto benefit = connect.BenefitForMovingBoundaries(move_in_or_out); - VLOG(1) << "benefit of moving in or out " + VLOG(2) << "benefit of moving in or out " << cur_boundary.operands()[0]->ToString() << ":" << benefit << "\n"; if (benefit >= 0) { new_boundaries.clear(); @@ -899,9 +942,20 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { // Gather all the conditional ops in the module ahead of time, to avoid // potential complications of modifying the code that affecting traversal. std::vector conditional_ops; + // Track how many times each branch computation is shared. + absl::flat_hash_map conditional_computations; for (auto* comp : module->MakeComputationPostOrder()) { for (auto* instr : comp->MakeInstructionPostOrder()) { if (instr->opcode() == HloOpcode::kConditional) { + int branch_count = instr->branch_count(); + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = instr->branch_computation(i); + if (ContainsKey(conditional_computations, branch_i)) { + conditional_computations[branch_i]++; + } else { + conditional_computations[branch_i] = 0; + } + } conditional_ops.push_back(instr); } } @@ -909,6 +963,17 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { bool changed = false; for (HloInstruction* conditional : conditional_ops) { + int branch_count = conditional->branch_count(); + // check for shared conditional computations + bool conditional_is_shared = false; + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = conditional->branch_computation(i); + if (conditional_computations[branch_i] > 0) { + conditional_is_shared = true; + break; + } + } + // Boundaries to move out or to move into the branches. 
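The scoring in BenefitForMovingBoundaries above reduces to a small amount of arithmetic; the standalone sketch below (hypothetical reuse counts, not the XLA data structures) restates the decision rule, including the extra -1 bias this change introduces against moving boundaries into branches. Callers treat a non-negative score as "worth moving".

#include <cstdint>

// `inside_branch` mirrors boundaries[0].IsInsideBranch().
int64_t BenefitSketch(int64_t reuses_before, int64_t reuses_after,
                      bool inside_branch) {
  if (reuses_before == 0 && reuses_after == 0) {
    return -1;  // nothing is reused either way; do not move
  }
  if (inside_branch) {
    return reuses_after - reuses_before;  // candidate for moving out of branch
  }
  return reuses_before - reuses_after - 1;  // moving in, biased down by one
}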
std::vector to_move_out, to_move_in, new_boundaries; // The conditional is moved into a worklist as the seed (starting point). @@ -926,6 +991,33 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { Boundary boundary = visitor.PopNextBoundary(); VLOG(2) << "Analyzing boundary:" << boundary.ToString() << "\n"; d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); + if (d != Decision::kNoChange && conditional_is_shared) { + for (int i = 0; i < branch_count; ++i) { + HloComputation* branch_i = conditional->branch_computation(i); + if (conditional_computations[branch_i] > 0) { + // Cloning is absolutely needed if the computation is shared by + // different branches, but the cloning can be potentially avoided + // if the sharing is only among branches of the same conditional. + // If cloning these branches causes a problem due to space issues, + // a fix can pass a vector of unique branches to the actual + // transformations, as an alternative representation of the + // conditional branches to be modified. Right now we assume the + // overhead of cloning is minimal since later stages of the compiler + // inline all the computations anyway. + HloComputation* clone_i = + conditional->parent()->parent()->AddEmbeddedComputation( + branch_i->Clone()); + conditional->set_branch_computation(i, clone_i); + conditional_computations[branch_i]--; + } + } + to_move.clear(); + next_boundary.clear(); + VLOG(2) << "Cloned branches as needed: " << conditional->ToString() + << "\n"; + // Need to reanalyze the cloned code to generate correct result. + d = ConsiderCodeMotion(conditional, boundary, to_move, next_boundary); + } switch (d) { case Decision::kMoveOutOfBranch: VLOG(2) << "Decision is move out of branch\n"; @@ -961,22 +1053,14 @@ StatusOr ConditionalCodeMotion::Run(HloModule* module) { MoveInstructionIn(conditional, to_move_in, new_boundaries)); VLOG(2) << "moving in result:" << result << "\n"; changed |= result; - } - } - // handling convert rematerialization/hoisting - if (!changed && pursue_full_conditional_code_motion_) { - std::vector conditional_ops; - for (auto* comp : module->MakeComputationPostOrder()) { - for (auto* instr : comp->MakeInstructionPostOrder()) { - if (instr->opcode() == HloOpcode::kConditional) { - conditional_ops.push_back(instr); - } - } - } - for (HloInstruction* conditional_op : conditional_ops) { + } else if (pursue_full_conditional_code_motion_ && !conditional_is_shared) { + // Invoke special handling for convert rematerialization/hoisting + // We need to make sure no sharing is present in the branches because no + // cloning has been done by the earlier analysis. + // TOOD[b/165848866]: extend solution to handle cloning for special move. 
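The sharing bookkeeping added to Run() above can be summarized as a small reference count (sketch below with std containers and a stand-in Computation type): a branch computation whose count is still positive when a transformation is about to mutate it gets cloned first, and the count is decremented so later conditionals still see the remaining shares.

#include <unordered_map>
#include <vector>

struct Computation {};  // stand-in for HloComputation

// Count how many *extra* conditionals reuse each branch computation. The first
// user leaves the count at 0, matching the pass, so count > 0 means "shared".
std::unordered_map<const Computation*, int> CountBranchSharing(
    const std::vector<std::vector<const Computation*>>& branches_per_cond) {
  std::unordered_map<const Computation*, int> counts;
  for (const auto& branches : branches_per_cond) {
    for (const Computation* b : branches) {
      auto it = counts.find(b);
      if (it == counts.end()) {
        counts[b] = 0;   // first conditional to use this branch
      } else {
        ++it->second;    // every additional user marks it as shared
      }
    }
  }
  return counts;
}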
TF_ASSIGN_OR_RETURN( bool convert_result, - ConvertSpecialMove(conditional_op, is_layout_sensitive_)); + ConvertSpecialMove(conditional, is_layout_sensitive_)); changed |= convert_result; } } diff --git a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc index b0a6ba92f48..b91f3813980 100644 --- a/tensorflow/compiler/xla/service/conditional_code_motion_test.cc +++ b/tensorflow/compiler/xla/service/conditional_code_motion_test.cc @@ -580,6 +580,154 @@ ENTRY main { HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); } + +TEST_F(ConditionalCodeMotionTest, MovePowInWithSharedBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + get-first-index = f32[10] get-tuple-element(conditional), index=0 + ROOT pow.1 = f32[10] power(get-first-index, get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleRoot) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] get-tuple-element(arg_tuple.1), index=0 + ROOT add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = f32[10] + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + ROOT pow.1 = f32[10] power(conditional, conditional) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 5); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithEmptyBranch) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch1 { + arg_tuple.1 = (f32[10]) parameter(0) + get-tuple-element.1 = f32[10] 
get-tuple-element(arg_tuple.1), index=0 + add.1 = f32[10] add(get-tuple-element.1, get-tuple-element.1) + ROOT tuple.3 = (f32[10]) tuple(add.1) +} + +branch2 { + ROOT arg_tuple.1 = (f32[10]) parameter(0) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = (f32[10]) parameter(1) + tuple.2 = (f32[10]) parameter(2) + conditional = (f32[10]) + conditional(pred.1, tuple.1, tuple.2), true_computation=branch1, + false_computation=branch2 + get-first-index = f32[10] get-tuple-element(conditional), index=0 + ROOT pow.1 = f32[10] power(get-first-index, get-first-index) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 5); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 4); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + +TEST_F(ConditionalCodeMotionTest, MovePowInWithNonTupleParameter) { + absl::string_view hlo_string = + R"( +HloModule RemoveIdenticalInstruction + +branch { + arg.1 = f32[10] parameter(0) + ROOT add.1 = f32[10] add(arg.1, arg.1) +} + +ENTRY main { + pred.1 = pred[] parameter(0) + tuple.1 = f32[10] parameter(1) + tuple.2 = f32[10] parameter(2) + conditional = f32[10] + conditional(pred.1, tuple.1, tuple.2), true_computation=branch, + false_computation=branch + ROOT pow.1 = f32[10] power(conditional, conditional) +} +)"; + auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie(); + ConditionalCodeMotion pass(true, true); + ASSERT_TRUE(pass.Run(&*module).ValueOrDie()); + const HloInstruction* conditional = + FindInstruction(module.get(), "conditional"); + const HloComputation* on_true = conditional->branch_computation(0); + ASSERT_EQ(on_true->instruction_count(), 4); + const HloComputation* on_false = conditional->branch_computation(1); + ASSERT_EQ(on_false->instruction_count(), 4); + + HloInstruction* root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::GetTupleElement(op::Conditional()))); +} + } // namespace conditional_opt } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 7c362b2da44..b622b712f82 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -140,7 +140,6 @@ cc_library( "//tensorflow/compiler/xla/service:map_inliner", "//tensorflow/compiler/xla/service:rng_bit_generator_expander", "//tensorflow/compiler/xla/service:tree_reduction_rewriter", - "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:conditional_canonicalizer", "//tensorflow/compiler/xla/service:conditional_to_select", "//tensorflow/compiler/xla/service:slow_operation_alarm", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 39d2b11ad37..d8bf15ecdeb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -85,7 +85,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -291,8 +290,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( /*expansion_type=*/LogisticExpansionType::kExp); pipeline.AddPass(); pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); + pipeline.AddPass(ScatterExpander::kEliminateAllScatters); pipeline.AddPass(target_machine_features); { auto& pass = @@ -624,6 +622,7 @@ StatusOr> CpuCompiler::RunBackend( // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; + mlir_context.loadAllGloballyRegisteredDialects(); llvm::LLVMContext llvm_context; auto llvm_module = absl::make_unique("__compute_module", llvm_context); @@ -835,6 +834,7 @@ CpuCompiler::CompileAheadOfTime(std::unique_ptr module_group, // Compile must be thread-safe so create a new LLVM context for the module. mlir::MLIRContext mlir_context; + mlir_context.loadAllGloballyRegisteredDialects(); llvm::LLVMContext llvm_context; llvm::Module llvm_module("__compute_module", llvm_context); llvm_module.setDataLayout(target_machine->createDataLayout()); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 242f3c6ceb7..36566d6c25f 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -1640,7 +1640,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( if (current_size_fragment >= vector_register_size_in_elements) { auto vector_type = llvm::VectorType::get( - element_ir_type, vector_register_size_in_elements); + element_ir_type, vector_register_size_in_elements, false); sharded_vector_type.insert( sharded_vector_type.end(), current_size_fragment / vector_register_size_in_elements, @@ -1656,7 +1656,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( // of two are all legal vector sizes (or at least can be lowered easily by // LLVM). sharded_vector_type.push_back( - llvm::VectorType::get(element_ir_type, current_size_fragment)); + llvm::VectorType::get(element_ir_type, current_size_fragment, false)); } return sharded_vector_type; } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 8d9229c1223..3afdd9c163e 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -115,7 +115,7 @@ void RewriteCalls( // Upcast to vector type if input is a scalar. if (vector_width == 1) { - llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1); + llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1, false); input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input, uint64_t{0}); } @@ -264,8 +264,8 @@ llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input, z = vsl.Add(one, z); // Convert n' to an i32. This is safe because we clamped it above. 
- llvm::Value* n_i32 = - b->CreateFPToSI(n, llvm::VectorType::get(b->getInt32Ty(), vector_width)); + llvm::Value* n_i32 = b->CreateFPToSI( + n, llvm::VectorType::get(b->getInt32Ty(), vector_width, false)); auto splat_i32 = [&](int32 v) { return b->CreateVectorSplat(vector_width, b->getInt32(v)); @@ -329,7 +329,7 @@ llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input, llvm::Value* vector_constant_23 = b->CreateVectorSplat(vector_width, b->getInt32(23)); llvm::Type* i32_vector_type = - llvm::VectorType::get(b->getInt32Ty(), vector_width); + llvm::VectorType::get(b->getInt32Ty(), vector_width, false); llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type), vector_constant_23); diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index 0d2eab9fd42..48aa32f6b8f 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -33,7 +33,7 @@ VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, scalar_type_ = llvm_ir::PrimitiveTypeToIrType( primitive_type, b_->GetInsertBlock()->getModule()); scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); - vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size, false); vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); } @@ -155,7 +155,7 @@ llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); llvm::Type* scalar_int_type = b()->getIntNTy(float_size_bits); if (vector) { - return llvm::VectorType::get(scalar_int_type, vector_size()); + return llvm::VectorType::get(scalar_int_type, vector_size(), false); } else { return scalar_int_type; } diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index f1a0b0a4406..cbed232897f 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -276,7 +276,7 @@ class VectorSupportLibrary { llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f); if (llvm::isa(type)) { return llvm::ConstantVector::getSplat( - llvm::ElementCount(vector_size(), /*Scalable=*/false), scalar_value); + llvm::ElementCount::getFixed(vector_size()), scalar_value); } return scalar_value; } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index b0def1a2dd8..60d832a940a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -245,6 +245,7 @@ class DfsHloVisitorBase { virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; virtual Status HandleReshape(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicReshape(HloInstructionPtr hlo) = 0; virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; virtual Status HandleParameter(HloInstructionPtr hlo) = 0; virtual Status HandleFusion(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index b1d674fe467..3d1a9a3c894 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ 
b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -198,6 +198,9 @@ class DfsHloVisitorWithDefaultBase Status HandlePad(HloInstructionPtr pad) override { return DefaultAction(pad); } + Status HandleDynamicReshape(HloInstructionPtr dynamic_reshape) override { + return DefaultAction(dynamic_reshape); + } Status HandleReshape(HloInstructionPtr reshape) override { return DefaultAction(reshape); } diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc index 36429d3d755..80f98775c01 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference.cc @@ -97,6 +97,8 @@ class DynamicDimensionInferenceVisitor : public DfsHloVisitorWithDefault { Status HandleTranspose(HloInstruction* hlo) override; + Status HandleDynamicReshape(HloInstruction* hlo) override; + Status HandleReshape(HloInstruction* hlo) override; Status HandleSort(HloInstruction* hlo) override; @@ -621,6 +623,18 @@ Status DynamicDimensionInferenceVisitor::HandleClamp(HloInstruction* hlo) { return PassThroughDynamicDimension(hlo); } +Status DynamicDimensionInferenceVisitor::HandleDynamicReshape( + HloInstruction* hlo) { + HloDynamicReshapeInstruction* dynamic_reshape = + Cast(hlo); + for (int64 i = 0; i < hlo->shape().rank(); ++i) { + if (hlo->shape().is_dynamic_dimension(i)) { + parent_->SetDynamicSize(hlo, {}, i, dynamic_reshape->dim_sizes(i)); + } + } + return Status::OK(); +} + Status DynamicDimensionInferenceVisitor::HandleReshape(HloInstruction* hlo) { return ForEachOperandDynamicDimension( hlo, diff --git a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc index b5a17619edf..69f64c31a2f 100644 --- a/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_dimension_inference_test.cc @@ -1248,5 +1248,34 @@ TEST_F(DynamicDimensionInferenceTest, InfersCustomOp) { EXPECT_TRUE(handler_called); } +TEST_F(DynamicDimensionInferenceTest, DynamicReshapeOp) { + auto builder = HloComputation::Builder(TestName()); + auto input = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {9}), "data_input")); + auto six = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(6))); + // Creates an input of shape [<=9], dynamic size is 6. 
+ auto dynamic_input = + builder.AddInstruction(HloInstruction::CreateSetDimensionSize( + ShapeUtil::MakeShape(F32, {9}, {true}), input, six, 0)); + auto dynamic_size = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(S32, {}), "size_param")); + auto three = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + + // Reshape [<=9] into [3, <=3] + + auto dynamic_reshape = + builder.AddInstruction(HloInstruction::CreateDynamicReshape( + ShapeUtil::MakeShape(F32, {3, 3}, {false, true}), dynamic_input, + {three, dynamic_size})); + + module_->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(RunInference()); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 0), nullptr); + EXPECT_EQ(inference_->GetDynamicSize(dynamic_reshape, {}, 1), dynamic_size); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/dynamic_padder.cc b/tensorflow/compiler/xla/service/dynamic_padder.cc index c1f9da599e8..9b4d24bbbe9 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder.cc @@ -32,6 +32,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -125,6 +127,58 @@ StatusOr ChooseIdentityValue(HloInstruction* inst, } } +StatusOr ReplaceGetSize( + HloInstruction* instr, + DynamicDimensionInference* dynamic_dimension_inference) { + if (instr->opcode() != HloOpcode::kGetDimensionSize) { + return false; + } + HloComputation* computation = instr->parent(); + + TF_ASSIGN_OR_RETURN(auto legal_shape, + ShapeInference::InferGetDimensionSizeShape( + instr->operand(0)->shape(), instr->dimension())); + TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)) + << "instr->shape() " << instr->shape().ToString() << " , " + << "legal_shape " << legal_shape.ToString(); + TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32)); + HloInstruction* operand = instr->mutable_operand(0); + int64 dim = instr->dimension(); + HloInstruction* dynamic_size = + dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); + if (dynamic_size != nullptr) { + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); + // The dependency between a instruction and its dynamic dimensions is not + // modeled in the IR. As instr is being replaced by dynamic_size, also tell + // dynamic dimension inference that the instruction is being replaced. 
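The ReplaceGetSize helper introduced in this dynamic_padder.cc hunk folds get-dimension-size away: if dynamic dimension inference knows a runtime size for the queried dimension, the op is replaced with that value; otherwise the dimension is static and the op becomes an s32 constant of the static bound. A condensed sketch of that decision (toy types, not the XLA classes):

#include <cstdint>
#include <optional>
#include <string>

// `known_dynamic_size` stands in for DynamicDimensionInference::GetDynamicSize;
// the returned string names what the get-dimension-size op is replaced with.
std::string ReplaceGetSizeSketch(std::optional<std::string> known_dynamic_size,
                                 int64_t static_bound) {
  if (known_dynamic_size.has_value()) {
    return *known_dynamic_size;  // dynamic case: use the runtime size value
  }
  // Static case: materialize the dimension's bound as a constant.
  return "s32 constant(" + std::to_string(static_bound) + ")";
}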
+ dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( + instr, dynamic_size); + } else { + int32 size = instr->operand(0)->shape().dimensions(dim); + HloInstruction* new_instr = computation->AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); + dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, + new_instr); + } + return true; +} + +StatusOr ReplaceSetSize(HloInstruction* instr) { + if (instr->opcode() != HloOpcode::kSetDimensionSize) { + return false; + } + + TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( + instr->shape(), instr->operand(0)->shape())) + << "instr->shape() " << instr->shape().ToString() << " , " + << "instruction operand shape " << instr->operand(0)->shape(); + HloInstruction* operand = instr->mutable_operand(0); + + TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); + return true; +} + bool ShouldSkipPadOnOperand(const HloInstruction* inst, int64 operand_num, int64 dimension) { if ((inst->opcode() == HloOpcode::kReduceWindow || @@ -1236,6 +1290,18 @@ StatusOr DynamicPadder::Run(HloModule* module) { changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); continue; } + + if (inst->opcode() == HloOpcode::kDynamicReshape) { + TF_ASSIGN_OR_RETURN( + changed, RewriteDynamicReshape(inst, &dynamic_dimension_inference)); + auto* static_reshape = + computation->AddInstruction(HloInstruction::CreateReshape( + inst->shape(), inst->mutable_operand(0))); + TF_RETURN_IF_ERROR(inst->ReplaceAllUsesWith(static_reshape)); + TF_RETURN_IF_ERROR(dynamic_dimension_inference.ForwardDynamicSize( + inst, static_reshape, {})); + continue; + } for (int64 operand_num = 0; operand_num < inst->operand_count(); ++operand_num) { HloInstruction* original_operand = inst->mutable_operand(operand_num); @@ -1292,6 +1358,22 @@ StatusOr DynamicPadder::Run(HloModule* module) { /*require_dynamic_output=*/require_dynamic_output)); } + for (auto* computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN( + bool replaced_get_size, + ReplaceGetSize(instruction, &dynamic_dimension_inference)); + changed = changed || replaced_get_size; + } + } + + for (auto* computation : module->computations()) { + for (auto instruction : computation->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); + changed = changed || replaced_set_size; + } + } + HloDCE dce; TF_ASSIGN_OR_RETURN(changed, dce.Run(module)); VLOG(2) << "Post DynamicPadder HLO:"; diff --git a/tensorflow/compiler/xla/service/dynamic_padder_test.cc b/tensorflow/compiler/xla/service/dynamic_padder_test.cc index e8f429d9db6..3855531a97b 100644 --- a/tensorflow/compiler/xla/service/dynamic_padder_test.cc +++ b/tensorflow/compiler/xla/service/dynamic_padder_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -380,10 +379,15 @@ class ExecutionTest : public HloTestBase { Literal PadAndExecute(std::unique_ptr module, absl::Span arguments, bool slice_dynamic_output = true) { + if (!slice_dynamic_output) { + auto new_config = module->config(); + new_config.mutable_entry_computation_layout() + ->mutable_result_layout() + ->ClearDynamicShape(); + module->set_config(new_config); + } DynamicPadder padder(slice_dynamic_output); TF_CHECK_OK(padder.Run(module.get()).status()); - HloGetDimensionSizeRewriter rewriter; - TF_CHECK_OK(rewriter.Run(module.get()).status()); HloDCE dce; TF_CHECK_OK(dce.Run(module.get()).status()); return ExecuteAndTransfer(std::move(module), arguments); @@ -1179,6 +1183,84 @@ ENTRY main { EXPECT_EQ(result, expected); } +XLA_TEST_F(ExecutionTest, DynamicReshapeDoubleDynamicDimensions) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +ENTRY main { + param = s32[2, 3, 3] parameter(0) + size = s32[] constant(2) + param_padded_partial = s32[2, <=3, 3] set-dimension-size(param, size), + dimensions={1} + param_padded = s32[2, <=3, <=3] set-dimension-size(param_padded_partial, size), + dimensions={2} + result_size = s32[] constant(8) + ROOT reshaped = s32[<=18] dynamic-reshape(param_padded, result_size) +} +)"; + + // First dimension (1) is dynamic. Since dynamic size is 0, result is also 0. + Literal operand = LiteralUtil::CreateR3( + {{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}}); + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}, false); + result.SetDynamicSize(0, 8); + // Padded data looks like this (P is padding which is ignored). + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + // + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + // + // Reshaping (with correct reshape rewriting) produces: + // [0, 1, 3, 4, 0, 1, 3, 4] + Literal expected = LiteralUtil::CreateR1({0, 1, 3, 4, 0, 1, 3, 4}); + + EXPECT_EQ(result, expected); +} + +XLA_TEST_F(ExecutionTest, DynamicReshapeOutputDoubleDynamicDimensions) { + const string hlo_text = R"( +HloModule TensorFlowScatterV1 + +ENTRY main { + param = s32[18] parameter(0) + eight = s32[] constant(8) + param_dynamic = s32[<=18] set-dimension-size(param, eight), dimensions={0} + two = s32[] constant(2) + // every dimension has dynamic size two. + ROOT reshaped = s32[2, <=3, <=3] dynamic-reshape(param_dynamic, two, two, two) +} +)"; + Literal operand = LiteralUtil::CreateR1( + {0, 1, 3, 4, 0, 1, 3, 4, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}); + + auto module = GetHloModule(hlo_text); + + Literal result = PadAndExecute(std::move(module), {&operand}, false); + + result.SetDynamicSize(1, 2); + result.SetDynamicSize(2, 2); + // Padded operand is: + // [0, 1, 3, 4, 0, 1, 3, 4, P, P ....] 
+ // + // Reshaping it should produce: + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + // + // [[0, 1, P] + // [3, 4, P] + // [P, P, P]] + Literal expected = + LiteralUtil::CreateR3({{{0, 1}, {3, 4}}, {{0, 1}, {3, 4}}}); + + EXPECT_EQ(result, expected); +} + XLA_TEST_F(ExecutionTest, SetGetDimensionSize) { const string hlo_text = R"( HloModule TensorFlowScatterV1 @@ -1371,5 +1453,70 @@ ENTRY main { EXPECT_EQ(result, expected); } +namespace op = xla::testing::opcode_matchers; + +class HloDimensionSizeLegalizerTest : public HloTestBase { + protected: + HloDimensionSizeLegalizerTest() {} +}; + +TEST_F(HloDimensionSizeLegalizerTest, Ok) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = s32[] get-dimension-size(p), dimensions={0} + size1 = s32[] get-dimension-size(p), dimensions={1} + ROOT mul = s32[] multiply(size0, size1) +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloDimensionSizeLegalizerTest, GetSetSetDimensionSizeRewriter) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3,4] parameter(0) + size0 = s32[] get-dimension-size(p), dimensions={0} + p_copy = s32[3,4] copy(p) + p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} + size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} + ROOT mul = s32[] multiply(size0, size1) +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); + EXPECT_THAT(module->entry_computation()->root_instruction(), + op::Multiply(op::Constant(), op::Constant())); +} + +TEST_F(HloDimensionSizeLegalizerTest, IllegalType) { + auto module = ParseAndReturnUnverifiedModule(R"( +HloModule _ +ENTRY gds { + p = s32[3]{0} parameter(0) + ROOT gds = s64[] get-dimension-size(p), dimensions={0} +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + +TEST_F(HloDimensionSizeLegalizerTest, IllegalDimension) { + auto module = ParseAndReturnUnverifiedModule(R"( +HloModule _ +ENTRY gds { + p = f32[2,5] parameter(0) + ROOT gds = s32[] get-dimension-size(p), dimensions={2} +})") + .ValueOrDie(); + DynamicPadder pass; + EXPECT_FALSE(pass.Run(module.get()).ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 074fbd92b27..d1d0827981e 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -254,6 +254,11 @@ cc_library( ":target_util", ":thunk", ":thunk_emitter", + "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/mlir/xla:hlo_utils", + "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", + "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", + "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -291,6 +296,8 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:StandardOps", ], ) @@ -1159,6 +1166,7 @@ cc_library( ":target_constants", ":tree_reduction_rewriter", ":variadic_op_splitter", + "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", "//tensorflow/compiler/xla:protobuf_util", 
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1186,7 +1194,6 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_dce", "//tensorflow/compiler/xla/service:hlo_element_type_converter", - "//tensorflow/compiler/xla/service:hlo_get_dimension_size_rewriter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto_util", @@ -1217,6 +1224,8 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", + "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", + "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index 60e4cb84b09..a499dc70e23 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -230,18 +230,15 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // This is done to avoid the duplication of expensive instructions, which // would occur if 'fusion' were merged into multiple users. // - // If 'fusion' has just one user, then an earlier fusion pass chose not to - // fuse this producer/consumer pair (likely because of expensive instruction - // re-use by the consumer), and so we honor that choice here as well. - // - // Moreover, if we are going to save a "lot" in memory bandwidth then we + // However, if we are going to save a "lot" in memory bandwidth then we // ignore how expensive the fusion instructions are. The heuristic used to // determine "a lot" is the following: merging must reduce memory traffic by a // factor of 0.3, and the amount of memory accessed must not be entirely // trivial (above 1K). This likely has room for improvement in the future. 
bool allow_expensive_ops = - merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024; + fusion->user_count() == 1 || + (merged_to_current_bytes_ratio < 0.3 && current_bytes_transferred > 1024); if (!allow_expensive_ops && absl::c_any_of(fusion->fused_instructions(), diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 42891154c23..cc4894f4c00 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -398,6 +398,29 @@ TEST_F(FusionMergerTest, WillMergeExpensiveFusionsIfSavesMemory) { EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); } +TEST_F(FusionMergerTest, WillMergeExpensiveFusionsWithSingleConsumer) { + auto module = ParseAndReturnVerifiedModule(R"( + HloModule m + + %f_b (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] { + %p = f32[1024,1024,1024] parameter(0) + ROOT %t = f32[1024,1024,1024] tanh(%p) + } + + %f_c (p: f32[1024,1024,1024]) -> f32[1024,1024,1024] { + %p = f32[1024,1024,1024] parameter(0) + ROOT %t = f32[1024,1024,1024] add(%p, %p) + } + + ENTRY entry { + p0 = f32[1024,1024,1024] parameter(0) + f1 = f32[1024,1024,1024] fusion(p0), kind=kLoop, calls=%f_b + ROOT f2 = f32[1024,1024,1024] fusion(f1), kind=kLoop, calls=%f_c + })") + .ValueOrDie(); + EXPECT_TRUE(FusionMerger().Run(module.get()).ValueOrDie()); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index f5bf7476059..77fcf2c59f7 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -29,6 +29,8 @@ limitations under the License. #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" +#include "mlir/IR/Module.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_reduce_combiner.h" @@ -81,7 +83,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -195,13 +196,12 @@ Status GpuCompiler::OptimizeHloModule( /*layout_sensitive=*/false, /*allow_mixed_precision=*/false); - pass.AddPass(); - // BatchNormExpander can create zero-sized ops, so zero-sized HLO // elimination has to come after that pass. 
pass.AddPass(); pass.AddPass(GatherExpander::kEliminateSimpleGathers); + pass.AddPass(ScatterExpander::kEliminateSimpleScatters); AlgebraicSimplifierOptions options; // When transposes appear in a fusion node, we can easily adjust the @@ -516,15 +516,22 @@ static Status CompileModuleToLlvmIrImpl( DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment, "after_optimizations"); + mlir::registerAllDialects(); + mlir::MLIRContext mlir_context; + IrEmitterContext ir_emitter_context( hlo_module, buffer_assignment->get(), platform_name, gpu_device_info, - cuda_compute_capability, profile_index_map, llvm_module->get()); + cuda_compute_capability, profile_index_map, &mlir_context, + llvm_module->get()); HloComputation* entry_computation = hlo_module->entry_computation(); - IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation, - &ir_emitter_context); - TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); + TF_ASSIGN_OR_RETURN( + auto ir_emitter, + IrEmitterUnnested::Create(hlo_module->config(), entry_computation, + &ir_emitter_context)); + + TF_RETURN_IF_ERROR(ir_emitter->EmitConstantGlobals()); { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); @@ -533,9 +540,10 @@ static Status CompileModuleToLlvmIrImpl( ThunkSequence thunk_sequence; absl::Span order = hlo_schedule->ThunkLaunchOrder(); for (HloInstruction* instruction : order) { - TF_RETURN_IF_ERROR(instruction->Visit(&ir_emitter)); - TF_RETURN_IF_ERROR(ir_emitter.Postprocess(instruction)); - std::unique_ptr thunks = ir_emitter.ConsumeThunkSequence(); + TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get())); + TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction)); + std::unique_ptr thunks = + ir_emitter->ConsumeThunkSequence(); // The invariants between each input HloInstruction* and output Thunk* are // not all explicitly checked, but at least we can document them here: diff --git a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc index 6287f1e3ca2..31f011fa734 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.cc @@ -23,26 +23,11 @@ limitations under the License. namespace xla { -StatusOr GpuScatterExpander::Run(HloModule* module) { - auto is_nontrivial_scatter = [](HloInstruction* inst) { - // TODO(b/129698548): Scattering elements larger than 64 bits is not - // supported by XLA:GPU. - return inst->opcode() == HloOpcode::kScatter && - inst->shape().element_type() == C128; - }; - - std::vector scatter_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - absl::c_copy_if(computation->instructions(), - std::back_inserter(scatter_instrs), is_nontrivial_scatter); - } - - for (HloInstruction* inst : scatter_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(inst)); - TF_RETURN_IF_ERROR(inst->parent()->ReplaceInstruction(inst, expanded_root)); - } - - return !scatter_instrs.empty(); +bool GpuScatterExpander::InstructionMatchesPattern(HloInstruction* inst) { + // TODO(b/129698548): Scattering elements larger than 64 bits is not + // supported by XLA:GPU. 
+ return inst->opcode() == HloOpcode::kScatter && + primitive_util::BitWidth(inst->shape().element_type()) > 64; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h index 0818b32474f..92acb909729 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h @@ -20,10 +20,17 @@ limitations under the License. namespace xla { +// Legalizes scatters on the GPU. class GpuScatterExpander : public ScatterExpander { public: + // Although we pass kEliminateAllScatters, we override this behavior in + // InstruuctionMatchesPattern and select only some scatters to expand. + GpuScatterExpander() : ScatterExpander(kEliminateAllScatters) {} + absl::string_view name() const override { return "gpu_scatter_expander"; } - StatusOr Run(HloModule* module) override; + + protected: + bool InstructionMatchesPattern(HloInstruction* inst) override; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 5d38d1b727c..332db83b6ad 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -117,11 +117,11 @@ static bool HasMeaningfulName(llvm::Value* value) { return false; } -llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - ShapeIndexView shape_index, - llvm::Value* ir_value) { - llvm::Type* pointee_type = llvm_ir::ShapeToIrType( - ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); +llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, + llvm::IRBuilder<>* b) { + llvm::Type* pointee_type = + llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule()); + llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; @@ -129,9 +129,17 @@ llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast(ir_value), dest_type); } else { - typed_ir_value = b_->CreatePointerBitCastOrAddrSpaceCast( + typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast( ir_value, pointee_type->getPointerTo()); } + return typed_ir_value; +} + +llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, + ShapeIndexView shape_index, + llvm::Value* ir_value) { + auto typed_ir_value = CastToTypedValue( + ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_); if (!HasMeaningfulName(ir_value)) { ir_value->setName(llvm_ir::IrName(&hlo, "raw")); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 5eef6727801..3813ec6c949 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -116,6 +116,10 @@ class HloToIrBindings { llvm::Value* temp_buffer_base_ = nullptr; }; +// Converts `ir_value` with type i8* to a typed LLVM Value* based on `shape`. 
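The GpuScatterExpander rewrite a few hunks above no longer walks the module itself; it derives from ScatterExpander with kEliminateAllScatters and narrows the selection in InstructionMatchesPattern to scatters whose element type is wider than 64 bits. The effective predicate, restated as a self-contained sketch:

#include <cstdint>

// `is_scatter` stands in for opcode == HloOpcode::kScatter, and `bit_width`
// for primitive_util::BitWidth(shape.element_type()).
bool GpuShouldExpandScatterSketch(bool is_scatter, int64_t bit_width) {
  return is_scatter && bit_width > 64;  // e.g. C128 scatters get expanded
}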
+llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, + llvm::IRBuilder<>* b); + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6309d7fcdee..9d4ec358bd3 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -433,7 +433,7 @@ llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset, builder->CreateZExt( builder->CreateBitCast(value, builder->getIntNTy(bit_width)), builder->getIntNTy(32 * num_segments)), - llvm::VectorType::get(builder->getInt32Ty(), num_segments)); + llvm::VectorType::get(builder->getInt32Ty(), num_segments, false)); for (int i = 0; i < num_segments; ++i) { llvm::Value* insert_val; if (target_triple.isNVPTX()) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 9c43f80dc60..7d5a8d032e6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ #include "llvm/IR/Module.h" +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" @@ -34,13 +35,15 @@ class IrEmitterContext { const HloModule* hlo_module, const BufferAssignment* buffer_assignment, std::string platform_name, GpuDeviceInfo gpu_device_info, absl::optional cuda_compute_capability, - const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module) + const HloProfileIndexMap* profile_index_map, + mlir::MLIRContext* mlir_context, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), cuda_compute_capability_(cuda_compute_capability), profile_index_map_(profile_index_map), + mlir_context_(mlir_context), llvm_module_(llvm_module) {} // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; @@ -57,6 +60,7 @@ class IrEmitterContext { return cuda_compute_capability_; } const HloProfileIndexMap* profile_index_map() { return profile_index_map_; } + mlir::MLIRContext* mlir_context() { return mlir_context_; } llvm::Module* llvm_module() { return llvm_module_; } NameUniquer* name_uniquer() { return &name_uniquer_; } @@ -67,6 +71,7 @@ class IrEmitterContext { GpuDeviceInfo gpu_device_info_; absl::optional cuda_compute_capability_; const HloProfileIndexMap* profile_index_map_; + mlir::MLIRContext* mlir_context_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 61b78b6004d..f88c70b1a33 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -37,6 +37,13 @@ limitations under the License. 
#include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Function.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" +#include "tensorflow/compiler/mlir/xla/hlo_utils.h" +#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -144,13 +151,86 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } +const BufferAllocation* GetAllocation( + mlir::BlockArgument func_arg, const BufferAssignment& buffer_assignment) { + auto func_op = + mlir::cast(func_arg.getParentRegion()->getParentOp()); + int64 allocation_index = func_op + .getArgAttrOfType( + func_arg.getArgNumber(), "lmhlo.alloc") + .getValue() + .getSExtValue(); + return &buffer_assignment.GetAllocation(allocation_index); +} + +StatusOr GetAllocationSliceForMlir( + mlir::Value v, const BufferAssignment& buffer_assignment) { + int64 size = v.getType().cast().getSizeInBits() / 8; + + if (auto arg = v.dyn_cast()) { + return BufferAllocation::Slice(GetAllocation(arg, buffer_assignment), 0, + size); + } + + // We match two patterns here: + // * v = ViewOp(arg); + // * v = StaticMemRefCastOp(ViewOp(arg)); + if (mlir::Operation* op = v.getDefiningOp()) { + if (auto cast = mlir::dyn_cast(op)) { + mlir::Value source = cast.getViewSource(); + op = source.getDefiningOp(); + if (!op) { + return Unimplemented("StaticMemRefCastOp has to wrap an op"); + } + } + if (auto view = mlir::dyn_cast(op)) { + return BufferAllocation::Slice( + GetAllocation(view.source().cast(), + buffer_assignment), + mlir::cast(view.byte_shift().getDefiningOp()) + .value() + .cast() + .getValue() + .getSExtValue(), + size); + } + return Unimplemented("StaticMemRefCastOp has to wrap a ViewOp"); + } + + return Unimplemented( + "Operand has to be in the form of ViewOp(arg) or " + "StaticMemRefCastOp(ViewOp(arg))"); +} + +absl::string_view GetHloName(mlir::Operation* op) { + if (auto attr = op->getAttrOfType("name")) { + auto ref = attr.getValue(); + return absl::string_view(ref.data(), ref.size()); + } + return ""; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context) : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), - hlo_computation_(hlo_computation) {} + hlo_computation_(hlo_computation), + mlir_scratch_module_(mlir::ModuleOp::create( + mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())), + lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(), + *hlo_computation, mlir_scratch_module_.get()) {} + +StatusOr> IrEmitterUnnested::Create( + const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context) { + auto emitter = std::unique_ptr(new IrEmitterUnnested( + hlo_module_config, hlo_computation, ir_emitter_context)); + TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize()); + return std::move(emitter); +} Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { bindings_.UnbindAllLocalIrValues(); @@ -158,12 
+238,11 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } llvm::Function* IrEmitterUnnested::BuildKernelPrototype( - const HloInstruction& inst, - absl::Span args) { + absl::string_view name, absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( - llvm_ir::SanitizeFunctionName(inst.name())); + llvm_ir::SanitizeFunctionName(std::string(name))); // Create the kernel and add it to the module. llvm::Module* module = ir_emitter_context_->llvm_module(); @@ -359,7 +438,8 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { } Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { - AddThunkToThunkSequence(BuildConditionalThunk(conditional)); + TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional)); + AddThunkToThunkSequence(std::move(thunk)); return Status::OK(); } @@ -1038,10 +1118,13 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { // Build ForThunk for conformant while loops, otherwise build WhileThunk. auto config = xla_while->backend_config(); if (config.ok() && config.ValueOrDie().has_known_trip_count()) { - AddThunkToThunkSequence( + TF_ASSIGN_OR_RETURN( + auto thunk, BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n())); + AddThunkToThunkSequence(std::move(thunk)); } else { - AddThunkToThunkSequence(BuildWhileThunk(xla_while)); + TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while)); + AddThunkToThunkSequence(std::move(thunk)); } return Status::OK(); } @@ -1264,39 +1347,109 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } +StatusOr +IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region) { + std::unique_ptr& module = scratch_nested_computations_[region]; + if (module == nullptr) { + xla::XlaComputation xla_computation; + TF_RETURN_IF_ERROR(ConvertRegionToComputation(region, &xla_computation)); + TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape()); + TF_ASSIGN_OR_RETURN( + module, HloModule::CreateFromProto(xla_computation.proto(), + HloModuleConfig(program_shape))); + } + return module->entry_computation(); +} + Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { + MlirEmitterInput result; + + TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort)); + result.op = sort_op; + result.name = GetHloName(sort_op); + // The name in sort op has no semantics, and it's for debug only. If the name + // doesn't exist, we should use a namer (e.g. count-based). + // TODO(timshen): use a namer instead of relying on the HloInstruction names. 
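HandleSort above now lowers through an LMHLO sort op. EmitMlirSort, defined in the next hunk, crafts one buffer slice per sort output plus a final slice for the on-device tuple storage; the sorting kernels work in place, so no separate operand slices are needed. A schematic of that layout (toy struct, not the XLA BufferAllocation machinery):

#include <vector>

struct SliceSketch {
  int source;    // which output (0..n-1), or -1 for the tuple-storage slice
  bool written;  // every slice here is written by the sort kernels
};

std::vector<SliceSketch> SortSliceLayoutSketch(int operand_count) {
  std::vector<SliceSketch> slices;
  for (int i = 0; i < operand_count; ++i) {
    slices.push_back({/*source=*/i, /*written=*/true});
  }
  slices.push_back({/*source=*/-1, /*written=*/true});  // n + 1st: tuple buffer
  return slices;
}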
+ if (result.name.empty()) { + result.name = sort->name(); + } + const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); + auto& slice = result.extra_slice; + TF_ASSIGN_OR_RETURN(slice.buffer_slice, + buffer_assignment.GetUniqueSlice(sort, {})); + slice.written = true; + slice.shape = sort->shape(); + + result.thunk_info = GetThunkInfo(sort); + + return EmitMlirSort(result); +} + +Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { + const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); + auto sort_op = mlir::cast(input.op); + + int operand_count = sort_op.operands().size(); + std::vector operand_shapes(operand_count); + std::vector slices; + std::vector output_shapes(sort_op.output().size()); + + for (int i = 0; i < operand_count; i++) { + operand_shapes[i] = + TypeToShape(sort_op.operands()[i].getType().cast()); + } + + // Craft n + 1 slices, where the first n are output parameters, and the last + // is the on-device tuple storage. We don't need n operands because sorting + // kernels are always in-place. + for (int i = 0; i < operand_count; i++) { + output_shapes[i] = + TypeToShape(sort_op.output()[i].getType().cast()); + MlirBufferSlice slice; + TF_ASSIGN_OR_RETURN( + slice.buffer_slice, + GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment)); + slice.written = true; + slice.shape = operand_shapes[i]; + slices.push_back(slice); + } + slices.push_back(input.extra_slice); + std::vector> thunks; - Shape keys_shape = sort->operand(0)->shape(); - int64 dimension_to_sort = sort->dimensions(0); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); + + Shape keys_shape = operand_shapes[0]; + int64 dimension_to_sort = sort_op.dimension().getSExtValue(); + for (int64 i = 0; i < operand_count; ++i) { // We assume that the layout of all involved operands and outputs is the // same. - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, - sort->operand(i)->shape())); - TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( - keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); + TF_RET_CHECK( + LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i])); + TF_RET_CHECK( + LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i])); // If possible, we share buffers. If that is not possible, we need to copy // the values, because the emitter does the sorting in-place. - auto destination_buffer = GetAllocationSlice(*sort, shape_index); - auto source_address = GetAllocationSlice(*sort->operand(i)); + TF_ASSIGN_OR_RETURN( + auto destination_buffer, + GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment)); + TF_ASSIGN_OR_RETURN( + auto source_address, + GetAllocationSliceForMlir(sort_op.operands()[i], buffer_assignment)); if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. 
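The stage count computed just below drives the bitonic sort: a dimension of size bound needs ceil(log2(bound)) stages, which is exactly what the two CHECKs around num_stages assert. A standalone stand-in for that Log2Ceiling call:

#include <cstdint>

// Smallest k with 2^k >= bound, so CHECK_GE(1 << k, bound) and
// CHECK_LT(1 << (k - 1), bound) both hold for bound > 1.
int64_t Log2CeilingSketch(uint64_t bound) {
  int64_t k = 0;
  while ((uint64_t{1} << k) < bound) ++k;
  return k;  // e.g. bound = 5 -> 3, since 8 >= 5 and 4 < 5
}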
- VLOG(2) << sort->name() << " requires initial D2D copy for operand " << i; + VLOG(2) << input.name << " requires initial D2D copy for operand " << i; thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); + /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i]))); } } uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - VLOG(2) << sort->name() << " requires " << num_stages << " stages."; + VLOG(2) << input.name << " requires " << num_stages << " stages."; CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); @@ -1360,10 +1513,10 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { // we have not enough threads, or not enough shared memory. Also it does not // give a speedup if the tile size is < 128. int64 total_shared_memory_needed = 0; - for (int64 i = 0; i < sort->operand_count(); ++i) { + for (int64 i = 0; i < operand_count; ++i) { total_shared_memory_needed += - kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( - sort->operand(i)->shape().element_type()); + kTileSize * + ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type()); } bool no_tiling = kTileSize < 128 || @@ -1376,7 +1529,7 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { "kTileSize=%d < 128, " "kThreadsPerBlock=%d > threads_per_block_limit=%d, " "total_shared_memory_needed=%d > shared_memory_per_block=%d", - sort->name(), (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, + input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, ir_emitter_context_->gpu_device_info().threads_per_block_limit, total_shared_memory_needed, ir_emitter_context_->gpu_device_info().shared_memory_per_block); @@ -1384,37 +1537,38 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block", - sort->name(), num_blocks, kThreadsPerBlock); + input.name, num_blocks, kThreadsPerBlock); + std::vector ir_arrays; auto emit_kernel = [&](absl::Span xor_masks) { VLOG(2) << absl::StreamFormat( - "%s uses kernel for xor masks [%s]", sort->name(), + "%s uses kernel for xor masks [%s]", input.name, absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) { absl::StrAppendFormat(out, "0x%x", xor_mask); })); - thunks.push_back( - BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); + thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(), + slices, &ir_arrays)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); std::vector values_arrays; - values_arrays.reserve(sort->operand_count()); - for (int64 i = 0; i < sort->operand_count(); ++i) { - ShapeIndex shape_index = - sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); - values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); + values_arrays.reserve(operand_count); + for (int64 i = 0; i < operand_count; ++i) { + values_arrays.push_back(ir_arrays[i]); } + TF_ASSIGN_OR_RETURN( + const HloComputation* comparator, + GetOrCreateSubComputationFromRegion(&sort_op.comparator())); return llvm_ir::EmitSortInPlace( - dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, kTileSize, [&](absl::Span operands, llvm::Value* output) { - return EmitCallToNestedComputation(*sort->to_apply(), operands, - output); + return EmitCallToNestedComputation(*comparator, operands, output); }); }; std::vector xor_masks; @@ -1441,17 +1595,18 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } VLOG(2) << absl::StreamFormat( - "%s requires %d thunks (including any D2D copies)", sort->name(), + "%s requires %d thunks (including any D2D copies)", input.name, thunks.size()); - AddThunkToThunkSequence(absl::make_unique( - GetThunkInfo(sort), std::move(thunks))); - if (sort->operand_count() > 1) { + AddThunkToThunkSequence( + absl::make_unique(input.thunk_info, std::move(thunks))); + if (operand_count > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(*sort, *sort), - ConstructIrArrayForOutputs(*sort), &b_); + llvm_ir::EmitTuple( + ir_arrays[operand_count], + absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_); } return Status::OK(); } @@ -1589,24 +1744,6 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } -// Describes how to access a particular subshape for an HLO. For instance if -// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at -// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is found -// at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we -// dereference twice -- first at index 3, and then at index 4 -- to get the -// address of our buffer. -struct HloBufferSlice { - const HloInstruction* instr; - ShapeIndex hlo_index; - - // The root buffer to look at. - BufferAllocation::Slice buffer_slice; - - // Describes how to dereference starting at that buffer to get to the buffer - // in question. - ShapeIndex gte_index; -}; - // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). // @@ -1715,22 +1852,22 @@ static std::vector GetHloBufferSlices( return result; } -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst, bool implements_whole_instruction) { - const BufferAssignment& buffer_assn = - ir_emitter_context_->buffer_assignment(); - - std::vector hlo_slices = - GetHloBufferSlices(inst, buffer_assn); +std::unique_ptr +IrEmitterUnnested::BuildKernelThunkFromBufferSlices( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::function + bind_slice_to_ir_value) { + const auto& buffer_assn = ir_emitter_context_->buffer_assignment(); // Figure out which buffer allocations need to be passed as arguments to our - // kernel. This is simply all of the allocations referenced in hlo_slices, + // kernel. 
This is simply all of the allocations referenced in slices, // plus the XLA temp buffer (if we have it). We always include the temp // buffer because even if the kernel itself doesn't use it, a nested // subcomputation within the kernel (e.g. a kMap's computation) might. std::unordered_set buffers_needed; - for (const auto& hlo_buffer_slice : hlo_slices) { - buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation()); + for (auto* slice : slices) { + buffers_needed.insert(slice->buffer_slice.allocation()); } absl::optional temp_buffer; for (const BufferAllocation& alloc : buffer_assn.Allocations()) { @@ -1759,7 +1896,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); + llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. @@ -1793,24 +1930,19 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( // For each buffer our kernel might want to touch, bind it to a value derived // from our kernel args. - for (const auto& hlo_buffer_slice : hlo_slices) { - const HloInstruction* instr = hlo_buffer_slice.instr; - const ShapeIndex& index = hlo_buffer_slice.hlo_index; - const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice; - const ShapeIndex& gte_index = hlo_buffer_slice.gte_index; - - VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() - << " is found in slice " << slice.ToString() << " at GTE index " - << gte_index.ToString(); + for (auto* slice : slices) { + const BufferAllocation::Slice& buffer_slice = slice->buffer_slice; + const ShapeIndex& gte_index = slice->gte_index; llvm::Value* loc; - if (slice.allocation()->is_constant()) { + if (buffer_slice.allocation()->is_constant()) { loc = ir_emitter_context_->llvm_module()->getGlobalVariable( - llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation())); + llvm_ir::ConstantBufferAllocationToGlobalName( + *buffer_slice.allocation())); CHECK_NE(loc, nullptr); } else { - loc = InBoundsGEP(kernel_args.at(slice.allocation()), - {b_.getInt64(slice.offset())}); + loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()), + {b_.getInt64(buffer_slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -1822,7 +1954,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } - bindings_.BindHloToIrValue(*instr, loc, index); + bind_slice_to_ir_value(slice, loc); } // Bind the temp buffer so that nested subcomputations can find it if they @@ -1834,9 +1966,66 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return absl::make_unique( + return absl::make_unique(thunk_info, non_constant_buffers, + std::string(kernel->getName())); +} + +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( + const HloInstruction* inst, bool implements_whole_instruction) { + std::vector hlo_slices = + GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment()); + + std::vector slice_ptrs; + slice_ptrs.reserve(hlo_slices.size()); + for (auto& slice : hlo_slices) { + slice_ptrs.push_back(&slice); + } + + return BuildKernelThunkFromBufferSlices( + inst->name(), implements_whole_instruction ? 
GetThunkInfo(inst) : Thunk::ThunkInfo(), - non_constant_buffers, std::string(kernel->getName())); + slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) { + const HloBufferSlice* hlo_buffer_slice = + static_cast(slice); + const HloInstruction* instr = hlo_buffer_slice->instr; + const ShapeIndex& index = hlo_buffer_slice->hlo_index; + VLOG(3) << "Buffer for " << instr->ToString() << " at " + << index.ToString() << " is found in slice " + << hlo_buffer_slice->buffer_slice.ToString() << " at GTE index " + << hlo_buffer_slice->gte_index.ToString(); + + bindings_.BindHloToIrValue(*instr, value, index); + }); +} + +std::unique_ptr IrEmitterUnnested::BuildKernelThunkForMlir( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::vector* ir_arrays) { + absl::flat_hash_set buffers_written; + std::vector slice_ptrs; + slice_ptrs.reserve(slices.size()); + for (auto& slice : slices) { + slice_ptrs.push_back(&slice); + if (slice.written) { + buffers_written.insert(slice.buffer_slice); + } + } + + ir_arrays->clear(); + return BuildKernelThunkFromBufferSlices( + name, thunk_info, slice_ptrs, + [&](const BufferSlice* slice, llvm::Value* value) { + const auto& mlir_slice = static_cast(*slice); + + llvm_ir::IrArray ir_array( + CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape); + if (!buffers_written.contains(slice->buffer_slice)) { + ir_array.MarkInvariantOverWholeProgram(&value->getContext()); + } + + ir_arrays->push_back(ir_array); + }); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -2043,7 +2232,7 @@ Status CheckConditionalBuffersShareAllocation( } // namespace -std::unique_ptr IrEmitterUnnested::BuildWhileThunk( +StatusOr> IrEmitterUnnested::BuildWhileThunk( const HloInstruction* hlo) { // Check that all while-related buffers share an allocation. TF_CHECK_OK(CheckWhileBuffersShareAllocation( @@ -2051,24 +2240,26 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( // Generate thunk sequence for while 'condition'. HloComputation* condition = hlo->while_condition(); - IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, - ir_emitter_context_); - TF_CHECK_OK(condition->Accept(&ir_emitter_condition)); + TF_ASSIGN_OR_RETURN(auto ir_emitter_condition, + IrEmitterUnnested::Create(hlo_module_config_, condition, + ir_emitter_context_)); + TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get())); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); - IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - ir_emitter_context_); - TF_CHECK_OK(body->Accept(&ir_emitter_body)); + TF_ASSIGN_OR_RETURN( + auto ir_emitter_body, + IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); + TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); - return absl::make_unique( + return std::unique_ptr(new WhileThunk( GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result - ir_emitter_condition.ConsumeThunkSequence(), - ir_emitter_body.ConsumeThunkSequence()); + ir_emitter_condition->ConsumeThunkSequence(), + ir_emitter_body->ConsumeThunkSequence())); } -std::unique_ptr IrEmitterUnnested::BuildForThunk( +StatusOr> IrEmitterUnnested::BuildForThunk( const HloInstruction* hlo, const int64 loop_limit) { // Check that all while-related buffers share an allocation. 
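Note: BuildKernelThunkFromBufferSlices above lets the existing HLO path and the new MLIR path share one kernel-thunk builder by deferring the slice-to-value binding to a callback: the HLO overload binds each value to an HloInstruction, while the MLIR overload wraps it in an llvm_ir::IrArray and marks read-only buffers invariant. A stripped-down illustration of that shape using made-up stand-in types, not the XLA classes:

#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Stand-ins for BufferSlice and the llvm::Value* location.
struct SliceSketch {
  std::string label;
  bool written;
};
using LocationSketch = int;

// One builder computes a location per slice; the caller decides how to bind it.
void BuildKernelFromSlicesSketch(
    const std::vector<SliceSketch>& slices,
    const std::function<void(const SliceSketch&, LocationSketch)>& bind_slice) {
  for (int i = 0; i < static_cast<int>(slices.size()); ++i) {
    bind_slice(slices[i], /*location=*/i);  // stand-in for the kernel-argument GEP
  }
}

int main() {
  std::vector<SliceSketch> slices = {{"operand", false}, {"output", true}};
  // "MLIR-style" binding: treat read-only slices differently from written ones.
  BuildKernelFromSlicesSketch(slices, [](const SliceSketch& s, LocationSketch loc) {
    std::cout << s.label << " -> arg " << loc
              << (s.written ? " (written)" : " (invariant)") << "\n";
  });
  return 0;
}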
TF_CHECK_OK(CheckWhileBuffersShareAllocation( @@ -2076,15 +2267,16 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( // Generate thunk sequence for while 'body' (will be used a For loop body). HloComputation* body = hlo->while_body(); - IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - ir_emitter_context_); - TF_CHECK_OK(body->Accept(&ir_emitter_body)); + TF_ASSIGN_OR_RETURN( + auto ir_emitter_body, + IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); + TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); - return absl::make_unique(GetThunkInfo(hlo), loop_limit, - ir_emitter_body.ConsumeThunkSequence()); + return std::unique_ptr(new ForThunk( + GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); } -std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( +StatusOr> IrEmitterUnnested::BuildConditionalThunk( const HloInstruction* hlo) { // Check that the buffers used in conditional are shared with the operands and // result appropriately. @@ -2096,15 +2288,17 @@ std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( for (int j = 0; j < hlo->branch_count(); ++j) { branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); HloComputation* branch_computation = hlo->branch_computation(j); - IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation, - ir_emitter_context_); - TF_CHECK_OK(branch_computation->Accept(&ir_emitter)); - branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence())); + TF_ASSIGN_OR_RETURN( + auto ir_emitter, + IrEmitterUnnested::Create(hlo_module_config_, branch_computation, + ir_emitter_context_)); + TF_CHECK_OK(branch_computation->Accept(ir_emitter.get())); + branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); } - return absl::make_unique( + return std::unique_ptr(new ConditionalThunk( GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, - std::move(branch_thunks)); + std::move(branch_thunks))); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 019fcdf21db..b9146dd8fae 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #include "absl/container/inlined_vector.h" +#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -28,6 +29,40 @@ limitations under the License. namespace xla { namespace gpu { +struct BufferSlice { + // The root buffer to look at. + BufferAllocation::Slice buffer_slice; + + // Describes how to dereference starting at that buffer to get to the buffer + // in question. + ShapeIndex gte_index; +}; + +// Describes how to access a particular subshape for an HLO. For instance if +// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at +// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is +// found at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we +// dereference twice -- first at index 3, and then at index 4 -- to get the +// address of our buffer. 
+struct HloBufferSlice : public BufferSlice { + const HloInstruction* instr; + ShapeIndex hlo_index; +}; + +struct MlirBufferSlice : public BufferSlice { + // The buffer is modified by the kernel. + bool written; + + Shape shape; +}; + +struct MlirEmitterInput { + mlir::Operation* op; + absl::string_view name; + Thunk::ThunkInfo thunk_info; + MlirBufferSlice extra_slice; +}; + // Emits LLVM IR for an "unnested computation". // // An unnested computation is an HloComputation which you run by executing one @@ -89,12 +124,14 @@ class IrEmitterUnnested : public IrEmitter, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, KernelSupportLibrary* ksl)>; - IrEmitterUnnested(const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); IrEmitterUnnested(const IrEmitterUnnested&) = delete; IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; + static StatusOr> Create( + const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + // Transfers the ownship of thunk_sequence_ out. std::unique_ptr ConsumeThunkSequence() { return std::make_unique(std::move(thunk_sequence_)); @@ -124,6 +161,7 @@ class IrEmitterUnnested : public IrEmitter, Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; + Status EmitMlirSort(MlirEmitterInput input); Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; @@ -148,6 +186,10 @@ class IrEmitterUnnested : public IrEmitter, Status Postprocess(HloInstruction* hlo) override; private: + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + // Add a owning Thunk object to the thunk sequence. void AddThunkToThunkSequence(std::unique_ptr thunk) override { thunk_sequence_.emplace_back(std::move(thunk)); @@ -264,8 +306,7 @@ class IrEmitterUnnested : public IrEmitter, // Builds the prototype of the IR kernel for `inst` and adds it to the module. // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( - const HloInstruction& inst, - absl::Span args); + absl::string_view name, absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( @@ -490,6 +531,12 @@ class IrEmitterUnnested : public IrEmitter, HloComputation* reducer, llvm::Type* element_type, llvm::Value* partial_result_address); + std::unique_ptr BuildKernelThunkFromBufferSlices( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::function + bind_slice_to_ir_value); + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. 
'implements_whole_instruction' specifies whether this @@ -498,6 +545,11 @@ class IrEmitterUnnested : public IrEmitter, std::unique_ptr BuildKernelThunk( const HloInstruction* inst, bool implements_whole_instruction); + std::unique_ptr BuildKernelThunkForMlir( + absl::string_view name, Thunk::ThunkInfo thunk_info, + absl::Span slices, + std::vector* ir_arrays); + // Returns a thunk that, given a reduce or select-and-scatter op, // initializes its memory to the appropriate initial value. StatusOr> BuildInitializerThunk( @@ -505,17 +557,18 @@ class IrEmitterUnnested : public IrEmitter, // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. - std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); + StatusOr> BuildWhileThunk(const HloInstruction* hlo); // Returns a ForThunk which executes 'loop_limit' invocations of a thunk // sequence from the 'body' sub-computation of the while instruction 'hlo'. - std::unique_ptr BuildForThunk(const HloInstruction* hlo, - const int64 loop_limit); + StatusOr> BuildForThunk(const HloInstruction* hlo, + const int64 loop_limit); // Returns a ConditionalThunk which executes the thunk sequence for the // 'branch_computation' corresponding to the predicate/branch_index of the // given conditional instruction. - std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); + StatusOr> BuildConditionalThunk( + const HloInstruction* hlo); // Emits current thread id with the given type. // @@ -545,6 +598,9 @@ class IrEmitterUnnested : public IrEmitter, absl::optional thread_id_filter = absl::nullopt, absl::optional block_id_filter = absl::nullopt); + StatusOr GetOrCreateSubComputationFromRegion( + mlir::Region* region); + // Returns the last generated thunk. Thunk* LastThunk() const { return thunk_sequence_.back().get(); } @@ -555,6 +611,14 @@ class IrEmitterUnnested : public IrEmitter, // The HloComputation that this IrEmitter emits code for. const HloComputation* hlo_computation_; + + mlir::OwningModuleRef mlir_scratch_module_; + + // This is for cache-purpose only. It has no significant semantics. + mlir::LhloDialectEmitter lhlo_scratch_emitter_; + + absl::flat_hash_map> + scratch_nested_computations_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 1228a1b4823..04af67a70b9 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -62,8 +62,10 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/random.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/profiler/lib/traceme.h" +#include "tensorflow/core/util/env_var.h" namespace xla { namespace gpu { @@ -86,14 +88,21 @@ static string GetSmName(std::pair compute_capability) { int sm_version = 30; // If the current compute capability isn't known, fallback to the // most recent version before it. 
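Note: the fallback below picks the newest SM version LLVM knows about that does not exceed the device's compute capability, and only warns when the device is older than every supported version, since PTX built for the newest supported SM (sm_75 here) still runs on newer GPUs such as sm_80. A standalone sketch of that policy, with the capability encoded as major * 10 + minor as in the surrounding code:

#include <cstdio>

int PickSmVersionSketch(int cc_version) {
  const int supported[] = {75, 72, 70, 62, 61, 60, 53, 52, 50, 37, 35, 32, 30};
  int sm = 30;
  for (int v : supported) {
    if (v <= cc_version) {
      sm = v;
      break;
    }
  }
  // Warn only for unknown *older* capabilities; newer ones can reuse sm_75 PTX.
  if (sm != cc_version && cc_version < supported[0]) {
    std::fprintf(stderr, "Unknown compute capability %d, defaulting to sm_%d\n",
                 cc_version, sm);
  }
  return sm;
}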
- for (int v : {75, 72, 70, 62, 61, 60, 53, 52, 50, 37, 35, 32, 30}) { + int supported_versions[] = {75, 72, 70, 62, 61, 60, 53, + 52, 50, 37, 35, 32, 30}; + for (int v : supported_versions) { if (v <= compute_capability_version) { sm_version = v; break; } } - if (sm_version != compute_capability_version) { + // If the current CC isn't supported by LLVM and it is newer then + // the max supported LLVM version, do not warn about it. The end + // user can't do anything about this. PTX compiled for SM75 will + // run on SM80 too. + if (sm_version != compute_capability_version && + compute_capability_version < supported_versions[0]) { LOG(WARNING) << "Unknown compute capability (" << compute_capability.first << ", " << compute_capability.second << ") ." << "Defaulting to telling LLVM that we're compiling for sm_" @@ -570,6 +579,60 @@ static std::vector GetROCDLPaths(int amdgpu_version, return result; } +struct HsacoCacheEntry { + uint64 hash; + std::string ir; + int gfx; + std::vector hsaco; +}; + +struct HsacoCache { + protected: + std::vector cache; + std::mutex m_mutex; + int request_count = 0; + int hit_count = 0; + + public: + static bool Find(const std::string& ir, uint64_t& hash, int gfx, + std::vector& hsaco); + static void Add(const std::string& ir, uint64_t hash, int gfx, + const std::vector& hsaco); +}; + +static HsacoCache g_hsacoCache; + +bool HsacoCache::Find(const std::string& ir, uint64_t& hash, int gfx, + std::vector& hsaco) { + std::lock_guard lg(g_hsacoCache.m_mutex); + hash = std::hash{}(ir); + bool hit = false; + for (auto& x : g_hsacoCache.cache) { + if (x.hash != hash) continue; + if (x.gfx != gfx) continue; + if (x.ir != ir) continue; + hsaco = x.hsaco; + hit = true; + break; + } + g_hsacoCache.request_count++; + if (hit) g_hsacoCache.hit_count++; + if (!(g_hsacoCache.request_count % 50)) + VLOG(1) << "HSACO cache: " << g_hsacoCache.request_count << " requests, " + << g_hsacoCache.hit_count << " hits"; + return hit; +} + +void HsacoCache::Add(const std::string& ir, uint64_t hash, int gfx, + const std::vector& hsaco) { + std::lock_guard lg(g_hsacoCache.m_mutex); + g_hsacoCache.cache.resize(g_hsacoCache.cache.size() + 1); + g_hsacoCache.cache.back().ir = ir; + g_hsacoCache.cache.back().hash = hash; + g_hsacoCache.cache.back().gfx = gfx; + g_hsacoCache.cache.back().hsaco = hsaco; +} + // Emits the given module to HSA Code Object. target_machine is an initialized // TargetMachine for the AMDGPU target. StatusOr> EmitModuleToHsaco( @@ -584,18 +647,29 @@ StatusOr> EmitModuleToHsaco( std::string tempdir_name = tempdir_vector.front(); VLOG(1) << "Compile-time artifacts located at: " << tempdir_name; + bool keep_tempfiles = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_ROCM_KEEP_XLA_TEMPFILES", + /*default_val=*/false, + &keep_tempfiles)); // Prepare filenames for all stages of compilation: // IR, binary ISA, and HSACO. 
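Note: the HsacoCache added above memoizes compiled code objects keyed on the module's LLVM IR and the gfx version, guarded by a mutex; the real cache also stores the hash and compares the full IR string so a hash collision cannot return the wrong binary. A self-contained sketch of the same idea built from standard containers (our own simplified class, not the one in the patch):

#include <cstdint>
#include <functional>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

class BinaryCacheSketch {
 public:
  bool Find(const std::string& ir, int gfx, std::vector<uint8_t>* out) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = table_.find(Key(ir, gfx));
    if (it == table_.end()) return false;  // miss: caller compiles and Adds
    *out = it->second;
    return true;
  }
  void Add(const std::string& ir, int gfx, std::vector<uint8_t> binary) {
    std::lock_guard<std::mutex> lock(mu_);
    table_[Key(ir, gfx)] = std::move(binary);
  }

 private:
  // The real cache keeps the hash and the full IR to rule out collisions.
  static std::string Key(const std::string& ir, int gfx) {
    return std::to_string(std::hash<std::string>{}(ir)) + ":" +
           std::to_string(gfx);
  }
  std::mutex mu_;
  std::unordered_map<std::string, std::vector<uint8_t>> table_;
};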
- std::string ir_filename = absl::StrCat(module->getModuleIdentifier(), ".ll"); + std::string random_number = std::to_string(tensorflow::random::New64()); + std::string ir_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + ".ll"); std::string ir_path = tensorflow::io::JoinPath(tempdir_name, ir_filename); + std::string ir_opt_filename = + absl::StrCat(module->getModuleIdentifier(), random_number + "_opt.ll"); + std::string ir_opt_path = + tensorflow::io::JoinPath(tempdir_name, ir_opt_filename); + std::string isabin_filename = - absl::StrCat(module->getModuleIdentifier(), ".o"); + absl::StrCat(module->getModuleIdentifier(), random_number + ".o"); std::string isabin_path = tensorflow::io::JoinPath(tempdir_name, isabin_filename); std::string hsaco_filename = - absl::StrCat(module->getModuleIdentifier(), ".hsaco"); + absl::StrCat(module->getModuleIdentifier(), random_number + ".hsaco"); std::string hsaco_path = tensorflow::io::JoinPath(tempdir_name, hsaco_filename); @@ -613,7 +687,7 @@ StatusOr> EmitModuleToHsaco( std::string module_id = module->getModuleIdentifier(); IrDumpingPassManager codegen_passes( ReplaceFilenameExtension(tensorflow::io::Basename(module_id), - "-amdgpu.dummy"), + random_number + "-amdgpu.dummy"), "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -627,6 +701,12 @@ StatusOr> EmitModuleToHsaco( codegen_passes.run(*module); isabin_fs->flush(); + if (keep_tempfiles) { + std::unique_ptr ir_fs( + new llvm::raw_fd_ostream(ir_opt_path, ec, llvm::sys::fs::F_None)); + module->print(*ir_fs, nullptr); + ir_fs->flush(); + } // Locate lld. // TODO(whchung@gmail.com): change to tensorflow::ROCmRoot() after // ROCm-Device-Libs PR. @@ -652,9 +732,9 @@ StatusOr> EmitModuleToHsaco( int lld_result = llvm::sys::ExecuteAndWait(*lld_program, llvm_ir::AsArrayRef(lld_args), llvm::None, {}, 0, 0, &error_message); - if (lld_result) { - return xla::InternalError("ld.lld execute fail: %s", error_message); + return xla::InternalError("ld.lld execute fail: %s, error code %d", + error_message, lld_result); } // Read HSACO. @@ -664,6 +744,12 @@ StatusOr> EmitModuleToHsaco( std::vector hsaco(hsaco_file_size); hsaco_file.seekg(0, std::ios::beg); hsaco_file.read(reinterpret_cast(&hsaco[0]), hsaco_file_size); + hsaco_file.close(); + if (!keep_tempfiles) { + remove(ir_path.c_str()); + remove(isabin_path.c_str()); + remove(hsaco_path.c_str()); + } return hsaco; } @@ -728,6 +814,20 @@ StatusOr> CompileToHsaco( std::vector hsaco; std::unique_ptr target_machine; + std::string str; + llvm::raw_string_ostream stream(str); + stream << *module; + // Delete the first two lines, since they usually vary even when the rest of + // the code is the same (but verify that they are what we expect). 
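Note: before hashing, the code below drops the `; ModuleID = ` and `source_filename = ` header lines, which differ between otherwise identical modules, and then appends the HLO module's compilation cache key. A compact sketch of just the normalization step (an illustrative helper, not part of the patch):

#include <string>

std::string NormalizeIrForCachingSketch(std::string ir) {
  // Strip the two leading header lines if present, in the order LLVM emits them.
  for (const char* prefix : {"; ModuleID = ", "source_filename = "}) {
    if (ir.rfind(prefix, 0) == 0) {  // line starts with the prefix
      auto pos = ir.find('\n');
      if (pos != std::string::npos) ir = ir.substr(pos + 1);
    }
  }
  // The real code then appends hlo_module_config.compilation_cache_key().
  return ir;
}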
+ if (str.size() >= 13 && str.substr(0, 13) == "; ModuleID = ") { + auto pos = str.find("\n"); + if (pos != std::string::npos) str = str.substr(pos + 1); + } + if (str.size() >= 18 && str.substr(0, 18) == "source_filename = ") { + auto pos = str.find("\n"); + if (pos != std::string::npos) str = str.substr(pos + 1); + } + str += hlo_module_config.compilation_cache_key(); { tensorflow::profiler::TraceMe activity( [&] { return absl::StrCat("Compiling IR", module->getName().str()); }, @@ -739,6 +839,21 @@ StatusOr> CompileToHsaco( return xla::InternalError( "Incompatible AMD GCN ISA version was specified."); } + uint64_t hash; + if (HsacoCache::Find(str, hash, *amdgpu_version, hsaco)) { + VLOG(1) << "HSACO cache hit"; + return hsaco; + } + VLOG(1) << "HSACO cache miss"; + bool dump_lls = false; + if (dump_lls) { + static int hsaco_count = 0; + std::string name = "/tmp/" + std::to_string(hsaco_count) + ".ll"; + hsaco_count++; + std::ofstream ofs(name); + ofs << str; + ofs.close(); + } llvm::Triple default_target_triple("amdgcn--amdhsa-amdgiz"); // Construct LLVM TargetMachine for AMDGPU. @@ -754,6 +869,7 @@ StatusOr> CompileToHsaco( // Lower optimized LLVM module to HSA code object. TF_ASSIGN_OR_RETURN(hsaco, EmitModuleToHsaco(module, target_machine.get())); + HsacoCache::Add(str, hash, *amdgpu_version, hsaco); } return hsaco; } diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index a2bddd2d0d7..809b277317f 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -458,6 +458,35 @@ xla_test( ], ) +tf_cc_test( + name = "sorting_test", + srcs = [ + "sorting_test.cc", + ], + tags = tf_cuda_tests_tags() + [ + "no_rocm", + ], + deps = [ + ":gpu_codegen_test", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_proto_cc", + "//tensorflow/compiler/xla/service:gpu_plugin", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:gpu_executable", + "//tensorflow/compiler/xla/tests:filecheck", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/memory", + ], +) + tf_cc_binary( name = "hlo_to_llvm_ir", srcs = ["hlo_to_llvm_ir.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo index 272c9a25769..4d29a8df116 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo @@ -8,162 +8,162 @@ compare { ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } -// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 -// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 -// CHECK-NEXT: 
[[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 -// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 -// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 +// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 -// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 -// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 -// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] -// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 -// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] -// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 +// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] +// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* 
[[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0 +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 -// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4 -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 -// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define internal void @compare(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK: define internal void @region_0_4(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) // CHECK-NEXT: entry: -// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_0_LHS_TYPED]], align 4 -// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_0_RHS_TYPED]], align 4 +// CHECK-NEXT: [[COMPARE_3_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_0_1_TYPED:%.*]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_1_2_TYPED:%.*]], align 4 // CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] // CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 -// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1 -// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1 -// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_3_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_3_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1 // CHECK-NEXT: ret void -// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) 
[[ALLOC1:%.*]]) { +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) { // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 -// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 -// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 -// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 -// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 +// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 -// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3 -// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]] -// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3 -// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]] -// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP11:%.*]] = xor i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP12:%.*]] = icmp slt i64 [[TMP8]], [[TMP11]] +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], 3 +// CHECK-NEXT: [[TMP14:%.*]] = and i1 [[TMP12]], [[TMP13]] +// CHECK-NEXT: br i1 [[TMP14]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* 
[[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]] -// CHECK-NEXT: call void @compare(float* [[TMP11]], float* [[TMP12]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP13:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP13]], 0 +// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]] +// CHECK-NEXT: call void @region_0_4(float* [[TMP15]], float* [[TMP16]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP17:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP17]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP14:%.*]] = load float, float* [[TMP11]], align 4 -// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 -// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]] -// CHECK-NEXT: store float [[TMP14]], float* [[TMP16]], align 4 -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = load float, float* [[TMP15]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP18]], float* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) { +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) { // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 -// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 -// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 -// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// 
CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 +// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 +// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 -// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 -// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 -// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] -// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 -// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] -// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 +// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] +// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0 +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: 
[[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 -// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4 -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 -// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] +// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] ENTRY main { x = f32[2, 3] parameter(0) @@ -182,210 +182,198 @@ compare { ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT } -// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4]], i64 0 -// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* -// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 -// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 -// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* -// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2]], i64 0 -// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3]], i64 0 -// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 -// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 -// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* 
[[TMP0]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* +// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 -// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 +// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 -// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 -// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] -// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 -// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] -// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP10]], 2 +// CHECK-NEXT: [[TMP14:%.*]] = xor i64 [[TMP13]], 1 +// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], [[TMP14]] +// CHECK-NEXT: [[TMP16:%.*]] = icmp slt i64 [[TMP14]], 3 +// CHECK-NEXT: [[TMP17:%.*]] = and i1 [[TMP15]], [[TMP16]] +// CHECK-NEXT: br i1 [[TMP17]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: call void @compare(i32* [[TMP12]], i32* [[TMP13]], float* [[TMP14]], float* [[TMP15]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP16:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP16]], 0 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x 
i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: call void @region_0_6(i32* [[TMP18]], i32* [[TMP19]], float* [[TMP20]], float* [[TMP21]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP22:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP22]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4 -// CHECK-NEXT: [[TMP18:%.*]] = load i32, i32* [[TMP13]], align 4 -// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4 -// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: store i32 [[TMP18]], i32* [[TMP20]], align 4 -// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4 -// CHECK-NEXT: [[TMP22:%.*]] = load float, float* [[TMP15]], align 4 -// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4 -// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]] -// CHECK-NEXT: store float [[TMP22]], float* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = load i32, i32* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: store i32 [[TMP24]], i32* [[TMP26]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = load float, float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4 +// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]] +// CHECK-NEXT: store float [[TMP28]], float* [[TMP30]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define internal void @compare(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK: define internal void @region_0_6(i32* 
dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) // CHECK-NEXT: entry: -// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_1_LHS_TYPED]], align 4 -// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_1_RHS_TYPED]], align 4 +// CHECK-NEXT: [[COMPARE_5_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_2_3_TYPED:%.*]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_3_4_TYPED:%.*]], align 4 // CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] // CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 -// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1 -// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1 -// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_5_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_5_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1 // CHECK-NEXT: ret void -// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* -// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* -// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0 -// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0 -// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 -// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 -// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: 
[[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* +// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 -// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 +// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3 -// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]] -// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3 -// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]] -// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP13:%.*]] = xor i64 [[TMP10]], 3 +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP10]], [[TMP13]] +// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], 3 +// CHECK-NEXT: [[TMP16:%.*]] = and i1 [[TMP14]], [[TMP15]] +// CHECK-NEXT: br i1 [[TMP16]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]] -// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]] -// CHECK-NEXT: call void @compare(i32* [[TMP11]], i32* [[TMP12]], float* [[TMP13]], float* [[TMP14]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP15:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP15]], 0 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds 
[2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: call void @region_0_6(i32* [[TMP17]], i32* [[TMP18]], float* [[TMP19]], float* [[TMP20]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP21:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP21]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP16:%.*]] = load i32, i32* [[TMP11]], align 4 -// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4 -// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]] -// CHECK-NEXT: store i32 [[TMP16]], i32* [[TMP18]], align 4 -// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4 -// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP13]], align 4 -// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4 -// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]] -// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 -// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] -// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]] +// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] +// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED:%.*]] 
= bitcast i8* [[SORT_RAW]] to [2 x i8*]* -// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 -// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* -// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0 -// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0 -// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 -// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 -// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] +// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* +// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 +// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 +// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 -// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 +// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: -// CHECK-NEXT: [[TMP7:%.*]] = bitcast [2 x [3 x i32]]* [[SORT_TYPED2]] to i8* -// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 0 -// CHECK-NEXT: store i8* [[TMP7]], i8** [[TMP8]], align 8 -// CHECK-NEXT: [[TMP9:%.*]] = bitcast [2 x [3 x float]]* [[SORT_TYPED4]] to i8* -// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 1 -// CHECK-NEXT: store i8* [[TMP9]], i8** [[TMP10]], align 8 +// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8* +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], 
i64 0, i64 0 +// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8 +// CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8* +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1 +// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8 // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP4]], 2 -// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 -// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] -// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 -// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] -// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2 +// CHECK-NEXT: [[TMP18:%.*]] = xor i64 [[TMP17]], 1 +// CHECK-NEXT: [[TMP19:%.*]] = icmp slt i64 [[TMP17]], [[TMP18]] +// CHECK-NEXT: [[TMP20:%.*]] = icmp slt i64 [[TMP18]], 3 +// CHECK-NEXT: [[TMP21:%.*]] = and i1 [[TMP19]], [[TMP20]] +// CHECK-NEXT: br i1 [[TMP21]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]] -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]] -// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]] -// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]] -// CHECK-NEXT: call void @compare(i32* [[TMP16]], i32* [[TMP17]], float* [[TMP18]], float* [[TMP19]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP20:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP20]], 0 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: call void @region_0_6(i32* [[TMP22]], i32* [[TMP23]], float* [[TMP24]], float* [[TMP25]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP26:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP26]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP21:%.*]] = load i32, i32* [[TMP16]], align 4 -// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4 -// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]] -// CHECK-NEXT: store i32 
[[TMP21]], i32* [[TMP23]], align 4 -// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]] -// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4 -// CHECK-NEXT: [[TMP25:%.*]] = load float, float* [[TMP18]], align 4 -// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4 -// CHECK-NEXT: [[TMP27:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]] -// CHECK-NEXT: store float [[TMP25]], float* [[TMP27]], align 4 -// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]] -// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = load i32, i32* [[TMP22]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = load i32, i32* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: store i32 [[TMP27]], i32* [[TMP29]], align 4 +// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: store i32 [[TMP28]], i32* [[TMP30]], align 4 +// CHECK-NEXT: [[TMP31:%.*]] = load float, float* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP32:%.*]] = load float, float* [[TMP25]], align 4 +// CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]] +// CHECK-NEXT: store float [[TMP31]], float* [[TMP33]], align 4 +// CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]] +// CHECK-NEXT: store float [[TMP32]], float* [[TMP34]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] ENTRY main { x = s32[2, 3] parameter(0) diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc b/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc new file mode 100644 index 00000000000..197a0c6cfeb --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tests/filecheck.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace gpu { + +namespace { + +class SortingTest : public GpuCodegenTest { + protected: + HloModuleConfig ConfigWithoutLayoutAssignment() { + HloModuleConfig config; + auto debug_options = HloTestBase::GetDebugOptionsForTest(); + // Disable layout_assignment to use the preassigned layouts. + debug_options.add_xla_disable_hlo_passes("layout-assignment"); + config.set_debug_options(debug_options); + return config; + } +}; + +TEST_F(SortingTest, Regression1) { + const char* hlo_text = R"( +HloModule TestModule + +compare { + p.0.lhs = f32[] parameter(0) + p.0.rhs = f32[] parameter(1) + ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT +} + +ENTRY TestComputation { + x = f32[3, 2]{1, 0} parameter(0) + x.copy = f32[3, 2]{0, 1} copy(x) + ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare +} + +)"; + + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5})); +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 10751752571..2e2b668eba7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/hlo_live_range.h" #include "tensorflow/compiler/xla/service/hlo_schedule.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" #include "tensorflow/compiler/xla/util.h" namespace xla { @@ -55,9 +56,10 @@ StatusOr HeapSimulator::MinimumMemoryForModule( // rather than summing each computation, since it gives us a better lower // bound, by minimizing the liveness of sub-computations. 
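The heap_simulator.cc changes below parameterize HeapSimulator and its heap algorithms on a BufferType. Expanded for readability, the updated call in MinimumMemoryForModule presumably takes the following form (a sketch, assuming the buffer type here is HloValue):

    TF_ASSIGN_OR_RETURN(
        HeapSimulator::Result<HloValue> result,
        HeapSimulator::Run(
            absl::make_unique<NoFragmentationStatsHeap<HloValue>>(), *module,
            schedule, *alias_analysis, size_function));
    return result.heap_size;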
TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique(), *module, - schedule, *alias_analysis, size_function)); + HeapSimulator::Result result, + HeapSimulator::Run( + absl::make_unique>(), *module, + schedule, *alias_analysis, size_function)); return result.heap_size; } @@ -69,10 +71,11 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( const absl::flat_hash_map* memory_by_computation) { TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique(), - computation, sequence, alias_analysis, size_function, - HeapSimulator::Options(), memory_by_computation)); + HeapSimulator::Result result, + HeapSimulator::Run( + absl::make_unique>(), computation, + sequence, alias_analysis, size_function, HeapSimulator::Options(), + memory_by_computation)); return result.heap_size; } @@ -82,16 +85,17 @@ StatusOr HeapSimulator::MinimumMemoryForComputation( const LogicalBuffer::SizeFunction& size_function, const HloSchedule* schedule) { TF_ASSIGN_OR_RETURN( - HeapSimulator::Result result, - HeapSimulator::Run(absl::make_unique(), - computation, sequence, alias_analysis, size_function, - schedule, HeapSimulator::Options())); + HeapSimulator::Result result, + HeapSimulator::Run( + absl::make_unique>(), computation, + sequence, alias_analysis, size_function, schedule, + HeapSimulator::Options())); return result.heap_size; } /*static*/ -StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, const HloModule& module, +StatusOr> HeapSimulator::Run( + std::unique_ptr> algorithm, const HloModule& module, const HloSchedule& schedule, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const Options& options) { HeapSimulator heap(std::move(algorithm), size_fn, options, &schedule); @@ -108,8 +112,9 @@ StatusOr HeapSimulator::Run( } /*static*/ -StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, const HloComputation& computation, +StatusOr> HeapSimulator::Run( + std::unique_ptr> algorithm, + const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const Options& options, @@ -128,8 +133,9 @@ StatusOr HeapSimulator::Run( } /*static*/ -StatusOr HeapSimulator::Run( - std::unique_ptr algorithm, const HloComputation& computation, +StatusOr> HeapSimulator::Run( + std::unique_ptr> algorithm, + const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, const BufferValue::SizeFunction& size_fn, const HloSchedule* schedule, @@ -326,12 +332,13 @@ Status HeapSimulator::RunComputation( } HeapSimulator::HeapSimulator( - std::unique_ptr algorithm, + std::unique_ptr> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule, const absl::flat_hash_map* memory_by_computation) - : no_fragmentation_stats_(absl::make_unique()), + : no_fragmentation_stats_( + absl::make_unique>()), algorithm_(std::move(algorithm)), size_fn_(size_fn), options_(options), @@ -396,8 +403,8 @@ void HeapSimulator::ShareBuffer(const HloValue* buffer, const HloValue* shared, shared); } -HeapSimulator::Result HeapSimulator::Finish() { - Result result = algorithm_->Finish(); +HeapSimulator::Result HeapSimulator::Finish() { + Result result = algorithm_->Finish(); // Post-process the result to add chunks for shared buffers. 
An empty chunk // map means that either no buffers were allocated, or the heap was only @@ -411,7 +418,7 @@ HeapSimulator::Result HeapSimulator::Finish() { } // Fragmentation is the difference between the actual and ideal sizes. - const Result no_frag_result = no_fragmentation_stats_->Finish(); + const Result no_frag_result = no_fragmentation_stats_->Finish(); result.fragmentation_size = result.heap_size - no_frag_result.heap_size; // Copy the debug trace we collected to the final result. @@ -437,14 +444,17 @@ void HeapSimulator::FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, } } -void NoFragmentationStatsHeap::Alloc(const HloValue* buffer, int64 size) { +template +void NoFragmentationStatsHeap::Alloc(const BufferType* buffer, + int64 size) { current_heap_size_ += size; if (current_heap_size_ > max_heap_size_) { max_heap_size_ = current_heap_size_; } } -void NoFragmentationStatsHeap::AccountForSubcomputationMemory( +template +void NoFragmentationStatsHeap::AccountForSubcomputationMemory( const HloInstruction* instruction, int64 alloc_size_by_instruction, const absl::flat_hash_map& memory_by_computation) { @@ -472,11 +482,15 @@ void NoFragmentationStatsHeap::AccountForSubcomputationMemory( std::max(max_heap_size_, current_heap_size_ + max_subcomputation_bytes); } -void NoFragmentationStatsHeap::Free(const HloValue* buffer, int64 size) { +template +void NoFragmentationStatsHeap::Free(const BufferType* buffer, + int64 size) { current_heap_size_ -= size; } -HeapSimulator::Result NoFragmentationStatsHeap::Finish() { +template +HeapSimulator::Result +NoFragmentationStatsHeap::Finish() { // The result.chunk_map is empty, since we only collect stats, and don't // actually compute chunk assignments. Result result; @@ -484,7 +498,8 @@ HeapSimulator::Result NoFragmentationStatsHeap::Finish() { return result; } -GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( +template +GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( int64 alignment, Type type) : alignment_(alignment) { if (type == kTemporal) { @@ -495,8 +510,10 @@ GlobalDecreasingSizeBestFitHeap::GlobalDecreasingSizeBestFitHeap( } } -GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare -GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const { +template +typename GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare +GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() + const { return [&](const BufferInterval& x, const BufferInterval& y) { int64 x_end = x.end; for (auto colocation : GetTransitiveColocations(x)) { @@ -515,12 +532,14 @@ GlobalDecreasingSizeBestFitHeap::GetTemporalBufferIntervalCompare() const { if (x.size != y.size) { return x.size > y.size; } - return x.buffer->id() < y.buffer->id(); + return *x.buffer < *y.buffer; }; } -/*static*/ GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare -GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() { +template +/*static*/ typename GlobalDecreasingSizeBestFitHeap< + BufferType>::BufferIntervalCompare +GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() { return [&](const BufferInterval& x, const BufferInterval& y) { if (x.size != y.size) { return x.size > y.size; @@ -528,12 +547,13 @@ GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare() { if (x.end - x.start != y.end - y.start) { return x.end - x.start > y.end - y.start; } - return x.buffer->id() < y.buffer->id(); + return *x.buffer < *y.buffer; }; } -void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer, - int64 
size) { +template +void GlobalDecreasingSizeBestFitHeap::Alloc( + const BufferType* buffer, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -546,9 +566,9 @@ void GlobalDecreasingSizeBestFitHeap::Alloc(const HloValue* buffer, ++current_time_; } -void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer, - const HloValue* share_with, - int64 size) { +template +void GlobalDecreasingSizeBestFitHeap::ShareWith( + const BufferType* buffer, const BufferType* share_with, int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. if (size == 0) { result_.chunk_map.emplace(buffer, Chunk{0, 0}); @@ -562,15 +582,16 @@ void GlobalDecreasingSizeBestFitHeap::ShareWith(const HloValue* buffer, ++current_time_; } -absl::flat_hash_set -GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations( +template +absl::flat_hash_set +GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations( const BufferInterval& interval) const { - absl::flat_hash_set result; + absl::flat_hash_set result; std::vector worklist = {&interval}; while (!worklist.empty()) { const BufferInterval* item = worklist.back(); worklist.pop_back(); - for (const HloValue* buffer_colocated : item->colocations) { + for (const BufferType* buffer_colocated : item->colocations) { result.insert(buffer_colocated); worklist.push_back(&buffer_intervals_.at(buffer_colocated)); } @@ -579,7 +600,9 @@ GlobalDecreasingSizeBestFitHeap::GetTransitiveColocations( return result; } -void GlobalDecreasingSizeBestFitHeap::Free(const HloValue* buffer, int64 size) { +template +void GlobalDecreasingSizeBestFitHeap::Free(const BufferType* buffer, + int64 size) { // Degenerate case: 0-sized buffers are always allocated at offset 0. 
if (size == 0) { return; @@ -785,7 +808,9 @@ std::vector BufferIntervalTree::ChunksOverlappingInTime( return result; } -HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { +template +HeapSimulator::Result +GlobalDecreasingSizeBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -803,8 +828,10 @@ HeapSimulator::Result GlobalDecreasingSizeBestFitHeap::Finish() { return result_; } -std::vector -GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { +template +std::vector< + typename GlobalDecreasingSizeBestFitHeap::BufferInterval> +GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { std::vector sorted_buffer_intervals; for (auto& entry : buffer_intervals_) { sorted_buffer_intervals.push_back(entry.second); @@ -814,8 +841,9 @@ GlobalDecreasingSizeBestFitHeap::GetSortedBufferIntervals() const { return sorted_buffer_intervals; } -GlobalDecreasingSizeBestFitHeap::ChunkCandidate -GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( +template +typename GlobalDecreasingSizeBestFitHeap::ChunkCandidate +GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval, int64 preferred_offset) const { VLOG(1) << "Finding chunks for buffer: " @@ -912,9 +940,12 @@ GlobalDecreasingSizeBestFitHeap::FindChunkCandidate( return chunk_candidate; } -void GlobalDecreasingSizeBestFitHeap::CommitChunk( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& buffer_interval, - GlobalDecreasingSizeBestFitHeap::ChunkCandidate chunk_candidate) { +template +void GlobalDecreasingSizeBestFitHeap::CommitChunk( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& + buffer_interval, + GlobalDecreasingSizeBestFitHeap::ChunkCandidate + chunk_candidate) { // Update the maximum heap size according to the one determined by the chunk // candidate. result_.heap_size = chunk_candidate.heap_size; @@ -930,13 +961,16 @@ void GlobalDecreasingSizeBestFitHeap::CommitChunk( AddToChunkMap(buffer_interval.buffer, chunk_candidate.chunk); } -void GlobalDecreasingSizeBestFitHeap::AddToChunkMap(const HloValue* buffer, - Chunk chunk) { +template +void GlobalDecreasingSizeBestFitHeap::AddToChunkMap( + const BufferType* buffer, Chunk chunk) { const auto emplace_result = result_.chunk_map.emplace(buffer, chunk); DCHECK(emplace_result.second); } -HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { +template +HeapSimulator::Result +ChooseBestHeapAlgorithm::Finish() { DCHECK(!algorithms_.empty()); std::vector results(algorithms_.size()); int64 min_size = INT64_MAX; @@ -953,4 +987,9 @@ HeapSimulator::Result ChooseBestHeapAlgorithm::Finish() { return results[min_size_index]; } +template class GlobalDecreasingSizeBestFitHeap; +template class GlobalDecreasingSizeBestFitHeap< + MemorySpaceAssignmentRepacker::AllocationBlock>; +template class ChooseBestHeapAlgorithm; + } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index d3b781ded0c..b47ff685139 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -40,7 +40,9 @@ limitations under the License. namespace xla { // Forward declare classes defined below. +template class HeapAlgorithm; +template class NoFragmentationStatsHeap; // HeapSimulator assigns buffer offsets by running a simulation of a regular @@ -66,9 +68,10 @@ class HeapSimulator { }; // Result represents the result of the heap simulation. 
+ template struct Result { // The assignment of buffers to chunks. - absl::flat_hash_map chunk_map; + absl::flat_hash_map chunk_map; // The total size in bytes of the heap, containing all assigned chunks. int64 heap_size = 0; @@ -128,19 +131,19 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - static StatusOr Run(std::unique_ptr algorithm, - const HloModule& module, - const HloSchedule& schedule, - const HloAliasAnalysis& alias_analysis, - const BufferValue::SizeFunction& size_fn, - const Options& options = Options()); + static StatusOr> Run( + std::unique_ptr> algorithm, + const HloModule& module, const HloSchedule& schedule, + const HloAliasAnalysis& alias_analysis, + const BufferValue::SizeFunction& size_fn, + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions // in the computation. The result is invalid if instructions are not run in // exactly this sequence. - static StatusOr Run( - std::unique_ptr algorithm, + static StatusOr> Run( + std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, @@ -151,8 +154,8 @@ class HeapSimulator { // Same as above, but runs on with a schedule that covers all nested // computations. - static StatusOr Run( - std::unique_ptr algorithm, + static StatusOr> Run( + std::unique_ptr> algorithm, const HloComputation& computation, const HloInstructionSequence& instruction_sequence, const HloAliasAnalysis& alias_analysis, @@ -163,7 +166,7 @@ class HeapSimulator { // If 'schedule' is non-null, it is used to find kCall and kWhile // sub-computations, and the heap simulation for those sub-computations will // be run recursively. I.e. the simulation is run over the whole module. - HeapSimulator(std::unique_ptr algorithm, + HeapSimulator(std::unique_ptr> algorithm, const BufferValue::SizeFunction& size_fn, const Options& options, const HloSchedule* schedule = nullptr, const absl::flat_hash_map* @@ -187,7 +190,7 @@ class HeapSimulator { // Two buffers belong to the same shared group. // Eight of the buffer has no shared group assigned. bool InSameSharedGroup(const HloValue* left, const HloValue* right); - Result Finish(); + Result Finish(); void FillDebugTrace(HeapSimulatorTrace::Event::Kind kind, const HloValue* buffer, const HloInstruction* instruction, @@ -196,8 +199,9 @@ class HeapSimulator { // Counterintuitive: the algorithm_ itself can be a NoFragmentationStatsHeap, // in which case we are calculating the same allocs/frees twice in the // simulation. - const std::unique_ptr no_fragmentation_stats_; - const std::unique_ptr algorithm_; + const std::unique_ptr> + no_fragmentation_stats_; + const std::unique_ptr> algorithm_; const BufferValue::SizeFunction size_fn_; const Options options_; // schedule_ is set by buffer assignment, and memory_by_computation_ is @@ -220,15 +224,16 @@ class HeapSimulator { // offsets to buffers. A sequence of Alloc / Free calls will be made, with the // same semantics as a regular memory heap. Finish will be called at the end to // collect the simulation results. +template class HeapAlgorithm { public: using Chunk = HeapSimulator::Chunk; - using Result = HeapSimulator::Result; + using Result = HeapSimulator::Result; virtual ~HeapAlgorithm() = default; // Alloc allocates a buffer of 'size' bytes. 
- virtual void Alloc(const HloValue* buffer, int64 size) = 0; + virtual void Alloc(const BufferType* buffer, int64 size) = 0; // Takes memory usage of subcomputations into account when calculating the // memory usage of a computation. Currently, we don't handle buffer aliasing @@ -247,7 +252,7 @@ class HeapAlgorithm { memory_by_computation) {} // Free de-allocates a previously allocated buffer. - virtual void Free(const HloValue* buffer, int64 size) = 0; + virtual void Free(const BufferType* buffer, int64 size) = 0; // Indicates that a buffer has to be collocated with another buffer. In // addition to Alloc and Free, the heap simulator exposes a concept of buffer @@ -255,7 +260,7 @@ class HeapAlgorithm { // the buffer, it associates the buffer with a previously allocated (or // shared) buffer. Each group of mutually-shared buffers points to a single // SharedGroup instance, which is a shared control block. - virtual void ShareWith(const HloValue* buffer, const HloValue* share_with, + virtual void ShareWith(const BufferType* buffer, const BufferType* share_with, int64 size) { Alloc(buffer, size); } @@ -269,19 +274,22 @@ class HeapAlgorithm { // this is the absolute minimum size for a given instruction sequence. The // result.chunk_map returned in Finish is always empty, since we only collect // stats, and don't actually compute chunk assignments. -class NoFragmentationStatsHeap : public HeapAlgorithm { +template +class NoFragmentationStatsHeap : public HeapAlgorithm { public: + using Result = HeapSimulator::Result; + NoFragmentationStatsHeap() = default; ~NoFragmentationStatsHeap() override = default; - void Alloc(const HloValue* buffer, int64 size) override; + void Alloc(const BufferType* buffer, int64 size) override; void AccountForSubcomputationMemory( const HloInstruction* instruction, int64 alloc_size_by_instruction, const absl::flat_hash_map& memory_by_computation) override; - void Free(const HloValue* buffer, int64 size) override; + void Free(const BufferType* buffer, int64 size) override; Result Finish() override; @@ -336,8 +344,12 @@ class BufferIntervalTree { // alloc/free time. It internally tracks the allocated buffers and their live // intervals; when allocating a buffer, it finds the best-fit free chunk during // its live interval. -class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { +template +class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { public: + using Result = HeapSimulator::Result; + using Chunk = HeapSimulator::Chunk; + enum Type { kSpatial = 0, kTemporal, @@ -345,7 +357,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // BufferInterval stores a buffer's size and time interval. struct BufferInterval { - const HloValue* buffer; + const BufferType* buffer; int64 size; // Alloc time of the buffer. int64 start; @@ -353,7 +365,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { int64 end; // Colocation buffers that need to be collocated with this one. - std::vector colocations; + std::vector colocations; // True if this buffer needs an allocation. False if it is collocated with // other buffer. 
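A minimal usage sketch of the templated algorithm interface above, mirroring the GlobalDecreasingSizeBestFitHeap unit tests updated later in this patch (it assumes the buffer type is HloValue and that buffer_a_ and buffer_b_ are HloValue* provided by the caller):

    GlobalDecreasingSizeBestFitHeap<HloValue> heap(/*alignment=*/1);
    heap.Alloc(buffer_a_, 10);  // live range of a begins
    heap.Alloc(buffer_b_, 20);  // b overlaps a, so it gets a disjoint chunk
    heap.Free(buffer_a_, 10);   // live range of a ends
    heap.Free(buffer_b_, 20);
    const HeapSimulator::Result<HloValue> result = heap.Finish();
    // result.heap_size is the peak heap size; result.chunk_map maps each
    // HloValue* to the offset/size chunk assigned to it.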
@@ -368,10 +380,10 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { Type type = kSpatial); ~GlobalDecreasingSizeBestFitHeap() override {} - void Alloc(const HloValue* buffer, int64 size) override; - void Free(const HloValue* buffer, int64 size) override; + void Alloc(const BufferType* buffer, int64 size) override; + void Free(const BufferType* buffer, int64 size) override; - void ShareWith(const HloValue* buffer, const HloValue* share_with, + void ShareWith(const BufferType* buffer, const BufferType* share_with, int64 size) override; Result Finish() override; @@ -404,7 +416,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { void CommitChunk(const BufferInterval& buffer_interval, ChunkCandidate chunk_candidate); // Adds the buffer and the chunk to the result chunk map. - virtual void AddToChunkMap(const HloValue* buffer, Chunk chunk); + virtual void AddToChunkMap(const BufferType* buffer, Chunk chunk); // Return a BufferIntervalCompare function that sorts by live ranges. A live // range is defined by the range between the start of the first buffer and the @@ -413,7 +425,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // contiguous. BufferIntervalCompare GetTemporalBufferIntervalCompare() const; - absl::flat_hash_map buffer_intervals_; + absl::flat_hash_map buffer_intervals_; Result result_; BufferIntervalCompare buffer_interval_compare_; BufferIntervalTree interval_tree_; @@ -428,33 +440,37 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm { // Returns all transitive colocated buffers of this buffer interval. I.e., If // a buffer A is colocated with B and B is colocated with C, this function // returns all three of them. - absl::flat_hash_set GetTransitiveColocations( + absl::flat_hash_set GetTransitiveColocations( const BufferInterval& interval) const; }; // A heap algorithm that chooses the best results from other algorithms added to // it. 
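ChooseBestHeapAlgorithm, templatized just below, forwards every Alloc/Free/ShareWith to all registered algorithms and keeps the Finish() result with the smallest heap size. A sketch of how it might be driven after this change (the alignment value and the temporal-plus-spatial combination are illustrative, not part of this patch):

    using BestFit = GlobalDecreasingSizeBestFitHeap<HloValue>;
    auto algorithms = absl::make_unique<
        std::vector<std::unique_ptr<HeapAlgorithm<HloValue>>>>();
    algorithms->push_back(
        absl::make_unique<BestFit>(/*alignment=*/64, BestFit::kTemporal));
    algorithms->push_back(
        absl::make_unique<BestFit>(/*alignment=*/64, BestFit::kSpatial));
    ChooseBestHeapAlgorithm<HloValue> chooser(std::move(algorithms));
    // chooser.Finish() returns the Result<HloValue> whose heap_size is the
    // minimum across the registered strategies.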
-class ChooseBestHeapAlgorithm : public HeapAlgorithm { +template +class ChooseBestHeapAlgorithm : public HeapAlgorithm { public: + using Result = HeapSimulator::Result; + ChooseBestHeapAlgorithm( - std::unique_ptr>> algorithms) + std::unique_ptr>>> + algorithms) : algorithms_(std::move(*algorithms)) {} ~ChooseBestHeapAlgorithm() override {} - void Alloc(const HloValue* buffer, int64 size) override { + void Alloc(const BufferType* buffer, int64 size) override { for (auto& algorithm : algorithms_) { algorithm->Alloc(buffer, size); } } - void ShareWith(const HloValue* buffer, const HloValue* share_with, + void ShareWith(const BufferType* buffer, const BufferType* share_with, int64 size) override { for (auto& algorithm : algorithms_) { algorithm->ShareWith(buffer, share_with, size); } } - void Free(const HloValue* buffer, int64 size) override { + void Free(const BufferType* buffer, int64 size) override { for (auto& algorithm : algorithms_) { algorithm->Free(buffer, size); } @@ -463,7 +479,7 @@ class ChooseBestHeapAlgorithm : public HeapAlgorithm { Result Finish() override; private: - std::vector> algorithms_; + std::vector>> algorithms_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index b5b711cab4f..8f7668b4965 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -228,7 +228,7 @@ const char kFinish[] = "Finish"; using CallSequence = std::vector>; // HeapCallRecorder is a dummy heap algorithm that simply records its calls. -class HeapCallRecorder : public HeapAlgorithm { +class HeapCallRecorder : public HeapAlgorithm { public: explicit HeapCallRecorder(CallSequence* calls) : calls_(calls) {} ~HeapCallRecorder() override {} @@ -396,7 +396,7 @@ class HeapSimulatorTracker { std::unique_ptr module_; std::unique_ptr alias_analysis_; CallSequence actual_calls_; - HeapSimulator::Result result_; + HeapSimulator::Result result_; }; class HeapSimulatorTest : public HloTestBase { @@ -976,12 +976,12 @@ class HeapAlgorithmTestBase : public ::testing::Test { class NoFragmentationStatsHeapTest : public HeapAlgorithmTestBase {}; TEST_F(NoFragmentationStatsHeapTest, Empty) { - NoFragmentationStatsHeap heap; + NoFragmentationStatsHeap heap; EXPECT_EQ(0, heap.Finish().heap_size); } TEST_F(NoFragmentationStatsHeapTest, Simple) { - NoFragmentationStatsHeap heap; + NoFragmentationStatsHeap heap; heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 30); @@ -994,7 +994,7 @@ TEST_F(NoFragmentationStatsHeapTest, Simple) { } TEST_F(NoFragmentationStatsHeapTest, Mixed) { - NoFragmentationStatsHeap heap; + NoFragmentationStatsHeap heap; heap.Alloc(buffer_a_, 10); // max: A heap.Alloc(buffer_b_, 20); // max: A+B @@ -1013,7 +1013,7 @@ TEST_F(NoFragmentationStatsHeapTest, Mixed) { class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase { protected: class InheritedGlobalDecreasingSizeBestFitHeap - : public GlobalDecreasingSizeBestFitHeap { + : public GlobalDecreasingSizeBestFitHeap { public: InheritedGlobalDecreasingSizeBestFitHeap() : GlobalDecreasingSizeBestFitHeap(/*alignment=*/1) {} @@ -1048,8 +1048,8 @@ class GlobalDecreasingSizeBestFitHeapTest : public HeapAlgorithmTestBase { }; TEST_F(GlobalDecreasingSizeBestFitHeapTest, Empty) { - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); - const HeapSimulator::Result result = heap.Finish(); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + 
const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(0, result.heap_size); EXPECT_EQ(0, result.chunk_map.size()); } @@ -1068,7 +1068,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { // | | d | // | +-------+ // -----------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 30); heap.Alloc(buffer_c_, 20); @@ -1078,7 +1078,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSize) { heap.Free(buffer_c_, 20); heap.Free(buffer_d_, 40); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(100, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); @@ -1107,7 +1107,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { // | | | // | +-------+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/20); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 50); @@ -1117,7 +1117,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, DecreasingSizeWithAlignment) { heap.Free(buffer_c_, 50); heap.Free(buffer_d_, 40); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(120, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1148,7 +1148,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { // | | | // | +-------+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 10); heap.Alloc(buffer_b_, 20); heap.Alloc(buffer_c_, 40); @@ -1160,7 +1160,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, BestFit) { heap.Free(buffer_d_, 30); heap.Free(buffer_e_, 50); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(140, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1184,7 +1184,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) { // || |+----+| | // |+--a---++-b--++---c---+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 40); heap.Free(buffer_a_, 40); heap.Alloc(buffer_b_, 20); @@ -1192,7 +1192,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, Colocated) { heap.ShareWith(buffer_c_, buffer_a_, 40); heap.Free(buffer_c_, 40); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(40, result.heap_size); EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1212,7 +1212,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) { // || | | | <--- colocate with a // |+--a---+ +---c---+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 40); heap.Free(buffer_a_, 40); heap.Alloc(buffer_b_, 20); @@ -1221,7 +1221,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedII) { heap.Free(buffer_c_, 40); heap.Free(buffer_b_, 20); - const HeapSimulator::Result 
result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(60, result.heap_size); EXPECT_EQ(40, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(20, result.chunk_map.at(buffer_b_).size); @@ -1242,7 +1242,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) { // | | | // | +-------b-------+ // ---------------------> time - GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); + GlobalDecreasingSizeBestFitHeap heap(/*alignment=*/1); heap.Alloc(buffer_a_, 10); heap.Free(buffer_a_, 10); heap.Alloc(buffer_b_, 30); @@ -1251,7 +1251,7 @@ TEST_F(GlobalDecreasingSizeBestFitHeapTest, ColocatedIII) { heap.Free(buffer_c_, 10); heap.Free(buffer_b_, 30); - const HeapSimulator::Result result = heap.Finish(); + const HeapSimulator::Result result = heap.Finish(); EXPECT_EQ(40, result.heap_size); EXPECT_EQ(10, result.chunk_map.at(buffer_a_).size); EXPECT_EQ(30, result.chunk_map.at(buffer_b_).size); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 17a7b18c84b..c3a7b3a5c14 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 73 +// Next ID: 74 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -251,6 +251,9 @@ message HloInstructionProto { // The comparison type used for kCompare. string comparison_type = 72; + + // Specifies if this is a cross-program-prefetch, used by kCopyStart. + bool is_cross_program_prefetch = 73; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 438aa6ff05f..14daf680ac9 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -545,7 +545,7 @@ string HloComputation::ToString( if (options.print_percent()) { s << "%"; } - if (options.print_ids() || !IsEntryComputation()) { + if (options.print_ids()) { // Exclude entry computation's name because it includes and leads to // non-deterministic fingerprint. s << PrintName(name(), options.print_ids()) << " "; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 72b15db0dcd..939c713fc18 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -486,6 +486,10 @@ Status HloCostAnalysis::HandleReshape(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleDynamicReshape(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-training. 
return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d9085dd7785..f101e3819c9 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -113,6 +113,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleBroadcast(const HloInstruction* broadcast) override; Status HandlePad(const HloInstruction* pad) override; Status HandleReshape(const HloInstruction* reshape) override; + Status HandleDynamicReshape(const HloInstruction* reshape) override; Status HandleAddDependency(const HloInstruction* add_dependency) override; Status HandleAfterAll(const HloInstruction* token) override; Status HandleTranspose(const HloInstruction* transpose) override; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 1bbbb248bbc..551ffb52031 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -1229,10 +1229,10 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary( + auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(), ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, constant)); + constant)); auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopyDone, copy_start)); module_->AddEntryComputation(builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc deleted file mode 100644 index 9415e20af7b..00000000000 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.cc +++ /dev/null @@ -1,120 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" - -#include "absl/algorithm/container.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/dynamic_dimension_inference.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" - -namespace xla { - -namespace { - -StatusOr ReplaceGetSize( - HloInstruction* instr, - DynamicDimensionInference* dynamic_dimension_inference) { - if (instr->opcode() != HloOpcode::kGetDimensionSize) { - return false; - } - HloComputation* computation = instr->parent(); - - TF_ASSIGN_OR_RETURN(auto legal_shape, - ShapeInference::InferGetDimensionSizeShape( - instr->operand(0)->shape(), instr->dimension())); - TF_RET_CHECK(ShapeUtil::Equal(instr->shape(), legal_shape)) - << "instr->shape() " << instr->shape().ToString() << " , " - << "legal_shape " << legal_shape.ToString(); - TF_RET_CHECK(ShapeUtil::HasPrimitiveType(instr->shape(), S32)); - HloInstruction* operand = instr->mutable_operand(0); - int64 dim = instr->dimension(); - HloInstruction* dynamic_size = - dynamic_dimension_inference->GetDynamicSize(operand, {}, dim); - if (dynamic_size != nullptr) { - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(dynamic_size)); - // The dependency between a instruction and its dynamic dimensions is not - // modeled in the IR. As instr is being replaced by dynamic_size, also tell - // dynamic dimension inference that the instruction is being replaced. - dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith( - instr, dynamic_size); - } else { - int32 size = instr->operand(0)->shape().dimensions(dim); - HloInstruction* new_instr = computation->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(size))); - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(new_instr)); - dynamic_dimension_inference->ReplaceAllDynamicDimensionUsesWith(instr, - new_instr); - } - return true; -} - -StatusOr ReplaceSetSize(HloInstruction* instr) { - if (instr->opcode() != HloOpcode::kSetDimensionSize) { - return false; - } - - TF_RET_CHECK(Shape::Equal().IgnoreDynamicDimension()( - instr->shape(), instr->operand(0)->shape())) - << "instr->shape() " << instr->shape().ToString() << " , " - << "instruction operand shape " << instr->operand(0)->shape(); - HloInstruction* operand = instr->mutable_operand(0); - - TF_RETURN_IF_ERROR(instr->ReplaceAllUsesWith(operand)); - return true; -} - -} // namespace - -StatusOr HloGetDimensionSizeRewriter::Run(HloModule* module) { - bool changed = false; - HloProto proto; - TF_ASSIGN_OR_RETURN(DynamicDimensionInference inference, - DynamicDimensionInference::Run(module)); - *proto.mutable_hlo_module() = module->ToProto(); - // It's important to replace get-dimension-size first before - // set-dimension-size for the case below: - // static_op dynamic_size - // | | - // set-dimension-size // Marks the dimension as dynamic - // | - // get-dimension-size - // - // If we replace set dimension size first, we'd have - // - // static_op - // | - // get-dimension-size - // - // This will get static size of the op, which is incorrect. 
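[Editor's illustration, not part of the deleted pass or its tests.] To make the ordering argument in the comment above concrete, here is a hypothetical HLO snippet written in the same textual syntax the tests below use, wrapped in a C++ raw string as those tests do. If set-dimension-size were rewritten first, `gds` would fold to the static bound 3; rewriting get-dimension-size first resolves it to `dyn_size`, which is the intended behavior.

// Illustration only; module and value names are made up for this sketch.
const char* const kOrderingExample = R"(
HloModule _
ENTRY example {
  p = s32[3,4] parameter(0)
  dyn_size = s32[] parameter(1)
  p_dynamic = s32[<=3,4] set-dimension-size(p, dyn_size), dimensions={0}
  ROOT gds = s32[] get-dimension-size(p_dynamic), dimensions={0}
})";
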
- for (auto* computation : module->computations()) { - for (auto instruction : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool replaced_get_size, - ReplaceGetSize(instruction, &inference)); - changed = changed || replaced_get_size; - } - } - for (auto* computation : module->computations()) { - for (auto instruction : computation->MakeInstructionPostOrder()) { - TF_ASSIGN_OR_RETURN(bool replaced_set_size, ReplaceSetSize(instruction)); - changed = changed || replaced_set_size; - } - } - return changed; -} - -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc b/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc deleted file mode 100644 index b1491e96095..00000000000 --- a/tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter_test.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/hlo_get_dimension_size_rewriter.h" - -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_matchers.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/tests/test_utils.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/util.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/types.h" - -namespace xla { -namespace { - -namespace op = xla::testing::opcode_matchers; - -class HloGetDimensionSizeRewriterTest : public HloTestBase { - protected: - HloGetDimensionSizeRewriterTest() {} -}; - -TEST_F(HloGetDimensionSizeRewriterTest, Ok) { - auto module = ParseAndReturnVerifiedModule(R"( -HloModule _ -ENTRY gds { - p = s32[3,4] parameter(0) - size0 = s32[] get-dimension-size(p), dimensions={0} - size1 = s32[] get-dimension-size(p), dimensions={1} - ROOT mul = s32[] multiply(size0, size1) -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Multiply(op::Constant(), op::Constant())); -} - -TEST_F(HloGetDimensionSizeRewriterTest, GetSetSetDimensionSizeRewriter) { - auto module = ParseAndReturnVerifiedModule(R"( -HloModule _ -ENTRY gds { - p = s32[3,4] parameter(0) - size0 = s32[] get-dimension-size(p), dimensions={0} - p_copy = s32[3,4] copy(p) - p_copy_dynamic = s32[<=3, 4] set-dimension-size(p_copy, size0), dimensions={0} - size1 = s32[] get-dimension-size(p_copy_dynamic), dimensions={0} - ROOT 
mul = s32[] multiply(size0, size1) -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_TRUE(pass.Run(module.get()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), - op::Multiply(op::Constant(), op::Constant())); -} - -TEST_F(HloGetDimensionSizeRewriterTest, IllegalType) { - auto module = ParseAndReturnUnverifiedModule(R"( -HloModule _ -ENTRY gds { - p = s32[3]{0} parameter(0) - ROOT gds = s64[] get-dimension-size(p), dimensions={0} -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_FALSE(pass.Run(module.get()).ok()); -} - -TEST_F(HloGetDimensionSizeRewriterTest, IllegalDimension) { - auto module = ParseAndReturnUnverifiedModule(R"( -HloModule _ -ENTRY gds { - p = f32[2,5] parameter(0) - ROOT gds = s32[] get-dimension-size(p), dimensions={2} -})") - .ValueOrDie(); - HloGetDimensionSizeRewriter pass; - EXPECT_FALSE(pass.Run(module.get()).ok()); -} - -} // namespace -} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index d7e8984dee8..164e92ae8e8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1012,6 +1012,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kReverse: case HloOpcode::kTupleSelect: case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2ce3c12b4e9..bb01fdd0e15 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -167,6 +167,11 @@ StatusOr> HloInstruction::CreateFromProto( absl::Span(fft_length)); break; } + case HloOpcode::kCopyStart: { + instruction = CreateCopyStart(shape, operands(0), + proto.is_cross_program_prefetch()); + break; + } case HloOpcode::kCompare: { // Auto-upgraded from deprecated opcode skips the following. 
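[Editor's note.] The kCopyStart deserialization case above uses the new CreateCopyStart factory together with the is_cross_program_prefetch proto field. A minimal construction sketch, mirroring the updated dataflow and matcher tests in this patch, is shown below; the helper name is hypothetical and the tuple shape (destination, source, u32[] context) follows those tests.

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Sketch: build a copy-start/copy-done pair around a parameter and mark the
// async copy as a cross-program prefetch via the new factory argument.
std::unique_ptr<HloComputation> BuildCrossProgramPrefetchExample() {
  HloComputation::Builder builder("copy_start_example");
  const Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 3});
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "p0"));
  // copy-start produces a (dest, source, context) tuple, as in the tests.
  HloInstruction* copy_start =
      builder.AddInstruction(HloInstruction::CreateCopyStart(
          ShapeUtil::MakeTupleShape(
              {f32_shape, f32_shape, ShapeUtil::MakeShape(U32, {})}),
          p0, /*is_cross_program_prefetch=*/true));
  // copy-done yields the destination value.
  builder.AddInstruction(
      HloInstruction::CreateUnary(f32_shape, HloOpcode::kCopyDone, copy_start));
  return builder.Build();
}

}  // namespace xla
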
if (!comparison_direction) { @@ -700,6 +705,17 @@ StatusOr> HloInstruction::CreateFromProto( instruction = CreateReshape(shape, operands(0), inferred_dimension); break; } + case HloOpcode::kDynamicReshape: { + TF_RET_CHECK(shape.IsArray() && operands(0)->shape().IsArray() && + ShapeUtil::ElementsIn(shape) == + ShapeUtil::ElementsIn(operands(0)->shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(operands(0)->shape()); + const auto& operand_vector = all_operands(); + instruction = CreateDynamicReshape( + shape, operands(0), absl::MakeSpan(operand_vector).subspan(1)); + break; + } default: { instruction = absl::WrapUnique(new HloInstruction(opcode, shape)); for (const int64 operand_id : proto.operand_ids()) { @@ -828,7 +844,6 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, case HloOpcode::kCeil: case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: - case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: case HloOpcode::kCos: case HloOpcode::kClz: @@ -935,6 +950,13 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state, fft_length); } +/* static */ std::unique_ptr HloInstruction::CreateCopyStart( + const Shape& shape, HloInstruction* operand, + bool is_cross_program_prefetch) { + return absl::make_unique(shape, operand, + is_cross_program_prefetch); +} + /* static */ std::unique_ptr HloInstruction::CreateCompare( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, ComparisonDirection direction, absl::optional type) { @@ -1373,6 +1395,19 @@ HloInstruction::CreateBroadcastSequence( inferred_dimension); } +/* static */ std::unique_ptr +HloInstruction::CreateDynamicReshape( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes) { + CHECK_EQ(ShapeUtil::ElementsIn(shape), + ShapeUtil::ElementsIn(data_operand[0].shape())) + << "shape: " << ShapeUtil::HumanString(shape) + << " operand: " << ShapeUtil::HumanString(data_operand[0].shape()); + CHECK_EQ(shape.rank(), dim_sizes.size()); + return absl::make_unique(shape, data_operand, + dim_sizes); +} + /* static */ std::unique_ptr HloInstruction::CreateTranspose( const Shape& shape, HloInstruction* operand, absl::Span dimensions) { @@ -1569,6 +1604,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kTranspose: case HloOpcode::kBroadcast: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kMap: case HloOpcode::kSlice: case HloOpcode::kConstant: @@ -2007,6 +2043,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kReplicaId: case HloOpcode::kRoundNearestAfz: case HloOpcode::kRsqrt: @@ -2812,7 +2849,8 @@ HloInstructionProto HloInstruction::ToProto() const { string HloInstruction::ToCategory() const { if (opcode() == HloOpcode::kTranspose || opcode() == HloOpcode::kCopy || - opcode() == HloOpcode::kReshape) { + opcode() == HloOpcode::kReshape || + opcode() == HloOpcode::kDynamicReshape) { return "data formatting"; } @@ -3033,6 +3071,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandlePad(this); case HloOpcode::kReshape: return visitor->HandleReshape(this); + case HloOpcode::kDynamicReshape: + return visitor->HandleDynamicReshape(this); case HloOpcode::kTranspose: return visitor->HandleTranspose(this); case HloOpcode::kReverse: @@ -4089,6 +4129,10 @@ const DomainMetadata& 
HloInstruction::user_side_metadata() const { return Cast(this)->user_side_metadata(); } +bool HloInstruction::is_cross_program_prefetch() const { + return Cast(this)->is_cross_program_prefetch(); +} + ComparisonDirection HloInstruction::comparison_direction() const { return Cast(this)->direction(); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index bdd64c908f0..7db128b4d34 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -592,6 +592,12 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, FftType fft_type, absl::Span fft_length); + // Creates a copy-start op, indicating whether this is a cross-program + // prefetch or not. + static std::unique_ptr CreateCopyStart( + const Shape& shape, HloInstruction* operand, + bool is_cross_program_prefetch = false); + // Creates a compare op, performing the comparison specified in direction. static std::unique_ptr CreateCompare( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, @@ -879,6 +885,14 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, int64 inferred_dimension = -1); + // Creates a dynamic reshape instruction. Similar to reshape but dynamic + // dimensions sizes are provided as additional variadic arguments. + // + // Precondition: dim_sizes.size() == shape.rank() + static std::unique_ptr CreateDynamicReshape( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes); + // Creates a transpose instruction which permutes the operand dimensions. static std::unique_ptr CreateTranspose( const Shape& shape, HloInstruction* operand, @@ -1857,6 +1871,9 @@ class HloInstruction { // Delegates to HloDomainInstruction::user_side_metadata(). const DomainMetadata& user_side_metadata() const; + // Delegates to HloCopyStartInstruction::is_cross_program_prefetch(). + bool is_cross_program_prefetch() const; + // Delegates to HloCompareInstruction::direction(). 
ComparisonDirection comparison_direction() const; diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index dbc1d85d1bb..df225e27aad 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -204,6 +204,47 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } +HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape, + HloInstruction* operand, + bool is_cross_program_prefetch) + : HloInstruction(HloOpcode::kCopyStart, shape), + is_cross_program_prefetch_(is_cross_program_prefetch) { + AppendOperand(operand); +} + +HloInstructionProto HloCopyStartInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + proto.set_is_cross_program_prefetch(is_cross_program_prefetch_); + return proto; +} + +std::vector HloCopyStartInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + std::vector result; + if (is_cross_program_prefetch()) { + result.push_back("is_cross_program_prefetch=true"); + } + return result; +} + +bool HloCopyStartInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + return is_cross_program_prefetch() == + casted_other.is_cross_program_prefetch(); +} + +std::unique_ptr +HloCopyStartInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], is_cross_program_prefetch()); +} + HloCompareInstruction::HloCompareInstruction( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, ComparisonDirection direction, absl::optional type) @@ -1027,6 +1068,25 @@ HloBroadcastInstruction::CloneWithNewOperandsImpl( dimensions()); } +HloDynamicReshapeInstruction::HloDynamicReshapeInstruction( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes) + : HloInstruction(HloOpcode::kDynamicReshape, shape) { + AppendOperand(data_operand); + for (auto operand : dim_sizes) { + AppendOperand(operand); + } +} + +std::unique_ptr +HloDynamicReshapeInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_GE(new_operands.size(), 1); + return absl::make_unique( + shape, new_operands[0], new_operands.subspan(1)); +} + HloReshapeInstruction::HloReshapeInstruction(const Shape& shape, HloInstruction* operand, int64 inferred_dimension) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 3f92bb92f02..17368e8b714 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -132,6 +132,28 @@ class HloFftInstruction : public HloInstruction { std::vector fft_length_; }; +class HloCopyStartInstruction : public HloInstruction { + public: + explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand, + bool is_cross_program_prefetch); + + bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; } + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + std::unique_ptr 
CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + bool is_cross_program_prefetch_; +}; + class HloCompareInstruction : public HloInstruction { public: explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs, @@ -679,6 +701,25 @@ class HloBroadcastInstruction : public HloInstruction { std::vector dimensions_; }; +class HloDynamicReshapeInstruction : public HloInstruction { + public: + explicit HloDynamicReshapeInstruction( + const Shape& shape, HloInstruction* data_operand, + absl::Span dim_sizes); + + // Returns the input dim sizes dimensions, which is operands[1:] + absl::Span dim_sizes() const { + return absl::MakeSpan(operands()).subspan(1, operand_count()); + } + + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + // Returns the input dim size dimension, which is operands[1+i] + HloInstruction* dim_sizes(int64 i) const { return operands()[i + 1]; } +}; + class HloReshapeInstruction : public HloInstruction { public: explicit HloReshapeInstruction(const Shape& shape, HloInstruction* operand, diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index cb5cbd05d65..9c6509d8b73 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -276,10 +276,10 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) { /*element_size_in_bits=*/0, /*memory_space=*/2); auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0"); - auto copy_start = HloInstruction::CreateUnary( + auto copy_start = HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape( {shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, p0.get()); + p0.get()); auto copy_done = HloInstruction::CreateUnary( shape_memspace2, HloOpcode::kCopyDone, copy_start.get()); diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc index 8ee8d332aff..076e31dc8eb 100644 --- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc +++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc @@ -50,9 +50,9 @@ int64 PeakMemoryUseOfEntryComputation( HloComputation* computation = module->entry_computation(); const HloInstructionSequence& sequence = schedule.sequence(computation); - return HeapSimulator::Run(absl::make_unique(), - *computation, sequence, *alias_analysis, - size_function) + return HeapSimulator::Run( + absl::make_unique>(), + *computation, sequence, *alias_analysis, size_function) .ValueOrDie() .heap_size; } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 1625d0bbae4..b50c7d9a584 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -123,6 +123,7 @@ namespace xla { V(kRemainder, "remainder", 2) \ V(kReplicaId, "replica-id", 0) \ V(kReshape, "reshape", 1) \ + V(kDynamicReshape, "dynamic-reshape", kHloOpcodeIsVariadic) \ V(kReverse, "reverse", 1) \ V(kRng, "rng", kHloOpcodeIsVariadic) \ V(kRngGetAndUpdateState, "rng-get-and-update-state", 0) \ diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 136e6702b21..cceb60a70e9 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -58,6 +58,7 @@ TEST(HloOpcodeTest, OpcodeProperties) { case HloOpcode::kCustomCall: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kDynamicReshape: case HloOpcode::kFusion: case HloOpcode::kMap: case HloOpcode::kReduce: diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 2afa06a5df4..e2bbda3a607 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -883,7 +883,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, case HloOpcode::kClz: case HloOpcode::kCollectivePermuteDone: case HloOpcode::kCopy: - case HloOpcode::kCopyStart: case HloOpcode::kCopyDone: case HloOpcode::kCos: case HloOpcode::kExp: @@ -1091,6 +1090,20 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, } break; } + case HloOpcode::kCopyStart: { + // If the is_cross_program_prefetch attribute is not present then default + // to false. + optional is_cross_program_prefetch = false; + attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool, + &is_cross_program_prefetch}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCopyStart( + shape, operands[0], *is_cross_program_prefetch)); + break; + } case HloOpcode::kReplicaId: { if (!ParseOperands(&operands, /*expected_size=*/0) || !ParseAttributes(attrs)) { @@ -1108,6 +1121,16 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder, builder->AddInstruction(HloInstruction::CreatePartitionId()); break; } + case HloOpcode::kDynamicReshape: { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateDynamicReshape( + shape, operands[0], + absl::Span(operands).subspan(1))); + break; + } case HloOpcode::kReshape: { optional inferred_dimension; attrs["inferred_dimension"] = {/*required=*/false, AttrTy::kInt64, diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index aba6aeff999..620e67c3a2f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -318,7 +318,7 @@ R"(HloModule CopyStartAndCopyDone_module ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) { %v1 = f32[] parameter(0) - %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1) + %copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true %copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1) %v2 = f32[2,3]{1,0:S(1)} parameter(1) %copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2) diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 83130108dd7..3a5e7ca6f40 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -259,9 +259,15 @@ StatusOr> HloRunner::ExecuteReplicated( return ExecuteReplicated(executable.get(), options, device_assignment); } -StatusOr> HloRunner::ExecuteReplicated( - Executable* executable, const ReplicatedExecuteOptions& options, - DeviceAssignment* device_assignment, ExecutionProfile* profile) { +StatusOr> HloRunner::ExecuteReplicatedImpl( + 
std::function>( + const std::vector&, + const std::vector>&)> + execution_helper, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment) { std::vector> streams; std::vector service_run_options; @@ -269,12 +275,19 @@ StatusOr> HloRunner::ExecuteReplicated( // This reserve() call is necessary for correctness, because // argument_buffer_ptrs contains pointers into the elements of // argument_buffers. - argument_buffers.reserve(options.num_replicas * options.arguments.size()); + const int64 total_argument_count = [&]() { + int64 total = 0; + for (int64 i = 0; i < options.num_replicas; ++i) { + total += argument_count_provider(i); + } + return total; + }(); + argument_buffers.reserve(total_argument_count); // Plus one so we can safely get &argument_buffer_ptrs[0] in case there are // no arguments. - std::vector argument_buffer_ptrs( - options.num_replicas * options.arguments.size() + 1); + std::vector argument_buffer_ptrs(total_argument_count + + 1); std::vector> argument_buffer_slices; int64 index = 0; RunId run_id; @@ -288,7 +301,10 @@ StatusOr> HloRunner::ExecuteReplicated( device, streams.back().get(), device_assignment, run_id)); // Copy arguments to device. - for (const Literal* argument : options.arguments) { + const int64 argument_count = argument_count_provider(i); + for (int64 arg_index = 0; arg_index < argument_count; arg_index++) { + const Literal* const argument = argument_provider(i, arg_index); + TF_RET_CHECK(argument != nullptr); TF_ASSIGN_OR_RETURN( ScopedShapedBuffer argument_buffer, backend().transfer_manager()->AllocateScopedShapedBuffer( @@ -299,8 +315,7 @@ StatusOr> HloRunner::ExecuteReplicated( argument_buffer_ptrs[index++] = &argument_buffers.back(); } argument_buffer_slices.emplace_back( - &argument_buffer_ptrs[index - options.arguments.size()], - options.arguments.size()); + &argument_buffer_ptrs[index - argument_count], argument_count); } std::unique_ptr pool; @@ -355,39 +370,9 @@ StatusOr> HloRunner::ExecuteReplicated( } LOG(INFO) << "Replicated execution started"; - std::vector results; - if (!options.use_threads) { - TF_ASSIGN_OR_RETURN(results, - executable->ExecuteOnStreams(service_run_options, - argument_buffer_slices)); - } else { - tensorflow::mutex mutex; - std::vector> thread_results( - options.num_replicas); - { - LOG(INFO) << "Creating thread pool for " << options.num_replicas - << " replicas"; - tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), - "replicas", options.num_replicas); - for (int64 i = 0; i < options.num_replicas; ++i) { - pool.Schedule([&, i] { - auto result = executable->ExecuteOnStream( - &service_run_options[i], argument_buffer_slices[i], nullptr); - tensorflow::mutex_lock lock(mutex); - thread_results[i] = std::move(result); - }); - } - - // Note: the thread pool destructor guarantees it completes all work - // before we leave this scope. 
- } - for (auto& thread_result : thread_results) { - if (!thread_result.ok()) { - return thread_result.status(); - } - results.push_back(std::move(thread_result).ValueOrDie()); - } - } + TF_ASSIGN_OR_RETURN( + std::vector results, + execution_helper(service_run_options, argument_buffer_slices)); LOG(INFO) << "Replicated execution terminated"; std::vector exec_results; @@ -401,6 +386,104 @@ StatusOr> HloRunner::ExecuteReplicated( return std::move(exec_results); } +StatusOr> HloRunner::ExecuteReplicated( + Executable* executable, const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment, ExecutionProfile* profile) { + return ExecuteReplicatedImpl( + [&](const std::vector& service_run_options, + const std::vector>& + argument_buffer_slices) + -> StatusOr> { + std::vector results; + if (!options.use_threads) { + TF_ASSIGN_OR_RETURN( + results, executable->ExecuteOnStreams(service_run_options, + argument_buffer_slices)); + } else { + tensorflow::mutex mutex; + std::vector> thread_results( + options.num_replicas); + { + LOG(INFO) << "Creating thread pool for " << options.num_replicas + << " replicas"; + tensorflow::thread::ThreadPool pool( + tensorflow::Env::Default(), "replicas", options.num_replicas); + for (int64 i = 0; i < options.num_replicas; ++i) { + pool.Schedule([&, i] { + auto result = executable->ExecuteOnStream( + &service_run_options[i], argument_buffer_slices[i], + nullptr); + tensorflow::mutex_lock lock(mutex); + thread_results[i] = std::move(result); + }); + } + + // Note: the thread pool destructor guarantees it completes all work + // before we leave this scope. + } + for (auto& thread_result : thread_results) { + if (!thread_result.ok()) { + return thread_result.status(); + } + results.push_back(std::move(thread_result).ValueOrDie()); + } + } + return results; + }, + [&](int64 replica) { return options.arguments.size(); }, + [&](int64 replica, int64 index) { return options.arguments[index]; }, + options, device_assignment); +} + +StatusOr> HloRunner::ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options) { + TF_ASSIGN_OR_RETURN( + DeviceAssignment device_assignment, + backend().computation_placer()->AssignDevices(options.num_replicas, 1)); + return ExecuteReplicatedImpl( + [&](const std::vector& service_run_options, + const std::vector>& + argument_buffer_slices) + -> StatusOr> { + TF_RET_CHECK(options.use_threads); + std::vector results; + tensorflow::mutex mutex; + std::vector> thread_results( + options.num_replicas); + { + LOG(INFO) << "Creating thread pool for " << options.num_replicas + << " replicas"; + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), + "replicas", options.num_replicas); + for (int64 i = 0; i < options.num_replicas; ++i) { + for (const auto& arg : argument_buffer_slices[i]) { + TF_RET_CHECK(arg != nullptr); + } + pool.Schedule([&, i] { + auto result = executable_provider(i)->ExecuteOnStream( + &service_run_options[i], argument_buffer_slices[i], nullptr); + tensorflow::mutex_lock lock(mutex); + thread_results[i] = std::move(result); + }); + } + + // Note: the thread pool destructor guarantees it completes all work + // before we leave this scope. 
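[Editor's note.] The overload being added here takes per-replica callbacks instead of a single Executable and argument list. A usage sketch follows, assuming the XLA headers in this patch (hlo_runner.h, literal.h); the exact std::function template parameters are elided in this rendering of the diff, so the callback types below are inferred from how the callbacks are invoked (executable_provider(i) is dereferenced as an Executable*, argument_provider(i, j) yields a const Literal*), and the helper name is hypothetical. Note the TF_RET_CHECK above: this overload only supports options.use_threads.

// Sketch: run a different pre-compiled Executable on each replica, feeding
// per-replica Literal arguments through the provider callbacks.
StatusOr<std::vector<Literal>> RunPerReplica(
    HloRunner& runner, std::vector<Executable*> executables,
    std::vector<std::vector<const Literal*>> args_per_replica) {
  HloRunner::ReplicatedExecuteOptions options;
  options.num_replicas = executables.size();
  options.use_threads = true;  // required by this overload
  return runner.ExecuteReplicated(
      /*executable_provider=*/
      [&](int64 replica) { return executables[replica]; },
      /*argument_count_provider=*/
      [&](int64 replica) {
        return static_cast<int64>(args_per_replica[replica].size());
      },
      /*argument_provider=*/
      [&](int64 replica, int64 index) {
        return args_per_replica[replica][index];
      },
      options);
}
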
+ } + for (auto& thread_result : thread_results) { + if (!thread_result.ok()) { + return thread_result.status(); + } + results.push_back(std::move(thread_result).ValueOrDie()); + } + return results; + }, + argument_count_provider, argument_provider, options, &device_assignment); +} + StatusOr> HloRunner::ExecuteReplicated( std::unique_ptr module, const ReplicatedExecuteOptions& options) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 7e8b301ab54..733bb8bff54 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -176,6 +176,17 @@ class HloRunner { Executable* executable, const ReplicatedExecuteOptions& options, DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr); + // Same as above, but with different reusable Executables. This may update the + // profile information in *executables. + // + // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes, + // since we've already compiled the Executable. + StatusOr> ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options); + // If backend is not created in the constructor, creates and returns the // default backend. If creation fails, crashes the program. // @@ -193,6 +204,17 @@ class HloRunner { int64 device, se::Stream* stream, DeviceAssignment* device_assignment, RunId run_id); + // Common implementation code for ExecuteReplicated() above. + StatusOr> ExecuteReplicatedImpl( + std::function>( + const std::vector&, + const std::vector>&)> + execution_helper, + std::function argument_count_provider, + std::function argument_provider, + const ReplicatedExecuteOptions& options, + DeviceAssignment* device_assignment); + std::unique_ptr backend_; }; diff --git a/tensorflow/compiler/xla/service/hlo_sharding_util.cc b/tensorflow/compiler/xla/service/hlo_sharding_util.cc index 007b6158fc2..e1e506b2892 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding_util.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding_util.cc @@ -106,21 +106,28 @@ HloSharding TransposeSharding(const HloSharding& sharding, if (sharding.IsTileMaximal()) { return sharding; } - const int64 rank = dimensions.size(); + auto perm_dimensions = dimensions; + if (sharding.ReplicateOnLastTileDim() && + dimensions.size() < sharding.tile_assignment().num_dimensions()) { + perm_dimensions.push_back(dimensions.size()); + } + const int64 rank = perm_dimensions.size(); std::vector tile_assignment_dim(rank); for (int64 i = 0; i < rank; ++i) { - tile_assignment_dim[i] = sharding.tile_assignment().dim(dimensions[i]); + tile_assignment_dim[i] = sharding.tile_assignment().dim(perm_dimensions[i]); } Array tile_assignment = sharding.tile_assignment(); tile_assignment.Reshape(tile_assignment_dim); tile_assignment.Each([&](absl::Span indices, int64* value) { std::vector src_indices(indices.size(), -1); for (int64 i = 0; i < indices.size(); ++i) { - src_indices[dimensions[i]] = indices[i]; + src_indices[perm_dimensions[i]] = indices[i]; } *value = sharding.tile_assignment()(src_indices); }); - return HloSharding::Tile(tile_assignment); + return sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(tile_assignment) + : HloSharding::Tile(tile_assignment); } absl::optional ReshapeSharding(const Shape& source_shape, @@ -227,8 +234,14 @@ absl::optional ReshapeSharding(const Shape& source_shape, } } Array new_tile_assignment = sharding.tile_assignment(); + if (sharding.ReplicateOnLastTileDim()) { + target_tile_assignment_dimensions.push_back( + sharding.tile_assignment().dimensions().back()); + } new_tile_assignment.Reshape(target_tile_assignment_dimensions); - return HloSharding::Tile(new_tile_assignment); + return sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ReverseSharding(const HloSharding& sharding, @@ -246,7 +259,9 @@ HloSharding ReverseSharding(const HloSharding& sharding, } *device = sharding.tile_assignment()(original_indices); }); - return HloSharding::Tile(new_tile_assignment); + return sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_tile_assignment) + : HloSharding::Tile(new_tile_assignment); } HloSharding ReshapeToTileDimension(const HloSharding& sharding, int64 dim, @@ -343,6 +358,7 @@ HloSharding GatherOutputSharding(const HloSharding& index_sharding, HloSharding GatherIndexSharding(const HloSharding& output_sharding, const HloInstruction* hlo) { + CHECK(hlo->opcode() == HloOpcode::kGather); if (output_sharding.IsTileMaximal()) { return output_sharding; } @@ -355,6 +371,14 @@ HloSharding GatherIndexSharding(const HloSharding& output_sharding, output_sharding.tile_assignment().dim(i)); } } + int64 index_rank = hlo->operand(1)->shape().rank(); + + // Vector indices sharding is not supported yet. + if (index_rank > index_tile_assignment_dims.size()) { + index_tile_assignment_dims.insert( + index_tile_assignment_dims.begin() + dnums.index_vector_dim(), 1); + } + Array new_tile_assignment = output_sharding.tile_assignment(); if (new_tile_assignment.num_elements() != Product(index_tile_assignment_dims)) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index d395fddcc5d..0af2a45bfc7 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -703,6 +703,20 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { return Status::OK(); } +Status ShapeVerifier::HandleDynamicReshape(HloInstruction* dynamic_reshape) { + // Check for mixed precision. + const Shape& operand_shape = dynamic_reshape->operand(0)->shape(); + TF_RET_CHECK(SameElementType(dynamic_reshape->shape(), operand_shape)); + TF_RET_CHECK(ShapeUtil::ElementsIn(dynamic_reshape->shape()) == + ShapeUtil::ElementsIn(operand_shape)); + TF_RET_CHECK(dynamic_reshape->shape().rank() + 1 == + dynamic_reshape->operand_count()); + for (int64 i = 1; i < dynamic_reshape->operand_count(); ++i) { + TF_RET_CHECK(dynamic_reshape->operand(i)->shape().element_type() == S32); + } + return Status::OK(); +} + Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { // Check for mixed precision. 
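[Editor's note.] As a construction-side illustration of the invariants the new HandleDynamicReshape check above enforces (matching element type and element count, and one s32[] size operand per output dimension, so operand_count == rank + 1), here is a minimal sketch using the CreateDynamicReshape factory added in this patch. The shapes and names are made up for the example, and the ShapeUtil::MakeShape overload with dynamic_dimensions is assumed from the existing XLA API.

// Reshape f32[2,3] into a rank-1 result with a dynamic bound of 6 elements.
// Output rank is 1, so there is 1 data operand plus 1 s32[] dimension-size
// operand, and both shapes hold 6 elements.
HloComputation::Builder b("dynamic_reshape_example");
HloInstruction* data = b.AddInstruction(HloInstruction::CreateParameter(
    0, ShapeUtil::MakeShape(F32, {2, 3}), "data"));
HloInstruction* d0 = b.AddInstruction(HloInstruction::CreateParameter(
    1, ShapeUtil::MakeShape(S32, {}), "d0"));
b.AddInstruction(HloInstruction::CreateDynamicReshape(
    ShapeUtil::MakeShape(F32, {6}, /*dynamic_dimensions=*/{true}), data,
    {d0}));
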
const Shape& operand_shape = reshape->operand(0)->shape(); diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 85b02e0518c..03fca5938ff 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -78,6 +78,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleReshape(HloInstruction* reshape) override; + Status HandleDynamicReshape(HloInstruction* dynamic_reshape) override; Status HandleTranspose(HloInstruction* transpose) override; Status HandleParameter(HloInstruction*) override; Status HandleFusion(HloInstruction*) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 8d8930615b2..b290b1bd68b 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -102,6 +102,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kReducePrecision: case HloOpcode::kReplicaId: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index cc7fdeaf0f6..1446b55f5a8 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -52,6 +52,7 @@ InterpreterExecutable::InterpreterExecutable( } StatusOr InterpreterExecutable::Evaluate( + const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) { // Execute the graph using the HloEvaluator. tensorflow::mutex_lock lock(evaluator_lock_); diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index ce68a8472f5..514ed029a22 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -51,7 +51,8 @@ class InterpreterExecutable : public InterpreterExecutableBase { static int64 ShapeSizeBytes(const Shape& shape); protected: - StatusOr Evaluate(const HloComputation& computation, + StatusOr Evaluate(const ServiceExecutableRunOptions* run_options, + const HloComputation& computation, absl::Span arg_literals) override TF_LOCKS_EXCLUDED(evaluator_lock_); diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.cc b/tensorflow/compiler/xla/service/interpreter/executable_base.cc index 4b6a8aa5202..745750bffe1 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.cc @@ -50,11 +50,15 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( // TransferManager methods below. 
std::vector argument_buffers; argument_buffers.reserve(arguments.size()); + int device_ordinal = run_options->device_ordinal(); + if (device_ordinal < 0) { + device_ordinal = 0; + } for (auto& argument : arguments) { const ShapeTree& buffers = argument.Buffers(); argument_buffers.push_back(ShapedBuffer(buffers.shape(), buffers.shape(), /*platform=*/nullptr, - /*device_ordinal=*/0)); + /*device_ordinal=*/device_ordinal)); auto in_it = buffers.begin(); auto out_it = argument_buffers.back().buffers().begin(); for (; in_it != buffers.end(); ++in_it, ++out_it) { @@ -118,7 +122,7 @@ StatusOr InterpreterExecutableBase::ExecuteAsyncOnStream( } TF_ASSIGN_OR_RETURN(Literal result_literal, - Evaluate(*computation, arg_literals)); + Evaluate(run_options, *computation, arg_literals)); // Shrink the generated dynamic shape into static shape. result_literal = result_literal.ToStatic(); diff --git a/tensorflow/compiler/xla/service/interpreter/executable_base.h b/tensorflow/compiler/xla/service/interpreter/executable_base.h index a02ab7af8d0..eb47841a179 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable_base.h +++ b/tensorflow/compiler/xla/service/interpreter/executable_base.h @@ -44,6 +44,7 @@ class InterpreterExecutableBase : public Executable { protected: virtual StatusOr Evaluate( + const ServiceExecutableRunOptions* run_options, const HloComputation& computation, absl::Span arg_literals) = 0; diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index bea0f1fb93c..55569cfde0e 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -1891,7 +1891,7 @@ Status LayoutAssignment::RunOnComputation( ? ShapeUtil::GetSubshape(instruction->literal().shape(), buffer.index()) .layout() - : LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + : GetUnconstrainedLayout(buffer); TF_RETURN_IF_ERROR(constraints.SetBufferLayout(new_layout, buffer, /*mandatory=*/false)); @@ -2278,6 +2278,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kReduce: case HloOpcode::kReplicaId: case HloOpcode::kReshape: + case HloOpcode::kDynamicReshape: case HloOpcode::kRng: case HloOpcode::kRngBitGenerator: case HloOpcode::kRngGetAndUpdateState: diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index a04d056c618..def620bcee9 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -27,6 +27,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -338,6 +339,9 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + virtual Layout GetUnconstrainedLayout(const LogicalBuffer& buffer) { + return LayoutUtil::GetDefaultLayoutForShape(buffer.shape()); + } // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. 
virtual Status Verify(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index b01ae2efe43..2963d546380 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -415,9 +415,10 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, return inst; } -string IrName(string a) { - a.erase(std::remove(a.begin(), a.end(), '%'), a.end()); - return a; +string IrName(absl::string_view a) { + std::string s(a); + s.erase(std::remove(s.begin(), s.end(), '%'), s.end()); + return s; } string IrName(absl::string_view a, absl::string_view b) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 642965b6470..c0a55e4da33 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -87,7 +87,7 @@ string DumpModuleToString(const llvm::Module& module); // - joining all of the nonempty inputs by '.', and then // - removing all '%'s. // -string IrName(string a); +string IrName(absl::string_view a); string IrName(absl::string_view a, absl::string_view b); string IrName(const HloInstruction* a, absl::string_view b = ""); diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.cc b/tensorflow/compiler/xla/service/memory_space_assignment.cc index c5ae0573bed..c53f2c19695 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment.cc @@ -80,7 +80,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit( } float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, MemorySpaceAssignmentCostAnalysis::Cache* cache) const { const HloInstruction& defining_instruction = *interval.buffer->defining_instruction(); @@ -236,15 +236,26 @@ int64 InstructionCountPrefetchIntervalPicker::PreferredEvictionEndTime( } int64 InstructionCountPrefetchIntervalPicker::LatestPrefetchStartTime( - const HloUse& use, int64 start_time, int64 end_time) const { + const Shape& shape, int64 start_time, int64 end_time, + const HloUse* use) const { return end_time - min_overlap_count_; } +int64 InstructionCountPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, int64 prefetch_end_time) const { + return std::max(earliest_prefetch_start_time, + prefetch_end_time - max_overlap_count_); +} + void InstructionCountPrefetchIntervalPicker::Begin(const HloUse& use, int64 start_time, int64 end_time) { end_time_ = end_time; - current_prefetch_time_ = std::max(start_time, end_time_ - max_overlap_count_); + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); + current_prefetch_time_ = + PreferredPrefetchStartTime(shape, start_time, end_time, end_time); } int64 InstructionCountPrefetchIntervalPicker::Next() { @@ -361,18 +372,22 @@ int64 CostAnalysisPrefetchIntervalPicker::PreferredEvictionEndTime( } int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( - const HloUse& use, int64 start_time, int64 end_time) const { - const Shape& shape = ShapeUtil::GetSubshape( - use.instruction->operand(use.operand_number)->shape(), use.operand_index); + 
const Shape& shape, int64 start_time, int64 end_time, + const HloUse* use) const { // Find the earliest time that satisfies max_async_copy_to_overlap_ratio_. float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); - // Estimate the time we would save by having this op in alternate memory. - float elapsed_time = cost_analysis_.GetInstructionElapsed(*use.instruction); - float elapsed_time_in_alternate_mem = - cost_analysis_.GetInstructionElapsedInAlternateMemory( - *use.instruction, use.operand_number, - /*output_in_alternate_mem=*/false); - float inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + // If there is a use, estimate the time we would save by having this op in + // alternate memory. + float inst_elapsed_reduction = 0.0f; + if (use) { + float elapsed_time = + cost_analysis_.GetInstructionElapsed(*use->instruction); + float elapsed_time_in_alternate_mem = + cost_analysis_.GetInstructionElapsedInAlternateMemory( + *use->instruction, use->operand_number, + /*output_in_alternate_mem=*/false); + inst_elapsed_reduction = elapsed_time - elapsed_time_in_alternate_mem; + } int end_nest_level = while_nest_level_[end_time]; // Find the latest time we're allowed to start prefetching. @@ -390,6 +405,33 @@ int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchStartTime( return latest_prefetch_time; } +int64 CostAnalysisPrefetchIntervalPicker::PreferredPrefetchStartTime( + const Shape& shape, int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, int64 prefetch_end_time) const { + // Between the earliest and latest prefetch interval, find the interval + // closest to the preferred interval and start iterating from there. + float async_copy_elapsed = cost_analysis_.GetAsyncCopyElapsed(shape); + int64 preferred_prefetch_start_time = earliest_prefetch_start_time; + float preferred_interval = + preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed; + float best_interval = GetLogicalIntervalElapsed(earliest_prefetch_start_time, + prefetch_end_time); + int end_nest_level = while_nest_level_[prefetch_end_time]; + for (int64 prefetch_start_time = earliest_prefetch_start_time + 1; + prefetch_start_time <= latest_prefetch_start_time; + ++prefetch_start_time) { + float interval = + GetLogicalIntervalElapsed(prefetch_start_time, prefetch_end_time); + if (while_nest_level_[prefetch_start_time] == end_nest_level && + std::abs(preferred_interval - interval) < + std::abs(preferred_interval - best_interval)) { + best_interval = interval; + preferred_prefetch_start_time = prefetch_start_time; + } + } + return preferred_prefetch_start_time; +} + int64 CostAnalysisPrefetchIntervalPicker::LatestPrefetchEndTime( int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const { // Iterate towards the beginning until we find a suitable end time that is the @@ -422,7 +464,8 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, // Find the latest time we're allowed to start prefetching. float min_interval = min_async_copy_to_overlap_ratio_ * async_copy_elapsed_; - latest_prefetch_time_ = LatestPrefetchStartTime(use, start_time, end_time); + latest_prefetch_time_ = + LatestPrefetchStartTime(shape, start_time, end_time, &use); // Find the earliest time we're allowed to start prefetching. 
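[Editor's note.] The new PreferredPrefetchStartTime above scans every candidate start time between the earliest and latest allowed times and keeps the one whose elapsed interval to the prefetch end is closest to preferred_ratio * async_copy_elapsed, restricted to candidates at the same while-nest level as the end time. The standalone sketch below restates that search outside the XLA classes; a plain cumulative-elapsed vector stands in for GetLogicalIntervalElapsed and a nest-level vector stands in for while_nest_level_, and all names are illustrative.

#include <cmath>
#include <cstdint>
#include <vector>

// Pick the start time whose overlap with `end` is closest to
// `preferred_interval`, considering only times at the same nest level.
int64_t PickPreferredStart(const std::vector<float>& cumulative_elapsed,
                           const std::vector<int>& nest_level,
                           int64_t earliest, int64_t latest, int64_t end,
                           float preferred_interval) {
  auto elapsed = [&](int64_t a, int64_t b) {
    return cumulative_elapsed[b] - cumulative_elapsed[a];
  };
  int64_t best_time = earliest;
  float best_interval = elapsed(earliest, end);
  for (int64_t t = earliest + 1; t <= latest; ++t) {
    const float interval = elapsed(t, end);
    if (nest_level[t] == nest_level[end] &&
        std::abs(preferred_interval - interval) <
            std::abs(preferred_interval - best_interval)) {
      best_interval = interval;
      best_time = t;
    }
  }
  return best_time;
}
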
float max_interval = max_async_copy_to_overlap_ratio_ * @@ -443,24 +486,10 @@ void CostAnalysisPrefetchIntervalPicker::Begin(const HloUse& use, return; } - // Between the earliest and latest prefetch interval, find the interval - // closest to the preferred interval and start iterating from there. - int64 starting_prefetch_time = earliest_prefetch_time_; + int64 starting_prefetch_time = PreferredPrefetchStartTime( + shape, earliest_prefetch_time_, latest_prefetch_time_, end_logical_time_); float preferred_interval = preferred_async_copy_to_overlap_ratio_ * async_copy_elapsed_; - float best_interval = - GetLogicalIntervalElapsed(earliest_prefetch_time_, end_logical_time_); - for (int64 prefetch_time = earliest_prefetch_time_ + 1; - prefetch_time <= latest_prefetch_time_; ++prefetch_time) { - float interval = - GetLogicalIntervalElapsed(prefetch_time, end_logical_time_); - if (while_nest_level_[prefetch_time] == end_nest_level && - std::abs(preferred_interval - interval) < - std::abs(preferred_interval - best_interval)) { - best_interval = interval; - starting_prefetch_time = prefetch_time; - } - } VLOG(4) << "Interval min/max/preferred = " << min_interval << " " << max_interval << " " << preferred_interval << " prefetch time earliest/latest/starting = " @@ -570,7 +599,8 @@ std::string CostAnalysisPrefetchIntervalPicker::ToNoCopyDebugString( absl::optional CostAnalysisPrefetchIntervalPicker::BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { return cost_analysis_.GetMemoryBoundedness(interval); } @@ -733,9 +763,9 @@ void AlternateMemoryBestFitHeap::FindAliases( } } -std::vector +std::vector AlternateMemoryBestFitHeap::GetSortedColocatedIntervals( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const AlternateMemoryBestFitHeap::BufferInterval& interval) const { std::vector colocated_intervals; std::vector worklist = {&interval}; while (!worklist.empty()) { @@ -864,7 +894,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory( } void AlternateMemoryBestFitHeap::AppendBufferInfoDebugString( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const AlternateMemoryBestFitHeap::BufferInterval& interval, std::string* debug_str) const { // Columns in buffer information: // buffer_id: int. This value can be used to match the allocation in @@ -954,7 +984,7 @@ void AlternateMemoryBestFitHeap::DumpDebugStringsIfEnabled() const { options_.dump_fn("allocinfo", allocation_info_str_); } -HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { +HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { std::vector sorted_buffer_intervals = GetSortedBufferIntervals(); @@ -1051,6 +1081,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { allocation_values); // Retry allocating this value with larger limits if allocation fails. + bool repacked = false; for (int retry_number = 0; retry_number < options_.max_retries; retry_number++) { bool final_retry = (retry_number == options_.max_retries - 1); @@ -1064,11 +1095,13 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { UncommitPendingChunks(absl::MakeSpan(allocation_values)); VLOG(2) << "Couldn't allocate. 
Retry number " << retry_number; } else if (result_is(result, Result::kFailOutOfMemory) && - num_repacks_ < options_.max_repacks) { + num_repacks_ < options_.max_repacks && !repacked) { UncommitPendingChunks(absl::MakeSpan(allocation_values)); ++num_repacks_; + repacked = true; CHECK_NE(options_.repacker, nullptr); - std::vector repack_allocation_blocks; + std::vector + repack_allocation_blocks; ExportAllocationsForRepacking(repack_allocation_blocks); VLOG(2) << "Repacking."; auto repack_status = @@ -1076,7 +1109,7 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() { CHECK_EQ(repack_status.status(), Status::OK()); VLOG(2) << "Repack complete. Modified = " << *repack_status; if (*repack_status) { - ImportRepackedAllocations(absl::MakeSpan(repack_allocation_blocks)); + ImportRepackedAllocations(); --retry_number; } } else { @@ -1367,21 +1400,80 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer( // Find the earliest use. const auto& instruction_schedule = hlo_live_range_.instruction_schedule(); auto uses = buffer->uses(); - auto first_use = - absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) { - return instruction_schedule.at(lhs.instruction) < - instruction_schedule.at(rhs.instruction); - }); + auto use_schedule_compare = [&](const HloUse& lhs, const HloUse& rhs) { + return instruction_schedule.at(lhs.instruction) < + instruction_schedule.at(rhs.instruction); + }; + auto first_use = absl::c_min_element(uses, use_schedule_compare); int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction); + // Find the latest use time. + int64 last_use_time = instruction_schedule.at( + absl::c_max_element(uses, use_schedule_compare)->instruction); + for (const HloValue* colocation : prefetch_candidate->colocations) { + last_use_time = std::max( + last_use_time, + instruction_schedule.at( + absl::c_max_element(colocation->uses(), use_schedule_compare) + ->instruction)); + } + + int64 end_of_program_prefetch_end_time = instruction_schedule.size() - 1; + int64 end_of_program_prefetch_start_time = + options_.prefetch_interval_picker->PreferredPrefetchStartTime( + buffer->defining_position().shape(), last_use_time, + end_of_program_prefetch_end_time, end_of_program_prefetch_end_time); + VLOG(2) << "last use time = " << last_use_time + << ", end-of-program prefetch start time = " + << end_of_program_prefetch_start_time; + bool free_buffer = + (end_of_program_prefetch_start_time > last_use_time && + end_of_program_prefetch_start_time < end_of_program_prefetch_end_time); + int64 cross_program_prefetch_end_time = + free_buffer ? 
last_use_time : prefetch_candidate->end; + AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate, chunk_candidate.chunk, prefetch_candidate->start, - prefetch_candidate->end, latest_prefetch_time, &allocations); + cross_program_prefetch_end_time, latest_prefetch_time, + &allocations, + /*is_cross_program_prefetch=*/true); absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); }); + int64 cross_program_prefetch_offset = allocations.back()->chunk().offset; + + if (free_buffer) { + VLOG(2) << "Adding an end-of-program prefetch for freed " + "cross-program-prefetched buffer."; + AddAsyncCopy(*allocations.front(), MemorySpace::kAlternate, + chunk_candidate.chunk, end_of_program_prefetch_start_time, + end_of_program_prefetch_end_time, + end_of_program_prefetch_end_time, &allocations); + CHECK_EQ(cross_program_prefetch_offset, allocations.back()->chunk().offset); + } + for (auto& allocation : allocations) { allocations_->push_back(std::move(allocation)); } + // Add a repack allocation block for the Allocation objects in alternate + // memory. + CHECK_EQ(repack_allocation_blocks_.size(), 0); + for (const auto& allocation : *allocations_) { + if (allocation->memory_space() == MemorySpace::kAlternate) { + repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( + allocation->start_time(), allocation->end_time(), + allocation->chunk().size, allocation->chunk().offset, + static_cast(repack_allocation_blocks_.size()), + allocation.get())); + RepackAllocationBlock* inserted = &repack_allocation_blocks_.back(); + for (RepackAllocationBlock& colocation : repack_allocation_blocks_) { + colocation.colocations.push_back(inserted); + if (&colocation != inserted) { + inserted->colocations.push_back(&colocation); + } + } + } + } + ClearPendingChunks(); } @@ -1560,29 +1652,27 @@ bool AlternateMemoryBestFitHeap::AreIntervalsReservedInAlternateMemory( } void AlternateMemoryBestFitHeap::ExportAllocationsForRepacking( - std::vector& - allocations) { + std::vector& allocations) { for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { allocations.push_back(&allocation_block); } } -void AlternateMemoryBestFitHeap::ImportRepackedAllocations( - absl::Span - repacked_allocations) { +void AlternateMemoryBestFitHeap::ImportRepackedAllocations() { interval_tree_ = {}; - for (RepackAllocationBlock* allocation_block : repacked_allocations) { - MemorySpaceAssignment::Allocation* allocation = allocation_block->opaque; + for (RepackAllocationBlock& allocation_block : repack_allocation_blocks_) { + MemorySpaceAssignment::Allocation* allocation = allocation_block.allocation; VLOG(3) << "Moved " << allocation->ToString() << ", size " - << allocation->chunk().size << " from " - << allocation_block->initial_offset << " to " - << allocation_block->offset; - allocation_block->opaque->mutable_chunk()->offset = - allocation_block->offset; - interval_tree_.Add(allocation_block->start_time, allocation_block->end_time, - {allocation_block->offset, allocation_block->size}); - allocation_block->initial_offset = allocation_block->offset; - allocation_block->offset = -1; + << allocation->chunk().size << ", (" << allocation_block.start_time + << ", " << allocation_block.end_time << ") from " + << allocation_block.initial_offset << " to " + << allocation_block.offset; + allocation_block.allocation->mutable_chunk()->offset = + allocation_block.offset; + interval_tree_.Add(allocation_block.start_time, allocation_block.end_time, + {allocation_block.offset, allocation_block.size}); + 
allocation_block.initial_offset = allocation_block.offset; + allocation_block.offset = -1; } } @@ -1655,17 +1745,19 @@ void AlternateMemoryBestFitHeap::FinalizeAllocations( // Export these to repack_allocation_blocks_ so that we can repack them to // reduce fragmentation. for (auto& colocation : colocation_map) { - std::vector colocations; + std::vector colocations; for (MemorySpaceAssignment::Allocation* colocated_allocation : colocation.second) { - repack_allocation_blocks_.push_back( - {colocated_allocation->start_time(), colocated_allocation->end_time(), - colocated_allocation->chunk().size, /*offset=*/-1, - colocated_allocation->chunk().offset, /*colocations=*/{}, - colocated_allocation}); + repack_allocation_blocks_.push_back(MakeRepackAllocationBlock( + colocated_allocation->start_time(), colocated_allocation->end_time(), + colocated_allocation->chunk().size, + colocated_allocation->chunk().offset, + static_cast(repack_allocation_blocks_.size()), + colocated_allocation)); colocations.push_back(&repack_allocation_blocks_.back()); } - for (RepackAllocationBlock* repack_block : colocations) { + for (MemorySpaceAssignmentRepacker::AllocationBlock* repack_block : + colocations) { repack_block->colocations = colocations; } } @@ -1842,7 +1934,8 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( const MemorySpaceAssignment::Allocation& prev_allocation, MemorySpace memory_space, absl::optional chunk, int64 start_time, int64 end_time, int64 copy_done_schedule_before_time, - MemorySpaceAssignment::AllocationSequence* allocations) { + MemorySpaceAssignment::AllocationSequence* allocations, + bool is_cross_program_prefetch) { VLOG(3) << "Copy to " << (memory_space == MemorySpaceAssignment::MemorySpace::kDefault ? "default" @@ -1854,7 +1947,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy( allocations->push_back( absl::make_unique( prev_allocation, memory_space, chunk, start_time, end_time, - copy_done_schedule_before_time)); + copy_done_schedule_before_time, is_cross_program_prefetch)); // Register the additional async copy with the interval tree to keep track of // the limit at any given time. @@ -2116,12 +2209,15 @@ int64 AlternateMemoryBestFitHeap::FindPrefetchEndTime( const AllocationRequest& request, int64 earliest_prefetch_time) const { int64 prefetch_end_time = request.latest_prefetch_time; + const HloUse& use = request.use->hlo_use; + const Shape& shape = ShapeUtil::GetSubshape( + use.instruction->operand(use.operand_number)->shape(), use.operand_index); for (int retry_number = 0; retry_number < options_.prefetch_copy_done_reorder_max_retries; ++retry_number) { int64 latest_prefetch_time = options_.prefetch_interval_picker->LatestPrefetchStartTime( - request.use->hlo_use, earliest_prefetch_time, prefetch_end_time); + shape, earliest_prefetch_time, prefetch_end_time, &use); VLOG(4) << "Latest prefetch start time = " << latest_prefetch_time << ", earliest prefetch start time = " << earliest_prefetch_time << ", prefetch end time = " << prefetch_end_time; @@ -2356,8 +2452,8 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare( return x_memory_boundedness > y_memory_boundedness; } // Tie-break if the memory boundedness is the same. 
- return GlobalDecreasingSizeBestFitHeap::GetSpatialBufferIntervalCompare()( - x, y); + return GlobalDecreasingSizeBestFitHeap< + HloValue>::GetSpatialBufferIntervalCompare()(x, y); }; } @@ -2428,7 +2524,9 @@ FindCrossProgramPrefetchCandidate( const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range, const MemorySpaceAssignment::Options& options) { std::vector candidates; - for (HloValue* value : alias_analysis.dataflow_analysis().values()) { + for (const HloBuffer& buffer : alias_analysis.buffers()) { + CHECK_GE(buffer.values().size(), 1); + const HloValue* value = buffer.values().at(0); if (IsCrossProgramPrefetchCandidate(*value, options)) { MemorySpaceAssignment::BufferInterval interval; interval.buffer = value; @@ -2436,6 +2534,7 @@ FindCrossProgramPrefetchCandidate( interval.start = 0; interval.end = hlo_live_range.schedule_end_time(); interval.need_allocation = true; + interval.colocations = {++buffer.values().begin(), buffer.values().end()}; candidates.emplace_back(interval); } } @@ -2665,9 +2764,9 @@ Status MemorySpaceAssignment::CopyAllocation::Process( Shape shape = defining_position().shape(); HloInstruction* producing_instruction = AddGetTupleElements(); HloComputation* computation = producing_instruction->parent(); - copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary( + copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, producing_instruction)); + producing_instruction, is_cross_program_prefetch_)); copy_done_ = computation->AddInstruction( HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_)); VLOG(4) << "Created " << copy_start_->name() diff --git a/tensorflow/compiler/xla/service/memory_space_assignment.h b/tensorflow/compiler/xla/service/memory_space_assignment.h index d366c06a599..04737663424 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment.h @@ -106,7 +106,7 @@ class MemorySpaceAssignmentCostAnalysis { // BufferInterval. The larger this number, the higher priority it will be // placed in the alternate memory. float GetMemoryBoundedness( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval, Cache* cache = nullptr) const; // Returns the elapsed time in seconds due to compute only. @@ -200,8 +200,15 @@ class PrefetchIntervalPicker { int64 latest_end_time) const = 0; // Returns the latest time that a prefetch can start. - virtual int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, - int64 end_time) const = 0; + virtual int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time, + int64 end_time, + const HloUse* use) const = 0; + + // Returns the preferred time that a prefetch can start. + virtual int64 PreferredPrefetchStartTime(const Shape& shape, + int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, + int64 prefetch_end_time) const = 0; // Returns the latest time that a prefetch can end that is less than or equal // to proposed_prefetch_end_time. @@ -235,7 +242,8 @@ class PrefetchIntervalPicker { // of placing the BufferInterval in the alternate memory. The larger value, // the more beneficial. 
virtual absl::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const { + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const { return absl::nullopt; } @@ -268,8 +276,14 @@ class InstructionCountPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const override; - int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, - int64 end_time) const override; + int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time, + int64 end_time, + const HloUse* use) const override; + + int64 PreferredPrefetchStartTime(const Shape& shape, + int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, + int64 prefetch_end_time) const override; void Begin(const HloUse& use, int64 start_time, int64 end_time) override; @@ -307,11 +321,18 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 PreferredEvictionEndTime(const Shape& shape, int64 start_time, int64 latest_end_time) const override; - int64 LatestPrefetchStartTime(const HloUse& use, int64 start_time, - int64 end_time) const override; int64 LatestPrefetchEndTime(int64 original_prefetch_end_time, int64 proposed_prefetch_end_time) const override; + int64 LatestPrefetchStartTime(const Shape& shape, int64 start_time, + int64 end_time, + const HloUse* use) const override; + + int64 PreferredPrefetchStartTime(const Shape& shape, + int64 earliest_prefetch_start_time, + int64 latest_prefetch_start_time, + int64 prefetch_end_time) const override; + void Begin(const HloUse& use, int64 start_time, int64 end_time) override; int64 Next() override; @@ -324,7 +345,7 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { int64 end_time) const override; absl::optional BufferIntervalAlternateMemoryBenefit( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) const override; private: @@ -370,9 +391,10 @@ class CostAnalysisPrefetchIntervalPicker : public PrefetchIntervalPicker { class MemorySpaceAssignment { public: using Chunk = HeapSimulator::Chunk; - using BufferInterval = GlobalDecreasingSizeBestFitHeap::BufferInterval; + using BufferInterval = + GlobalDecreasingSizeBestFitHeap::BufferInterval; using BufferIntervalCompare = - GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; + GlobalDecreasingSizeBestFitHeap::BufferIntervalCompare; using IsAllowedInAlternateMemoryFunction = std::function; @@ -435,7 +457,7 @@ class MemorySpaceAssignment { // The repacking algorithm to reduce fragmentation. Must be non-null if // max_repacks is greater than 0. - MemorySpaceAssignmentRepacker* repacker = nullptr; + MemorySpaceAssignmentRepacker* repacker = nullptr; // If true, tries allocating buffers across (e.g., before and inside a while // loop body) sequential calls (kWhile, kCall, and kConditional). 
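// Usage sketch (not part of the patch): how a client can drive the new
// non-templated repacker interface directly, mirroring the unit tests added
// below. Assumes compilation inside the XLA tree; the times, sizes, and the
// 100-byte / 1-byte-alignment budget are arbitrary illustration values.
#include <vector>

#include "tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h"

namespace xla {

bool RepackTwoBlocksExample() {
  using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock;
  // Field order: start_time, end_time, size, offset, initial_offset, id,
  // colocations. offset == -1 means "not assigned yet"; id must be unique so
  // the heap's tie-breaking comparison stays deterministic.
  AllocationBlock a{/*start_time=*/10, /*end_time=*/20, /*size=*/10,
                    /*offset=*/-1, /*initial_offset=*/0, /*id=*/0, {}};
  AllocationBlock b{/*start_time=*/5, /*end_time=*/25, /*size=*/15,
                    /*offset=*/-1, /*initial_offset=*/16, /*id=*/1, {}};
  std::vector<AllocationBlock*> blocks = {&a, &b};
  MemorySpaceAssignmentBestFitRepacker repacker(/*max_size=*/100,
                                                /*alignment=*/1);
  bool modified = repacker.Repack(absl::MakeSpan(blocks)).ValueOrDie();
  // On success, a.offset and b.offset now hold non-overlapping offsets within
  // the 100-byte budget; on failure they stay -1 and the caller keeps (or
  // retries) its previous assignment.
  return modified;
}

}  // namespace xla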
@@ -559,12 +581,14 @@ class MemorySpaceAssignment { public: CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space, absl::optional chunk, int64 start_time, - int64 end_time, int64 copy_done_schedule_before_time) + int64 end_time, int64 copy_done_schedule_before_time, + bool is_cross_program_prefetch = false) : Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk, start_time, end_time), prev_allocation_(prev_allocation), copy_start_schedule_after_(start_time), - copy_done_schedule_before_(copy_done_schedule_before_time) {} + copy_done_schedule_before_(copy_done_schedule_before_time), + is_cross_program_prefetch_(is_cross_program_prefetch) {} bool is_copy_allocation() const override { return true; } @@ -604,6 +628,10 @@ class MemorySpaceAssignment { copy_start_schedule_after_ = copy_start_schedule_after; } + bool is_cross_program_prefetch() const { + return is_cross_program_prefetch_; + } + bool operator==(const CopyAllocation& other) const; std::string ToString() const override; @@ -615,6 +643,7 @@ class MemorySpaceAssignment { // is before copy_done_schedule_before_. int64 copy_start_schedule_after_; int64 copy_done_schedule_before_; + bool is_cross_program_prefetch_; HloInstruction* copy_start_; HloInstruction* copy_done_; }; @@ -913,7 +942,8 @@ class AsynchronousCopyOrdering { // This class inherits from GlobalDecreasingSizeBestFitHeap with a notion of // maximum size. -class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { +class AlternateMemoryBestFitHeap + : public GlobalDecreasingSizeBestFitHeap { public: using MemorySpace = MemorySpaceAssignment::MemorySpace; using AllocationValue = MemorySpaceAssignment::AllocationValue; @@ -940,11 +970,15 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { void AllocateCrossProgramPrefetchBuffer( HloModule* module, absl::optional prefetch_candidate); - HeapSimulator::Result Finish() override; + HeapSimulator::Result Finish() override; private: - using RepackAllocationBlock = MemorySpaceAssignmentRepacker< - MemorySpaceAssignment::Allocation*>::AllocationBlock; + // We inherit AllocationBlock struct to attach the Allocation information to + // make importing repacked offsets easier. + struct RepackAllocationBlock + : MemorySpaceAssignmentRepacker::AllocationBlock { + MemorySpaceAssignment::Allocation* allocation; + }; // An allocation request for a use segment. A use segment is the time segment // between the definition and the first use, and the time segment between the @@ -1169,19 +1203,20 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { // Exports the allocations for repacking and puts them into the vector in the // parameter. void ExportAllocationsForRepacking( - std::vector& allocations); + std::vector& + allocations); // Imports repacked allocations and updates the internal data structures // consistent with the new packing. - void ImportRepackedAllocations( - absl::Span repacked_allocations); + void ImportRepackedAllocations(); // Adds an asynchronous copy to the allocations. 
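+  // When `is_cross_program_prefetch` is true, the copy-start instruction that
+  // is eventually emitted for this copy is tagged as a cross-program prefetch.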
void AddAsyncCopy(const MemorySpaceAssignment::Allocation& prev_allocation, MemorySpace memory_space, absl::optional chunk, int64 start_time, int64 end_time, int64 copy_done_schedule_before_time, - MemorySpaceAssignment::AllocationSequence* allocations); + MemorySpaceAssignment::AllocationSequence* allocations, + bool is_cross_program_prefetch = false); // This method is used for committing the chunk candidate but adding it to // pending_chunks_ so that we can "uncommit" them in case we need to roll back @@ -1215,6 +1250,22 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap { return options_.max_size_in_bytes - reserved_in_bytes_; } + // Creates and returns a RepackAllocationBlock. + static RepackAllocationBlock MakeRepackAllocationBlock( + int64 start_time, int64 end_time, int64 size, int64 initial_offset, + int64 id, MemorySpaceAssignment::Allocation* allocation) { + RepackAllocationBlock allocation_block; + allocation_block.start_time = start_time; + allocation_block.end_time = end_time; + allocation_block.size = size; + allocation_block.offset = -1; + allocation_block.initial_offset = initial_offset; + allocation_block.id = id; + allocation_block.colocations = {}; + allocation_block.allocation = allocation; + return allocation_block; + } + MemorySpaceAssignment::AllocationSequence* allocations_; const MemorySpaceAssignment::Options& options_; const HloAliasAnalysis& alias_analysis_; diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.cc b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.cc new file mode 100644 index 00000000000..53b092f1939 --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.cc @@ -0,0 +1,88 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h" + +#include "tensorflow/compiler/xla/service/heap_simulator.h" + +namespace xla { + +namespace { + +using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock; +using Type = GlobalDecreasingSizeBestFitHeap::Type; + +// This class inherits GlobalDecreasingSizeBestFitHeap and converts +// AllocationBlock objects into BufferIntervals that the heap algorithm +// understands. +class BestFitRepacker + : public GlobalDecreasingSizeBestFitHeap { + public: + BestFitRepacker(int64 max_size, int64 alignment, Type type) + : GlobalDecreasingSizeBestFitHeap(alignment, type), + max_size_(max_size) {} + + void ImportAllocationBlocks(absl::Span allocations) { + allocation_blocks_ = allocations; + for (AllocationBlock* allocation_block : allocations) { + // Check if any of the colocations are already added to buffer_intervals_. 
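+      // Only the first block seen from each colocation group keeps
+      // need_allocation = true; later members are recorded as colocations of
+      // that interval, so the whole group is assigned a single offset.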
+ bool need_allocation = true; + auto aliased_it = absl::c_find_if( + allocation_block->colocations, [&](AllocationBlock* search) { + return buffer_intervals_.contains(search); + }); + if (aliased_it != allocation_block->colocations.end()) { + buffer_intervals_[*aliased_it].colocations.push_back(allocation_block); + need_allocation = false; + } + buffer_intervals_[allocation_block] = {allocation_block, + allocation_block->size, + allocation_block->start_time, + allocation_block->end_time, + {}, + need_allocation}; + } + } + + bool Repack() { + Finish(); + bool success = result_.heap_size <= max_size_; + if (success) { + for (AllocationBlock* block : allocation_blocks_) { + auto chunk_it = result_.chunk_map.find(block); + if (chunk_it != result_.chunk_map.end()) { + block->offset = chunk_it->second.offset; + } + } + } + return success; + } + + private: + int64 max_size_; + absl::Span allocation_blocks_; +}; + +} // namespace + +StatusOr MemorySpaceAssignmentBestFitRepacker::Repack( + absl::Span allocations) { + BestFitRepacker best_fit_repacker = + BestFitRepacker(max_size_, alignment_, type_); + best_fit_repacker.ImportAllocationBlocks(allocations); + return best_fit_repacker.Repack(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h new file mode 100644 index 00000000000..6937b8b0e8c --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h @@ -0,0 +1,44 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ + +#include "tensorflow/compiler/xla/service/heap_simulator.h" +#include "tensorflow/compiler/xla/service/memory_space_assignment_repacking.h" + +namespace xla { + +// This is a repacker algorithm that wraps around best fit heap algorithm in +// heap simulator. 
+class MemorySpaceAssignmentBestFitRepacker + : public MemorySpaceAssignmentRepacker { + public: + using Type = GlobalDecreasingSizeBestFitHeap::Type; + + explicit MemorySpaceAssignmentBestFitRepacker( + int64 max_size, int64 alignment, + Type type = GlobalDecreasingSizeBestFitHeap::kTemporal) + : MemorySpaceAssignmentRepacker(max_size, alignment), type_(type) {} + + StatusOr Repack(absl::Span allocations) override; + + private: + Type type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MEMORY_SPACE_ASSIGNMENT_BEST_FIT_REPACKER_H_ diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker_test.cc new file mode 100644 index 00000000000..44da2828eac --- /dev/null +++ b/tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker_test.cc @@ -0,0 +1,89 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/memory_space_assignment_best_fit_repacker.h" + +#include "tensorflow/core/platform/test.h" + +namespace xla { + +class MemorySpaceAssignmentBestFitRepackerTest : public ::testing::Test { + protected: + using AllocationBlock = MemorySpaceAssignmentRepacker::AllocationBlock; + + MemorySpaceAssignmentBestFitRepackerTest() : repacker_(100, 1) {} + + AllocationBlock* MakeAllocationBlock(int64 start_time, int64 end_time, + int64 size, int64 initial_offset = -1) { + allocation_blocks_.push_back({start_time, + end_time, + size, + -1, + initial_offset, + static_cast(allocation_blocks_.size()), + {}}); + AllocationBlock* block = &allocation_blocks_.back(); + block->colocations.push_back(block); + return block; + } + + std::list allocation_blocks_; + MemorySpaceAssignmentBestFitRepacker repacker_; +}; + +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, Simple) { + std::vector allocation_blocks; + allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); + allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + EXPECT_EQ(allocation_blocks[0]->offset, 15); + EXPECT_EQ(allocation_blocks[1]->offset, 0); +} + +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, Colocation) { + std::vector allocation_blocks; + allocation_blocks.push_back(MakeAllocationBlock(0, 2, 10)); + allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); + // Allocation blocks 0 and 1 are colocated. 
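+  // (MakeAllocationBlock already makes each block list itself, so after the
+  // two push_backs below the repacker sees the full colocation group.)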
+ allocation_blocks[0]->colocations.push_back(allocation_blocks[1]); + allocation_blocks[1]->colocations.push_back(allocation_blocks[0]); + allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); + EXPECT_TRUE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + EXPECT_EQ(allocation_blocks[0]->offset, 15); + EXPECT_EQ(allocation_blocks[1]->offset, 15); + EXPECT_EQ(allocation_blocks[2]->offset, 0); +} + +TEST_F(MemorySpaceAssignmentBestFitRepackerTest, TooLarge) { + // Memory size is 100, total size of buffers is 105. + std::vector allocation_blocks; + allocation_blocks.push_back(MakeAllocationBlock(10, 20, 10)); + allocation_blocks.push_back(MakeAllocationBlock(5, 25, 15)); + allocation_blocks.push_back(MakeAllocationBlock(15, 20, 10)); + allocation_blocks.push_back(MakeAllocationBlock(12, 22, 50)); + allocation_blocks.push_back(MakeAllocationBlock(10, 18, 20)); + EXPECT_FALSE(*repacker_.Repack(absl::MakeSpan(allocation_blocks))); + + // Make sure the buffers didn't get offset assignments. + EXPECT_EQ(allocation_blocks[0]->offset, -1); + EXPECT_EQ(allocation_blocks[1]->offset, -1); + EXPECT_EQ(allocation_blocks[2]->offset, -1); + EXPECT_EQ(allocation_blocks[3]->offset, -1); + EXPECT_EQ(allocation_blocks[4]->offset, -1); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h b/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h index fcfdfc797fb..eb2f0698a95 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment_repacking.h @@ -22,10 +22,10 @@ limitations under the License. namespace xla { // An interface to define allocation repacking algorithms. -template class MemorySpaceAssignmentRepacker { public: - MemorySpaceAssignmentRepacker() = default; + MemorySpaceAssignmentRepacker(int64 max_size, int64 alignment) + : max_size_(max_size), alignment_(alignment) {} virtual ~MemorySpaceAssignmentRepacker() = default; // A contiguous block of allocation consisting of start and end (logical) @@ -33,23 +33,36 @@ class MemorySpaceAssignmentRepacker { // successful and the allocations were modified, the offset field holds the // new offset. To support aliased allocations, AllocationBlock also includes a // vector of AllocationBlock pointers, called colocations. All AllocationBlock - // objects within the colocations must get the same offset. The opaque field - // is used by the MemorySpaceAssignment pass and should not be accessed by the - // repacking algorithm. + // objects within the colocations must get the same offset. The id should be + // unique and is used to ensure determinism for comparison tie-breaker. struct AllocationBlock { int64 start_time; int64 end_time; int64 size; int64 offset; int64 initial_offset; + int64 id; std::vector colocations; - O opaque; + + std::string ToString() const { + return absl::StrCat("[", start_time, ", ", end_time, "] : size = ", size, + ", offset = ", offset, + " initial offset = ", initial_offset); + } + + // This is required by BufferIntervalCompare as a tie breaker. Use a unique + // and deterministic id. + bool operator<(const AllocationBlock& other) const { return id < other.id; } }; // Repack the AllocationBlocks provided in the parameter. Returns true if // allocations have been modified and false if not. Returns a non-ok status if // there was an error. 
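+  // Illustrative example of the contract (assuming some concrete `repacker`):
+  // blocks that alias each other list one another in `colocations` and come
+  // out of a successful Repack() with identical offsets:
+  //
+  //   AllocationBlock a{/*start_time=*/0, /*end_time=*/2, /*size=*/8,
+  //                     /*offset=*/-1, /*initial_offset=*/0, /*id=*/0, {}};
+  //   AllocationBlock b{/*start_time=*/5, /*end_time=*/9, /*size=*/8,
+  //                     /*offset=*/-1, /*initial_offset=*/0, /*id=*/1, {}};
+  //   a.colocations.push_back(&b);
+  //   b.colocations.push_back(&a);
+  //   std::vector<AllocationBlock*> blocks = {&a, &b};
+  //   TF_ASSIGN_OR_RETURN(bool modified,
+  //                       repacker->Repack(absl::MakeSpan(blocks)));
+  //   // After a successful repack, a.offset == b.offset.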
virtual StatusOr Repack(absl::Span allocations) = 0; + + protected: + int64 max_size_; + int64 alignment_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 464cfb502be..cc4f740bc25 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -4069,12 +4069,12 @@ TEST_P(MemorySpaceAssignmentTest, MoveCopyDoneEarlier) { // A mock MemorySpaceAssignmentRepacker class that accepst a map of // (start_time,offset) -> new_offset values. Using this map, the repacker // repacks the allocations to the new_offset. -class FakeMemorySpaceAssignmentRepacker - : public MemorySpaceAssignmentRepacker { +class FakeMemorySpaceAssignmentRepacker : public MemorySpaceAssignmentRepacker { public: - FakeMemorySpaceAssignmentRepacker( + explicit FakeMemorySpaceAssignmentRepacker( absl::flat_hash_map, int64>& repack_map) - : repack_map_(repack_map) {} + : MemorySpaceAssignmentRepacker(/*max_size=*/128, /*alignment=*/8), + repack_map_(repack_map) {} StatusOr Repack(absl::Span allocations) override { bool modified = false; @@ -4566,6 +4566,125 @@ TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchPinnedTest) { EXPECT_EQ(cross_program_prefetches.size(), 0); } +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchReuse) { + // This test is for checking if the cross-program-prefetched buffer is freed + // after its last use and there is an end-of-program prefetch. + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + ENTRY CrossProgramPrefetch { + p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0) + get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0 + get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1 + dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT negate.9 = f32[8,2]{1,0} negate(negate.8) + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].first, 0); + EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1})); + } + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {1}); + // Expect that there are two prefetches that use this value, one is the + // cross-program prefetch, the other is the end-of-program prefetch. 
+ auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_end_of_program_prefetch), + 1); +} + +TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNoReuse) { + // This tests the scenario that the cross-program-prefetched buffer is used + // again close to the end of the computation. In this case, it is better not + // to free the buffer. + absl::string_view hlo_string = R"( + HloModule cross_program_prefetch, is_scheduled=true + + ENTRY CrossProgramPrefetch { + p0 = (f32[8,8]{1,0}, f32[8,2]{1,0}) parameter(0) + get-tuple-element = f32[8,8]{1,0} get-tuple-element(p0), index=0 + get-tuple-element.1 = f32[8,2]{1,0} get-tuple-element(p0), index=1 + dot = f32[8,2]{1,0} dot(get-tuple-element, get-tuple-element.1), lhs_contracting_dims={1}, rhs_contracting_dims={0} + negate.1 = f32[8,2]{1,0} negate(dot) + negate.2 = f32[8,2]{1,0} negate(negate.1) + negate.3 = f32[8,2]{1,0} negate(negate.2) + negate.4 = f32[8,2]{1,0} negate(negate.3) + negate.5 = f32[8,2]{1,0} negate(negate.4) + negate.6 = f32[8,2]{1,0} negate(negate.5) + negate.7 = f32[8,2]{1,0} negate(negate.6) + negate.8 = f32[8,2]{1,0} negate(negate.7) + ROOT dot.2 = f32[2,2]{1,0} dot(negate.8, get-tuple-element.1), lhs_contracting_dims={0}, rhs_contracting_dims={0} + } + )"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + AssignMemorySpace(module.get(), /*max_outstanding_async_copies=*/-1, + /*max_prefetch_interval=*/5, /*min_prefetch_interval=*/2); + + auto cross_program_prefetches = module->CrossProgramPrefetches(); + EXPECT_EQ(cross_program_prefetches.size(), 1); + if (!cross_program_prefetches.empty()) { + EXPECT_EQ(cross_program_prefetches[0].first, 0); + EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1})); + } + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dataflow_analysis, + HloDataflowAnalysis::Run(*module)); + const HloValue& cross_program_prefetched_value = + dataflow_analysis->GetValueDefinedAt( + module->entry_computation()->parameter_instruction(0), {1}); + // Expect that there is one prefetch that use this value, the cross-program + // prefetch. There shouldn't be an end-of-program prefetch. 
+ auto is_cross_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_cross_program_prefetch), + 1); + auto is_end_of_program_prefetch = [](const HloUse& use) { + return use.instruction->opcode() == HloOpcode::kCopyStart && + !use.instruction->is_cross_program_prefetch(); + }; + EXPECT_EQ(absl::c_count_if(cross_program_prefetched_value.uses(), + is_end_of_program_prefetch), + 0); +} + using CostAnalysisPrefetchIntervalPickerTest = HloTestBase; TEST_F(CostAnalysisPrefetchIntervalPickerTest, PrefetchIntervalOrder) { @@ -4790,11 +4909,12 @@ TEST_F(CostAnalysisPrefetchIntervalPickerTest, NestedWhile) { HloInstruction* root = module->entry_computation()->root_instruction(); const HloUse use{root, /*operand_number=*/1, /*operand_index=*/{}}; + const Shape& shape = root->operand(1)->shape(); // We expect the root's latest prefetch start time to be before the while loop // (logical time 4). - EXPECT_EQ(interval_picker.LatestPrefetchStartTime(use, /*start_time=*/0, - /*end_time=*/23), + EXPECT_EQ(interval_picker.LatestPrefetchStartTime(shape, /*start_time=*/0, + /*end_time=*/23, &use), 4); } diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc index 0215f007c9c..0c44ae0d766 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.cc @@ -17,21 +17,21 @@ limitations under the License. namespace xla { -bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { +bool MemorySpaceAssignmentUtils::IsValueAllowedInAlternateMemory( + const HloValue* value) { // If the buffer is a tuple, don't use this algorithm for now. The buffers // that are pointed to by the tuple will still use this algorithm. Because // tuples are cheap to place in the alternate memory (they are just pointers) // we don't need to use prefetch/evict logic. - if (interval.buffer->shape().IsTuple()) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (value->shape().IsTuple()) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a tuple."; return false; } // Don't place scalars in the alternate memory. - if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + if (ShapeUtil::IsEffectiveScalar(value->shape())) { + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a scalar."; return false; } @@ -44,10 +44,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // allocate TupleSelect in the alternate memory space. // TODO(berkin): Not allocating add-dependencies either since they need to be // treated specially. We should revisit this later. 
- for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if (position.instruction->opcode() == HloOpcode::kTupleSelect || position.instruction->opcode() == HloOpcode::kAddDependency) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it has a tuple-select or " << "add-dependency position."; return false; @@ -56,18 +56,18 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( // Send and Recv HLOs return a request identifier. These should not be // allocated in the alternate memory. - for (const HloPosition& position : interval.buffer->positions()) { + for (const HloPosition& position : value->positions()) { if ((position.instruction->opcode() == HloOpcode::kSend || position.instruction->opcode() == HloOpcode::kRecv)) { // TODO(berkin): Send/recv buffers need a stable buffer allocation // throughout sending/receiving. Disable memory space allocation for these // for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a send/recv buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a request identifier for " "send/recv."; return false; @@ -78,11 +78,11 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) { // Disable memory space allocation for these for now. if (position.index == ShapeIndex({0})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } else if (position.index == ShapeIndex({1})) { - VLOG(4) << "Keeping value " << interval.buffer->ToShortString() + VLOG(4) << "Keeping value " << value->ToShortString() << " in default mem because it is a collective-permute buffer."; return false; } @@ -92,4 +92,10 @@ bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( return true; } +bool MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory( + const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval) { + return IsValueAllowedInAlternateMemory(interval.buffer) && + absl::c_all_of(interval.colocations, IsValueAllowedInAlternateMemory); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_utils.h b/tensorflow/compiler/xla/service/memory_space_assignment_utils.h index 651ac107c25..082efa5eb64 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_utils.h +++ b/tensorflow/compiler/xla/service/memory_space_assignment_utils.h @@ -26,7 +26,11 @@ class MemorySpaceAssignmentUtils { // Returns true if this buffer is allowed to be placed in the alternate // memory. static bool IsIntervalAllowedInAlternateMemory( - const GlobalDecreasingSizeBestFitHeap::BufferInterval& interval); + const GlobalDecreasingSizeBestFitHeap::BufferInterval& + interval); + + // Returns true if the HloValue is allowed to be placed in alternate memory. 
+ static bool IsValueAllowedInAlternateMemory(const HloValue* value); }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/mlir_gpu/BUILD b/tensorflow/compiler/xla/service/mlir_gpu/BUILD index 31cf36dee85..68bcde4f7ee 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/BUILD +++ b/tensorflow/compiler/xla/service/mlir_gpu/BUILD @@ -149,6 +149,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/stream_executor:stream_executor_headers", "@com_google_absl//absl/container:flat_hash_map", + "@llvm-project//llvm:Core", "@llvm-project//mlir:IR", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:StandardOps", diff --git a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc index ca979262df0..cb5ea946c1b 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/emission_context.cc @@ -25,6 +25,7 @@ namespace mlir_gpu { EmissionContext::EmissionContext(std::unique_ptr module) : module_(std::move(module)), context_() { + context_.loadAllGloballyRegisteredDialects(); error_handler_ = [](const ErrorMap& instructions_with_error, HloModule* module) { std::set computations_with_error; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc index d5cad385324..f7a7decff76 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_test.cc @@ -46,6 +46,7 @@ std::string CompileHloConvAndGetMlir(absl::string_view hlo_text) { hlo_module.entry_computation()->root_instruction(); mlir::MLIRContext context; + context.loadAllGloballyRegisteredDialects(); mlir::OwningModuleRef mlir_module( mlir::ModuleOp::create(mlir::UnknownLoc::get(&context))); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index e0d7456fbb8..b275dd4525f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/IR/DataLayout.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -203,9 +204,13 @@ LhloDialectEmitter::LhloDialectEmitter( builder_(mlir_module_.getContext()), buffer_assignment_(assignment), platform_(platform) { - LLVMDialect* llvmDialect = - mlir_module.getContext()->getRegisteredDialect(); - pointer_size_ = llvmDialect->getDataLayout().getPointerSize(); + llvm::DataLayout data_layout(""); + if (auto data_layout_attr = mlir_module.getAttrOfType( + mlir::LLVM::LLVMDialect::getDataLayoutAttrName())) { + data_layout.reset(data_layout_attr.getValue()); + } + + pointer_size_ = data_layout.getPointerSize(); } void LhloDialectEmitter::AddThunkToThunkSequence(std::unique_ptr thunk) { diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc index df2bd2e4c23..26c9e155c0c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.cc @@ -25,19 +25,8 @@ limitations under the License. 
namespace xla { namespace mlir_gpu { -namespace { -using ::mlir::MLIRContext; -using ::mlir::LLVM::LLVMDialect; - -int64 GetPointerSize(MLIRContext* context) { - LLVMDialect* dialect = context->getRegisteredDialect(); - return dialect->getDataLayout().getPointerSize(); -} - -} // namespace - -MlirCompiler::MlirCompiler() : pointer_size_(GetPointerSize(&context_)) {} +MlirCompiler::MlirCompiler() : data_layout_("") {} se::Platform::Id MlirCompiler::PlatformId() const { return stream_executor::cuda::kCudaPlatformId; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h index a7b2f9446fa..261e249c0a1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_MLIR_COMPILER_H_ +#include "llvm/IR/DataLayout.h" #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/xla/service/compiler.h" @@ -58,7 +59,7 @@ class MlirCompiler : public Compiler { protected: ::mlir::MLIRContext context_; - int64 pointer_size_; + llvm::DataLayout data_layout_; IRHook module_hook_; ErrorHandler error_handler_; }; diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index 4879c6b5099..c7977aa776a 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -104,7 +104,7 @@ class MlirCompilerImpl : public MlirCompiler { const AotCompilationOptions& options) override; HloCostAnalysis::ShapeSizeFunction ShapeSizeBytesFunction() const override { - int64 pointer_size = pointer_size_; + int64 pointer_size = data_layout_.getPointerSize(); return [pointer_size](const Shape& shape) { return ShapeUtil::ByteSizeOf(shape, pointer_size); }; @@ -462,9 +462,9 @@ StatusOr> MlirCompilerImpl::RunBackend( // must also be used to determine the thunk launch schedule. std::unique_ptr stream_assignment = xla::gpu::AssignStreams(*module); - TF_ASSIGN_OR_RETURN( - std::unique_ptr hlo_schedule, - GpuHloSchedule::Build(*module, *stream_assignment, pointer_size_)); + TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_schedule, + GpuHloSchedule::Build(*module, *stream_assignment, + data_layout_.getPointerSize())); // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. diff --git a/tensorflow/compiler/xla/service/pattern_matcher.h b/tensorflow/compiler/xla/service/pattern_matcher.h index febbf9294b0..eb29fa89098 100644 --- a/tensorflow/compiler/xla/service/pattern_matcher.h +++ b/tensorflow/compiler/xla/service/pattern_matcher.h @@ -351,8 +351,7 @@ class AllOfPattern { // Returns a pattern that represents the conjunction of all input patterns. All // patterns need to match in order to have the AllOf pattern match. template -detail::AllOfPattern::type, Patterns...> AllOf( - const Patterns&... patterns) { +auto AllOf(const Patterns&... patterns) { return detail::AllOfPattern::type, Patterns...>(patterns...); } @@ -361,10 +360,8 @@ detail::AllOfPattern::type, Patterns...> AllOf( // // This transformation is necessary for good pretty-printing. 
template -detail::AllOfPattern::type, InnerPs..., - OuterPs...> -AllOf(const detail::AllOfPattern& inner_p, - const OuterPs&... outer_ps) { +auto AllOf(const detail::AllOfPattern& inner_p, + const OuterPs&... outer_ps) { // Invoke constructor of AllOfPattern. auto make_all_of = [](const InnerPs&... inner_ps, const OuterPs&... outer_ps) { @@ -453,10 +450,7 @@ template class LayoutPattern { private: template - auto AppendImpl(NewImpl new_impl) const - -> LayoutPattern(std::declval(), - std::move(new_impl)))> { + auto AppendImpl(NewImpl new_impl) const { auto new_allof = AllOf<::xla::Layout>(impl_, std::move(new_impl)); return LayoutPattern(std::move(new_allof), matched_layout_); @@ -495,14 +489,12 @@ class LayoutPattern { // Modifies the pattern to match only if the layout equals the given proto. // The layout must outlive the returned pattern. - constexpr auto EqualTo(const ::xla::Layout* layout) const - -> decltype(this->AppendImpl(LayoutPatternEqualImpl(layout))) { + constexpr auto EqualTo(const ::xla::Layout* layout) const { return AppendImpl(LayoutPatternEqualImpl(layout)); } // Modifies the pattern to match only if the layout has a dense format. - constexpr auto WithDenseFormat() const - -> decltype(this->AppendImpl(LayoutPatternFormatImpl(DENSE))) { + constexpr auto WithDenseFormat() const { return AppendImpl(LayoutPatternFormatImpl(DENSE)); } @@ -626,17 +618,14 @@ class AnyOfPattern { // patterns. The returned pattern matches from left to right, and stops on the // first match. template -detail::AnyOfPattern::type, Patterns...> AnyOf( - const Patterns&... patterns) { +auto AnyOf(const Patterns&... patterns) { return detail::AnyOfPattern::type, Patterns...>(patterns...); } // Creates a layout pattern that will capture the matched layout in the // argument. -inline constexpr detail::LayoutPattern -Layout(const ::xla::Layout** matched_layout = nullptr) { +inline constexpr auto Layout(const ::xla::Layout** matched_layout = nullptr) { return detail::LayoutPattern( detail::LayoutPatternBaseImpl(), matched_layout); @@ -644,9 +633,7 @@ Layout(const ::xla::Layout** matched_layout = nullptr) { // Creates a layout pattern that will capture the matched layout in the // argument. -inline constexpr detail::LayoutPattern<::xla::Layout, - detail::LayoutPatternBaseImpl> -Layout(::xla::Layout** matched_layout) { +inline constexpr auto Layout(::xla::Layout** matched_layout) { return detail::LayoutPattern<::xla::Layout, detail::LayoutPatternBaseImpl>( detail::LayoutPatternBaseImpl(), matched_layout); } @@ -939,10 +926,7 @@ template class ShapePattern { private: template - auto AppendImpl(NewImpl new_impl) const - -> ShapePattern(std::declval(), - std::move(new_impl)))> { + auto AppendImpl(NewImpl new_impl) const { auto new_all_of = AllOf<::xla::Shape>(impl_, std::move(new_impl)); return ShapePattern(std::move(new_all_of), matched_shape_); @@ -988,80 +972,66 @@ class ShapePattern { // Modifies the pattern to match only if the shape equals the given proto. // The layout must outlive the returned pattern. - constexpr auto EqualTo(const ::xla::Shape* shape) const - -> decltype(this->AppendImpl(ShapePatternEqualImpl(shape))) { + constexpr auto EqualTo(const ::xla::Shape* shape) const { return AppendImpl(ShapePatternEqualImpl(shape)); } // Modifies the pattern to match only if the shape is compatible to the given // proto. The layout must outlive the returned pattern. 
- constexpr auto CompatibleTo(const ::xla::Shape* shape) const - -> decltype(this->AppendImpl(ShapePatternCompatibleImpl(shape))) { + constexpr auto CompatibleTo(const ::xla::Shape* shape) const { return AppendImpl(ShapePatternCompatibleImpl(shape)); } // Modifies the pattern to match only if the shape has the given element type. - constexpr auto WithElementType(PrimitiveType element_type) const - -> decltype(this->AppendImpl(ShapePatternElementTypeImpl(element_type))) { + constexpr auto WithElementType(PrimitiveType element_type) const { return AppendImpl(ShapePatternElementTypeImpl(element_type)); } // Modifies the pattern to match only if the shape is scalar. - constexpr auto IsScalar() const - -> decltype(this->AppendImpl(ShapePatternIsScalarImpl())) { + constexpr auto IsScalar() const { return AppendImpl(ShapePatternIsScalarImpl()); } // Modifies the pattern to match only if the shape is an array. - constexpr auto IsArray() const - -> decltype(this->AppendImpl(ShapePatternIsArrayImpl())) { + constexpr auto IsArray() const { return AppendImpl(ShapePatternIsArrayImpl()); } // Modifies the pattern to match only if the shape is a tuple. - constexpr auto IsTuple() const - -> decltype(this->AppendImpl(ShapePatternIsTupleImpl())) { + constexpr auto IsTuple() const { return AppendImpl(ShapePatternIsTupleImpl()); } - constexpr auto IsEffectiveScalar() const - -> decltype(this->AppendImpl(ShapePatternEffectiveScalarImpl())) { + constexpr auto IsEffectiveScalar() const { return AppendImpl(ShapePatternEffectiveScalarImpl()); } // Modifies the pattern to match only if the shape has the given rank. - constexpr auto WithRank(int64 rank) const - -> decltype(this->AppendImpl(ShapePatternRankImpl(rank))) { + constexpr auto WithRank(int64 rank) const { return AppendImpl(ShapePatternRankImpl(rank)); } // Modifies the pattern to match only if the shape has a layout that matches // the given pattern. template - auto WithLayout(const LayoutPattern& layout) const - -> decltype(this->AppendImpl( - ShapePatternLayoutImpl(layout))) { + auto WithLayout(const LayoutPattern& layout) const { return AppendImpl(ShapePatternLayoutImpl(layout)); } - constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const - -> decltype(this->WithLayout(Layout().EqualTo(layout))) { + constexpr auto WithLayoutEqualTo(const ::xla::Layout* layout) const { return WithLayout(Layout().EqualTo(layout)); } - constexpr auto IsDenseArray() const - -> decltype(this->WithLayout(Layout().WithDenseFormat())) { + constexpr auto IsDenseArray() const { return WithLayout(Layout().WithDenseFormat()); } // Modifies the pattern to match only if the shape has a subshape that matches // the given pattern. template - auto WithSubshape(ShapeIndexView index, - const ShapePattern& subshape) - const -> decltype(this->AppendImpl( - ShapePatternSubshapeImpl(index, - subshape))) { + auto WithSubshape( + ShapeIndexView index, + const ShapePattern& subshape) const { return AppendImpl( ShapePatternSubshapeImpl(index, subshape)); } @@ -1101,17 +1071,13 @@ class ShapePattern { } // namespace detail // Creates a shape pattern that will capture the matched layout in the argument. -inline constexpr detail::ShapePattern -Shape(const ::xla::Shape** matched_shape = nullptr) { +inline constexpr auto Shape(const ::xla::Shape** matched_shape = nullptr) { return detail::ShapePattern( detail::ShapePatternBaseImpl(), matched_shape); } // Creates a shape pattern that will capture the matched layout in the argument. 
-inline constexpr detail::ShapePattern<::xla::Shape, - detail::ShapePatternBaseImpl> -Shape(::xla::Shape** matched_shape) { +inline constexpr auto Shape(::xla::Shape** matched_shape) { return detail::ShapePattern<::xla::Shape, detail::ShapePatternBaseImpl>( detail::ShapePatternBaseImpl(), matched_shape); } @@ -1797,9 +1763,7 @@ template class HloInstructionPattern { private: template - auto AppendImpl(NewImpl new_impl) const -> HloInstructionPattern< - HloInstructionType, decltype(AllOf<::xla::HloInstruction>( - std::declval(), std::move(new_impl)))> { + auto AppendImpl(NewImpl new_impl) const { auto new_allof = AllOf<::xla::HloInstruction>(impl_, std::move(new_impl)); return HloInstructionPattern( std::move(new_allof), matched_inst_); @@ -1837,51 +1801,38 @@ class HloInstructionPattern { } // Modifies the pattern to match only if the instruction has the given name. - auto WithName(absl::string_view name) const - -> decltype(this->AppendImpl(HloInstructionPatternNameImpl(name))) { + auto WithName(absl::string_view name) const { return AppendImpl(HloInstructionPatternNameImpl(name)); } // Modifies the pattern to match only if the instruction has the given opcode. - auto WithOpcode(HloOpcode opcode) const - -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, - false))) { + auto WithOpcode(HloOpcode opcode) const { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, false)); } // Modifies the pattern to match only the custom call with a given target. - auto WithCustomCallTarget(absl::string_view custom_call_target) const - -> decltype(this->AppendImpl( - HloInstructionCustomCallTargetImpl(custom_call_target))) { + auto WithCustomCallTarget(absl::string_view custom_call_target) const { return AppendImpl(HloInstructionCustomCallTargetImpl(custom_call_target)); } - auto WithNumOperands(int64 num_operands) const -> decltype( - this->AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands))) { + auto WithNumOperands(int64 num_operands) const { return AppendImpl(HloInstructionPatternNumOperandsImpl(num_operands)); } // Modifies the pattern to match only if the instruction does not have the // given opcode. - auto WithoutOpcode(HloOpcode opcode) const - -> decltype(this->AppendImpl(HloInstructionPatternOpcodeImpl(opcode, - true))) { + auto WithoutOpcode(HloOpcode opcode) const { return AppendImpl(HloInstructionPatternOpcodeImpl(opcode, true)); } - constexpr auto Is(const HloInstruction* instr) const - -> decltype(this->AppendImpl(HloInstructionIsImpl(instr))) { + constexpr auto Is(const HloInstruction* instr) const { return AppendImpl(HloInstructionIsImpl(instr)); } // Modifies the pattern to match only if the instruction is a constant. - constexpr auto IsConstant() const - -> decltype(this->WithOpcode(HloOpcode::kConstant)) { - return WithOpcode(HloOpcode::kConstant); - } + constexpr auto IsConstant() const { return WithOpcode(HloOpcode::kConstant); } - constexpr auto IsConstantScalar() const -> decltype(this->AppendImpl( - HloConstantScalarImpl(/*match_effective_scalar=*/false))) { + constexpr auto IsConstantScalar() const { return AppendImpl( HloConstantScalarImpl(/*match_effective_scalar=*/false)); } @@ -1889,39 +1840,32 @@ class HloInstructionPattern { // This does not check that T has the same type as the instruction, so e.g. // IsConstantScalar(1.0) may match a constant of shape int32[]. 
template - constexpr auto IsConstantScalar(const ScalarTy& val) const - -> decltype(this->AppendImpl(HloConstantScalarImpl( - val, /*match_effective_scalar=*/false))) { + constexpr auto IsConstantScalar(const ScalarTy& val) const { return AppendImpl( HloConstantScalarImpl(val, /*match_effective_scalar=*/false)); } - constexpr auto IsConstantEffectiveScalar() const -> decltype(this->AppendImpl( - HloConstantScalarImpl(/*match_effective_scalar=*/true))) { + constexpr auto IsConstantEffectiveScalar() const { return AppendImpl( HloConstantScalarImpl(/*match_effective_scalar=*/true)); } template - constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const - -> decltype(this->AppendImpl(HloConstantScalarImpl( - val, /*match_effective_scalar=*/true))) { + constexpr auto IsConstantEffectiveScalar(const ScalarTy& val) const { return AppendImpl( HloConstantScalarImpl(val, /*match_effective_scalar=*/true)); } // Modifies the pattern to match only if the instruction is not a constant. - constexpr auto IsNonConstant() const - -> decltype(this->WithoutOpcode(HloOpcode::kConstant)) { + constexpr auto IsNonConstant() const { return WithoutOpcode(HloOpcode::kConstant); } // Modifies the pattern to match only if the instruction has a shape that // matches the given pattern. template - constexpr auto WithShape(const ShapePattern& shape) - const -> decltype(this->AppendImpl( - HloInstructionPatternShapeImpl(shape))) { + constexpr auto WithShape( + const ShapePattern& shape) const { return AppendImpl( HloInstructionPatternShapeImpl(shape)); } @@ -1929,16 +1873,14 @@ class HloInstructionPattern { // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const - -> decltype(this->WithShape(Shape().EqualTo(shape))) { + constexpr auto WithShapeEqualTo(const ::xla::Shape* shape) const { return WithShape(Shape().EqualTo(shape)); } // Make this a templated function to work around gcc 4.9.4 template infinite // recursion bug. template - constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const - -> decltype(this->WithShape(Shape().CompatibleTo(shape))) { + constexpr auto WithShapeCompatibleTo(const ::xla::Shape* shape) const { return WithShape(Shape().CompatibleTo(shape)); } @@ -1947,10 +1889,7 @@ class HloInstructionPattern { template constexpr auto WithOperand( int64 operand_index, - const HloInstructionPattern& operand) const - -> decltype(this->AppendImpl( - HloInstructionPatternOperandImpl( - operand_index, operand))) { + const HloInstructionPattern& operand) const { return AppendImpl( HloInstructionPatternOperandImpl( operand_index, operand)); @@ -1960,11 +1899,7 @@ class HloInstructionPattern { typename OperandImpl2> constexpr auto WithBinaryOperandsAnyOrder( const HloInstructionPattern& op1, - const HloInstructionPattern& op2) const - -> decltype(this->AppendImpl( - HloInstructionPatternBinaryOperandsAnyOrderImpl< - OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, - op2))) { + const HloInstructionPattern& op2) const { return AppendImpl( HloInstructionPatternBinaryOperandsAnyOrderImpl< OperandType1, OperandImpl1, OperandType2, OperandImpl2>(op1, op2)); @@ -1972,46 +1907,39 @@ class HloInstructionPattern { // Modifies the pattern to match only if the instruction is a fusion node with // the given kind. 
- constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const - -> decltype(this->AppendImpl(HloInstructionPatternFusionKindImpl(kind))) { + constexpr auto WithFusionKind(HloInstruction::FusionKind kind) const { return AppendImpl(HloInstructionPatternFusionKindImpl(kind)); } // Modifies the pattern to match only if the instruction is a // get-tuple-element with the given tuple index. - constexpr auto WithTupleIndex(int64 tuple_index) const -> decltype( - this->AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index))) { + constexpr auto WithTupleIndex(int64 tuple_index) const { return AppendImpl(HloInstructionPatternTupleIndexImpl(tuple_index)); } // Modifies the pattern to match only if the instruction is a parameter // with the given parameter number. - constexpr auto WithParameterNum(int64 parameter_num) const -> decltype( - this->AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num))) { + constexpr auto WithParameterNum(int64 parameter_num) const { return AppendImpl(HloInstructionPatternParameterNumImpl(parameter_num)); } // Modifies the pattern to match if the instruction is used exactly once. // Does not match if the instruction is used twice by the same user (e.g. // multiply(x,x)). - constexpr auto WithOneUse() const - -> decltype(this->AppendImpl(HloInstructionPatternOneUseImpl())) { + constexpr auto WithOneUse() const { return AppendImpl(HloInstructionPatternOneUseImpl()); } // Modifies the pattern to match if the instruction is used by exactly one // other instruction. Will match if the instruction is used twice, so long as // it's by the same user (e.g. multiply(x,x)). - constexpr auto WithOneUser() const - -> decltype(this->AppendImpl(HloInstructionPatternOneUserImpl())) { + constexpr auto WithOneUser() const { return AppendImpl(HloInstructionPatternOneUserImpl()); } // Modifies the pattern to match only if the instruction has the given // comparison direction. - auto WithComparisonDirection(ComparisonDirection direction) const - -> decltype(this->AppendImpl( - HloInstructionPatternComparisonDirectionImpl(direction))) { + auto WithComparisonDirection(ComparisonDirection direction) const { return AppendImpl(HloInstructionPatternComparisonDirectionImpl(direction)); } @@ -2028,9 +1956,7 @@ class HloInstructionPattern { // Creates an instruction pattern that will capture the matched instruction in // the argument. -inline constexpr detail::HloInstructionPattern< - const ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> -Op(const ::xla::HloInstruction** matched_inst = nullptr) { +inline constexpr auto Op(const ::xla::HloInstruction** matched_inst = nullptr) { return detail::HloInstructionPattern( detail::HloInstructionPatternBaseImpl(), matched_inst); @@ -2038,24 +1964,19 @@ Op(const ::xla::HloInstruction** matched_inst = nullptr) { // Creates an instruction pattern that will capture the matched instruction in // the argument. -inline constexpr detail::HloInstructionPattern< - ::xla::HloInstruction, detail::HloInstructionPatternBaseImpl> -Op(::xla::HloInstruction** matched_inst) { +inline constexpr auto Op(::xla::HloInstruction** matched_inst) { return detail::HloInstructionPattern<::xla::HloInstruction, detail::HloInstructionPatternBaseImpl>( detail::HloInstructionPatternBaseImpl(), matched_inst); } // Helpers for nullary instructions. 
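The switch to deduced (auto) return types above does not change how these factories are used; call sites still compose Op() with the fluent modifiers. A minimal usage sketch, illustrative only: the function name and its arguments are hypothetical, while xla::Match, xla::match::Op, and the modifiers shown are the existing pattern-matcher API touched by this patch.

  #include "tensorflow/compiler/xla/service/pattern_matcher.h"

  // Returns true if instr is add(x, constant-scalar) in either operand order,
  // capturing the non-constant operand in *x.
  bool IsAddOfConstantScalar(xla::HloInstruction* instr,
                             xla::HloInstruction** x) {
    namespace m = xla::match;
    return Match(instr, m::AddAnyOrder(m::Op(x).IsNonConstant(),
                                       m::ConstantScalar()));
  }
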
-#define XLA_NULLOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst) \ - ->decltype(Op(matched_inst).WithOpcode(HloOpcode::k##NAME)) { \ - return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ +#define XLA_NULLOP_PATTERN(NAME) \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst) { \ + return Op(matched_inst).WithOpcode(HloOpcode::k##NAME); \ } XLA_NULLOP_PATTERN(Constant) XLA_NULLOP_PATTERN(Parameter) @@ -2064,28 +1985,21 @@ XLA_NULLOP_PATTERN(Rng) #undef XLA_NULLOP_PATTERN // Helpers for unary instructions. -#define XLA_UNOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ - \ - template \ - inline auto NAME(Arg&& arg)->decltype( \ - Op().WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg))) { \ - return Op() \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg)); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg)); \ +#define XLA_UNOP_PATTERN(NAME) \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ + \ + template \ + inline auto NAME(Arg&& arg) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Arg&& arg) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(arg)); \ } XLA_UNOP_PATTERN(Abs) XLA_UNOP_PATTERN(RoundNearestAfz) @@ -2124,55 +2038,40 @@ XLA_UNOP_PATTERN(Transpose) #undef XLA_UNOP_PATTERN // Helpers for binary instructions. 
-#define XLA_BINOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ - \ - template \ - inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs))) { \ - return Op() \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)); \ +#define XLA_BINOP_PATTERN(NAME) \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \ + return Op() \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)); \ } -#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ - XLA_BINOP_PATTERN(NAME) \ - \ - template \ - inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ - Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs)); \ - } \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs))) { \ - return NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs)); \ +#define XLA_COMMUTATIVE_BINOP_PATTERN(NAME) \ + XLA_BINOP_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::k##NAME) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_BINOP_PATTERN(Add) XLA_BINOP_PATTERN(Atan2) @@ -2202,16 +2101,10 @@ XLA_BINOP_PATTERN(ShiftRightLogical) // Helpers for ternary instructions. 
#define XLA_TERNOP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ \ template \ - inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) \ - ->decltype(Op().WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg0)) \ - .WithOperand(1, std::forward(arg1)) \ - .WithOperand(2, std::forward(arg2))) { \ + inline auto NAME(Arg0&& arg0, Arg1&& arg1, Arg2&& arg2) { \ return Op() \ .WithOpcode(HloOpcode::k##NAME) \ .WithOperand(0, std::forward(arg0)) \ @@ -2222,12 +2115,7 @@ XLA_BINOP_PATTERN(ShiftRightLogical) template \ inline auto NAME(HloInstructionType** matched_inst, Arg0&& arg0, \ - Arg1&& arg1, Arg2&& arg2) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithOperand(0, std::forward(arg0)) \ - .WithOperand(1, std::forward(arg1)) \ - .WithOperand(2, std::forward(arg2))) { \ + Arg1&& arg1, Arg2&& arg2) { \ return Op(matched_inst) \ .WithOpcode(HloOpcode::k##NAME) \ .WithOperand(0, std::forward(arg0)) \ @@ -2241,17 +2129,13 @@ XLA_TERNOP_PATTERN(Select); namespace detail { template -inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) - -> decltype(m.WithOperand(operand_num, std::forward(first_arg))) { +inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg) { return m.WithOperand(operand_num, std::forward(first_arg)); } template inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, - Args&&... args) - -> decltype(WithOperands(m.WithOperand(operand_num, - std::forward(first_arg)), - operand_num + 1, std::forward(args)...)) { + Args&&... args) { return WithOperands( m.WithOperand(operand_num, std::forward(first_arg)), operand_num + 1, std::forward(args)...); @@ -2259,26 +2143,17 @@ inline auto WithOperands(Matcher&& m, int64 operand_num, FirstArg&& first_arg, } // namespace detail #define XLA_VARIADIC_OP_PATTERN(NAME) \ - inline auto NAME()->decltype(Op().WithOpcode(HloOpcode::k##NAME)) { \ - return Op().WithOpcode(HloOpcode::k##NAME); \ - } \ + inline auto NAME() { return Op().WithOpcode(HloOpcode::k##NAME); } \ \ template \ - inline auto NAME(Args&&... args) \ - ->decltype(detail::WithOperands(Op().WithOpcode(HloOpcode::k##NAME) \ - .WithNumOperands(sizeof...(Args)), \ - 0, std::forward(args)...)) { \ + inline auto NAME(Args&&... args) { \ return detail::WithOperands( \ Op().WithOpcode(HloOpcode::k##NAME).WithNumOperands(sizeof...(Args)), \ /*operand_num=*/0, std::forward(args)...); \ } \ \ template \ - inline auto NAME(HloInstructionType** matched_inst, Args&&... args) \ - ->decltype(detail::WithOperands(Op(matched_inst) \ - .WithOpcode(HloOpcode::k##NAME) \ - .WithNumOperands(sizeof...(Args)), \ - 0, std::forward(args)...)) { \ + inline auto NAME(HloInstructionType** matched_inst, Args&&... args) { \ return detail::WithOperands(Op(matched_inst) \ .WithOpcode(HloOpcode::k##NAME) \ .WithNumOperands(sizeof...(Args)), \ @@ -2299,63 +2174,46 @@ XLA_VARIADIC_OP_PATTERN(Sort); XLA_VARIADIC_OP_PATTERN(Tuple); // Helpers for comparison instructions. 
-#define XLA_COMPARE_PATTERN(NAME) \ - inline auto NAME()->decltype( \ - Op().WithOpcode(HloOpcode::kCompare) \ - .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ - return Op() \ - .WithOpcode(HloOpcode::kCompare) \ - .WithComparisonDirection(ComparisonDirection::k##NAME); \ - } \ - \ - template \ - inline auto NAME(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op().WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ - return Op() \ - .WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME); \ - } \ - \ - template \ - inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME)) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithOperand(0, std::forward(lhs)) \ - .WithOperand(1, std::forward(rhs)) \ - .WithComparisonDirection(ComparisonDirection::k##NAME); \ +#define XLA_COMPARE_PATTERN(NAME) \ + inline auto NAME() { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(Lhs&& lhs, Rhs&& rhs) { \ + return Op() \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ + } \ + \ + template \ + inline auto NAME(HloInstructionType** matched_inst, Lhs&& lhs, Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithOperand(0, std::forward(lhs)) \ + .WithOperand(1, std::forward(rhs)) \ + .WithComparisonDirection(ComparisonDirection::k##NAME); \ } -#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ - XLA_COMPARE_PATTERN(NAME) \ - \ - template \ - inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ - Rhs&& rhs) \ - ->decltype(Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs))) { \ - return Op(matched_inst) \ - .WithOpcode(HloOpcode::kCompare) \ - .WithBinaryOperandsAnyOrder(std::forward(lhs), \ - std::forward(rhs)); \ - } \ - template \ - inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) \ - ->decltype(NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs))) { \ - return NAME##AnyOrder( \ - nullptr, std::forward(lhs), std::forward(rhs)); \ +#define XLA_COMMUTATIVE_COMPARE_PATTERN(NAME) \ + XLA_COMPARE_PATTERN(NAME) \ + \ + template \ + inline auto NAME##AnyOrder(HloInstructionType** matched_inst, Lhs&& lhs, \ + Rhs&& rhs) { \ + return Op(matched_inst) \ + .WithOpcode(HloOpcode::kCompare) \ + .WithBinaryOperandsAnyOrder(std::forward(lhs), \ + std::forward(rhs)); \ + } \ + template \ + inline auto NAME##AnyOrder(Lhs&& lhs, Rhs&& rhs) { \ + return NAME##AnyOrder( \ + nullptr, std::forward(lhs), std::forward(rhs)); \ } XLA_COMMUTATIVE_COMPARE_PATTERN(Eq); @@ -2366,23 +2224,17 @@ XLA_COMPARE_PATTERN(Le); XLA_COMPARE_PATTERN(Lt); // Helpers for matching non-constant instructions. 
-inline auto NonConstant() -> decltype(Op().IsNonConstant()) { - return Op().IsNonConstant(); -} +inline auto NonConstant() { return Op().IsNonConstant(); } template -inline auto NonConstant(HloInstructionType** matched_inst) - -> decltype(Op(matched_inst).IsNonConstant()) { +inline auto NonConstant(HloInstructionType** matched_inst) { return Op(matched_inst).IsNonConstant(); } // Add overloads for GetTupleElement which take a int64 specifying which tuple // element is selected. template -inline auto GetTupleElement(Arg&& arg, int64 tuple_index) - -> decltype(Op().WithOpcode(HloOpcode::kGetTupleElement) - .WithOperand(0, std::forward(arg)) - .WithTupleIndex(tuple_index)) { +inline auto GetTupleElement(Arg&& arg, int64 tuple_index) { return Op() .WithOpcode(HloOpcode::kGetTupleElement) .WithOperand(0, std::forward(arg)) @@ -2391,11 +2243,7 @@ inline auto GetTupleElement(Arg&& arg, int64 tuple_index) template inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, - int64 tuple_index) - -> decltype(Op(matched_inst) - .WithOpcode(HloOpcode::kGetTupleElement) - .WithOperand(0, std::forward(arg)) - .WithTupleIndex(tuple_index)) { + int64 tuple_index) { return Op(matched_inst) .WithOpcode(HloOpcode::kGetTupleElement) .WithOperand(0, std::forward(arg)) @@ -2404,62 +2252,50 @@ inline auto GetTupleElement(HloInstructionType** matched_inst, Arg&& arg, // Add overloads for Parameter which take an int64 specifying the parameter // number. -inline auto Parameter(int64 parameter_num) -> decltype( - Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num)) { +inline auto Parameter(int64 parameter_num) { return Op().WithOpcode(HloOpcode::kParameter).WithParameterNum(parameter_num); } template -inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) - -> decltype(Op(matched_inst) - .WithOpcode(HloOpcode::kParameter) - .WithParameterNum(parameter_num)) { +inline auto Parameter(HloInstructionType** matched_inst, int64 parameter_num) { return Op(matched_inst) .WithOpcode(HloOpcode::kParameter) .WithParameterNum(parameter_num); } -inline auto ConstantScalar() -> decltype(Op().IsConstantScalar()) { - return Op().IsConstantScalar(); -} +inline auto ConstantScalar() { return Op().IsConstantScalar(); } template -inline auto ConstantScalar(HloInstructionType** matched_inst) - -> decltype(Op(matched_inst).IsConstantScalar()) { +inline auto ConstantScalar(HloInstructionType** matched_inst) { return Op(matched_inst).IsConstantScalar(); } template -inline auto ConstantScalar(ScalarTy val) - -> decltype(Op().IsConstantScalar(val)) { +inline auto ConstantScalar(ScalarTy val) { return Op().IsConstantScalar(val); } template -inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) - -> decltype(Op(matched_inst).IsConstantScalar(val)) { +inline auto ConstantScalar(HloInstructionType** matched_inst, ScalarTy val) { return Op(matched_inst).IsConstantScalar(val); } -inline auto ConstantEffectiveScalar() -> decltype(Op().IsConstantScalar()) { +inline auto ConstantEffectiveScalar() { return Op().IsConstantEffectiveScalar(); } template -inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) - -> decltype(Op(matched_inst).IsConstantScalar()) { +inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst) { return Op(matched_inst).IsConstantEffectiveScalar(); } template -inline auto ConstantEffectiveScalar(ScalarTy val) - -> decltype(Op().IsConstantEffectiveScalar(val)) { +inline auto ConstantEffectiveScalar(ScalarTy val) { return 
Op().IsConstantEffectiveScalar(val); } template inline auto ConstantEffectiveScalar(HloInstructionType** matched_inst, - ScalarTy val) - -> decltype(Op(matched_inst).IsConstantEffectiveScalar(val)) { + ScalarTy val) { return Op(matched_inst).IsConstantEffectiveScalar(val); } diff --git a/tensorflow/compiler/xla/service/scatter_expander.cc b/tensorflow/compiler/xla/service/scatter_expander.cc index e3a3feb8640..bd99f920ea0 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.cc +++ b/tensorflow/compiler/xla/service/scatter_expander.cc @@ -325,6 +325,22 @@ static StatusOr> ScatterLoopBody( {updated_operand, scatter_indices, updates}}; } +static int64 ScatterTripCount(HloInstruction* scatter) { + // Compute the trip count for the while loop to be used for scatter. This + // should be the number of indices we should scatter into the operand. + HloInstruction* scatter_indices = scatter->mutable_operand(1); + const Shape& scatter_indices_shape = scatter_indices->shape(); + const ScatterDimensionNumbers& dim_numbers = + scatter->scatter_dimension_numbers(); + int64 scatter_loop_trip_count = 1; + for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { + if (i != dim_numbers.index_vector_dim()) { + scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); + } + } + return scatter_loop_trip_count; +} + // High Level Algorithm. // // 1. Canonicalize the scatter_indices tensor such that it has rank 2, where @@ -342,7 +358,7 @@ static StatusOr> ScatterLoopBody( // from c. and d. using the update_computation of scatter. // f. Write the updated value of the slice into the operand tensor. -StatusOr ScatterExpander::ExpandScatter( +StatusOr ScatterExpander::ExpandInstruction( HloInstruction* scatter) { HloInstruction* operand = scatter->mutable_operand(0); HloInstruction* scatter_indices = scatter->mutable_operand(1); @@ -358,13 +374,7 @@ StatusOr ScatterExpander::ExpandScatter( // Compute the trip count for the while loop to be used for scatter. This // should be the number of indices we should scatter into the operand. 
- const Shape& scatter_indices_shape = scatter_indices->shape(); - int64 scatter_loop_trip_count = 1; - for (int64 i = 0, e = scatter_indices_shape.dimensions_size(); i < e; i++) { - if (i != dim_numbers.index_vector_dim()) { - scatter_loop_trip_count *= scatter_indices_shape.dimensions(i); - } - } + int64 scatter_loop_trip_count = ScatterTripCount(scatter); if (!IsInt32(scatter_loop_trip_count)) { return Unimplemented( "Scatter operations with more than 2147483647 scatter indices are not " @@ -408,23 +418,9 @@ StatusOr ScatterExpander::ExpandScatter( return scatter_loop_result.front(); } -StatusOr ScatterExpander::Run(HloModule* module) { - std::vector scatter_instrs; - for (HloComputation* computation : module->MakeNonfusionComputations()) { - for (HloInstruction* instr : computation->instructions()) { - if (instr->opcode() == HloOpcode::kScatter) { - scatter_instrs.push_back(instr); - } - } - } - - for (auto instr : scatter_instrs) { - TF_ASSIGN_OR_RETURN(HloInstruction * expanded_root, ExpandScatter(instr)); - TF_RETURN_IF_ERROR( - instr->parent()->ReplaceInstruction(instr, expanded_root)); - } - - return !scatter_instrs.empty(); +bool ScatterExpander::InstructionMatchesPattern(HloInstruction* inst) { + return inst->opcode() == HloOpcode::kScatter && + (mode_ == kEliminateAllScatters || ScatterTripCount(inst) == 1); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander.h b/tensorflow/compiler/xla/service/scatter_expander.h index 533af060bc9..aa59e7ec3b0 100644 --- a/tensorflow/compiler/xla/service/scatter_expander.h +++ b/tensorflow/compiler/xla/service/scatter_expander.h @@ -16,17 +16,43 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_SCATTER_EXPANDER_H_ -#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { -class ScatterExpander : public HloModulePass { +// This pass rewrites scatter operations into (roughly) while loops of +// dynamic-update-slices. +// +// This pass can be used in two ways: +// +// - kEliminateAllScatters: For backends that don't support scatter, this pass +// can convert every scatter into a loop. +// +// - kEliminateSimpleScatters: For backends that *do* support scatter, this +// pass can strength-reduce "simple" scatters -- specifically, scatters that +// can be represented without a loop -- to dynamic-update-slices. +// +// Note that even in kEliminateSimpleScatters mode, this pass may still expand a +// scatter into a loop (with a trip-count of 1). It's up to other +// simplification passes to remove the loop. 
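For context, a backend picks the mode when it registers the pass in its pipeline. A minimal sketch under stated assumptions: the helper function and the backend_supports_scatter flag are hypothetical, while HloPassPipeline::AddPass and the ScatterExpander constructor added in this patch are real.

  #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
  #include "tensorflow/compiler/xla/service/scatter_expander.h"

  // Backends with a native scatter lowering only strength-reduce trivial
  // scatters; backends without one expand every scatter into a loop.
  void AddScatterExpansion(xla::HloPassPipeline* pipeline,
                           bool backend_supports_scatter) {
    pipeline->AddPass<xla::ScatterExpander>(
        backend_supports_scatter
            ? xla::ScatterExpander::kEliminateSimpleScatters
            : xla::ScatterExpander::kEliminateAllScatters);
  }
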
+class ScatterExpander : public OpExpanderPass { public: + enum Mode { + kEliminateAllScatters, + kEliminateSimpleScatters, + }; + + explicit ScatterExpander(Mode m) : mode_(m) {} + absl::string_view name() const override { return "scatter_expander"; } - StatusOr Run(HloModule* module) override; protected: - StatusOr ExpandScatter(HloInstruction* scatter); + bool InstructionMatchesPattern(HloInstruction* inst) override; + + StatusOr ExpandInstruction(HloInstruction* scatter) override; + + private: + Mode mode_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/scatter_expander_test.cc b/tensorflow/compiler/xla/service/scatter_expander_test.cc index 3852b82c1ef..9f4cc5406d8 100644 --- a/tensorflow/compiler/xla/service/scatter_expander_test.cc +++ b/tensorflow/compiler/xla/service/scatter_expander_test.cc @@ -57,11 +57,79 @@ TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) { ParseAndReturnVerifiedModule(kModuleStr)); // The HLO parser changes all no layout shapes from the input to have a - // default layout, clear the layout of the scatter operand for testing. + // default layout. Clear the layout of the scatter operand for testing. HloInstruction* scatter_operand = FindInstruction(module.get(), "operand"); scatter_operand->mutable_shape()->clear_layout(); - ScatterExpander scatter_expander; + ScatterExpander scatter_expander(ScatterExpander::kEliminateAllScatters); + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&scatter_expander, module.get())); + EXPECT_TRUE(result); +} + +TEST_F(ScatterExpanderTest, EliminateSimpleScattersSkipsNontrivialScatter) { + const char* kModuleStr = R"( + HloModule scatter_expander + + scatter_computation { + parameter0 = s32[] parameter(0) + ROOT parameter1 = s32[] parameter(1) + } + + ENTRY kernel_entry { + operand = s32[3,3] parameter(0) + indices = s32[2] parameter(1) + updates = s32[2,3] parameter(2) + ROOT scatter = s32[3,3] scatter(operand, indices, updates), + to_apply=scatter_computation, + update_window_dims={1}, + inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, + index_vector_dim=1 + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // The HLO parser changes all no layout shapes from the input to have a + // default layout. Clear the layout of the scatter operand for testing. + HloInstruction* scatter_operand = FindInstruction(module.get(), "operand"); + scatter_operand->mutable_shape()->clear_layout(); + + ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters); + TF_ASSERT_OK_AND_ASSIGN(bool result, + RunHloPass(&scatter_expander, module.get())); + EXPECT_FALSE(result); +} + +TEST_F(ScatterExpanderTest, EliminateSimpleScattersRewritesTrivialScatter) { + const char* kModuleStr = R"( + HloModule scatter_expander + + scatter_computation { + parameter0 = s32[] parameter(0) + ROOT parameter1 = s32[] parameter(1) + } + + ENTRY kernel_entry { + operand = s32[5] iota(), iota_dimension=0 + indices = s32[1] parameter(0) + update = s32[] constant(0) + ROOT scatter = s32[5]{0} scatter(operand, indices, update), + update_window_dims={}, inserted_window_dims={0}, + scatter_dims_to_operand_dims={0}, index_vector_dim=0, + to_apply=scatter_computation + })"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(kModuleStr)); + + // The HLO parser changes all no layout shapes from the input to have a + // default layout. Clear the layout of the scatter operand for testing. 
+ HloInstruction* scatter_operand = FindInstruction(module.get(), "operand"); + scatter_operand->mutable_shape()->clear_layout(); + + ScatterExpander scatter_expander(ScatterExpander::kEliminateSimpleScatters); TF_ASSERT_OK_AND_ASSIGN(bool result, RunHloPass(&scatter_expander, module.get())); EXPECT_TRUE(result); diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 8e39e32e4c3..a96c9c34260 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -2825,6 +2825,38 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return output_shape; } +/* static */ StatusOr ShapeInference::InferDynamicReshapeShape( + const Shape& operand, absl::Span dim_size_shapes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic) { + if (new_size_bounds.size() != dims_are_dynamic.size()) { + return InvalidArgument( + "DynamicReshape has to have the same number of elements in new_sizes " + "(%d) and dims_are_dynamic (%d)", + new_size_bounds.size(), dims_are_dynamic.size()); + } + + for (const Shape* dim_size_shape : dim_size_shapes) { + if (dim_size_shape->element_type() != S32 && dim_size_shape->rank() != 0) { + return InvalidArgument( + "DynamicReshape's dim size has to be scalar S32, got (%s): ", + dim_size_shape->ToString()); + } + } + + Shape inferred_shape = ShapeUtil::MakeShape( + operand.element_type(), new_size_bounds, dims_are_dynamic); + if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { + return InvalidArgument( + "Reshape operation has mismatched element counts: from=%d (%s) " + "to=%d (%s).", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand), + ShapeUtil::ElementsIn(inferred_shape), + ShapeUtil::HumanString(inferred_shape)); + } + return inferred_shape; +} + /* static */ StatusOr ShapeInference::InferReshapeShape( const Shape& operand, absl::Span dimensions, absl::Span new_sizes, int64 inferred_dimension) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index d47d96ab52d..f03e4e5fa98 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -241,6 +241,15 @@ class ShapeInference { absl::Span new_sizes, int64 inferred_dimension); + // Infers the shape produced by a dynamic reshape operation from the element + // type of its operand and the new dimension sizes specified. The result shape + // will have dynamic dimensions as specific in `dim_is_dynamic` and bound + // `new_size_bounds`. + static StatusOr InferDynamicReshapeShape( + const Shape& operand, absl::Span dim_size_shapes, + absl::Span new_size_bounds, + const std::vector& dims_are_dynamic); + // Infers the shape produced by a transpose operation from the element type of // its operand and its dimensions field. static StatusOr InferTransposeShape( diff --git a/tensorflow/compiler/xla/service/sharding_propagation.cc b/tensorflow/compiler/xla/service/sharding_propagation.cc index bcbebf3460f..7136ce82e25 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation.cc @@ -120,34 +120,34 @@ HloSharding MergeForMoreSpecificSharding(const HloSharding& a, return IsShardingMoreSpecific(a, b) ? a : b; } -// Returns a sharding that is refined by merging old and to_merge. 
May combine -// partial sharding in addition to MergeForMoreSpecificSharding(). -HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, - bool may_combine_partial_sharding) { +// Tries to refine `to_merge` by combining with `old`. Returns if the final +// `to_merge` is more specific than `old`. May combine partial sharding in +// addition to MergeForMoreSpecificSharding(). +bool MergeSharding(const HloSharding& old, HloSharding* to_merge, + bool may_combine_partial_sharding) { if (old.IsTuple()) { - HloSharding result = old; - CHECK(to_merge.IsTuple()); - CHECK_EQ(old.tuple_elements().size(), to_merge.tuple_elements().size()); - for (int64 i = 0; i < result.tuple_elements().size(); ++i) { - result.tuple_elements()[i] = - MergeSharding(old.tuple_elements()[i], to_merge.tuple_elements()[i], + CHECK(to_merge->IsTuple()); + bool changed = false; + for (int64 i = 0; i < old.tuple_elements().size(); ++i) { + changed |= + MergeSharding(old.tuple_elements()[i], &to_merge->tuple_elements()[i], may_combine_partial_sharding); } - return result; + return changed; } if (!may_combine_partial_sharding || !old.ReplicateOnLastTileDim() || - !to_merge.ReplicateOnLastTileDim() || + !to_merge->ReplicateOnLastTileDim() || old.tile_assignment().num_elements() != - to_merge.tile_assignment().num_elements()) { - return IsShardingMoreSpecific(to_merge, old) ? to_merge : old; + to_merge->tile_assignment().num_elements()) { + return IsShardingMoreSpecific(*to_merge, old); } // Combine the tile dimension sizes from new and old. int64 num_devices = old.tile_assignment().num_elements(); std::vector new_tile_dims; bool compatible = true; - new_tile_dims.reserve(to_merge.tile_assignment().num_dimensions()); - for (int64 i = 0; i < to_merge.tile_assignment().num_dimensions() - 1; ++i) { - int64 new_dim = to_merge.tile_assignment().dim(i); + new_tile_dims.reserve(to_merge->tile_assignment().num_dimensions()); + for (int64 i = 0; i < to_merge->tile_assignment().num_dimensions() - 1; ++i) { + int64 new_dim = to_merge->tile_assignment().dim(i); int64 old_dim = old.tile_assignment().dim(i); if (new_dim == 1) { new_tile_dims.push_back(old_dim); @@ -163,7 +163,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, int64 replication = num_devices / Product(new_tile_dims); if (!compatible || num_devices % Product(new_tile_dims) != 0 || replication >= old.tile_assignment().dimensions().back()) { - return IsShardingMoreSpecific(to_merge, old) ? 
to_merge : old; + return IsShardingMoreSpecific(*to_merge, old); } new_tile_dims.push_back(replication); Array new_tile(new_tile_dims); @@ -174,7 +174,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, const HloSharding& sharding) { int64 group_id = 0; for (int64 i = 0; i < tile_indices.size() - 1; ++i) { - group_id *= to_merge.tile_assignment().dim(i); + group_id *= to_merge->tile_assignment().dim(i); group_id += tile_indices[i]; } return group_id; @@ -183,9 +183,9 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, [&](absl::Span indices, int64 device) { old_group_members[get_group_index(indices, old)].insert(device); }); - to_merge.tile_assignment().Each( + to_merge->tile_assignment().Each( [&](absl::Span indices, int64 device) { - new_group_members[get_group_index(indices, to_merge)].insert(device); + new_group_members[get_group_index(indices, *to_merge)].insert(device); }); // Try to find the intersection of old and new replication groups, in // order to determine the merged tile assignment. @@ -199,12 +199,12 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, if (old.tile_assignment().dim(i) == 1) { old_index[i] = 0; } - if (to_merge.tile_assignment().dim(i) == 1) { + if (to_merge->tile_assignment().dim(i) == 1) { new_index[i] = 0; } } int64 old_group_id = get_group_index(old_index, old); - int64 new_group_id = get_group_index(new_index, to_merge); + int64 new_group_id = get_group_index(new_index, *to_merge); if (old_group_members[old_group_id].empty() || new_group_members[new_group_id].empty() || *old_group_members[old_group_id].begin() != @@ -220,11 +220,13 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, if (replication == 1) { new_tile_dims.pop_back(); new_tile.Reshape(new_tile_dims); - return HloSharding::Tile(new_tile); + *to_merge = HloSharding::Tile(new_tile); + } else { + *to_merge = HloSharding::PartialTile(new_tile); } - return HloSharding::PartialTile(new_tile); + return true; } - return IsShardingMoreSpecific(to_merge, old) ? to_merge : old; + return IsShardingMoreSpecific(*to_merge, old); } // Updates the sharding of the specified instruction with the specified sharding @@ -232,7 +234,7 @@ HloSharding MergeSharding(const HloSharding& old, const HloSharding& to_merge, // been applied. If may_combine_partial_sharding is true, this may combine the // new and existing sharding if they are both partial tiling partial // replication. -bool MaybeImproveInstructionSharding(const HloSharding& sharding, +bool MaybeImproveInstructionSharding(HloSharding sharding, HloInstruction* instruction, bool may_combine_partial_sharding) { // We don't want to propagate tile maximal shardings. @@ -241,13 +243,13 @@ bool MaybeImproveInstructionSharding(const HloSharding& sharding, } // Any sharding is better then no sharding. 
if (!instruction->has_sharding()) { - instruction->set_sharding(sharding); + instruction->set_sharding(std::move(sharding)); return true; } - auto merged = MergeSharding(instruction->sharding(), sharding, + auto merged = MergeSharding(instruction->sharding(), &sharding, may_combine_partial_sharding); - if (merged != instruction->sharding()) { - instruction->set_sharding(merged); + if (merged) { + instruction->set_sharding(std::move(sharding)); return true; } return false; @@ -387,6 +389,7 @@ const HloInstruction* PickRepresentativeOperand( case HloOpcode::kDot: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kDynamicReshape: case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kGather: @@ -538,7 +541,7 @@ bool InferDotShardingFromOperands( // Convolution handling for InferShardingFromOperands(). bool InferConvolutionShardingFromOperands(HloInstruction* instruction, - bool aggressive_prop, + int64 aggressiveness, bool may_combine_partial_sharding) { if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution( instruction)) { @@ -586,12 +589,27 @@ bool InferConvolutionShardingFromOperands(HloInstruction* instruction, may_combine_partial_sharding); } +bool CanPropagateThroughAtAgressiveLevel(const HloInstruction& inst, + int64 aggressiveness) { + // At minimum agressiveness, only allow pass-through ops. + if (aggressiveness < 1 && !inst.IsElementwise() && + inst.opcode() != HloOpcode::kTranspose && + inst.opcode() != HloOpcode::kReshape) { + return false; + } + return true; +} + // Tries to update the sharding of the specified instruction based on its // operands and returns true if the sharding of the instruction have been // changed and false otherwise. bool InferShardingFromOperands(HloInstruction* instruction, const ComputationMap& computation_map, - bool is_spmd, bool aggressive_prop) { + bool is_spmd, int64 aggressiveness) { + if (!CanPropagateThroughAtAgressiveLevel(*instruction, aggressiveness)) { + return false; + } + const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { // If an array shaped HLO doesn't support spatial partitioning but at least // one of its operand is replicated then we make the HLO replicated as well. 
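Two signature changes above drive most of the remaining hunks in this file: MaybeImproveInstructionSharding now takes HloSharding by value (which is why later call sites std::move their temporaries), and the boolean aggressive_prop flag becomes an integer aggressiveness level, with partial-sharding merging allowed only when is_spmd is true and aggressiveness > 0. A rough restatement of the level-0 gate, with a hypothetical helper name that is not part of the patch:

  #include "tensorflow/compiler/xla/service/hlo_instruction.h"
  #include "tensorflow/compiler/xla/service/hlo_opcode.h"

  // At aggressiveness 0, only shape-preserving pass-through ops may pick up a
  // sharding; all other opcodes wait for the more aggressive second fix point.
  bool ParticipatesAtLevelZero(const xla::HloInstruction& inst) {
    return inst.IsElementwise() ||
           inst.opcode() == xla::HloOpcode::kTranspose ||
           inst.opcode() == xla::HloOpcode::kReshape;
  }
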
@@ -604,8 +622,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, return op->has_sharding() && op->sharding().IsReplicated(); })) { return MaybeImproveInstructionSharding( - HloSharding::Replicate(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + HloSharding::Replicate(), instruction, may_combine_partial_sharding); } return false; } @@ -619,7 +636,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, HloSharding new_sharding = operand->sharding().GetSubSharding( operand->shape(), {instruction->tuple_index()}); return MaybeImproveInstructionSharding( - new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); + std::move(new_sharding), instruction, may_combine_partial_sharding); } case HloOpcode::kTuple: { if (absl::c_none_of(instruction->operands(), @@ -684,12 +701,12 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (!IsSpatiallyPartitioned(operand)) { continue; } - auto get_maybe_tuple_sharding = [&](const HloSharding& sharding) { + auto get_maybe_tuple_sharding = [&](HloSharding sharding) { if (instruction->operand_count() == 2) { return sharding; } std::vector tuple(instruction->operand_count() / 2, - sharding); + std::move(sharding)); return HloSharding::Tuple(instruction->shape(), tuple); }; if (operand->sharding().IsReplicated() || @@ -701,7 +718,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, // support this in SPMD. changed |= MaybeImproveInstructionSharding( get_maybe_tuple_sharding(HloSharding::Replicate()), instruction, - /*may_combine_partial_sharding=*/is_spmd); + may_combine_partial_sharding); continue; } auto after_partial_replication = @@ -712,7 +729,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (after_partial_replication.IsReplicated()) { changed |= MaybeImproveInstructionSharding( get_maybe_tuple_sharding(HloSharding::Replicate()), instruction, - /*may_combine_partial_sharding=*/is_spmd); + may_combine_partial_sharding); continue; } // Use the same sharding for all tuple elements, because they are part @@ -721,8 +738,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, get_maybe_tuple_sharding(hlo_sharding_util::RemoveShapeDimensions( after_partial_replication, instruction->dimensions())); changed |= MaybeImproveInstructionSharding( - new_sharding, instruction, - /*may_combine_partial_sharding=*/is_spmd); + std::move(new_sharding), instruction, may_combine_partial_sharding); } return changed; } @@ -763,12 +779,11 @@ bool InferShardingFromOperands(HloInstruction* instruction, ? 
HloSharding::PartialTile(new_tile_assignment) : HloSharding::Tile(new_tile_assignment); return MaybeImproveInstructionSharding( - new_sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); + std::move(new_sharding), instruction, may_combine_partial_sharding); } case HloOpcode::kConvolution: - return InferConvolutionShardingFromOperands( - instruction, aggressive_prop, - /*may_combine_partial_sharding=*/is_spmd); + return InferConvolutionShardingFromOperands(instruction, aggressiveness, + may_combine_partial_sharding); case HloOpcode::kTranspose: { const HloInstruction* input = instruction->operand(0); if (!IsSpatiallyPartitioned(input)) { @@ -776,8 +791,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, } HloSharding sharding = hlo_sharding_util::TransposeSharding( input->sharding(), instruction->dimensions()); - return MaybeImproveInstructionSharding( - sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(std::move(sharding), instruction, + may_combine_partial_sharding); } case HloOpcode::kReduceWindow: { const HloInstruction* lhs = instruction->operand(0); @@ -795,9 +810,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, << instruction->ToString(); return false; } - return MaybeImproveInstructionSharding( - lhs->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(lhs->sharding(), instruction, + may_combine_partial_sharding); } case HloOpcode::kSelectAndScatter: { // Shard according to first operand, as output keeps the same shape. @@ -816,9 +830,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, << instruction->ToString(); return false; } - return MaybeImproveInstructionSharding( - lhs->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(lhs->sharding(), instruction, + may_combine_partial_sharding); } case HloOpcode::kReshape: { if (!IsSpatiallyPartitioned(instruction->operand(0))) { @@ -829,9 +842,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, instruction->operand(0)->shape(), instruction->shape(), instruction->operand(0)->sharding()); if (new_sharding.has_value()) { - return MaybeImproveInstructionSharding( - new_sharding.value(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(std::move(*new_sharding), + instruction, + may_combine_partial_sharding); } return false; } @@ -842,14 +855,13 @@ bool InferShardingFromOperands(HloInstruction* instruction, return MaybeImproveInstructionSharding( hlo_sharding_util::ReverseSharding( instruction->operand(0)->sharding(), instruction->dimensions()), - instruction, /*may_combine_partial_sharding=*/is_spmd); + instruction, may_combine_partial_sharding); } case HloOpcode::kDot: { const auto& dnums = dot_as_convolution_util::ParseDotGeneralFromDot(instruction); - return InferDotShardingFromOperands( - instruction, dnums, - /*may_combine_partial_sharding=*/is_spmd); + return InferDotShardingFromOperands(instruction, dnums, + may_combine_partial_sharding); } case HloOpcode::kParameter: { auto parent_it = computation_map.find(instruction->parent()); @@ -864,7 +876,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (parent->operand(i)->has_sharding()) { return MaybeImproveInstructionSharding( parent->operand(i)->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + may_combine_partial_sharding); } return false; } @@ -891,16 +903,15 
@@ bool InferShardingFromOperands(HloInstruction* instruction, if (instruction->shape().IsTuple()) { return MaybeImproveInstructionSharding( HloSharding::SingleTuple(instruction->shape(), operand->sharding()), - instruction, /*may_combine_partial_sharding=*/is_spmd); + instruction, may_combine_partial_sharding); } else { - return MaybeImproveInstructionSharding( - operand->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(operand->sharding(), instruction, + may_combine_partial_sharding); } } case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: { - auto propagate_slicing = [instruction, is_spmd]() { + auto propagate_slicing = [&]() { const HloInstruction* operand = instruction->opcode() == HloOpcode::kDynamicSlice ? instruction->operand(0) @@ -910,9 +921,9 @@ bool InferShardingFromOperands(HloInstruction* instruction, } if (operand->sharding().IsReplicated()) { - return MaybeImproveInstructionSharding( - HloSharding::Replicate(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(HloSharding::Replicate(), + instruction, + may_combine_partial_sharding); } const auto& tile_assignment = operand->sharding().tile_assignment(); @@ -923,11 +934,10 @@ bool InferShardingFromOperands(HloInstruction* instruction, return false; } } - return MaybeImproveInstructionSharding( - operand->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(operand->sharding(), instruction, + may_combine_partial_sharding); }; - auto propagate_base = [instruction, is_spmd]() { + auto propagate_base = [&]() { if (instruction->opcode() != HloOpcode::kDynamicUpdateSlice) { return false; } @@ -936,7 +946,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, } return MaybeImproveInstructionSharding( instruction->operand(0)->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + may_combine_partial_sharding); }; return propagate_slicing() || propagate_base(); } @@ -946,8 +956,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, HloSharding new_sharding = hlo_sharding_util::GatherOutputSharding( instruction->operand(1)->sharding(), instruction); changed |= MaybeImproveInstructionSharding( - new_sharding, instruction, - /*may_combine_partial_sharding=*/is_spmd); + std::move(new_sharding), instruction, may_combine_partial_sharding); } if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { auto maybe_from_data = @@ -955,8 +964,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, instruction->operand(0)->sharding(), *instruction); if (maybe_from_data) { changed |= MaybeImproveInstructionSharding( - *maybe_from_data, instruction, - /*may_combine_partial_sharding=*/is_spmd); + std::move(*maybe_from_data), instruction, + may_combine_partial_sharding); } } return changed; @@ -966,7 +975,7 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (is_spmd && IsSpatiallyPartitioned(instruction->operand(0))) { changed |= MaybeImproveInstructionSharding( instruction->operand(0)->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + may_combine_partial_sharding); } if (!IsSpatiallyPartitioned(instruction->operand(1)) && !IsSpatiallyPartitioned(instruction->operand(2))) { @@ -978,13 +987,12 @@ bool InferShardingFromOperands(HloInstruction* instruction, instruction->operand(2)->sharding(), *instruction); if (maybe_from_update) { changed |= MaybeImproveInstructionSharding( - 
*maybe_from_update, instruction, - /*may_combine_partial_sharding=*/is_spmd); + std::move(*maybe_from_update), instruction, + may_combine_partial_sharding); } } changed |= MaybeImproveInstructionSharding( - HloSharding::Replicate(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + HloSharding::Replicate(), instruction, may_combine_partial_sharding); return changed; } case HloOpcode::kWhile: { @@ -996,17 +1004,16 @@ bool InferShardingFromOperands(HloInstruction* instruction, sharding = MergeForMoreSpecificSharding(sharding, instruction->sharding()); } - return MaybeImproveInstructionSharding( - sharding, instruction, /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(std::move(sharding), instruction, + may_combine_partial_sharding); } default: { - if (instruction->IsElementwise() && is_spmd) { + if (instruction->IsElementwise() && may_combine_partial_sharding) { bool changed = false; for (auto operand : instruction->operands()) { if (IsSpatiallyPartitioned(operand)) { changed |= MaybeImproveInstructionSharding( - operand->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + operand->sharding(), instruction, may_combine_partial_sharding); } } return changed; @@ -1015,9 +1022,8 @@ bool InferShardingFromOperands(HloInstruction* instruction, if (!operand || !IsSpatiallyPartitioned(operand)) { return false; } - return MaybeImproveInstructionSharding( - operand->sharding(), instruction, - /*may_combine_partial_sharding=*/is_spmd); + return MaybeImproveInstructionSharding(operand->sharding(), instruction, + may_combine_partial_sharding); } } return false; @@ -1088,12 +1094,14 @@ HloSharding InferDotOperandSharding( operand_to_other_dims[operand_index == 0 ? dim.lhs : dim.rhs] = operand_index == 0 ? dim.rhs : dim.lhs; } - sharding = - MergeSharding(sharding, - *hlo_sharding_util::TransposeShardingWithCollapsedDims( - other_operand_dims_replicated, other_to_operand_dims, - operand_to_other_dims), - may_combine_partial_sharding); + HloSharding sharding_from_other = + *hlo_sharding_util::TransposeShardingWithCollapsedDims( + other_operand_dims_replicated, other_to_operand_dims, + operand_to_other_dims); + if (MergeSharding(sharding, &sharding_from_other, + may_combine_partial_sharding)) { + sharding = std::move(sharding_from_other); + } } return sharding; } @@ -1101,10 +1109,14 @@ HloSharding InferDotOperandSharding( // Return the sharding that should be propagated from user to instruction. 
absl::optional GetShardingFromUser( const HloInstruction& instruction, const HloInstruction& user, - bool aggressive_prop, bool is_spmd) { + int64 aggressiveness, bool is_spmd) { + if (!CanPropagateThroughAtAgressiveLevel(user, aggressiveness)) { + return absl::nullopt; + } if (!IsSpatiallyPartitioned(&user)) { return absl::nullopt; } + const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; switch (user.opcode()) { case HloOpcode::kBroadcast: { if (user.sharding().IsReplicated()) { @@ -1176,9 +1188,8 @@ absl::optional GetShardingFromUser( if (auto dot_dims = dot_as_convolution_util::ParseDotGeneralFromConvolution(&user)) { int64 op_idx = user.operand_index(&instruction); - return InferDotOperandSharding( - &user, *dot_dims, op_idx, - /*may_combine_partial_sharding=*/is_spmd); + return InferDotOperandSharding(&user, *dot_dims, op_idx, + may_combine_partial_sharding); } return absl::nullopt; } @@ -1263,7 +1274,7 @@ absl::optional GetShardingFromUser( int64 op_idx = user.operand_index(&instruction); auto dnums = dot_as_convolution_util::ParseDotGeneralFromDot(&user); return InferDotOperandSharding(&user, dnums, op_idx, - /*may_combine_partial_sharding=*/is_spmd); + may_combine_partial_sharding); } case HloOpcode::kReduce: { if (instruction.shape().rank() == 0) { @@ -1364,18 +1375,18 @@ absl::optional GetShardingFromUser( // false otherwise. bool InferShardingFromUsers(HloInstruction* instruction, const ComputationMap& computation_map, - bool aggressive_prop, bool is_spmd) { + int64 aggressiveness, bool is_spmd) { if (!SupportSpatialPartitioning(instruction, computation_map, is_spmd)) { return false; } bool improved_sharding = false; + const bool may_combine_partial_sharding = is_spmd && aggressiveness > 0; for (const HloInstruction* user : instruction->users()) { absl::optional user_sharding = - GetShardingFromUser(*instruction, *user, aggressive_prop, is_spmd); + GetShardingFromUser(*instruction, *user, aggressiveness, is_spmd); if (user_sharding) { improved_sharding |= MaybeImproveInstructionSharding( - *user_sharding, instruction, - /*may_combine_partial_sharding=*/is_spmd); + std::move(*user_sharding), instruction, may_combine_partial_sharding); } } return improved_sharding; @@ -1645,10 +1656,18 @@ StatusOr ShardingPropagation::Run(HloModule* module) { // strictly improve the sharding of the graph and it can't be improved // indefinitely. int64 iterations = 0; - auto run_to_fix_point = [&](bool aggressive_prop) { - bool changed = true; - while (changed) { - changed = false; + auto run_to_fix_point = [&](int64 aggressiveness) { + absl::flat_hash_set workset; + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + // Remove the instructions where the sharding was provided from the + // outside so we don't modify them. + if (!provided_shardings.contains(instruction)) { + workset.insert(instruction); + } + } + } + while (!workset.empty()) { int64 inferred_from_operand_counter = 0; int64 inferred_from_user_counter = 0; int64 instruction_counter = 0; @@ -1662,12 +1681,10 @@ StatusOr ShardingPropagation::Run(HloModule* module) { already_sharded_counter += (instruction->has_sharding() ? 1 : 0); } - // Remove the instructions where the sharding was provided from the - // outside so we don't modify them. 
instructions.erase( std::remove_if(instructions.begin(), instructions.end(), [&](HloInstruction* instruction) { - return provided_shardings.contains(instruction); + return !workset.contains(instruction); }), instructions.end()); @@ -1675,28 +1692,40 @@ StatusOr ShardingPropagation::Run(HloModule* module) { // operands. for (HloInstruction* instruction : instructions) { if (InferShardingFromOperands(instruction, computation_map, is_spmd_, - aggressive_prop)) { + aggressiveness)) { ++inferred_from_operand_counter; - changed = true; + any_changed = true; VLOG(2) << "Add sharding (forward-pass): " << instruction->ToString(); maybe_computation_propagation(instruction); + for (auto user : instruction->users()) { + if (!provided_shardings.contains(user)) { + workset.insert(user); + } + } + } else { + workset.erase(instruction); } } // Then iterate the HLO graph in reverse post order taking shardings // from users. for (auto it = instructions.rbegin(); it != instructions.rend(); ++it) { - if (InferShardingFromUsers(*it, computation_map, aggressive_prop, + if (InferShardingFromUsers(*it, computation_map, aggressiveness, is_spmd_)) { ++inferred_from_user_counter; - changed = true; + any_changed = true; VLOG(2) << "Add sharding (backward-pass): " << (*it)->ToString(); maybe_computation_propagation(*it); + workset.insert(*it); + for (auto operand : (*it)->operands()) { + if (!provided_shardings.contains(operand)) { + workset.insert(operand); + } + } } } } - any_changed |= changed; VLOG(1) << "Sharding propagation iteration " << iterations << ";"; VLOG(1) << " total instructions: " << instruction_counter; VLOG(1) << " instructions already sharded: " << already_sharded_counter; @@ -1707,8 +1736,8 @@ StatusOr ShardingPropagation::Run(HloModule* module) { ++iterations; } }; - run_to_fix_point(false); - run_to_fix_point(true); + run_to_fix_point(0); + run_to_fix_point(1); VLOG(1) << "Sharding propagation completed after " << iterations << " iterations"; diff --git a/tensorflow/compiler/xla/service/sharding_propagation_test.cc b/tensorflow/compiler/xla/service/sharding_propagation_test.cc index 5ed1398149b..03c77c2038c 100644 --- a/tensorflow/compiler/xla/service/sharding_propagation_test.cc +++ b/tensorflow/compiler/xla/service/sharding_propagation_test.cc @@ -556,6 +556,43 @@ ENTRY %replicated { op::Sharding("{devices=[1,2,2,1]0,1,2,3}")); } +TEST_F(ShardingPropagationTest, PartialReplicateReshapeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[1430,1]{1,0} parameter(0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + %reshape = f32[10,11,13]{2,1,0} reshape(%param0) + ROOT %copy = f32[10,11,13]{2,1,0} copy(%reshape) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "reshape"), + op::Sharding("{devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, PartialReplicateReshapeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %reshape { + %param0 = f32[2002,1]{1,0} parameter(0) + %copy = f32[2002,1]{1,0} copy(f32[2002,1]{1,0} %param0) + ROOT %reshape = f32[14,11,13]{2,1,0} reshape(%copy), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool 
changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "copy"), + op::Sharding("{devices=[2,1,2]0,1,2,3 last_tile_dim_replicate}")); +} + TEST_F(ShardingPropagationTest, DontShardTuplesIfAllInputIsMaximal) { const char* const hlo_string = R"( HloModule module @@ -1779,6 +1816,52 @@ ENTRY entry { op::Sharding("{devices=[2]0,1}")); } +TEST_F(ShardingPropagationTest, GatherToIndex2) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = bf16[2,4819,4] parameter(0), sharding={replicated} + %p1 = s32[2,1000,2] parameter(1) + %indices = s32[2,1000,2] copy(%p1) + ROOT %gather = bf16[2,1000,4] + gather(bf16[2,4819,4] %input, s32[2,1000,2] %indices), + offset_dims={2}, collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=2, slice_sizes={1,1,4}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[1,2,1]0,1}")); +} + +TEST_F(ShardingPropagationTest, GatherToIndex3) { + const char* hlo_string = R"( +HloModule module + +ENTRY entry { + %input = bf16[2,4819,4] parameter(0), sharding={replicated} + %p1 = s32[2,2,1000] parameter(1) + %indices = s32[2,2,1000] copy(%p1) + ROOT %gather = bf16[2,1000,4] + gather(bf16[2,4819,4] %input, s32[2,2,1000] %indices), + offset_dims={2}, collapsed_slice_dims={0,1}, + start_index_map={0,1}, index_vector_dim=1, slice_sizes={1,1,4}, + sharding={devices=[1,2,1]0,1} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + ShardingPropagation().Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT(FindInstruction(module.get(), "indices"), + op::Sharding("{devices=[1,1,2]0,1}")); +} + TEST_F(ShardingPropagationTest, GatherToDataOperand) { const char* hlo_string = R"( HloModule module @@ -2039,5 +2122,45 @@ ENTRY entry { op::Sharding("{devices=[2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); } +TEST_F(ShardingPropagationTest, PartialShardingTransposeForwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0), + sharding={devices=[2,1,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %transpose = f32[11,13,7]{2,1,0} transpose(%param), dimensions={1,2,0} + ROOT %copy = f32[11,13,7]{2,1,0} copy(%transpose) +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + EXPECT_THAT( + FindInstruction(module.get(), "transpose"), + op::Sharding( + "{devices=[1,2,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + +TEST_F(ShardingPropagationTest, PartialShardingTransposeBackwardPass) { + const char* const hlo_string = R"( +HloModule module +ENTRY %transpose { + %param = f32[7,11,13]{2,1,0} parameter(0) + %copy = f32[7,11,13]{2,1,0} copy(%param) + ROOT %transpose = f32[11,13,7]{2,1,0} transpose(%copy), dimensions={1,2,0}, + sharding={devices=[1,2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN( + bool changed, ShardingPropagation(/*is_spmd=*/true).Run(module.get())); + EXPECT_TRUE(changed); + 
EXPECT_THAT( + FindInstruction(module.get(), "copy"), + op::Sharding( + "{devices=[2,1,2,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate}")); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/BUILD b/tensorflow/compiler/xla/service/spmd/BUILD index dd3da796d61..d2243d30adf 100644 --- a/tensorflow/compiler/xla/service/spmd/BUILD +++ b/tensorflow/compiler/xla/service/spmd/BUILD @@ -74,3 +74,16 @@ tf_cc_test( "//tensorflow/core:test", ], ) + +cc_library( + name = "schedule_aware_all_gather_cse", + srcs = ["schedule_aware_all_gather_cse.cc"], + hdrs = ["schedule_aware_all_gather_cse.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_casting_utils", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/stream_executor/lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc index a24bafe26ce..da432965497 100644 --- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc +++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc @@ -100,7 +100,8 @@ StatusOr PartitionBaseCase( int64 output_rhs_non_contracting_partitions, int64 threshold_for_windowed_einsum_mib, SpmdBuilder* b, std::vector* - windowed_dot_general_loops) { + windowed_dot_general_loops, + bool may_reshard_without_detecting_match) { const HloSharding& lhs_sharding = lhs.sharding(); const HloSharding& rhs_sharding = rhs.sharding(); if (lhs_sharding.ReplicateOnLastTileDim() || @@ -491,29 +492,36 @@ StatusOr PartitionBaseCase( return dot; } - // Output is batch partitioned. - if (output_batch_partitions == num_partitions) { - auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - resharded_rhs.hlo(), b)); - return dot; - } - // Output is partitioned along LHS non-contracting dimensions. - if (output_lhs_non_contracting_partitions == num_partitions) { - auto resharded_lhs = lhs.Reshard(*output_sharding_transposed_to_match_lhs); - auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), - replicated_rhs.hlo(), b)); - return dot; - } - // Output is partitioned along RHS non-contracting dimensions. - if (output_rhs_non_contracting_partitions == num_partitions) { - auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); - auto resharded_rhs = rhs.Reshard(*output_sharding_transposed_to_match_rhs); - TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), - resharded_rhs.hlo(), b)); - return dot; + if (may_reshard_without_detecting_match) { + // Output is batch partitioned. + if (output_batch_partitions == num_partitions) { + auto resharded_lhs = + lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto resharded_rhs = + rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(resharded_lhs.hlo(), + resharded_rhs.hlo(), b)); + return dot; + } + // Output is partitioned along LHS non-contracting dimensions. 
+ if (output_lhs_non_contracting_partitions == num_partitions) { + auto resharded_lhs = + lhs.Reshard(*output_sharding_transposed_to_match_lhs); + auto replicated_rhs = rhs.Reshard(HloSharding::Replicate()); + TF_ASSIGN_OR_RETURN( + auto dot, + create_sharded_dot(resharded_lhs.hlo(), replicated_rhs.hlo(), b)); + return dot; + } + // Output is partitioned along RHS non-contracting dimensions. + if (output_rhs_non_contracting_partitions == num_partitions) { + auto replicated_lhs = lhs.Reshard(HloSharding::Replicate()); + auto resharded_rhs = + rhs.Reshard(*output_sharding_transposed_to_match_rhs); + TF_ASSIGN_OR_RETURN(auto dot, create_sharded_dot(replicated_lhs.hlo(), + resharded_rhs.hlo(), b)); + return dot; + } } // Returns true if it is beneficial to reshard the operand at `operand_idx` @@ -808,7 +816,8 @@ StatusOr PartitionDotGroupOnBatch( StatusOr PartitionDotGroupOnNonContracting( bool lhs_matching, PartitionedHlo matching, PartitionedHlo other, int64 matching_contracting_partitions, int64 other_contracting_partitions, - int64 matching_non_contracting_partitions, + absl::Span + partitioned_non_contractin_dims, int64 other_non_contracting_partitions, int64 output_other_non_contracting_partitions, const Shape& output_base_shape, const HloSharding& output_sharding, @@ -828,48 +837,20 @@ StatusOr PartitionDotGroupOnNonContracting( } }); - const bool may_replicate_other_contracting_dims = - (other_contracting_partitions == matching_non_contracting_partitions && - other_non_contracting_partitions == - output_other_non_contracting_partitions); - const bool may_replicate_other_non_contracting_dims = - matching_non_contracting_partitions == other_non_contracting_partitions && - matching_contracting_partitions == other_contracting_partitions; - std::vector other_group_dims; - if (may_replicate_other_contracting_dims && - (!may_replicate_other_non_contracting_dims || - ShapeUtil::ByteSizeOf(other.hlo()->shape()) <= - ShapeUtil::ByteSizeOf( - MakePartitionedShape(output_base_shape, output_sharding)))) { - for (const auto& dim : dims_mapping.contracting_dims) { - other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); - } - } else if (may_replicate_other_non_contracting_dims) { - for (const auto& dim : lhs_matching - ? dims_mapping.rhs_non_contracting_dims - : dims_mapping.lhs_non_contracting_dims) { - other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); - } - } else if (!(other.sharding().ReplicateOnLastTileDim() && - other.sharding().tile_assignment().dimensions().back() % - matching_non_contracting_partitions == - 0) && - !other.sharding().IsReplicated()) { - return nullptr; - } auto matching_sharding_dims = matching.sharding().tile_assignment().dimensions(); std::vector matching_dims; std::vector output_dims; + int64 group_count = 1; // Make sure the partitioning on matching's non-contracting dimensions // defines the same device groups for both matching and output. - for (const auto& dim : lhs_matching ? dims_mapping.lhs_non_contracting_dims - : dims_mapping.rhs_non_contracting_dims) { + for (const auto& dim : partitioned_non_contractin_dims) { int64 md = lhs_matching ? 
dim.lhs : dim.rhs; matching_sharding_dims[md] = output_sharding.tile_assignment().dim(dim.output); matching_dims.push_back(md); output_dims.push_back(dim.output); + group_count *= output_sharding.tile_assignment().dim(dim.output); } auto output_grouped = GroupShardingOnDims(output_sharding, output_dims); auto reshaped_matching_tiling = matching.sharding().tile_assignment(); @@ -885,6 +866,42 @@ StatusOr PartitionDotGroupOnNonContracting( matching.sharding() != UngroupSharding(matching_grouped)) { return nullptr; } + + std::vector other_group_dims; + if (other.sharding().ReplicateOnLastTileDim() && + other.sharding().tile_assignment().dimensions().back() % group_count == + 0) { + other_group_dims.push_back(other.base_shape().rank()); + } else { + const bool may_replicate_other_contracting_dims = + (other_contracting_partitions == group_count && + other_non_contracting_partitions == + output_other_non_contracting_partitions); + const bool may_replicate_other_non_contracting_dims = + group_count == other_non_contracting_partitions && + matching_contracting_partitions == other_contracting_partitions; + if (auto found_dims = FindMatchingPartitionedDimsForGrouping( + other.sharding(), output_grouped.device_groups)) { + other_group_dims = std::move(*found_dims); + } else if (may_replicate_other_contracting_dims && + (!may_replicate_other_non_contracting_dims || + ShapeUtil::ByteSizeOf(other.hlo()->shape()) <= + ShapeUtil::ByteSizeOf(MakePartitionedShape( + output_base_shape, output_sharding)))) { + for (const auto& dim : dims_mapping.contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else if (may_replicate_other_non_contracting_dims) { + for (const auto& dim : lhs_matching + ? dims_mapping.rhs_non_contracting_dims + : dims_mapping.lhs_non_contracting_dims) { + other_group_dims.push_back(lhs_matching ? dim.rhs : dim.lhs); + } + } else { + other = other.Replicate(); + } + } + matching = matching.Reshard(UngroupSharding(matching_grouped)); auto per_group_partitioner_state = CreatePerGroupPartitioningState( matching.state(), matching_grouped.device_groups, b); @@ -896,16 +913,14 @@ StatusOr PartitionDotGroupOnNonContracting( per_group_partitioner_state); auto partially_replicated_other = other.hlo(); - if (other.sharding().ReplicateOnLastTileDim() && - other.sharding().tile_assignment().dimensions().back() % - matching_non_contracting_partitions == - 0) { + if (other_group_dims.size() == 1 && + other_group_dims[0] == other.base_shape().rank()) { + // Group on replication dim. auto grouped = AlignGroupsWith( GroupShardingOnDims( - other.sharding(), - {other.sharding().tile_assignment().num_dimensions() - 1}, + other.sharding(), {other_group_dims[0]}, {other.sharding().tile_assignment().dimensions().back() / - matching_non_contracting_partitions}), + group_count}), output_grouped); other = other.Reshard(UngroupSharding(grouped)); partially_replicated_other = other.hlo(); @@ -916,9 +931,13 @@ StatusOr PartitionDotGroupOnNonContracting( AlignGroupsWith(GroupShardingOnDims(other.sharding(), other_group_dims), output_grouped, /*ignore_group_order=*/true); other = other.Reshard(UngroupSharding(other_grouped)); - // TODO(yuanzx): Use reshard to replicate when ready. 
partially_replicated_other = - other.ReplicatePartial(other_grouped.group_dims); + other + .Reshard(hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + other.sharding(), other_grouped.group_dims)) + .hlo(); + top_level_sharding_to_reset.emplace_back( + partially_replicated_other, partially_replicated_other->sharding()); partially_replicated_other->set_sharding(other_grouped.sharding); } auto other_p = PartitionedHlo(partially_replicated_other, other.base_shape(), @@ -937,7 +956,9 @@ StatusOr PartitionDotGroupOnNonContracting( } StatusOr PartitionDotGroupOnContracting( - PartitionedHlo lhs, PartitionedHlo rhs, int64 contracting_partitions, + PartitionedHlo lhs, PartitionedHlo rhs, + absl::Span + partitioned_contractin_dims, int64 output_batch_partitions, int64 output_lhs_non_contracting_partitions, int64 output_rhs_non_contracting_partitions, const Shape& output_base_shape, const HloSharding& output_sharding, @@ -962,13 +983,15 @@ StatusOr PartitionDotGroupOnContracting( auto rhs_tile_shape = rhs_sharding.tile_assignment().dimensions(); std::vector lhs_dims; std::vector rhs_dims; - for (const auto& dim : dims_mapping.contracting_dims) { + int64 group_count = 1; + for (const auto& dim : partitioned_contractin_dims) { lhs_dims.push_back(dim.lhs); rhs_dims.push_back(dim.rhs); + group_count *= lhs_sharding.tile_assignment().dim(dim.lhs); } if (ShapeUtil::ByteSizeOf(lhs.hlo()->shape()) > ShapeUtil::ByteSizeOf(rhs.hlo()->shape())) { - for (const auto& dim : dims_mapping.contracting_dims) { + for (const auto& dim : partitioned_contractin_dims) { rhs_tile_shape[dim.rhs] = lhs_tile_shape[dim.lhs]; } auto new_tile = rhs.sharding().tile_assignment(); @@ -977,7 +1000,7 @@ StatusOr PartitionDotGroupOnContracting( ? HloSharding::PartialTile(new_tile) : HloSharding::Tile(new_tile); } else { - for (const auto& dim : dims_mapping.contracting_dims) { + for (const auto& dim : partitioned_contractin_dims) { lhs_tile_shape[dim.lhs] = rhs_tile_shape[dim.rhs]; } auto new_tile = lhs.sharding().tile_assignment(); @@ -1012,43 +1035,47 @@ StatusOr PartitionDotGroupOnContracting( HloSharding inner_output_sharding = HloSharding::Replicate(); HloSharding outer_output_tmp_sharding = HloSharding::Replicate(); if (output_sharding.ReplicateOnLastTileDim() && - output_sharding.tile_assignment().dimensions().back() % - contracting_partitions == + output_sharding.tile_assignment().dimensions().back() % group_count == 0) { auto grouped = AlignGroupsWith( GroupShardingOnDims( output_sharding, {output_sharding.tile_assignment().num_dimensions() - 1}, {output_sharding.tile_assignment().dimensions().back() / - contracting_partitions}), - GroupShardingOnDims(lhs_sharding, lhs_dims)); + group_count}), + lhs_grouped); outer_output_tmp_sharding = UngroupSharding(grouped); inner_output_sharding = std::move(grouped.sharding); - } else if (output_lhs_non_contracting_partitions == contracting_partitions || - output_rhs_non_contracting_partitions == contracting_partitions || - output_batch_partitions == contracting_partitions) { + } else { std::vector group_dims; - if (output_lhs_non_contracting_partitions == contracting_partitions) { - for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { - group_dims.push_back(dim.output); - } - } else if (output_rhs_non_contracting_partitions == - contracting_partitions) { - for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { - group_dims.push_back(dim.output); - } - } else { - for (const auto& dim : dims_mapping.batch_dims) { - group_dims.push_back(dim.output); + if (auto 
found_dims = FindMatchingPartitionedDimsForGrouping( + output_sharding, lhs_grouped.device_groups)) { + group_dims = std::move(*found_dims); + } else if (output_lhs_non_contracting_partitions == group_count || + output_rhs_non_contracting_partitions == group_count || + output_batch_partitions == group_count) { + if (output_lhs_non_contracting_partitions == group_count) { + for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { + group_dims.push_back(dim.output); + } + } else if (output_rhs_non_contracting_partitions == group_count) { + for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { + group_dims.push_back(dim.output); + } + } else { + for (const auto& dim : dims_mapping.batch_dims) { + group_dims.push_back(dim.output); + } } } - auto grouped = - AlignGroupsWith(GroupShardingOnDims(output_sharding, group_dims), - GroupShardingOnDims(lhs_sharding, lhs_dims)); - inner_output_sharding = grouped.sharding; - outer_output_tmp_sharding = - hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( - UngroupSharding(grouped), group_dims); + if (!group_dims.empty()) { + auto grouped = AlignGroupsWith( + GroupShardingOnDims(output_sharding, group_dims), lhs_grouped); + inner_output_sharding = grouped.sharding; + outer_output_tmp_sharding = + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + UngroupSharding(grouped), group_dims); + } } auto inner_state = CreatePerGroupPartitioningState( lhs.state(), lhs_grouped.device_groups, b); @@ -1062,10 +1089,9 @@ StatusOr PartitionDotGroupOnContracting( GetPerGroupBaseShape(rhs_grouped, rhs.base_shape()), inner_state), MakePartitionedShape(output_base_shape, outer_output_tmp_sharding), - inner_output_sharding, dims_mapping, - num_partitions / contracting_partitions, create_sharded_dot, module, - original_hlo, threshold_for_windowed_einsum_mib, b, - windowed_dot_general_loops)); + inner_output_sharding, dims_mapping, num_partitions / group_count, + create_sharded_dot, module, original_hlo, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); if (!dot) { return nullptr; } @@ -1141,6 +1167,8 @@ StatusOr PartitionDot( output_sharding, dims_mapping.lhs_non_contracting_dims, 2); const int64 output_rhs_non_contracting_partitions = get_partitions_for_dims( output_sharding, dims_mapping.rhs_non_contracting_dims, 2); + // Before we find partial matches along the dimensions, invoke base case again + // without may_reshard_without_detecting_match. TF_ASSIGN_OR_RETURN( auto try_partitioned_dot, PartitionBaseCase( @@ -1151,7 +1179,8 @@ StatusOr PartitionDot( lhs_non_contracting_partitions, rhs_non_contracting_partitions, output_lhs_non_contracting_partitions, output_rhs_non_contracting_partitions, - threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops)); + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops, + /*may_reshard_without_detecting_match=*/false)); if (try_partitioned_dot) { return try_partitioned_dot; } @@ -1202,8 +1231,8 @@ StatusOr PartitionDot( : rhs_contracting_partitions, lhs_matching ? rhs_contracting_partitions : lhs_contracting_partitions, - lhs_matching ? lhs_non_contracting_partitions - : rhs_non_contracting_partitions, + lhs_matching ? dims_mapping.lhs_non_contracting_dims + : dims_mapping.rhs_non_contracting_dims, lhs_matching ? rhs_non_contracting_partitions : lhs_non_contracting_partitions, lhs_matching ? 
output_rhs_non_contracting_partitions @@ -1216,6 +1245,62 @@ StatusOr PartitionDot( return dot; } } + if (lhs_non_contracting_partitions > 1 && + output_lhs_non_contracting_partitions > 1) { + // If part of LHS non-contracting dims match output, try them. + std::vector matching_dims; + for (const auto& dim : dims_mapping.lhs_non_contracting_dims) { + int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs); + if (lhs_partitions > 1 && + lhs_partitions == output_sharding.tile_assignment().dim(dim.output)) { + matching_dims.push_back(dim); + } + } + if (!matching_dims.empty()) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnNonContracting( + /*lhs_matching=*/true, lhs, rhs, lhs_contracting_partitions, + rhs_contracting_partitions, matching_dims, + rhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, + threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + } + if (rhs_non_contracting_partitions > 1 && + output_rhs_non_contracting_partitions > 1) { + // If part of RHS non-contracting dims match output, try them. + std::vector matching_dims; + for (const auto& dim : dims_mapping.rhs_non_contracting_dims) { + int64 rhs_partitions = rhs.sharding().tile_assignment().dim(dim.rhs); + if (rhs_partitions > 1 && + rhs_partitions == output_sharding.tile_assignment().dim(dim.output)) { + matching_dims.push_back(dim); + } + } + if (!matching_dims.empty()) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnNonContracting( + /*lhs_matching=*/false, rhs, lhs, rhs_contracting_partitions, + lhs_contracting_partitions, matching_dims, + lhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, + threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + } // Case 3: Group partitions by contracting dimensions. if (lhs_contracting_partitions == rhs_contracting_partitions && @@ -1223,7 +1308,7 @@ StatusOr PartitionDot( TF_ASSIGN_OR_RETURN( auto dot, PartitionDotGroupOnContracting( - lhs, rhs, lhs_contracting_partitions, output_batch_partitions, + lhs, rhs, dims_mapping.contracting_dims, output_batch_partitions, output_lhs_non_contracting_partitions, output_rhs_non_contracting_partitions, output_base_shape, output_sharding, dims_mapping, num_partitions, create_sharded_dot, @@ -1233,6 +1318,71 @@ StatusOr PartitionDot( return dot; } } + if (lhs_contracting_partitions > 1 && rhs_contracting_partitions > 1) { + // If part of contracting dims match, try them. 
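Restating the selection rule used by the new partial-match cases above and below as a standalone sketch: any subset of dimension pairs whose partition counts agree between an operand's sharding and the output's (or the other operand's) sharding can serve as the grouping dimensions, and group_count is the product of their partition counts. The small structs here stand in for the dims-mapping and tile-assignment types; this is an illustration, not the XLA API.

#include <cstdint>
#include <vector>

struct DimPair {
  int64_t lhs;
  int64_t output;
};

struct Match {
  std::vector<DimPair> dims;
  int64_t group_count = 1;
};

Match FindMatchingNonContractingDims(const std::vector<int64_t>& lhs_tiles,
                                     const std::vector<int64_t>& out_tiles,
                                     const std::vector<DimPair>& mapping) {
  Match m;
  for (const DimPair& d : mapping) {
    // Keep a dim only if it is partitioned, and partitioned identically in
    // the operand and in the output.
    if (lhs_tiles[d.lhs] > 1 && lhs_tiles[d.lhs] == out_tiles[d.output]) {
      m.dims.push_back(d);
      m.group_count *= lhs_tiles[d.lhs];
    }
  }
  return m;
}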
+ std::vector matching_dims; + for (const auto& dim : dims_mapping.contracting_dims) { + int64 lhs_partitions = lhs.sharding().tile_assignment().dim(dim.lhs); + if (lhs_partitions > 1 && + lhs_partitions == rhs.sharding().tile_assignment().dim(dim.rhs)) { + matching_dims.push_back(dim); + } + } + if (!matching_dims.empty()) { + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDotGroupOnContracting( + lhs, rhs, matching_dims, output_batch_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, output_base_shape, + output_sharding, dims_mapping, num_partitions, create_sharded_dot, + module, original_hlo, require_matching_devices_to_group, + threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + } + + // Case 4: If operands are replicated but output is partially replicated, + // recursive call with partial replication removed. + if (lhs.sharding().IsReplicated() && rhs.sharding().IsReplicated() && + output_sharding.ReplicateOnLastTileDim()) { + auto grouped_output = + GroupShardingOnDims(output_sharding, {output_base_shape.rank()}); + auto inner_state = CreatePerGroupPartitioningState( + lhs.state(), grouped_output.device_groups, b); + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionDot(PartitionedHlo(lhs.hlo(), lhs.base_shape(), inner_state), + PartitionedHlo(rhs.hlo(), rhs.base_shape(), inner_state), + output_base_shape, grouped_output.sharding, dims_mapping, + output_sharding.NumTiles(), create_sharded_dot, module, + original_hlo, threshold_for_windowed_einsum_mib, b, + windowed_dot_general_loops)); + if (dot) { + return dot; + } + } + + // We failed to find partial matches, invoke base case again with + // may_reshard_without_detecting_match. + TF_ASSIGN_OR_RETURN( + auto dot, + PartitionBaseCase( + lhs, rhs, output_base_shape, output_sharding, dims_mapping, + num_partitions, create_sharded_dot, module, original_hlo, + lhs_batch_partitions, rhs_batch_partitions, output_batch_partitions, + lhs_contracting_partitions, rhs_contracting_partitions, + lhs_non_contracting_partitions, rhs_non_contracting_partitions, + output_lhs_non_contracting_partitions, + output_rhs_non_contracting_partitions, + threshold_for_windowed_einsum_mib, b, windowed_dot_general_loops, + /*may_reshard_without_detecting_match=*/true)); + if (dot) { + return dot; + } return nullptr; } diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc new file mode 100644 index 00000000000..cc97d5ebda7 --- /dev/null +++ b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.cc @@ -0,0 +1,132 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h" + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/stream_executor/lib/statusor.h" + +namespace xla { +namespace { + +HloCollectiveInstruction* MayConsiderAsAllGather(HloInstruction* hlo, + bool for_replicas) { + auto coll = DynCast(hlo); + if (!coll) { + return nullptr; + } + if (coll->constrain_layout()) { + return nullptr; + } + if (for_replicas == coll->channel_id().has_value()) { + return nullptr; + } + if (coll->opcode() == HloOpcode::kAllGather) { + return coll; + } + // Consider broadcast -> dynamic-update-slice -> all-reduce as all-gather. + if (coll->opcode() == HloOpcode::kAllReduce && coll->shape().IsArray()) { + auto operand = coll->operand(0); + return operand->opcode() == HloOpcode::kDynamicUpdateSlice && + operand->operand(0)->opcode() == HloOpcode::kBroadcast + ? coll + : nullptr; + } + return nullptr; +} + +StatusOr RunOnComputation(HloComputation* comp, bool for_replicas, + int64 distance_threshold) { + // We consider estimate the live ranges of all-gathers by comparing their + // users' distance to the root, e.g., height. + absl::flat_hash_map height; + auto ordered_hlos = comp->MakeInstructionPostOrder(); + int64 max_height = 0; + for (auto it = ordered_hlos.rbegin(); it != ordered_hlos.rend(); ++it) { + auto hlo = *it; + int64 h = 0; + for (auto user : hlo->users()) { + h = std::max(h, height[user]) + 1; + } + max_height = std::max(max_height, h); + height[hlo] = h; + } + + auto lowest_user_height = [&](const HloInstruction* hlo) { + int64 lowest = height[hlo]; + for (auto user : hlo->users()) { + lowest = std::min(lowest, height[user]); + } + return lowest; + }; + + absl::flat_hash_map> + operand_to_ag; + bool changed = false; + for (auto hlo : ordered_hlos) { + auto ag = MayConsiderAsAllGather(hlo, for_replicas); + if (!ag) { + continue; + } + + auto& earlier_ags = operand_to_ag[ag->operand(0)]; + bool found = false; + int64 lowest_user_h = lowest_user_height(ag); + for (auto& eag : earlier_ags) { + auto old_channel_id = ag->channel_id(); + if (eag->channel_id() && ag->channel_id()) { + ag->set_channel_id(eag->channel_id()); + } + if (!eag->Identical(*ag)) { + ag->set_channel_id(old_channel_id); + continue; + } + found = true; + ag->set_channel_id(old_channel_id); + if (lowest_user_height(eag) > lowest_user_h + distance_threshold) { + eag = ag; + continue; + } + changed = true; + VLOG(1) << "Replacing " << ag->ToString() << " with " << eag->ToString(); + TF_RETURN_IF_ERROR(ag->ReplaceAllUsesWith(eag)); + break; + } + if (!found) { + earlier_ags.push_back(ag); + } + } + return changed; +} + +} // namespace + +StatusOr ScheduleAwareAllGatherCSE::Run(HloModule* module) { + bool changed = false; + for (auto comp : module->computations()) { + TF_ASSIGN_OR_RETURN( + auto comp_changed, + RunOnComputation(comp, for_replicas_, distance_threshold_)); + changed |= comp_changed; + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h new file mode 100644 index 00000000000..4653286ae97 --- /dev/null +++ 
b/tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h @@ -0,0 +1,49 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Performs CSE for all-gather if their users are within reasonable live range. +class ScheduleAwareAllGatherCSE : public HloModulePass { + public: + // distance_threshold: maximum live range (in number of HLO instructions on + // the path) to consider CSE. + // for_replicas: specifies if this pass is for cross-replica or + // cross-partition all-gathers. + explicit ScheduleAwareAllGatherCSE(int64 distance_threshold, + bool for_replicas) + : distance_threshold_(distance_threshold), for_replicas_(for_replicas) {} + + ~ScheduleAwareAllGatherCSE() override = default; + absl::string_view name() const override { + return "schedule-aware-all-gather-cse"; + } + + StatusOr Run(HloModule* module) override; + + private: + int64 distance_threshold_; + bool for_replicas_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_SPMD_SCHEDULE_AWARE_ALL_GATHER_CSE_H_ diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index a850c05600e..f16b7bacda3 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -221,15 +221,23 @@ HloInstruction* SpmdBuilder::AddInstruction( PartitionedHlo PartitionedHlo::Reshard(const HloSharding& target) { auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; - for (auto& entry : cache) { - if (entry.first == target) { - return entry.second; + const bool is_to_replicate = + hlo_->shape().IsArray() && target.NumTiles() < sharding().NumTiles(); + if (!is_to_replicate || state_.partitioner->options().cache_all_gather) { + for (auto& entry : cache) { + if (entry.first == target) { + return entry.second; + } } } - cache.emplace_back(target, ReshardNoCache(target)); - state_.reshard_cache->per_hlo_cache[cache.back().second.hlo()] + auto resharded = ReshardNoCache(target); + state_.reshard_cache->per_hlo_cache[resharded.hlo()] .reshard_cache.emplace_back(sharding(), *this); - return cache.back().second; + if (!is_to_replicate || state_.partitioner->options().cache_all_gather) { + cache.emplace_back(target, std::move(resharded)); + return cache.back().second; + } + return resharded; } PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { @@ -282,133 +290,17 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) { return ReshardWithAllToAll(target, *src_tgt_dims); } - // Partial replicated to tiled. 
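For the ScheduleAwareAllGatherCSE pass added above, a standalone sketch of the "height" bookkeeping it relies on: each instruction gets a distance-to-root estimate, and two equivalent all-gathers are only merged when the survivor's nearest user is within distance_threshold heights of the duplicate's nearest user, so CSE does not stretch live ranges arbitrarily. Plain STL types stand in for HloInstruction; this is an illustration, not the pass's exact implementation.

#include <algorithm>
#include <cstdint>
#include <unordered_map>
#include <vector>

struct Op {
  std::vector<Op*> users;
};

// Assumes `post_order` lists operands before their users (a post-order walk),
// so the root is last and a reverse walk sees users before the instruction.
std::unordered_map<const Op*, int64_t> ComputeHeights(
    const std::vector<Op*>& post_order) {
  std::unordered_map<const Op*, int64_t> height;
  for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
    int64_t h = 0;  // root has height 0
    for (const Op* user : (*it)->users) {
      h = std::max(h, height[user] + 1);
    }
    height[*it] = h;
  }
  return height;
}

// Two equivalent all-gathers may be CSE'd only if the one kept does not have
// to stay live much longer than the duplicate it replaces.
bool WithinDistance(int64_t kept_lowest_user_height,
                    int64_t duplicate_lowest_user_height,
                    int64_t distance_threshold) {
  return kept_lowest_user_height <=
         duplicate_lowest_user_height + distance_threshold;
}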
- if (sharding().ReplicateOnLastTileDim() && !target.ReplicateOnLastTileDim() && - !target.IsTileMaximal()) { - // Get the temp sharding target from partial replicate to target tile dims. - // target_compatible_sharding has the same tile_assignment dimensions - // as the target and can reshard to target by collective permute. - // target_compatible_sharding could have different device assignment as - // targe. sharding() can reshard to target_compatible_sharding by - // dynamic slice. - auto target_compatible_sharding = PartialReplicateToTileCompatibleSharding( - sharding(), target.tile_assignment().dimensions()); - // Reshard to target_compatible_sharding by dynamic slice. - if (target_compatible_sharding.has_value()) { - std::vector expand_tile_dims; - std::vector tiling_dim_factors; - int64 rank = shape.rank(); - tiling_dim_factors.reserve(rank); - auto temp_target_sharding = target_compatible_sharding.value(); - for (int64 dim = 0; dim < rank; dim++) { - if (temp_target_sharding.tile_assignment().dim(dim) > - sharding().tile_assignment().dim(dim)) { - expand_tile_dims.push_back(dim); - } - tiling_dim_factors.emplace_back( - temp_target_sharding.tile_assignment().dim(dim) / - sharding().tile_assignment().dim(dim)); - } - - // Get per_group partitioner state. - std::vector group_dims( - sharding().tile_assignment().num_dimensions() - 1); - std::iota(group_dims.begin(), group_dims.end(), 0); - auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims); - auto per_group_partitioner_state = CreatePerGroupPartitioningState( - state_, sharding_grouped.device_groups, state_.b); - // 2. Get the padded_hlo, do right halo exchange if needed. - auto padded_hlo = PadFromPartialReplicateShape( - hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims, - state_.collective_ops_creator, state_.next_channel_id, - state_.partition_id, state_.b); - if (padded_hlo.has_value()) { - // 3. Slice out the tile from replicate ones. - auto shard_shape = - MakePartitionedShape(base_shape_, temp_target_sharding); - // device assignment within each group is sorted in - // HloSharding::PartialTile, thus partiton_id within each group can be - // matched with the order in tile_assignment. - Array tiling_assignment(tiling_dim_factors); - tiling_assignment.FillIota(0); - auto slice = - state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( - shard_shape, padded_hlo.value(), - MakePartitionOffsets(padded_hlo.value()->shape(), - HloSharding::Tile(tiling_assignment), - per_group_partitioner_state.partition_id, - per_group_partitioner_state.b), - shard_shape.dimensions())); - slice->set_sharding(temp_target_sharding); - auto result = PartitionedHlo(slice, base_shape_, state_); - // If temp_target_sharding's device assignment is different from target, - // use collective permute to reshard. - if (CanReshardWithCollectivePermute(temp_target_sharding, target)) { - return result.ReshardWithCollectivePermute(target); - } - // If device assignment in temp_target_sharding and target are the same, - // return result directly. - return result; - } + if (!target.IsTileMaximal() && sharding().ReplicateOnLastTileDim()) { + auto try_reshard = ReshardFromPartialReplicateWithDynamicSlice(target); + if (try_reshard.has_value()) { + return try_reshard.value(); } } - // Tiled to partial replicate - if (!sharding().ReplicateOnLastTileDim() && !sharding().IsTileMaximal() && - target.ReplicateOnLastTileDim()) { - // Get the comptible sharding to target with resharding by all reduce. 
- auto compatible_sharding = PartialReplicateToTileCompatibleSharding( - target, sharding().tile_assignment().dimensions()); - if (compatible_sharding.has_value()) { - auto temp_sharding = compatible_sharding.value(); - auto partitioned_hlo = *this; - // Use collective permute to adjust device assignment if needed. - if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) { - partitioned_hlo = - partitioned_hlo.ReshardWithCollectivePermute(temp_sharding); - } - - // Get replicate dims and replicate factor of each dimensions. - int64 rank = hlo_->shape().rank(); - std::vector replicate_dims; - std::vector replicate_factors; - for (int64 dim = 0; dim < rank; dim++) { - int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) / - target.tile_assignment().dim(dim); - if (replicate_factor > 1) { - replicate_dims.emplace_back(dim); - replicate_factors.emplace_back(replicate_factor); - } - } - - // Do left halo exchange if all-reduce directly will remove useful data - // from the source. - auto halo_exchange = TileToPartialReplicateHaloExchange( - partitioned_hlo.hlo_, base_shape_, temp_sharding, target, - replicate_dims, partitioned_hlo.state().collective_ops_creator, - partitioned_hlo.state().next_channel_id, - partitioned_hlo.state().partition_id, partitioned_hlo.state().b); - if (halo_exchange.has_value()) { - auto halo_exchange_hlo = halo_exchange.value(); - // Grouped on replicate dimensions. - auto sharding_grouped = GroupShardingOnDims( - temp_sharding, replicate_dims, replicate_factors); - auto per_group_partitioner_state = CreatePerGroupPartitioningState( - partitioned_hlo.state(), sharding_grouped.device_groups, - partitioned_hlo.state().b); - auto base_shape = MakePartitionedShape(base_shape_, target); - // It's possible that halo_exchange_hlo == hlo.hlo(). - // Record the sharding of hlo here, and reset it before return. 
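A hedged, standalone sketch of the per-dimension bookkeeping behind the two reshard helpers introduced in this change (ReshardToPartialReplicateWithAllGather and ReshardFromPartialReplicateWithDynamicSlice): once a compatible intermediate sharding is found, a dimension that becomes more tiled by an integer factor is handled with a dynamic-slice, and one that becomes less tiled by an integer factor is handled with all-gather style partial replication. The real helpers each handle one direction at a time; vectors of per-dimension tile counts stand in for HloSharding, so this is an illustration, not the XLA API.

#include <cstdint>
#include <optional>
#include <vector>

struct PerDimPlan {
  std::vector<int64_t> slice_factors;      // >1 where the target is more tiled
  std::vector<int64_t> replicate_factors;  // >1 where the target is less tiled
};

std::optional<PerDimPlan> PlanReshard(const std::vector<int64_t>& src_tiles,
                                      const std::vector<int64_t>& tgt_tiles) {
  PerDimPlan plan;
  for (size_t i = 0; i < src_tiles.size(); ++i) {
    const int64_t src = src_tiles[i];
    const int64_t tgt = tgt_tiles[i];
    if (tgt >= src && tgt % src == 0) {
      plan.slice_factors.push_back(tgt / src);      // dynamic-slice path
      plan.replicate_factors.push_back(1);
    } else if (src % tgt == 0) {
      plan.slice_factors.push_back(1);
      plan.replicate_factors.push_back(src / tgt);  // all-gather path
    } else {
      return std::nullopt;  // incompatible split; a different reshard is needed
    }
  }
  return plan;
}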
- auto original_sharding = partitioned_hlo.sharding(); - halo_exchange_hlo->set_sharding(sharding_grouped.sharding); - auto partial_replicate_hlo = PartitionedHlo( - halo_exchange_hlo, base_shape, per_group_partitioner_state); - HloInstruction* result = - partial_replicate_hlo.ReplicatePartial(replicate_dims); - partitioned_hlo.hlo()->set_sharding(original_sharding); - result->set_sharding(target); - return PartitionedHlo(result, base_shape_, partitioned_hlo.state()); - } + if (!sharding().IsTileMaximal() && target.ReplicateOnLastTileDim()) { + auto try_reshard = ReshardToPartialReplicateWithAllGather(target); + if (try_reshard.has_value()) { + return try_reshard.value(); } } @@ -794,6 +686,14 @@ PartitionedHlo::ReshardAsWindowedInput(const Window& window, } PartitionedHlo PartitionedHlo::Replicate() { + auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; + if (state_.partitioner->options().cache_all_gather) { + for (auto& entry : cache) { + if (entry.first.IsReplicated()) { + return entry.second; + } + } + } const HloSharding& sharding = hlo_->sharding(); const Shape& shape = hlo_->shape(); CHECK(!shape.IsTuple() && shape.element_type() != TOKEN); @@ -801,7 +701,6 @@ PartitionedHlo PartitionedHlo::Replicate() { if (sharding.IsReplicated()) { return *this; } - auto& cache = state_.reshard_cache->per_hlo_cache[hlo()].reshard_cache; for (auto& entry : cache) { if (entry.first.IsReplicated()) { return entry.second; @@ -810,8 +709,11 @@ PartitionedHlo PartitionedHlo::Replicate() { auto update_cache = [&](PartitionedHlo resharded) { state_.reshard_cache->per_hlo_cache[resharded.hlo()] .reshard_cache.emplace_back(sharding, *this); - cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); - return cache.back().second; + if (state_.partitioner->options().cache_all_gather) { + cache.emplace_back(HloSharding::Replicate(), std::move(resharded)); + return cache.back().second; + } + return resharded; }; // 'Single Device' to 'Repliated'. if (sharding.IsTileMaximal()) { @@ -872,6 +774,155 @@ HloInstruction* PartitionedHlo::ReplicatePartial(absl::Span dims) { return result; } +absl::optional +PartitionedHlo::ReshardToPartialReplicateWithAllGather( + const HloSharding& target) { + if (!target.ReplicateOnLastTileDim()) { + return absl::nullopt; + } + // Tiled/partial replicate to partial replicate + // Get the comptible sharding to target with resharding by all reduce. + auto compatible_sharding = + PartialReplicateReshardCompatibleSharding(target, sharding()); + if (!compatible_sharding.has_value()) { + return absl::nullopt; + } + + const auto& temp_sharding = compatible_sharding.value(); + auto partitioned_hlo = *this; + // Use collective permute to adjust device assignment if needed. + if (CanReshardWithCollectivePermute(sharding(), temp_sharding)) { + partitioned_hlo = + partitioned_hlo.ReshardWithCollectivePermute(temp_sharding); + } + + // Get replicate dims and replicate factor of each dimensions. + int64 rank = hlo_->shape().rank(); + std::vector replicate_dims; + std::vector replicate_factors; + for (int64 dim = 0; dim < rank; dim++) { + int64 replicate_factor = temp_sharding.tile_assignment().dim(dim) / + target.tile_assignment().dim(dim); + if (replicate_factor > 1) { + replicate_dims.emplace_back(dim); + replicate_factors.emplace_back(replicate_factor); + } + } + + // Do left halo exchange if all-reduce directly will remove useful data + // from the source. 
+ auto halo_exchange = TileToPartialReplicateHaloExchange( + partitioned_hlo.hlo_, base_shape_, temp_sharding, target, replicate_dims, + partitioned_hlo.state().collective_ops_creator, + partitioned_hlo.state().next_channel_id, + partitioned_hlo.state().partition_id, partitioned_hlo.state().b); + if (!halo_exchange.has_value()) { + return absl::nullopt; + } + auto halo_exchange_hlo = halo_exchange.value(); + // Grouped on replicate dimensions. + auto sharding_grouped = + GroupShardingOnDims(temp_sharding, replicate_dims, replicate_factors); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + partitioned_hlo.state(), sharding_grouped.device_groups, + partitioned_hlo.state().b); + auto base_shape = MakePartitionedShape(base_shape_, target); + // It's possible that halo_exchange_hlo == hlo.hlo(). + // Record the sharding of hlo here, and reset it before return. + auto original_sharding = partitioned_hlo.sharding(); + halo_exchange_hlo->set_sharding(sharding_grouped.sharding); + auto partial_replicate_hlo = PartitionedHlo(halo_exchange_hlo, base_shape, + per_group_partitioner_state); + HloInstruction* result = + partial_replicate_hlo.ReplicatePartial(replicate_dims); + partitioned_hlo.hlo()->set_sharding(original_sharding); + result->set_sharding(target); + return PartitionedHlo(result, base_shape_, partitioned_hlo.state()); +} + +absl::optional +PartitionedHlo::ReshardFromPartialReplicateWithDynamicSlice( + const HloSharding& target) { + if (!sharding().ReplicateOnLastTileDim()) { + return absl::nullopt; + } + + // Get the temp sharding target from partial replicate to target tile dims. + // target_compatible_sharding has the same tile_assignment dimensions + // as the target and can reshard to target by collective permute. + // target_compatible_sharding could have different device assignment as + // targe. sharding() can reshard to target_compatible_sharding by + // dynamic slice. + auto target_compatible_sharding = + PartialReplicateReshardCompatibleSharding(sharding(), target); + // Reshard to target_compatible_sharding by dynamic slice. + if (!target_compatible_sharding.has_value()) { + return absl::nullopt; + } + std::vector expand_tile_dims; + std::vector tiling_dim_factors; + int64 rank = hlo_->shape().rank(); + tiling_dim_factors.reserve(target.tile_assignment().num_dimensions()); + const auto& temp_target_sharding = target_compatible_sharding.value(); + for (int64 dim = 0; dim < rank; dim++) { + if (temp_target_sharding.tile_assignment().dim(dim) > + sharding().tile_assignment().dim(dim)) { + expand_tile_dims.push_back(dim); + } + tiling_dim_factors.emplace_back( + temp_target_sharding.tile_assignment().dim(dim) / + sharding().tile_assignment().dim(dim)); + } + + // Add another dimension in tiling_dim_factors if target is partial replicate. + if (target.ReplicateOnLastTileDim()) { + tiling_dim_factors.emplace_back( + target.tile_assignment().dimensions().back()); + } + + // Get per_group partitioner state. + std::vector group_dims(sharding().tile_assignment().num_dimensions() - + 1); + std::iota(group_dims.begin(), group_dims.end(), 0); + auto sharding_grouped = GroupShardingOnDims(sharding(), group_dims); + auto per_group_partitioner_state = CreatePerGroupPartitioningState( + state_, sharding_grouped.device_groups, state_.b); + // 2. Get the padded_hlo, do right halo exchange if needed. 
+ auto padded_hlo = PadFromPartialReplicateShape( + hlo_, base_shape_, sharding(), temp_target_sharding, expand_tile_dims, + state_.collective_ops_creator, state_.next_channel_id, + state_.partition_id, state_.b); + if (!padded_hlo.has_value()) { + return absl::nullopt; + } + // 3. Slice out the tile from replicate ones. + auto shard_shape = MakePartitionedShape(base_shape_, temp_target_sharding); + // device assignment within each group is sorted in + // HloSharding::PartialTile, thus partiton_id within each group can be + // matched with the order in tile_assignment. + Array tiling_assignment(tiling_dim_factors); + tiling_assignment.FillIota(0); + auto slice = state_.b->AddInstruction(HloInstruction::CreateDynamicSlice( + shard_shape, padded_hlo.value(), + MakePartitionOffsets(padded_hlo.value()->shape(), + target.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(tiling_assignment) + : HloSharding::Tile(tiling_assignment), + per_group_partitioner_state.partition_id, + per_group_partitioner_state.b), + shard_shape.dimensions())); + slice->set_sharding(temp_target_sharding); + auto result = PartitionedHlo(slice, base_shape_, state_); + // If temp_target_sharding's device assignment is different from target, + // use collective permute to reshard. + if (CanReshardWithCollectivePermute(temp_target_sharding, target)) { + return result.ReshardWithCollectivePermute(target); + } + // If device assignment in temp_target_sharding and target are the same, + // return result directly. + return result; +} + PartitionedHlo PartitionedHlo::Broadcast() const { const Shape& shape = hlo_->shape(); const HloSharding& sharding = hlo_->sharding(); @@ -1048,6 +1099,25 @@ PartitionedHlo PartitionedHlo::ReshardWithCollectivePermute( const HloSharding& target) const { CHECK(CanReshardWithCollectivePermute(sharding(), target)) << sharding().ToString() << " to " << target.ToString(); + if (hlo()->opcode() == HloOpcode::kBroadcast) { + // If hlo() is a broadcast, check if data is already the same between + // source/destination pairs. + std::vector new_dims; + for (int64 i = 0; i < hlo()->shape().rank(); ++i) { + if (!absl::c_linear_search(hlo()->dimensions(), i)) { + new_dims.push_back(i); + } + } + if (hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(sharding(), + new_dims) == + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims(target, + new_dims)) { + auto copy = state_.b->AddInstruction( + HloInstruction::CreateUnary(hlo()->shape(), HloOpcode::kCopy, hlo())); + copy->set_sharding(target); + return PartitionedHlo(copy, base_shape_, state_); + } + } std::vector> src_dst_pairs; sharding().tile_assignment().Each( [&](absl::Span indices, int64 src_device) { @@ -1868,6 +1938,16 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { return Status::OK(); } + // Check if operand sharding and sharding are both tiled or partial replicate. + // If both of them are partial replicate, check num_replications are the same. + if (operand.sharding().ReplicateOnLastTileDim() != + sharding.ReplicateOnLastTileDim() || + (sharding.ReplicateOnLastTileDim() && + (operand.sharding().tile_assignment().dimensions().back() != + sharding.tile_assignment().dimensions().back()))) { + return DefaultAction(hlo); + } + // Try use halo exchange for certain split-dim/merge-dims cases. // ReshapeSharding failed in these cases probably due to uneven partitioning, // where halo exchange could help. 
Specifically we check the following @@ -1903,7 +1983,14 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { Array new_input_tile_assignment = sharding.tile_assignment(); new_input_tile_assignment.Reshape( operand.sharding().tile_assignment().dimensions()); - operand = operand.Reshard(HloSharding::Tile(new_input_tile_assignment)); + auto aligned_sharding = + sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(new_input_tile_assignment) + : HloSharding::Tile(new_input_tile_assignment); + operand = operand.Reshard(aligned_sharding); + auto replication_count = sharding.ReplicateOnLastTileDim() + ? sharding.tile_assignment().dimensions().back() + : 1; int64 input_dim_size = operand.base_shape().dimensions(input_sharded_dim); int64 output_dim_size = hlo->shape().dimensions(output_sharded_dim); @@ -1926,7 +2013,7 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { dim->set_padding_low(0); if (i == input_sharded_dim) { dim->set_padding_high(output_shard_size * split_factor * - num_partitions_ - + num_partitions_ / replication_count - input_dim_size); } else { dim->set_padding_high(0); @@ -1964,8 +2051,8 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { tmp_reshape->set_sharding(hlo->sharding()); auto tmp_full_shape = tmp_shard_shape; tmp_full_shape.set_dimensions( - output_sharded_dim, - tmp_shard_shape.dimensions(output_sharded_dim) * num_partitions_); + output_sharded_dim, tmp_shard_shape.dimensions(output_sharded_dim) * + num_partitions_ / replication_count); auto tmp_output = PartitionedHlo(tmp_reshape, tmp_full_shape, MakePartitioningState()); @@ -1982,7 +2069,7 @@ Status SpmdPartitioningVisitor::HandleReshape(HloInstruction* hlo) { if (i == output_sharded_dim) { dim->set_padding_high(output_dim_size - tmp_shard_shape.dimensions(output_sharded_dim) * - num_partitions_); + num_partitions_ / replication_count); } else { dim->set_padding_high(0); } @@ -2605,7 +2692,13 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { .Reshard(HloSharding::Replicate()) .hlo()); inputs.push_back(GetPartitionedHlo(hlo->operand(operand_id))); - if (operand_id > 0) { + if (hlo->shape().IsTuple() && operand_id == 0) { + // We cannot do tuple-reduce where partitioned dimensions are reduced. + // Partially replicate on those dims. + inputs[0] = inputs[0].Reshard( + hlo_sharding_util::PartiallyReplicateTiledShardingOnDims( + inputs[0].sharding(), hlo->dimensions())); + } else { // Make sure all operands are sharded in the same way. inputs.back() = inputs.back().Reshard(inputs[0].sharding()); } @@ -2613,17 +2706,6 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { inputs.back() = inputs.back().PadWithValue(inits[operand_id]); } } - bool reduce_sharded_dimension = false; - if (!inputs[0].sharding().IsTileMaximal()) { - reduce_sharded_dimension = absl::c_any_of(hlo->dimensions(), [&](int64 i) { - return inputs[0].sharding().tile_assignment().dim(i) > 1; - }); - - // reduce_sharded_dimension is not supported for tuple-shaped reduces. 
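A small standalone sketch of the check that HandleReduce's updated logic hinges on: a cross-partition combine is needed only when some reduced dimension is actually split across partitions, and for tuple-shaped reduces this change sidesteps that case by partially replicating the reduced dimensions first. Plain integers stand in for HloSharding's tile assignment; this is an illustration, not the XLA API.

#include <cstdint>
#include <vector>

bool ReducesShardedDimension(const std::vector<int64_t>& tile_dims,
                             const std::vector<int64_t>& reduce_dims) {
  for (int64_t dim : reduce_dims) {
    if (tile_dims[dim] > 1) return true;  // this dimension is partitioned
  }
  return false;
}

// Example: tiles {2, 1, 2} with reduce_dims {1} -> false (a local reduce is
// enough); reduce_dims {0} -> true (needs a cross-partition combine, or, for
// tuple reduces, partial replication of the reduced dims before reducing).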
- if (reduce_sharded_dimension && input_count > 1) { - return DefaultAction(hlo); - } - } std::vector new_operand_shapes(input_count * 2); for (int64 i = 0; i < input_count; ++i) { @@ -2646,6 +2728,11 @@ Status SpmdPartitioningVisitor::HandleReduce(HloInstruction* hlo) { SetPartitionedHlo(hlo, [&]() { HloInstruction* reduce = local_reduce; + const bool reduce_sharded_dimension = + !inputs[0].sharding().IsTileMaximal() && + absl::c_any_of(hlo->dimensions(), [&](int64 i) { + return inputs[0].sharding().tile_assignment().dim(i) > 1; + }); if (reduce_sharded_dimension) { CHECK(local_reduce->shape().IsArray()); std::vector preserved_dims; @@ -3353,7 +3440,7 @@ StatusOr SpmdPartitioner::Run(HloModule* module) { HloPassPipeline pass("spmd-cleanup"); pass.AddPass(); pass.AddPass(); - pass.AddPass(/*is_layout_sensitive=*/true); + pass.AddPass(/*is_layout_sensitive=*/false); pass.AddPass(); TF_RETURN_IF_ERROR(pass.Run(module).status()); } diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h index a612c16bdae..6447d08be41 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.h @@ -47,6 +47,12 @@ struct SpmdPartitionerOptions { // Whether the entry computations' signature could change after partitioning. bool allow_module_signature_change = false; + + // Whether to use cached all-gather to avoid repeatedly replicate a tiled + // tensor. If it is set to false, the result tends to be more + // memory-efficient, and the compiler can use the ScheduleAwareAllGatherCSE + // pass to CSE some all-gathers which are relatively close to each other. + bool cache_all_gather = true; }; // Class to wrap the computation builder to capture information during SPMD @@ -180,6 +186,8 @@ class SpmdPartitioner : public HloModulePass { int64 channel_id, absl::Span selected_dims, const SPMDCollectiveOpsCreator& collectives_creator); + const SpmdPartitionerOptions& options() { return options_; } + protected: virtual std::unique_ptr CreateVisitor( HloComputation* computation, int64 num_partitions, int64 num_replicas, @@ -305,6 +313,14 @@ class PartitionedHlo { // Helper function to reshard the tensor using CollectivePermute. PartitionedHlo ReshardWithCollectivePermute(const HloSharding& target) const; + // Helper function to reshard to partial replicate using AllGather. + absl::optional ReshardToPartialReplicateWithAllGather( + const HloSharding& target); + + // Helper function to reshard from partial replicate using DynamicSlice. + absl::optional ReshardFromPartialReplicateWithDynamicSlice( + const HloSharding& target); + // SPMD instruction. 
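The comment on the new cache_all_gather option describes the intended pairing with the CSE pass; the following is a hedged sketch of how that wiring might look, using the constructor added in this change and the standard HloPassPipeline::AddPass hook. The distance threshold and the choice to disable caching are illustrative assumptions, not recommended settings.

#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/service/spmd/schedule_aware_all_gather_cse.h"
#include "tensorflow/compiler/xla/service/spmd/spmd_partitioner.h"

namespace xla {

SpmdPartitionerOptions MakeOptionsForScheduleAwareCse() {
  SpmdPartitionerOptions options;
  // Trade memory for duplicate all-gathers and rely on the CSE pass instead.
  options.cache_all_gather = false;
  return options;
}

void AddAllGatherCse(HloPassPipeline* pipeline) {
  // Cross-partition all-gathers carry channel ids, so for_replicas is false;
  // the threshold below is an arbitrary illustrative value.
  pipeline->AddPass<ScheduleAwareAllGatherCSE>(/*distance_threshold=*/100,
                                               /*for_replicas=*/false);
}

}  // namespace xla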
HloInstruction* hlo_; diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc index 1dc4c474c49..089c4c339a4 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc @@ -138,8 +138,7 @@ ENTRY entry { op::AllReduce(op::Select( op::Broadcast(op::Compare(op::PartitionId(), op::Constant())), op::Constant(), op::Broadcast())), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), op::Constant())), op::Shape("s32[1,3]"))); } @@ -161,8 +160,7 @@ ENTRY entry { op::Copy(op::AllReduce(AllOf( op::DynamicUpdateSlice( op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), op::Constant()), op::Shape("s32[2,3]"))))); } @@ -184,8 +182,7 @@ ENTRY entry { op::Copy(op::Copy(op::AllReduce(AllOf( op::DynamicUpdateSlice( op::Broadcast(), AllOf(op::Constant(), op::Shape("s32[1,3]")), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), op::Constant()), op::Shape("s32[2,3]")))))); } @@ -279,8 +276,8 @@ ENTRY entry { HloInstruction* root = module->entry_computation()->root_instruction(); ASSERT_THAT(root, op::Tuple()); - auto offset = op::Reshape( - op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); EXPECT_THAT(root->operand(0), op::DynamicSlice(op::GetTupleElement(op::Parameter()), offset, @@ -305,13 +302,13 @@ ENTRY entry { PartitionComputation(hlo_string, /*num_devices=*/2)); HloInstruction* root = module->entry_computation()->root_instruction(); EXPECT_THAT( - root, op::Copy(op::AllReduce(op::DynamicUpdateSlice( - op::Broadcast(), - op::GetTupleElement( - AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), - op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId(), - op::Constant())), - op::Constant())))); + root, + op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(), + op::GetTupleElement( + AllOf(op::Infeed(), op::Shape("(f32[4,2]{1,0}, token[])"))), + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())), + op::Constant())))); } TEST_F(SpmdPartitioningTest, UnevenTiledInfeed) { @@ -2598,6 +2595,79 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicateShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, + sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[16,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), 
op::Shape("f32[16,4,19,38]"))); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateNonShardableTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %transpose = f32[16,4,38,38] transpose(%param0.copy), + dimensions={0,3,1,2}, + sharding={devices=[1,2,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto resahrd = AllOf(op::Reshape(op::Transpose(op::Reshape(op::AllToAll()))), + op::Shape("f32[16,38,38,2]")); + EXPECT_THAT(root, AllOf(op::Transpose(), op::Shape("f32[16,2,38,38]"))); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateMultiDimensionShardedTranspose) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[16,38,38,4] parameter(0) + %param0.copy = f32[16,38,38,4] copy(%param0), + sharding={devices=[2,2,1,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %transpose = f32[38,4,16,38] transpose(%param0.copy), + dimensions={1,3,0,2}, + sharding={devices=[2,1,2,1,2]0,1,4,5,2,3,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = AllOf( + op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[8,19,38,4]")); + EXPECT_THAT(root, AllOf(op::Transpose(param0), op::Shape("f32[19,4,8,38]"))); +} + TEST_F(SpmdPartitioningTest, ShardableReshape) { const char* const hlo_string = R"( HloModule module @@ -2621,6 +2691,30 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicateShardableReshape) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[38,38,324] parameter(0) + %param0.copy = f32[38,38,324] copy(%param0), + sharding={devices=[2,1,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %reshape = f32[38,38,4,81] reshape(%param0.copy), + sharding={devices=[2,1,1,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto root = module->entry_computation()->root_instruction(); + auto param0 = + AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Reshape(), + op::Constant(), op::Constant())), + op::Shape("f32[19,38,324]")); + EXPECT_THAT(root, AllOf(op::Reshape(param0), op::Shape("f32[19,38,4,81]"))); +} + TEST_F(SpmdPartitioningTest, NonShardableReshape) { const char* const hlo_string = R"( HloModule module @@ -2673,6 +2767,30 @@ ENTRY entry { EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); } +TEST_F(SpmdPartitioningTest, PartialReplicateReshapeMergeDimsWithHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %input = s32[2,3,7,10] parameter(0), + sharding={devices=[1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %reshape = s32[3,2,1,14,5] reshape(%input), + sharding={devices=[1,1,1,2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) 
<< module->ToString(); + + auto reshape = + AllOf(op::Reshape(op::Parameter(0)), op::Shape("s32[3,2,1,8,5]")); + auto halo = op::CollectivePermute(op::Slice(reshape)); + auto exchanged = + op::DynamicSlice(op::Concatenate(halo, reshape), _, _, _, _, _); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(exchanged, op::Shape("s32[3,2,1,7,5]"))); +} + // Produces an invalid module after transformation. TEST_F(SpmdPartitioningTest, InceptionV3_4_way_ReduceWindowDilated) { const char* const hlo_string = R"( @@ -2831,6 +2949,48 @@ ENTRY %main { op::Shape("(f32[14], s32[14])"))); } +TEST_F(SpmdPartitioningTest, TiledToTiledTupleReduce2) { + const char* const hlo_string = R"( +HloModule module + +%minmax_func { + %lhs_value = f32[] parameter(0) + %rhs_value = f32[] parameter(2) + %compare.2 = pred[] compare(%lhs_value, %rhs_value), direction=GT + %select.4 = f32[] select(%compare.2, %lhs_value, %rhs_value) + %lhs_index = s32[] parameter(1) + %rhs_index = s32[] parameter(3) + %select.5 = s32[] select(%compare.2, %lhs_index, %rhs_index) + ROOT %tuple.2 = (f32[], s32[]) tuple(%select.4, %select.5) +} + +ENTRY %main { + %param0 = f32[28,10] parameter(0), sharding={devices=[2,2]0,1,2,3} + %param1 = s32[28,10] parameter(1), sharding={devices=[2,2]0,1,2,3} + %init0 = f32[] parameter(2) + %init1 = s32[] parameter(3) + ROOT %reduce = (f32[28], s32[28]) reduce(%param0, %param1, %init0, %init1), + dimensions={1}, to_apply=%minmax_func, + sharding={{devices=[2,2]0,1,2,3 last_tile_dim_replicate}, + {devices=[2,2]0,1,2,3 last_tile_dim_replicate}} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = + AllOf(op::Shape("f32[14,10]"), + op::AllReduce(op::DynamicUpdateSlice(_, op::Parameter(0), _, _))); + auto rhs = + AllOf(op::Shape("s32[14,10]"), + op::AllReduce(op::DynamicUpdateSlice(_, op::Parameter(1), _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, + AllOf(op::Reduce(lhs, rhs, op::Parameter(2), op::Parameter(3)), + op::Shape("(f32[14], s32[14])"))); +} + TEST_F(SpmdPartitioningTest, TiledToTiledReduceOutputReshard) { const char* const hlo_string = R"( HloModule module @@ -3793,8 +3953,8 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); - auto offset = op::Reshape( - op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); auto min = AllOf(op::Broadcast(offset), op::Shape("s32[2,3]")); auto max = AllOf(op::Broadcast(op::Add(offset, op::Constant())), op::Shape("s32[2,3]")); @@ -3930,8 +4090,8 @@ ENTRY entry { TF_ASSERT_OK_AND_ASSIGN(auto module, PartitionComputation(hlo_string, /*num_devices=*/2)); VLOG(1) << module->ToString(); - auto offset = op::Reshape( - op::DynamicSlice(op::Constant(), op::PartitionId(), op::Constant())); + auto offset = + op::Reshape(op::DynamicSlice(op::Constant(), op::PartitionId())); auto indices = op::Subtract( op::Parameter(1), AllOf(op::Broadcast(offset), op::Shape("s32[2,3]"))); HloInstruction* root = module->entry_computation()->root_instruction(); @@ -4119,7 +4279,7 @@ HloModule module ENTRY entry { %lhs = f32[48,12] parameter(0), sharding={devices=[2,2]0,1,2,3} - %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,1,2,3} + %rhs = f32[32,12] parameter(1), sharding={devices=[2,2]0,2,1,3} ROOT %dot = f32[48,32] 
dot(%lhs, %rhs), lhs_batch_dims={}, rhs_batch_dims={}, lhs_contracting_dims={1}, rhs_contracting_dims={1}, @@ -4136,8 +4296,8 @@ ENTRY entry { op::AllReduce(op::DynamicUpdateSlice(_, lhs, _, _))); auto rhs = AllOf(op::Shape("f32[16,6]"), op::Parameter(1)); auto partial_replicated_rhs = - AllOf(op::Shape("f32[16,12]"), op::AllReduce(op::DynamicUpdateSlice( - _, op::CollectivePermute(rhs), _, _))); + AllOf(op::Shape("f32[16,12]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))); auto root = module->entry_computation()->root_instruction(); EXPECT_THAT(root, AllOf(op::Dot(partial_replicated_lhs, partial_replicated_rhs), @@ -4429,6 +4589,33 @@ ENTRY entry { EXPECT_THAT(root, op::AllReduce(dot)); } +TEST_F(SpmdPartitioningTest, DotPartialContracting3) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,100] parameter(0), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + %rhs = f32[32,100] parameter(1), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %dot = f32[24,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1}, rhs_contracting_dims={1}, + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,50]"), op::Parameter(0)); + auto rhs = + AllOf(op::Shape("f32[16,50]"), op::DynamicSlice(op::Parameter(1), _, _)); + auto dot = AllOf(op::Shape("f32[24,16]"), op::Dot(lhs, rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::CollectivePermute(op::AllReduce(dot))); +} + TEST_F(SpmdPartitioningTest, DotBatchAndPartialContracting) { const char* const hlo_string = R"( HloModule module @@ -4484,6 +4671,119 @@ ENTRY entry { EXPECT_THAT(root, dot); } +TEST_F(SpmdPartitioningTest, DotPartialNonContractingPartialMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %rhs = f32[32,100] parameter(1), + sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} + ROOT %dot = f32[24,8,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={1}, + sharding={devices=[2,1,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,4,100]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[16,100]"), op::Parameter(1)); + auto partially_replicated_lhs = AllOf( + op::Shape("f32[12,8,100]"), + op::AllReduce(op::DynamicUpdateSlice(op::Broadcast(_), lhs, _, _, _))); + auto dot = + AllOf(op::Shape("f32[12,8,16]"), op::Dot(partially_replicated_lhs, rhs)); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot); +} + +TEST_F(SpmdPartitioningTest, DotPartialContractingPartialMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), sharding={devices=[1,2,2]0,1,2,3} + %rhs = f32[32,8,100] parameter(1), + sharding={devices=[1,1,2,2]0,2,1,3 last_tile_dim_replicate} + ROOT %dot = f32[24,32] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={1,2}, rhs_contracting_dims={1,2}, + sharding={replicated} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, 
/*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[24,4,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[32,8,50]"), op::Parameter(1)); + auto dot = AllOf(op::Shape("f32[24,32]"), + op::Dot(lhs, AllOf(op::Shape("f32[32,4,50]"), + op::DynamicSlice(rhs, _, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, op::AllReduce(op::AllReduce(dot))); +} + +TEST_F(SpmdPartitioningTest, DotNonContractingPartialMatchContractingMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,100] parameter(0), sharding={devices=[2,1,2]0,1,2,3} + %rhs = f32[100,50] parameter(1), sharding={devices=[2,2]0,2,1,3} + ROOT %dot = f32[24,8,50] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,8,50]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[50,25]"), op::Parameter(1)); + auto dot = AllOf( + op::Shape("f32[12,8,50]"), + op::Dot(lhs, AllOf(op::Shape("f32[50,50]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, AllOf(op::Shape("f32[12,4,50]"), + op::DynamicSlice(op::AllReduce(dot), _, _, _))) + << module->ToString(); +} + +TEST_F(SpmdPartitioningTest, DotLHSMutiNonContractingRHSNotMatch) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %lhs = f32[24,8,10] parameter(0), sharding={devices=[2,2,1]0,1,2,3} + %rhs = f32[10,50] parameter(1), + sharding={devices=[2,1,2]0,2,1,3 last_tile_dim_replicate} + ROOT %dot = f32[24,8,50] dot(%lhs, %rhs), + lhs_batch_dims={}, rhs_batch_dims={}, + lhs_contracting_dims={2}, rhs_contracting_dims={0}, + sharding={devices=[2,2,1]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + + auto lhs = AllOf(op::Shape("f32[12,4,10]"), op::Parameter(0)); + auto rhs = AllOf(op::Shape("f32[5,50]"), op::Parameter(1)); + auto dot = AllOf( + op::Shape("f32[12,4,50]"), + op::Dot(lhs, AllOf(op::Shape("f32[10,50]"), + op::AllReduce(op::DynamicUpdateSlice(_, rhs, _, _))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, dot) << module->ToString(); +} + TEST_F(SpmdPartitioningTest, ElementwiseTest_PartialReplicateToTiledHaloExchange) { const char* const hlo_string = R"( @@ -4531,6 +4831,266 @@ ENTRY entry { EXPECT_THAT(root, AllOf(op::Shape("f32[2,3]"), op::Add(add_lhs, add_rhs))); } +TEST_F(SpmdPartitioningTest, TileToPartialReplicateReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[2,2]0,1,2,3} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Reshape()))); + auto partially_replicated = AllOf( + op::Shape("f32[4,8]"), op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(_), tiled, _, _)))); + auto root = 
module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, PartialReplicateToTileReshard) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[2,1,2]0,1,2,3 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2]0,1,2,3} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/4)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[4,8]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant()))); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, + op::Constant(), op::Reshape()))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshard_AllReduce) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(param0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Reshape()))); + auto partially_replicated = + AllOf(op::Shape("f32[4,8]"), + op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(_), partially_replicated_init, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshard_DynamicSlice) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[4,8]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant()))); + auto tiled = AllOf(op::Shape("f32[4,4]"), + op::Copy(op::DynamicSlice(partially_replicated, + op::Constant(), op::Reshape()))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshard_DynamicSlice2) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[1,1,8]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[8,8]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Constant()))); + auto tiled = AllOf(op::Shape("f32[4,4]"), + 
op::Copy(op::DynamicSlice(partially_replicated, + op::Reshape(), op::Reshape()))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardWithCollectivePermute) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(param0), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[4,4]"), + op::CollectivePermute(op::Copy(op::DynamicSlice( + op::Parameter(0), op::Reshape(), op::Reshape())))); + auto partially_replicated = + AllOf(op::Shape("f32[8,4]"), + op::Copy(op::AllReduce(op::DynamicUpdateSlice( + op::Broadcast(_), partially_replicated_init, _, _)))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardCollectivePermute1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[8,8] parameter(0) + %copy = f32[8,8] copy(%param0), + sharding={devices=[1,2,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[8,8] copy(%copy), + sharding={devices=[2,2,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + VLOG(1) << module->ToString(); + auto partially_replicated = + AllOf(op::Shape("f32[8,4]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(), + op::Reshape()))); + auto tiled = + AllOf(op::Shape("f32[4,4]"), + op::Copy(op::CollectivePermute(op::DynamicSlice( + partially_replicated, op::Reshape(), op::Constant())))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, tiled); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardHaloExchange) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[6,3] parameter(0) + %copy = f32[6,3] copy(param0), + sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[6,3] copy(%copy), + sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[2,3]"), + op::Copy(op::DynamicSlice(op::Pad(op::Parameter(0), op::Constant()), + op::Reshape(), op::Constant()))); + auto slice = + AllOf(op::Shape("f32[2,3]"), + op::DynamicSlice(op::Concatenate(op::CollectivePermute(op::Slice( + partially_replicated_init)), + partially_replicated_init), + _, _)); + auto partially_replicated = + AllOf(op::Shape("f32[3,3]"), + op::Copy(op::Slice(op::AllReduce( + op::DynamicUpdateSlice(op::Broadcast(_), slice, _, _))))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + +TEST_F(SpmdPartitioningTest, + PartialReplicateToPartialReplicateReshardHaloExchange1) { + const char* const hlo_string = R"( +HloModule module + +ENTRY entry { + %param0 = f32[6,3] parameter(0) + %copy = f32[6,3] copy(param0), + 
sharding={devices=[2,1,4]0,1,2,3,4,5,6,7 last_tile_dim_replicate} + ROOT %copy0 = f32[6,3] copy(%copy), + sharding={devices=[4,1,2]0,1,2,3,4,5,6,7 last_tile_dim_replicate} +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + PartitionComputation(hlo_string, /*num_devices=*/8)); + + VLOG(1) << module->ToString(); + auto partially_replicated_init = + AllOf(op::Shape("f32[3,3]"), + op::Copy(op::DynamicSlice(op::Parameter(0), op::Reshape(), + op::Constant()))); + auto slice = AllOf( + op::Shape("f32[4,3]"), + op::DynamicSlice(op::Pad(op::Concatenate(partially_replicated_init, + op::CollectivePermute(op::Slice( + partially_replicated_init))), + op::Constant()), + _, _)); + auto partially_replicated = + AllOf(op::Shape("f32[2,3]"), op::Copy(op::DynamicSlice(slice, _, _))); + auto root = module->entry_computation()->root_instruction(); + EXPECT_THAT(root, partially_replicated); +} + } // namespace } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc index da2a3a44405..0edbd4f2b8d 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/hlo_sharding_util.h" #include "tensorflow/compiler/xla/service/pattern_matcher.h" @@ -202,13 +203,17 @@ std::vector MakePartitionOffsets( absl::Span dims) { CHECK(!shape.IsTuple()); - Array2D offset_array( - {sharding.tile_assignment().num_elements(), shape.rank()}); - offset_array.Each([&](int64 i, int64 j, int32* value) { - *value = sharding.TileOffsetForDevice(shape, i)[j]; - }); - auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(offset_array))); + std::vector> offset_arrays(shape.rank()); + for (int64 i = 0; i < shape.rank(); ++i) { + offset_arrays[i].resize(sharding.tile_assignment().num_elements()); + } + auto shard_shape = MakePartitionedShape(shape, sharding); + sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + for (int64 i = 0; i < shape.rank(); ++i) { + offset_arrays[i][device] = indices[i] * shard_shape.dimensions(i); + } + }); std::vector offsets; for (int64 i = 0; i < shape.rank(); ++i) { if (sharding.tile_assignment().dim(i) == 1 || @@ -216,11 +221,10 @@ std::vector MakePartitionOffsets( offsets.push_back(b->AddInstruction( HloInstruction::CreateConstant(LiteralUtil::Zero(S32)))); } else { + auto offset_table = b->AddInstruction(HloInstruction::CreateConstant( + LiteralUtil::CreateR1(offset_arrays[i]))); auto index = b->AddInstruction(HloInstruction::CreateDynamicSlice( - ShapeUtil::MakeShape(S32, {1, 1}), offset_table, - {partition_id, b->AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(i)))}, - {1, 1})); + ShapeUtil::MakeShape(S32, {1}), offset_table, {partition_id}, {1})); offsets.push_back(b->AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), index))); } @@ -292,17 +296,29 @@ HloInstruction* PadBaseShapeBeforeUnevenTiledSharding( return PadToShape(hlo, padded_base_shape, b); } -// TODO(wangtao): generize this function when target is 
partial replicate. -absl::optional PartialReplicateToTileCompatibleSharding( - const HloSharding& partial_sharding, - const std::vector& target_tile_dims) { +absl::optional PartialReplicateReshardCompatibleSharding( + const HloSharding& partial_sharding, const HloSharding& target_sharding) { if (!partial_sharding.ReplicateOnLastTileDim()) { return absl::nullopt; } int64 rank = partial_sharding.tile_assignment().num_dimensions() - 1; - if (target_tile_dims.size() < rank) { + int64 target_rank = target_sharding.tile_assignment().num_dimensions() - + (target_sharding.ReplicateOnLastTileDim() ? 1 : 0); + if (target_rank != rank) { return absl::nullopt; } + + absl::flat_hash_map device_to_replication_group; + partial_sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + int64 gid = 0; + for (int64 i = 0; i < rank; ++i) { + gid *= partial_sharding.tile_assignment().dim(i); + gid += indices[i]; + } + device_to_replication_group[device] = gid; + }); + // A dimension is expanded when target_tile_size > partial_tile_size and // target_tile_size % partial_tile_size == 0. // expand_tile_dims_positions is the index of the expand_dim. @@ -312,7 +328,7 @@ absl::optional PartialReplicateToTileCompatibleSharding( int num_expand_dims = 0; for (int64 dim = 0; dim < rank; dim++) { int64 partial_tile_size = partial_sharding.tile_assignment().dim(dim); - int64 target_tile_size = target_tile_dims[dim]; + int64 target_tile_size = target_sharding.tile_assignment().dim(dim); if (target_tile_size % partial_tile_size != 0 || target_tile_size < partial_tile_size) { return absl::nullopt; @@ -325,14 +341,26 @@ absl::optional PartialReplicateToTileCompatibleSharding( } // Reshape the partial replicate tile_dimensions. + int64 num_target_replication = 1; + if (target_sharding.ReplicateOnLastTileDim()) { + num_target_replication = + target_sharding.tile_assignment().dimensions().back(); + } auto reshape_dimensions = partial_sharding.tile_assignment().dimensions(); int64 num_replication = reshape_dimensions.back(); - if (num_replication != Product(expand_tile_sizes)) { + if (num_replication / num_target_replication != Product(expand_tile_sizes) || + num_replication % num_target_replication != 0) { return absl::nullopt; } + reshape_dimensions.pop_back(); reshape_dimensions.insert(reshape_dimensions.end(), expand_tile_sizes.begin(), expand_tile_sizes.end()); + + if (target_sharding.ReplicateOnLastTileDim()) { + reshape_dimensions.push_back(num_target_replication); + } + auto reshape_tile_assignment = partial_sharding.tile_assignment(); reshape_tile_assignment.Reshape(reshape_dimensions); @@ -346,13 +374,31 @@ absl::optional PartialReplicateToTileCompatibleSharding( } } auto transpose_sharding = hlo_sharding_util::TransposeSharding( - HloSharding::Tile(reshape_tile_assignment), perm); + target_sharding.ReplicateOnLastTileDim() + ? 
HloSharding::PartialTile(reshape_tile_assignment) + : HloSharding::Tile(reshape_tile_assignment), + perm); // Reshape to target shape auto transpose_tile_assignment = transpose_sharding.tile_assignment(); - transpose_tile_assignment.Reshape(target_tile_dims); + transpose_tile_assignment.Reshape( + target_sharding.tile_assignment().dimensions()); - return HloSharding::Tile(transpose_tile_assignment); + bool groups_matching = true; + target_sharding.tile_assignment().Each( + [&](absl::Span indices, int64 device) { + if (device_to_replication_group[device] != + device_to_replication_group[transpose_tile_assignment(indices)]) { + groups_matching = false; + } + }); + + if (groups_matching) { + return target_sharding; + } + return target_sharding.ReplicateOnLastTileDim() + ? HloSharding::PartialTile(transpose_tile_assignment) + : HloSharding::Tile(transpose_tile_assignment); } absl::optional TileToPartialReplicateHaloExchange( @@ -581,7 +627,10 @@ absl::optional UniqueTiledDim(const HloSharding& sharding) { return absl::nullopt; } int64 dim = -1; - for (int64 i = 0; i < sharding.tile_assignment().num_dimensions(); ++i) { + int64 rank = sharding.ReplicateOnLastTileDim() + ? sharding.tile_assignment().num_dimensions() - 1 + : sharding.tile_assignment().num_dimensions(); + for (int64 i = 0; i < rank; ++i) { if (sharding.tile_assignment().dim(i) > 1) { if (dim != -1) { return absl::nullopt; @@ -1403,7 +1452,7 @@ HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { } for (int64 i = 0; i < grouped_sharding.group_dims.size(); ++i) { int64 dim = grouped_sharding.group_dims[i]; - tiling_dims[dim] = grouped_sharding.group_dim_sizes[i]; + tiling_dims[dim] *= grouped_sharding.group_dim_sizes[i]; } Array tiling(tiling_dims); grouped_tiling.Each([&](absl::Span indices, int64 device) { @@ -1411,9 +1460,12 @@ HloSharding UngroupSharding(const GroupedSharding& grouped_sharding) { for (int64 g = 0; g < grouped_sharding.device_groups.size(); ++g) { int64 remaining_group_index = g; for (int64 i = grouped_sharding.group_dims.size() - 1; i >= 0; --i) { - ungrouped_inds[grouped_sharding.group_dims[i]] = - remaining_group_index % grouped_sharding.group_dim_sizes[i]; - remaining_group_index /= grouped_sharding.group_dim_sizes[i]; + int64 dim = grouped_sharding.group_dims[i]; + int64 groups_in_this_dim = grouped_sharding.group_dim_sizes[i]; + ungrouped_inds[dim] = (remaining_group_index % groups_in_this_dim) * + grouped_tiling.dim(dim) + + indices[dim]; + remaining_group_index /= groups_in_this_dim; } tiling(ungrouped_inds) = grouped_sharding.device_groups[g][device]; } @@ -1684,5 +1736,47 @@ absl::optional ParseReductionComputation( return root->opcode(); } +absl::optional> FindMatchingPartitionedDimsForGrouping( + const HloSharding& sharding, + const std::vector>& device_groups) { + if (sharding.NumTiles() < device_groups.size() || device_groups.size() < 2 || + device_groups[0].size() < 2) { + return absl::nullopt; + } + int64 rank = sharding.tile_assignment().num_dimensions(); + if (sharding.ReplicateOnLastTileDim()) { + rank--; + } + absl::flat_hash_map> device_to_index; + sharding.tile_assignment().Each( + [&](absl::Span index, int64 device) { + device_to_index[device] = + std::vector(index.begin(), index.begin() + rank); + }); + std::vector dims; + int64 group_count = 1; + for (int64 i = 0; i < rank; ++i) { + if (device_to_index[device_groups[0][0]][i] == + device_to_index[device_groups[0][1]][i]) { + dims.push_back(i); + group_count *= sharding.tile_assignment().dim(i); + } + } + if (group_count 
!= device_groups.size()) { + return absl::nullopt; + } + for (const auto& group : device_groups) { + for (int64 i = 1; i < group.size(); ++i) { + if (absl::c_any_of(dims, [&](const int64 dim) { + return device_to_index[group[i]][dim] != + device_to_index[group[0]][dim]; + })) { + return absl::nullopt; + } + } + } + return dims; +} + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h index 69ed90a4b66..f6f15481b55 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h @@ -356,8 +356,8 @@ absl::optional PadFromPartialReplicateShape( const SPMDCollectiveOpsCreator& collective_ops_creator, int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b); -// Get the compatible sharding from a partial replicate sharding to a given -// target tile dimensions. +// Get the compatible sharding from a partial replicate sharding to a desired +// target tiled sharding. // Compatible means replicate sharding can transform to the target tile // dimensions by dynamic slice. // For example, if partial_sharding is @@ -366,9 +366,9 @@ absl::optional PadFromPartialReplicateShape( // sharding={devices=[1,2,2]0,2,1,3 last_tile_dim_replicate}. // If patial replicate sharding is not partial replicate or can't reshard to // target_tile_dims by dynamic slice, return absl::nullopt. -absl::optional PartialReplicateToTileCompatibleSharding( - const HloSharding& partial_sharding, - const std::vector& target_tile_dims); +// If target_sharding is already compatible, returns it. +absl::optional PartialReplicateReshardCompatibleSharding( + const HloSharding& partial_sharding, const HloSharding& target_sharding); // Do left halo exchange if all-reduce directly from tile sharding to partial // replicate sharding will remove useful data from the source. @@ -379,6 +379,12 @@ absl::optional TileToPartialReplicateHaloExchange( const SPMDCollectiveOpsCreator& collective_ops_creator, int64* next_channel_id, HloInstruction* partition_id, SpmdBuilder* b); +// Finds a list of dimensions that can be grouped on such that it will have the +// specified device groups. Group order and dimension order are ignored. 
+absl::optional> FindMatchingPartitionedDimsForGrouping( + const HloSharding& sharding, + const std::vector>& device_groups); + } // namespace spmd } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index c66f9d96a50..e2b977ad493 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -333,10 +333,10 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); - auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary( + auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart( ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(), ShapeUtil::MakeShape(U32, {})}), - HloOpcode::kCopyStart, constant)); + constant)); auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopyDone, copy_start)); diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index b4982f1d8e4..64c9635f335 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -61,6 +61,10 @@ class ShapeLayout { // Returns the shape (with layouts). const Shape& shape() const { return shape_; } + // Clear dynamic dimensions of this module. Pretending the module creates + // static results. Useful in inspecting full outputs when testing. + void ClearDynamicShape() { shape_.clear_dynamic_dimensions(); } + // Checks that a layout is set for the shape, and returns a reference to the // layout directly on the shape. Shape must not be a tuple. 
const Layout& layout() const; diff --git a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc index ba4092def16..a7e032448e0 100644 --- a/tensorflow/compiler/xla/tests/dynamism_inference_test.cc +++ b/tensorflow/compiler/xla/tests/dynamism_inference_test.cc @@ -104,12 +104,26 @@ TEST_F(DynamismInferenceTest, ScalarInt32Literal) { } } +TEST_F(DynamismInferenceTest, TupleSimple) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + + auto tuple = Tuple(&b, {c, p}); + EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {0}).ValueOrDie(), + false); + EXPECT_EQ(ComputeDynamismScalar(client, tuple, &b, {1}).ValueOrDie(), true); + } +} + TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto c = ConstantR0(&b, 42); - auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); auto tuple = Tuple(&b, {c, p}); auto gte0 = GetTupleElement(tuple, 0); @@ -122,12 +136,25 @@ TEST_F(DynamismInferenceTest, TupleGteKeepsDynamism) { } } +TEST_F(DynamismInferenceTest, PredValueUsedTwice) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + XlaBuilder b(TestName()); + auto c = ConstantR0(&b, 42); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); + auto pred = Eq(c, p); + auto result = Select(pred, p, c); + EXPECT_EQ(ComputeDynamismScalar(client, result, &b, {}).ValueOrDie(), + false); + } +} + TEST_F(DynamismInferenceTest, ConcatSliceReshapeKeepsDynamism) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto c = ConstantR0(&b, 42); - auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); auto concat = ConcatScalars(&b, {c, p}); auto slice0 = SliceInDim(concat, 0, 1, 1, 0); @@ -146,7 +173,7 @@ TEST_F(DynamismInferenceTest, ParameterIsDynamic) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); - auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + auto computation = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); auto value = ComputeDynamismScalar(client, computation, &b); ASSERT_TRUE(value.ok()) << value.status(); @@ -160,7 +187,7 @@ TEST_F(DynamismInferenceTest, UnaryOpKeepsDynamism) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto c = ConstantR0(&b, 42); - auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); auto neg0 = Neg(c); auto neg1 = Neg(p); @@ -177,7 +204,7 @@ TEST_F(DynamismInferenceTest, BinaryOpsOrsDynamism) { Client* client = ClientOrDie(platform_, client_type); XlaBuilder b(TestName()); auto c = ConstantR0(&b, 42); - auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "0"); + auto p = Parameter(&b, 0, ShapeUtil::MakeScalarShape(S32), "p0"); // Static value + static value = static auto add1 = Add(c, c); @@ -198,8 +225,8 @@ TEST_F(DynamismInferenceTest, GetDimensionSize) { // param = Param([<=2, 3]) // 
get_dimension_size(param, 0) is dynamic // get_dimension_size(param, 1) is static - auto p = - Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), "0"); + auto p = Parameter(&b, 0, ShapeUtil::MakeShape(S32, {2, 3}, {true, false}), + "p0"); auto gds0 = GetDimensionSize(p, 0); auto gds1 = GetDimensionSize(p, 1); diff --git a/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc b/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc index 09c91d4be14..dca8e31e792 100644 --- a/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc +++ b/tensorflow/compiler/xla/tests/exhaustive_binary_16_bit_test.cc @@ -123,8 +123,16 @@ BINARY_TEST_16BIT(Min, { }) // TODO(bixia): Pow fails with bfloat16 on CPU. -BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow), - { Run(AddEmptyBroadcastDimension(Pow), std::pow); }) +BINARY_TEST_16BIT(DISABLED_ON_CPU(Pow), { + // See b/162664705. + known_incorrect_fn_ = [](int64 val) { + Eigen::bfloat16 f; + uint16_t val_16 = val; + memcpy(&f, &val_16, 2); + return std::isnan(f); + }; + Run(AddEmptyBroadcastDimension(Pow), std::pow); +}) // TODO(bixia): Atan2 fails with bfloat16 on CPU. BINARY_TEST_16BIT(DISABLED_ON_CPU(Atan2), diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc index 0fd5f191db0..0f8a4c1e273 100644 --- a/tensorflow/compiler/xla/tests/gather_operation_test.cc +++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc @@ -711,6 +711,24 @@ ENTRY main { RunTest(hlo_text, &operand, &start_indices); } +XLA_TEST_F(GatherOperationTest, GatherFromScalarNonZeroIndices) { + const string hlo_text = R"( +HloModule GatherFromScalar + +ENTRY main { + operand = f32[1,1,1] parameter(0) + indices = s32[2,3,50] parameter(1) + ROOT gather = f32[1,2,50] gather(operand, indices), + offset_dims={0}, + collapsed_slice_dims={0,1}, + start_index_map={1,0,2}, + index_vector_dim=1, + slice_sizes={1,1,1} +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_text, ErrorSpec{0, 0})); +} + class GatherClientLibraryTest : public ClientLibraryTestBase {}; // Disabled on interpreter since ExecuteAsyncOnStream is not supported. 
diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d0b6e5f80ed..663e7d81006 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -230,6 +230,19 @@ StatusOr> HloTestBase::ExecuteReplicated( device_assignment); } +StatusOr> HloTestBase::ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + int64 num_replicas, bool run_hlo_passes) { + HloRunner::ReplicatedExecuteOptions options; + options.num_replicas = num_replicas; + options.run_hlo_passes = run_hlo_passes; + options.use_threads = true; + return test_runner_.ExecuteReplicated( + executable_provider, argument_count_provider, argument_provider, options); +} + StatusOr> HloTestBase::MakeReferenceModule( const HloModule& test_module, const std::function& reference_preprocessor) { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 17c2a55ba5b..fc680e39682 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -169,6 +169,13 @@ class HloTestBase : public ManifestCheckingTest { int64 num_replicas, DeviceAssignment* device_assignment, bool run_hlo_passes, bool use_threads); + // Same as above, but allows passing different programs for replicas. + StatusOr> ExecuteReplicated( + std::function executable_provider, + std::function argument_count_provider, + std::function argument_provider, + int64 num_replicas, bool run_hlo_passes); + // Executes the given hlo module on two backends and compares results. // // 'arguments': the input of the hlo module. diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1fbce96625b..4034e5fdd27 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -31,10 +31,10 @@ limitations under the License. #include "absl/strings/str_split.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/bfloat16.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numbers.h" diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b9fe544783c..e45e0000017 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -318,7 +318,6 @@ alias( cc_library( name = "lib_proto_parsing", hdrs = [ - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_proto_parsing_headers", "//tensorflow/core/lib/strings:legacy_lib_proto_parsing_headers", "//tensorflow/core/platform:lib_proto_parsing_hdrs", @@ -328,7 +327,6 @@ cc_library( ":platform_base", "@com_google_absl//absl/strings", "@double_conversion//:double-conversion", - "//tensorflow/core/lib/bfloat16", "//tensorflow/core/lib/core:errors", "//tensorflow/core/lib/core:stringpiece", "//tensorflow/core/lib/core:status", @@ -353,6 +351,7 @@ cc_library( cc_library( name = "lib", hdrs = [ + # TODO(rmlarsen): Remove bfloat16.h once dependency in third_party/swift is updated. 
"//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_headers", "//tensorflow/core/lib/gtl:legacy_lib_gtl_headers", @@ -582,7 +581,6 @@ cc_library( "//tensorflow/core/framework:numeric_types.h", "//tensorflow/core/framework:tensor_types.h", "//tensorflow/core/framework:type_traits.h", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/platform:framework_lite_hdrs", "//tensorflow/core/platform/default:integral_types.h", "//tensorflow/core/platform/default:logging.h", @@ -593,7 +591,6 @@ cc_library( "@nsync//:nsync_cpp", ] + [ "//third_party/eigen3", - "//tensorflow/core/lib/bfloat16", "//tensorflow/core/platform:dynamic_annotations", "//tensorflow/core/platform:platform_port", "//tensorflow/core/platform:thread_annotations", @@ -1014,6 +1011,7 @@ cc_library( "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:io", + "//tensorflow/core/kernels:isotonic_regression_op", "//tensorflow/core/kernels:lookup", "//tensorflow/core/kernels:logging", "//tensorflow/core/kernels:manip", @@ -1258,7 +1256,6 @@ filegroup( "//tensorflow/core/example:mobile_srcs_no_runtime", "//tensorflow/core/framework:attr_value_proto_text_srcs", "//tensorflow/core/framework:mobile_srcs_no_runtime", - "//tensorflow/core/lib/bfloat16:mobile_srcs_no_runtime", "//tensorflow/core/lib/core:mobile_srcs_no_runtime", "//tensorflow/core/lib/gtl:mobile_srcs_no_runtime", "//tensorflow/core/lib/hash:mobile_srcs_no_runtime", @@ -1696,7 +1693,6 @@ filegroup( "//tensorflow/core/framework:resource_handle.h", "//tensorflow/core/platform:legacy_lib_internal_headers", "//tensorflow/core/platform:lib_internal_private_hdrs", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_all_headers", "//tensorflow/core/lib/gtl:legacy_lib_gtl_all_headers", "//tensorflow/core/lib/histogram:legacy_lib_histogram_all_headers", @@ -1813,7 +1809,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "//third_party/eigen3", - "//tensorflow/core/lib/bfloat16", "//tensorflow/core/lib/core:arena", "//tensorflow/core/lib/core:bitmap", "//tensorflow/core/lib/core:blocking_counter", @@ -1894,6 +1889,7 @@ cc_library( "//tensorflow/core/lib/strings:strcat", "//tensorflow/core/lib/strings:stringprintf", "//tensorflow/core/platform:abi", + "//tensorflow/core/platform:bfloat16", "//tensorflow/core/platform:base64", "//tensorflow/core/platform:blocking_counter", "//tensorflow/core/platform:casts", @@ -2021,7 +2017,6 @@ alias( cc_library( name = "tflite_portable_logging", hdrs = [ - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/platform:tflite_portable_logging_hdrs", "//tensorflow/core/platform/default:integral_types.h", "//tensorflow/core/platform/default:logging.h", @@ -2051,7 +2046,6 @@ cc_library( hdrs = [ "lib/jpeg/jpeg_handle.h", "lib/jpeg/jpeg_mem.h", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", "//tensorflow/core/platform:jpeg_internal_hdrs", "//tensorflow/core/platform/default:integral_types.h", @@ -2078,7 +2072,6 @@ cc_library( ]), hdrs = [ "lib/gif/gif_io.h", - "//tensorflow/core/lib/bfloat16:bfloat16.h", "//tensorflow/core/lib/core:legacy_lib_core_stringpiece_header", "//tensorflow/core/lib/gtl:legacy_android_gif_internal_headers", "//tensorflow/core/platform:gif_internal_hdrs", @@ -2969,6 +2962,8 @@ filegroup( srcs = [ # PNG data "//tensorflow/core/lib/png:testdata", + "//tensorflow/core/lib/ssim:testdata", + 
"//tensorflow/core/lib/psnr:testdata", # JPEG data "lib/jpeg/testdata/jpeg_merge_test1.jpg", "lib/jpeg/testdata/jpeg_merge_test1_cmyk.jpg", @@ -2998,13 +2993,6 @@ filegroup( "lib/bmp/testdata/grayscale_small.bmp", "lib/bmp/testdata/grayscale_small_3channels.bmp", "lib/bmp/testdata/grayscale_small_4channels.bmp", - # SSIM, PSNR data - "lib/ssim/testdata/checkerboard1.png", - "lib/ssim/testdata/checkerboard2.png", - "lib/ssim/testdata/checkerboard3.png", - "lib/psnr/testdata/cat_q20.jpg", - "lib/psnr/testdata/cat_q72.jpg", - "lib/psnr/testdata/cat_q95.jpg", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt b/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt index d87c088899e..5e736078f18 100644 --- a/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DataFormatVecPermute.pbtxt @@ -24,8 +24,27 @@ END destination data format. END } - summary: "Returns the permuted vector/tensor in the destination data format given the" + summary: "Permute input tensor from `src_format` to `dst_format`." description: <