Merge branch 'master' into op_tests_16x8

Elena Zhelezina 2020-06-17 09:09:54 +01:00 committed by GitHub
commit 81c8a6605d
653 changed files with 11790 additions and 5622 deletions


@ -30,6 +30,7 @@
# short_logs: Only log errors during build, skip warnings.
# monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of libstdc++
#
#
# TF version options;
@ -79,6 +80,14 @@
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from the environment to keep the build file clean
build:libc++ --action_env=CC
build:libc++ --action_env=CXX
build:libc++ --action_env=CXXFLAGS=-stdlib=libc++
build:libc++ --action_env=PATH
build:libc++ --define force_libcpp=enabled
build:libc++ --linkopt -fuse-ld=lld
# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
# target CPU to build transient dependencies correctly. See
@ -200,6 +209,8 @@ build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nonccl --define=no_nccl_support=true
build:stackdriver_support --define=stackdriver_support=true
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true

.github/stale.yml vendored

@ -23,7 +23,7 @@
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
# Only issues or pull requests with all of these labels are checked for staleness. Defaults to `[]` (disabled)
onlyLabels:
- stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable


@ -61,7 +61,6 @@ commands.
*Nightly binaries are available for testing using the
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPI.*
#### *Try your first TensorFlow program*
```shell
@ -114,6 +113,12 @@ Build Type | Status
**Android** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg)](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
**Raspberry Pi 0 and 1** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
**Raspberry Pi 2 and 3** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
**Libtensorflow MacOS CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Linux CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Linux GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Windows CPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
**Libtensorflow Windows GPU** | [![Status](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.svg)](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
### Community Supported Builds


@ -298,6 +298,13 @@ config_setting(
visibility = ["//visibility:public"],
)
# Experimental features
config_setting(
name = "stackdriver_support",
define_values = {"stackdriver_support": "true"},
visibility = ["//visibility:public"],
)
# Crosses between platforms and file system libraries not supported on those
# platforms due to limitations in nested select() statements.
config_setting(


@ -713,8 +713,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
status->status = tfrt::ListOpHandlerChains(
opts->session_options.options, &op_handler_chains, &device_attributes);
if (!status->status.ok()) return nullptr;
return tensorflow::wrap(
new tfrt::ContextInterface(op_handler_chains, device_attributes));
return tensorflow::wrap(new tfrt::ContextInterface(
op_handler_chains, device_attributes, opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;


@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
TFE_DeleteCancellationManager(c_mgr);
}
TEST(CAPI, ExecutorContextDestructionOrder) {
TF_Status* status = TF_NewStatus();
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteContext(ctx);
TFE_DeleteExecutor(executor);
}
{
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
TFE_ContextSetExecutorForThread(ctx, executor);
TFE_DeleteExecutor(executor);
TFE_DeleteContext(ctx);
}
TF_DeleteStatus(status);
}
TEST(CAPI, Function_ident_CPU) {
// First create a simple identity function.
TF_Graph* function_graph = TF_NewGraph();


@ -37,6 +37,15 @@ class StatusDeleter {
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
@ -51,6 +60,13 @@ class DeviceThread {
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
// If the context's default executor is set to async, re-using that in
// each thread would cause collectives to deadlock. For consistency we
// create a new sync executor for every thread.
//
// TODO(allenl): We should have an async API that works with the
// parallel device.
executor_(TFE_NewExecutor(/*is_async=*/false)),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
@ -105,6 +121,7 @@ class DeviceThread {
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
@ -186,6 +203,7 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
TFE_ContextSetExecutorForThread(context, executor_.get());
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);


@ -412,6 +412,7 @@ void TestCollective(bool async) {
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
TFE_ContextOptionsSetAsync(opts.get(), async);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
@ -423,9 +424,6 @@ void TestCollective(bool async) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
TFE_NewExecutor(async), TFE_DeleteExecutor);
TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@ -455,8 +453,6 @@ void TestCollective(bool async) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
// Destroying the context's default executor first isn't safe.
context.reset();
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }


@ -27,5 +27,6 @@ cc_library(
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)


@ -15,6 +15,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
@ -35,6 +36,45 @@ static inline void TF_SetStatusFromGCSStatus(
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
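// Splits a GCS path of the form "gs://bucket/object" into newly allocated
// `bucket` and `object` C strings (allocated with plugin_memory_allocate).
// If `object_empty_ok` is true, a missing object part is accepted and
// `*object` is set to nullptr; otherwise the status is TF_INVALID_ARGUMENT.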
static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
char** bucket, char** object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't start with 'gs://'.");
return;
}
size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == absl::string_view::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");
return;
}
absl::string_view bucket_view =
fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*bucket =
static_cast<char*>(plugin_memory_allocate(bucket_view.length() + 1));
memcpy(*bucket, bucket_view.data(), bucket_view.length());
(*bucket)[bucket_view.length()] = '\0';
absl::string_view object_view = fname.substr(bucket_end + 1);
if (object_view.empty()) {
if (object_empty_ok) {
*object = nullptr;
return;
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
return;
}
}
*object =
static_cast<char*>(plugin_memory_allocate(object_view.length() + 1));
// object_view.data() points into fname's null-terminated buffer, so strcpy is safe.
strcpy(*object, object_view.data());
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {

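A minimal usage sketch of the ParseGCSPath helper above (illustrative only, not part of the commit; it assumes the code sits in the same translation unit as the helper, and the example path is made up):

```cpp
// Illustrative sketch: exercising ParseGCSPath as defined above.
static void ParseGCSPathExample() {
  char* bucket = nullptr;
  char* object = nullptr;
  TF_Status* status = TF_NewStatus();
  ParseGCSPath("gs://my-bucket/dir/file.txt", /*object_empty_ok=*/false,
               &bucket, &object, status);
  if (TF_GetCode(status) == TF_OK) {
    // bucket == "my-bucket", object == "dir/file.txt"; both were allocated
    // with plugin_memory_allocate, so release them with plugin_memory_free.
    plugin_memory_free(bucket);
    plugin_memory_free(object);
  }
  // A path like "gs://my-bucket/" is an error unless object_empty_ok is true,
  // in which case *object stays nullptr.
  TF_DeleteStatus(status);
}
```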

@ -0,0 +1,94 @@
# This package contains hand-written convenience helpers for Eager Operations
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
package(
default_visibility = [
# Restricting visibility for now
"//tensorflow/c/experimental/saved_model/core:__subpackages__",
"//tensorflow/c/experimental/saved_model/internal:__subpackages__",
],
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "owned_eager_op",
hdrs = [
"owned_eager_op.h",
],
deps = [
"//tensorflow/c/eager:operation_interface",
],
)
cc_library(
name = "owned_tensor_handle",
hdrs = [
"owned_tensor_handle.h",
],
deps = [
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/core/common_runtime/eager:tensor_handle",
],
)
cc_library(
name = "owned_eager_context",
hdrs = ["owned_eager_context.h"],
deps = [
"//tensorflow/c/eager:context_interface",
"//tensorflow/core/common_runtime/eager:context",
],
)
cc_library(
name = "owned_tensor",
hdrs = ["owned_tensor.h"],
deps = [
"//tensorflow/c:tensor_interface",
],
)
cc_library(
name = "variable_ops",
srcs = [
"variable_ops.cc",
],
hdrs = [
"variable_ops.h",
],
deps = [
":owned_eager_op",
":owned_tensor_handle",
"//tensorflow/c/eager:context_interface",
"//tensorflow/c/eager:tensor_handle_interface",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "variable_ops_test",
srcs = [
"variable_ops_test.cc",
],
deps = [
":owned_eager_context",
":owned_tensor",
":owned_tensor_handle",
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime:core_cpu_lib",
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/core/common_runtime/eager:core",
],
)


@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
#include <memory>
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/core/common_runtime/eager/context.h"
namespace tensorflow {
namespace internal {
struct AbstractContextInterfaceDeleter {
void operator()(AbstractContextInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct EagerContextDeleter {
void operator()(EagerContext* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractContextPtr =
std::unique_ptr<AbstractContextInterface,
internal::AbstractContextInterfaceDeleter>;
using EagerContextPtr =
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_


@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
#include <memory>
#include "tensorflow/c/eager/operation_interface.h"
namespace tensorflow {
namespace internal {
struct AbstractOperationInterfaceDeleter {
void operator()(AbstractOperationInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractOpPtr =
std::unique_ptr<AbstractOperationInterface,
internal::AbstractOperationInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_


@ -0,0 +1,42 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
#include <memory>
#include "tensorflow/c/tensor_interface.h"
namespace tensorflow {
namespace internal {
struct AbstractTensorInterfaceDeleter {
void operator()(AbstractTensorInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using AbstractTensorPtr =
std::unique_ptr<AbstractTensorInterface,
internal::AbstractTensorInterfaceDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_


@ -0,0 +1,54 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
#include <memory>
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
namespace tensorflow {
namespace internal {
struct TensorHandleDeleter {
void operator()(TensorHandle* p) const {
if (p != nullptr) {
p->Release();
}
}
};
struct AbstractTensorHandleDeleter {
void operator()(AbstractTensorHandleInterface* p) const {
if (p != nullptr) {
p->Release();
}
}
};
} // namespace internal
using TensorHandlePtr =
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
using AbstractTensorHandlePtr =
std::unique_ptr<AbstractTensorHandleInterface,
internal::AbstractTensorHandleDeleter>;
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_


@ -0,0 +1,104 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
static const char kNoSharingResourceID[] =
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle) {
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
// Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
// shape.dims() will be -1.
gtl::InlinedVector<int64, 4> dim_sizes = shape.dim_sizes();
TF_RETURN_IF_ERROR(varhandle_op->SetAttrShape(
"shape", reinterpret_cast<const int64_t*>(dim_sizes.data()),
shape.dims()));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString("container", "", 0));
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
AbstractTensorHandleInterface* var_handle = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(varhandle_op->Execute(
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
handle->reset(var_handle);
return Status();
}
Status AssignVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandleInterface* value) {
AbstractOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
int num_retvals = 0;
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
return Status();
}
Status ReadVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output) {
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
AbstractTensorHandleInterface* value = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
output->reset(value);
return Status();
}
Status DestroyResource(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* handle) {
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));
int num_retvals = 0;
TF_RETURN_IF_ERROR(destroy_op->Execute({}, &num_retvals));
return Status();
}
} // namespace internal
} // namespace tensorflow


@ -0,0 +1,62 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace internal {
// Executes a VarHandleOp using `ctx`, and fills `handle` with the DT_RESOURCE
// TensorHandle associated with the variable. This is equivalent to creating an
// uninitialized TF2 tf.Variable.
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
// underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was
// created with.
Status AssignVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandleInterface* value);
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
Status DestroyResource(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* handle);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
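The variable_ops_test.cc added later in this commit exercises each helper separately; condensed, the intended lifecycle is roughly the sketch below (illustrative only; VariableRoundTrip is a hypothetical name, and ctx / value are assumed to be a live eager context and a float scalar handle):

```cpp
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// Sketch: create, assign, read back, then destroy a resource variable.
Status VariableRoundTrip(AbstractContextInterface* ctx,
                         AbstractTensorHandleInterface* value) {
  // Create an uninitialized scalar DT_FLOAT variable (a DT_RESOURCE handle).
  AbstractTensorHandlePtr var;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, DT_FLOAT, /*shape=*/{}, &var));
  // Write `value` into the variable, then read it back out.
  TF_RETURN_IF_ERROR(
      internal::AssignVariable(ctx, var.get(), DT_FLOAT, value));
  AbstractTensorHandlePtr read_back;
  TF_RETURN_IF_ERROR(
      internal::ReadVariable(ctx, var.get(), DT_FLOAT, &read_back));
  // Explicitly release the backing resource (mirrors EagerResourceDeleter).
  return internal::DestroyResource(ctx, var.get());
}

}  // namespace tensorflow
```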


@ -0,0 +1,107 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
#include <memory>
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
return handle;
}
class VariableOpsTest : public ::testing::Test {
public:
VariableOpsTest()
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
"CPU", {}, "/job:localhost/replica:0/task:0"))),
ctx_(new EagerContext(
SessionOptions(),
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
/* async= */ false,
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
/* custom_kernel_creator= */ nullptr,
/* cluster_flr= */ nullptr)) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// Sanity check for variable creation
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// The created TensorHandle should be a DT_Resource
EXPECT_EQ(handle->DataType(), DT_RESOURCE);
}
// Sanity check for variable destruction
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr handle;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &handle));
// Destroy the variable
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
}
// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
my_value.get()));
// Read back the value from the variable, and check that it is 42.
AbstractTensorHandlePtr read_value_handle;
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
&read_value_handle));
Status status;
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
TF_EXPECT_OK(status);
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
}
} // namespace
} // namespace tensorflow


@ -4,6 +4,7 @@ traces: {
value: {
file_line_cols: {
line: 1
col: 1
}
}
}
@ -12,9 +13,11 @@ traces: {
value: {
file_line_cols: {
line: 3
col: 1
}
file_line_cols: {
line: 4
col: 1
}
}
}
@ -23,6 +26,7 @@ traces: {
value: {
file_line_cols: {
line: 2
col: 1
}
}
}


@ -195,53 +195,46 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
}
void XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
// Pass remaining parameters.
const Tensor* t;
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
int arg_num = kernel->input_mapping[i];
DCHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = kernel->xla_input_shapes[i];
if (variables.count(arg_num)) {
t = &(variables.at(arg_num).value);
CHECK(t);
} else {
t = &(ctx->input(arg_num - missing_ctx_input_prefix));
}
xla::TransferManager* transfer_manager =
client_->backend().transfer_manager();
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
int arg_num = compilation_result->input_mapping[i];
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value)
: &(ctx->input(arg_num - missing_ctx_input_prefix));
CHECK(t);
if (use_multiple_streams_) {
CHECK(stream) << "Must have a stream available when using XLA tensors!";
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
<< "Must have a stream available when using XLA tensors!";
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor);
xla_tensor->WaitForDefinitionEventOnStream(stream);
xla_tensor->WaitForDefinitionEventOnStream(
ctx->op_device_context()->stream());
}
const xla::Shape on_device_shape =
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
if (on_device_shape.IsTuple()) {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
} else {
CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
on_device_shape))
<< "On-device shape "
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
} else {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
}
}
}
@ -370,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
return Status::OK();
}
// Sets output `output_num` for `ctx` provided it is known at compile time.
static Status SetOutputForConstant(
OpKernelContext* ctx, se::Stream* stream,
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
CHECK(compilation_result->outputs[output_num].is_constant);
// Output is a constant.
const Tensor& const_tensor =
compilation_result->outputs[output_num].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
// GetByteSizeRequirement). This avoids XlaTransferManager having to
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
TF_RETURN_IF_ERROR(
ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
if (device->device_type() == DEVICE_GPU) {
// The GPUDeviceContext enqueues the host->device transfer in a
// separate stream from the main compute stream. We must ensure the
// compute stream is synchronized with the host->device transfer
// stream now otherwise we will create a race condition.
auto* gpu_device_context =
static_cast<GPUDeviceContext*>(ctx->op_device_context());
gpu_device_context->stream()->ThenWaitFor(
gpu_device_context->host_to_device_stream());
}
} else {
// No copy required.
ctx->set_output(output_num, const_tensor);
output_tensor = ctx->mutable_output(output_num);
}
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
xla_tensor->set_host_tensor(const_tensor);
}
return Status::OK();
}
// Creates a list of VariableInfos for the updated resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(compilation_result->resource_updates.size());
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
}
return variable_infos;
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
@ -384,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
}
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
// If the on-host-shape isn't a tuple, create a new single-element tuple
// buffer with a nullptr root index table. This allows the code below to treat
@ -413,82 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// Copy XLA results to the OpOutputList.
int output_num = 0;
for (int i = 0; i < ctx->num_outputs(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
if (kernel->outputs[i].is_constant) {
// Output is a constant.
const Tensor& const_tensor = kernel->outputs[i].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
// GetByteSizeRequirement). This avoids XlaTransferManager having to
// reallocate the device buffer later.
VLOG(1) << "Constant output tensor on device";
const TensorShape& shape = compilation_result->outputs[i].shape;
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
"is not implemented");
}
TF_RETURN_IF_ERROR(
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
Device* device = dynamic_cast<Device*>(ctx->device());
if (device == nullptr) {
return errors::Internal("DeviceBase was not a Device.");
}
ctx->op_device_context()->CopyCPUTensorToDevice(
&const_tensor, device, output_tensor,
[&](Status status) { TF_CHECK_OK(status); });
if (device->device_type() == DEVICE_GPU) {
// The GPUDeviceContext enqueues the host->device transfer in a
// separate stream from the main compute stream. We must ensure the
// compute stream is synchronized with the host->device transfer
// stream now otherwise we will create a race condition.
auto* gpu_device_context =
static_cast<GPUDeviceContext*>(ctx->op_device_context());
gpu_device_context->stream()->ThenWaitFor(
gpu_device_context->host_to_device_stream());
}
} else {
// No copy required.
ctx->set_output(i, const_tensor);
output_tensor = ctx->mutable_output(i);
}
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
xla_tensor->set_host_tensor(const_tensor);
}
if (compilation_result->outputs[i].is_constant) {
TF_RETURN_IF_ERROR(
SetOutputForConstant(ctx, stream, compilation_result, i));
} else if (type == DT_RESOURCE) {
int input_index =
compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
const TensorShape& shape = kernel->outputs[i].shape;
const DataType& type = kernel->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
if (type == DT_RESOURCE) {
int input_index =
kernel->outputs[i].input_index - missing_ctx_input_prefix;
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
<< "Invalid input for outputs " << i << ": " << input_index;
ctx->set_output(i, ctx->input(input_index));
} else {
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
} else {
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
"is not implemented");
}
if (allocate_xla_tensors_) {
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
input_output_alias, output_num, ctx, i, shape, &output,
definition_event, stream, use_multiple_streams_));
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_var_snapshots,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
}
if (VLOG_IS_ON(3)) {
@ -498,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(kernel->resource_updates.size());
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
ctx, HandleFromInput(ctx, actual_input_index), &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, variable);
}
TF_ASSIGN_OR_RETURN(
std::vector<VariableInfo> variable_infos,
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
Allocator* allocator = ctx->device()->GetAllocator({});
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write");
}
@ -539,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
output.set_buffer(se::OwningDeviceMemory(), {output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
kernel->input_mapping, resource_var_snapshots, write.type,
compilation_result->input_mapping, resource_var_snapshots, write.type,
write.shape, buffer, allocator);
*variable_infos[i].var()->tensor() = output_tensor;
variable_infos[i].var()->is_initialized |= write.modified;


@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* kernel,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, OptionalTensor>& variables,
int missing_ctx_input_prefix);
@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
// See jit/resource_operation_safety_analysis for details.
//
//
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
// missing and adjusts input indices accordingly.
// Assumes that the first `missing_ctx_input_prefix` inputs to the
// compilation_result are missing and adjusts input indices accordingly.
Status PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
const std::map<int, OptionalTensor>& resource_var_snapshots);


@ -55,7 +55,7 @@ class XlaTensor {
// manage the memory for these tensors a ShapedBuffer may be required.
// Return true if this XlaTensor contains a ShapedBuffer.
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
// Return the contained ShapedBuffer.
// REQUIRES: has_shaped_buffer()
const xla::ShapedBuffer& shaped_buffer() const {
@ -68,8 +68,7 @@ class XlaTensor {
}
// Mutates the XlaTensor to set the ShapedBuffer.
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
shaped_buffer_ =
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
shaped_buffer_ = std::move(shaped_buffer);
}
// Some tensors on the device may have known values on the host. We use these
@ -77,14 +76,12 @@ class XlaTensor {
// host value already.
// Return true if this XlaTensor contains a host tensor.
bool has_host_tensor() const { return host_tensor_ != nullptr; }
bool has_host_tensor() const { return host_tensor_.has_value(); }
// Return the contained host tensor.
// REQUIRES: has_host_tensor()
const Tensor& host_tensor() const { return *host_tensor_; }
// Sets the contained host tensor.
void set_host_tensor(const Tensor& tensor) {
host_tensor_.reset(new Tensor(tensor));
}
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
// Adds synchronization events to 'stream' that wait for this tensor to be
// defined on 'stream'. Does nothing if the tensor is already defined on that
@ -111,9 +108,9 @@ class XlaTensor {
private:
// The optional contained ShapedBuffer.
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
// An optional host tensor value.
std::unique_ptr<Tensor> host_tensor_;
absl::optional<Tensor> host_tensor_;
// An optional event that is triggered when the tensor's content has been
// defined. If this event is nullptr, it is assumed that the tensor's content
// is always defined.


@ -314,6 +314,7 @@ tf_cc_test(
cc_library(
name = "tensorflow_lite_legalize_tf",
srcs = [
"transforms/device_index_selector.cc",
"transforms/dilated_conv.cc",
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",


@ -1406,22 +1406,67 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
for (int j = 0; j < segments.size(); j++) {
vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
}
auto array_segments =
tflite::CreateInt32Vector(builder_,
builder_.CreateVector(vector_segments))
.Union();
tflite::SparseIndexVector segments_type;
BufferOffset<void> array_segments;
// The segment array is sorted.
// TODO(b/147449640): Clean this up with util functions.
int max_of_segments = vector_segments[segments.size() - 1];
if (max_of_segments <= UINT8_MAX) {
segments_type = tflite::SparseIndexVector_Uint8Vector;
std::vector<uint8_t> uint8_vector(vector_segments.begin(),
vector_segments.end());
array_segments = tflite::CreateUint8Vector(
builder_, builder_.CreateVector(uint8_vector))
.Union();
} else if (max_of_segments <= UINT16_MAX) {
segments_type = tflite::SparseIndexVector_Uint16Vector;
std::vector<uint16_t> uint16_vector(vector_segments.begin(),
vector_segments.end());
array_segments = tflite::CreateUint16Vector(
builder_, builder_.CreateVector(uint16_vector))
.Union();
} else {
segments_type = tflite::SparseIndexVector_Int32Vector;
array_segments = tflite::CreateInt32Vector(
builder_, builder_.CreateVector(vector_segments))
.Union();
}
auto indices = dim_metadata.indices();
std::vector<int> vector_indices(indices.size(), 0);
int max_of_indices = 0;
for (int j = 0; j < indices.size(); j++) {
vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
if (vector_indices[j] > max_of_indices) {
max_of_indices = vector_indices[j];
}
}
auto array_indices = tflite::CreateInt32Vector(
builder_, builder_.CreateVector(vector_indices))
.Union();
tflite::SparseIndexVector indices_type;
BufferOffset<void> array_indices;
if (max_of_indices <= UINT8_MAX) {
indices_type = tflite::SparseIndexVector_Uint8Vector;
std::vector<uint8_t> uint8_vector(vector_indices.begin(),
vector_indices.end());
array_indices = tflite::CreateUint8Vector(
builder_, builder_.CreateVector(uint8_vector))
.Union();
} else if (max_of_indices <= UINT16_MAX) {
indices_type = tflite::SparseIndexVector_Uint16Vector;
std::vector<uint16_t> uint16_vector(vector_indices.begin(),
vector_indices.end());
array_indices = tflite::CreateUint16Vector(
builder_, builder_.CreateVector(uint16_vector))
.Union();
} else {
indices_type = tflite::SparseIndexVector_Int32Vector;
array_indices = tflite::CreateInt32Vector(
builder_, builder_.CreateVector(vector_indices))
.Union();
}
fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
builder_, tflite::DimensionType_SPARSE_CSR, 0,
tflite::SparseIndexVector_Int32Vector, array_segments,
tflite::SparseIndexVector_Int32Vector, array_indices);
builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
array_segments, indices_type, array_indices);
}
}


@ -436,7 +436,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>,
TFL_GpuTargetOp]> {
TFL_GpuTargetOp, TFL_SparseOp]> {
let summary = opSummary # " operator";
let description = [{
@ -571,7 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
TFL_OperandHasRank<2, 4>,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 2>>,
TFL_GpuTargetOp]> {
TFL_GpuTargetOp,
TFL_SparseOp]> {
let summary = "Transpose convolution operator";
let description = [{
@ -593,6 +594,13 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
let hasOptions = 1;
let verifier = [{ return Verify(*this); }];
let extraClassDeclaration = [{
// SparseOpInterface:
std::vector<int> GetSparseOperands() { return {1}; }
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
}];
}
def TFL_AveragePool2DOp:
@ -826,6 +834,10 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 0; }
// SparseOpInterface:
std::vector<int> GetSparseOperands() { return {1}; }
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
}];
}
@ -866,6 +878,10 @@ def TFL_DepthwiseConv2DOp :
let extraClassDeclaration = [{
// ChannelDimIndexInterface:
int GetChannelDimIndex() { return 3; }
// SparseOpInterface:
std::vector<int> GetSparseOperands() { return {1}; }
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
}];
}


@ -157,7 +157,7 @@
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 49, 46, 49, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",


@ -154,7 +154,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",


@ -190,7 +190,7 @@
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 49, 46, 49, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",


@ -190,7 +190,7 @@
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: data: [ 49, 46, 49, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",


@ -0,0 +1,25 @@
// Test DeviceIndex selector.
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
// CHECK-LABEL: func @select
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
// CHECK: %[[first:.*]] = "tf.DeviceIndex"
// CHECK: constant dense<2>
// CHECK: return %[[first]],
%0 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
%1 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
%4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %0, %4 : tensor<i32>, tensor<f32>
}
func @add(%i: tensor<i32>, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%i: tensor<i32>, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}


@ -63,6 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
standard_pipeline_options.enable_inliner = false;
standard_pipeline_options.form_clusters = pass_config.form_clusters;
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
if (pass_config.shape_inference) {
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());


@ -40,14 +40,22 @@ void PopulateEncodingParams(const std::vector<int>& block_size,
std::vector<int>* traversal_order,
std::vector<TfLiteDimensionType>* format,
std::vector<int>* b_map, std::vector<int>* b_size) {
*traversal_order = {0, 1};
*format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
const int dims_count = block_size.size();
traversal_order->resize(dims_count);
format->resize(dims_count);
for (int i = 0; i < dims_count; i++) {
(*traversal_order)[i] = i;
}
for (int i = 0; i < dims_count - 1; i++) {
(*format)[i] = kTfLiteDimDense;
}
(*format)[dims_count - 1] = kTfLiteDimSparseCSR;
*b_map = {};
*b_size = {};
int block_rank = 0;
for (int i = 0; i < 2; i++) {
for (int i = 0; i < dims_count; i++) {
if (block_size[i] != 1) {
traversal_order->push_back(block_rank + 2);
traversal_order->push_back(block_rank + dims_count);
format->push_back(kTfLiteDimDense);
block_rank++;
b_map->push_back(i);
@ -58,27 +66,18 @@ void PopulateEncodingParams(const std::vector<int>& block_size,
float CalculateRandomSparsity(const ElementsAttr& attr,
const ShapedType& type) {
int num_elements = 1;
for (int i = 0; i < 2; i++) {
num_elements *= type.getDimSize(i);
}
int num_elements = type.getNumElements();
int num_zeros = 0;
if (type.getElementType().isF32()) {
std::vector<float> data;
data.reserve(type.getNumElements());
for (const auto val : attr.getValues<float>()) data.push_back(val);
for (int i = 0; i < data.size(); i++) {
if (data[i] == 0) {
for (const auto val : attr.getValues<float>()) {
if (val == 0.f) {
num_zeros++;
}
}
} else if (type.getElementType().isa<quant::QuantizedType>()) {
std::vector<int8_t> data;
data.reserve(type.getNumElements());
for (const auto val : attr.getValues<int8_t>()) data.push_back(val);
for (int i = 0; i < data.size(); i++) {
if (data[i] == 0) {
for (const auto val : attr.getValues<int8_t>()) {
if (val == 0) {
num_zeros++;
}
}
@ -150,9 +149,10 @@ InspectResult InspectWeight(
type = cst.getType().cast<ShapedType>();
}
// TODO(b/147449640): Add ability to encode weights more than 2-D, e.g. Conv
// weights.
if (type.getRank() != 2) {
// Currently we only support compressing weights of ops:
// Conv, DepthwiseConv, TransposeConv, whose filter has rank 4, and
// FullyConnected, whose filter has rank 2.
if (type.getRank() != 2 && type.getRank() != 4) {
result.can_compress = false;
return result;
}
@ -195,9 +195,11 @@ std::vector<T> BuildSparsityParameterAttribute(
attr = cst.value();
type = cst.getType().cast<ShapedType>();
}
std::vector<int> shape(2);
shape[0] = type.getDimSize(0);
shape[1] = type.getDimSize(1);
const int dims_count = type.getRank();
std::vector<int> shape(dims_count);
for (int i = 0; i < dims_count; i++) {
shape[i] = type.getDimSize(i);
}
std::vector<int> traversal_order = {};
std::vector<TfLiteDimensionType> format = {};
@ -271,10 +273,13 @@ void DenseToSparse::runOnFunction() {
continue;
}
ShapedType type;
if (isa<ConstOp>(inst)) {
supported_block_size = sparse_op.GetFloatBlockSize();
type = dyn_cast<ConstOp>(inst).getType().cast<ShapedType>();
} else if (isa<QConstOp>(inst)) {
supported_block_size = sparse_op.GetQuantizedBlockSize();
type = dyn_cast<QConstOp>(inst).getType().cast<ShapedType>();
} else {
continue;
}
@ -286,7 +291,7 @@ void DenseToSparse::runOnFunction() {
// The weight is not block sparse. Encode with random sparsity.
if (result.selected_block_size.empty()) {
result.selected_block_size = {1, 1};
result.selected_block_size = std::vector<int>(type.getRank(), 1);
}
builder.setInsertionPoint(op);
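To make the generalized sparsity encoding above concrete, here is a small standalone C++ sketch (hypothetical stand-ins for the TfLite enums, not the converter's actual code) that derives the traversal order, per-dimension formats, and block map for an arbitrary-rank weight, mirroring the updated PopulateEncodingParams:

```cpp
#include <cstdio>
#include <vector>

// Stand-ins for the TfLiteDimensionType values used in the pass above.
enum DimType { kDimDense = 0, kDimSparseCSR = 1 };

// Mirrors the generalized logic: the first `rank` entries traverse the
// original dimensions (last one sparse), and every blocked dimension appends
// one extra dense dimension plus a block-map / block-size entry.
void PopulateParams(const std::vector<int>& block_size,
                    std::vector<int>* traversal_order,
                    std::vector<DimType>* format,
                    std::vector<int>* b_map, std::vector<int>* b_size) {
  const int rank = static_cast<int>(block_size.size());
  traversal_order->resize(rank);
  format->resize(rank);
  for (int i = 0; i < rank; ++i) (*traversal_order)[i] = i;
  for (int i = 0; i < rank - 1; ++i) (*format)[i] = kDimDense;
  (*format)[rank - 1] = kDimSparseCSR;
  b_map->clear();
  b_size->clear();
  int block_rank = 0;
  for (int i = 0; i < rank; ++i) {
    if (block_size[i] != 1) {
      traversal_order->push_back(block_rank + rank);
      format->push_back(kDimDense);
      ++block_rank;
      b_map->push_back(i);
      b_size->push_back(block_size[i]);
    }
  }
}

int main() {
  // Example: a rank-4 filter blocked 1x1x1x16 along its last dimension.
  std::vector<int> traversal_order, b_map, b_size;
  std::vector<DimType> format;
  PopulateParams({1, 1, 1, 16}, &traversal_order, &format, &b_map, &b_size);
  for (int d : traversal_order) std::printf("%d ", d);  // prints: 0 1 2 3 4
  std::printf("\n");
  return 0;
}
```

For this 1x1x1x16-blocked rank-4 filter the sketch yields traversal order {0, 1, 2, 3, 4}, formats {dense, dense, dense, sparse-CSR, dense}, block map {3}, and block size {16}; when the rank is 2 it reduces to the old hard-coded behaviour.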

View File

@ -0,0 +1,85 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Converts DeviceIndex to constant device.
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TFL {
namespace {
// Folds the DeviceIndex op to a constant value. The DeviceIndex op returns the
// index of the device the op should run on. The user can use this to provide
// different op specializations. E.g.,
//
// ```mlir
// %1 = "tf.DeviceIndex"()
// {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
// %4 = "tf.Case"(%1, %arg0, %arg1)
// {branches = [@foo, @baz], output_shapes = [#tf.shape<>]} :
// (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// ```
//
// This shows an example where two different functions could be executed to
// produce the same values, each optimized for CPU or GPU.
struct DeviceIndexSelector
: public PassWrapper<DeviceIndexSelector, OperationPass<FuncOp>> {
void runOnOperation() override;
};
} // namespace
void DeviceIndexSelector::runOnOperation() {
FuncOp func = getOperation();
// Convert all the DeviceIndex ops to constant values.
func.getBody().walk([](TF::DeviceIndexOp op) {
// This just selects the default in all cases where DeviceIndex feeds into
// tf.Case. This could be enhanced based on an explicit TFLite specification or
// TAC in the future.
OpBuilder b(op);
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
int index = op.device_names().size();
for (auto use : op.getOperation()->getUsers()) {
// Skip if it doesn't feed into case. Alternatively this could always
// return the CPU device index if it exists.
if (!isa<TF::CaseOp>(use)) return;
}
DenseElementsAttr attr =
DenseElementsAttr::get(type, b.getI32IntegerAttr(index));
auto constant = b.create<ConstantOp>(op.getLoc(), type, attr);
op.replaceAllUsesWith(constant.getOperation());
op.erase();
});
}
// Creates an instance of the TensorFlow DeviceIndex selector pass.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
return std::make_unique<DeviceIndexSelector>();
}
static PassRegistration<DeviceIndexSelector> pass(
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
} // namespace TFL
} // namespace mlir
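A rough standalone illustration of the selection policy above (a hypothetical helper, not TensorFlow code): a tf.DeviceIndex whose users are all tf.Case ops is folded to a scalar constant equal to the number of device names, which the pass comment describes as selecting the default.

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Hypothetical stand-in for the policy in DeviceIndexSelector::runOnOperation:
// only when every user of the DeviceIndex result is a tf.Case is the op folded
// to the constant `device_names.size()`.
int FoldDeviceIndex(const std::vector<std::string>& device_names,
                    bool all_users_are_case) {
  if (!all_users_are_case) return -1;  // the pass leaves the op untouched
  return static_cast<int>(device_names.size());
}

int main() {
  std::printf("%d\n",
              FoldDeviceIndex({"CPU", "GPU"}, /*all_users_are_case=*/true));   // 2
  std::printf("%d\n",
              FoldDeviceIndex({"CPU", "GPU"}, /*all_users_are_case=*/false));  // -1
  return 0;
}
```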

View File

@ -91,6 +91,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
// Verifies runtime constraints.
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates function pass to select device index/fold tf.DeviceIndex.
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
} // namespace TFL
} // namespace mlir

View File

@ -397,6 +397,73 @@ cc_library(
],
)
cc_library(
name = "lift_variables_lib",
srcs = [
"transforms/lift_variables.cc",
],
hdrs = [
"transforms/lift_variables.h",
],
deps = [
":convert_tensor",
":tensorflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:threadpool_options",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "lift_variables_pass",
hdrs = [
"transforms/lift_variables_pass.h",
],
deps = [
":lift_variables_lib",
":tensorflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "lift_variables_test_pass",
hdrs = [
"transforms/lift_variables_test_pass.h",
],
deps = [
":lift_variables_lib",
":tensorflow",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:threadpool_options",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
alwayslink = 1,
)
cc_library(
name = "tensorflow_passes",
srcs = [
@ -520,9 +587,11 @@ cc_library(
cc_library(
name = "tensorflow_test_passes",
srcs = [
"transforms/lift_variables_test_pass_registration.cc",
"transforms/lower_tf_pass.cc",
],
deps = [
":lift_variables_test_pass",
":lower_tf_lib",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",

View File

@ -820,6 +820,12 @@ followed by cropping along the `height` and `width` dimensions.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
}
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
@ -7280,6 +7286,49 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
}
def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> {
let summary = "Update '*var' according to the centered RMSProp algorithm.";
let description = [{
The centered RMSProp algorithm uses an estimate of the centered second moment
(i.e., the variance) for normalization, as opposed to regular RMSProp, which
uses the (uncentered) second moment. This often helps with training, but is
slightly more expensive in terms of computation and memory.
Note that in the dense implementation of this algorithm, mg, ms, and mom will
update even if the grad is zero, but in this sparse implementation, mg, ms,
and mom will not update in iterations during which the grad is zero.
mean_square = decay * mean_square + (1-decay) * gradient ** 2
mean_grad = decay * mean_grad + (1-decay) * gradient
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
mg <- rho * mg_{t-1} + (1-rho) * grad
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
var <- var - mom
}];
let arguments = (ins
TF_ResourceTensor:$var,
TF_ResourceTensor:$mg,
TF_ResourceTensor:$ms,
TF_ResourceTensor:$mom,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
DefaultValuedAttr<BoolAttr, "false">:$use_locking
);
let results = (outs);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>;
}
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";
@ -9055,11 +9104,11 @@ particular,
begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0)
end = [2, 4, x, x, -3, x]
strides = [1, 1, x, x, -1, 1]
begin_mask = 1<<4 | 1 << 5 = 48
begin_mask = 1<<4 | 1<<5 = 48
end_mask = 1<<5 = 32
ellipsis_mask = 1<<3 = 8
new_axis_mask = 1<<2 4
shrink_axis_mask = 1<<0
new_axis_mask = 1<<2 = 4
shrink_axis_mask = 1<<0 = 1
```
In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of

View File

@ -695,6 +695,154 @@ void BatchMatMulV2Op::getCanonicalizationPatterns(
results.insert<BatchMatMulV2ToMatMul>(context);
}
//===----------------------------------------------------------------------===//
// BatchToSpaceOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(BatchToSpaceOp op) {
// Op already has a constraint that block_size >= 2.
int64_t block_size = op.block_size().getSExtValue();
llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
auto input_type = op.input().getType().cast<TensorType>();
if (input_type.hasRank()) {
if (input_type.getRank() != 4)
return op.emitOpError()
<< "requires input to be a 4D tensor, but got " << input_type;
int64_t input_batch = input_type.getDimSize(0);
if (input_batch != ShapedType::kDynamicSize &&
input_batch % (block_size * block_size) != 0) {
return op.emitOpError()
<< "requires input batch (dimension 0) to be evenly divisible "
"by (block_size * block_size), but got input batch "
<< input_batch << " and block_size " << block_size;
}
input_shape.assign(input_type.getShape().begin(),
input_type.getShape().end());
}
auto crops_type = op.crops().getType().cast<TensorType>();
if (crops_type.hasRank()) {
if (crops_type.getRank() != 2)
return op.emitOpError()
<< "requires crops to be a 2D tensor, but got " << crops_type;
auto dim_of_size = [&](int64_t dim, int64_t size) {
if (crops_type.isDynamicDim(dim)) return true;
return crops_type.getDimSize(dim) == size;
};
if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
return op.emitOpError()
<< "requires crops to be a tensor<2x2>, but got " << crops_type;
}
DenseIntElementsAttr crops_attr;
// Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
// and flattened as [crop_top, crop_bottom, crop_left, crop_right]
llvm::SmallVector<int64_t, 4> crops_values;
if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
assert(crops_attr.getNumElements() == 4 &&
"tf.BatchToSpace crops must have 4 elements");
auto crops_range = crops_attr.getIntValues();
for (const auto &crops_value : crops_range) {
int64_t crops_value_int = crops_value.getSExtValue();
if (crops_value_int < 0)
return op.emitOpError()
<< "requires all crop values to be nonnegative, but got "
<< crops_attr;
crops_values.push_back(crops_value_int);
}
}
auto output_type = op.output().getType().cast<TensorType>();
if (output_type.hasRank()) {
if (output_type.getRank() != 4)
return op.emitOpError()
<< "requires output to be a 4D tensor, but got " << output_type;
auto static_dims = [](int64_t dim_a, int64_t dim_b) {
return dim_a != ShapedType::kDynamicSize &&
dim_b != ShapedType::kDynamicSize;
};
auto output_shape = output_type.getShape();
// output batch = input batch / (block_size * block_size).
int64_t input_batch = input_shape[0];
int64_t output_batch = output_shape[0];
if (static_dims(input_batch, output_batch) &&
(output_batch * block_size * block_size) != input_batch)
return op.emitOpError()
<< "requires output batch (dimension 0) to be equal to input "
"batch (dimension 0) / (block_size * block_size), but got "
"output batch "
<< output_batch << ", input batch " << input_batch
<< ", and block_size " << block_size;
auto check_spatial_dim = [&](int64_t spatial_dim_index,
llvm::StringRef dim_name,
llvm::StringRef crop_a_name,
llvm::StringRef crop_b_name) -> LogicalResult {
int64_t input_dim = input_shape[spatial_dim_index];
int64_t output_dim = output_shape[spatial_dim_index];
if (!static_dims(input_dim, output_dim)) return success();
int64_t input_dim_pad = input_dim * block_size;
// If crops are unknown, the maximum output spatial dim size is input
// spatial dim size * block_size, as crops are at least 0.
if (crops_values.empty() && output_dim > input_dim * block_size)
return op.emitOpError()
<< "requires output " << dim_name << " (dimension "
<< spatial_dim_index << ") to be less than or equal to input "
<< dim_name << " (dimension " << spatial_dim_index
<< ") * block_size, but got output " << dim_name << " "
<< output_dim << ", input " << dim_name << " " << input_dim
<< ", and block_size " << block_size;
if (!crops_values.empty()) {
// output spatial dim = input spatial dim * block_size - crops.
int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
if (output_dim != input_dim_pad - crop_a - crop_b)
return op.emitOpError()
<< "requires output " << dim_name << " (dimension "
<< spatial_dim_index << ") to be equal to input " << dim_name
<< " (dimension " << spatial_dim_index << ") * block_size - "
<< crop_a_name << " - " << crop_b_name << ", but got output "
<< dim_name << " " << output_dim << ", input " << dim_name
<< " " << input_dim << ", " << crop_a_name << " " << crop_a
<< ", " << crop_b_name << " " << crop_b << ", and block_size "
<< block_size;
}
return success();
};
if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
return failure();
int64_t input_depth = input_shape[3];
int64_t output_depth = output_shape[3];
if (static_dims(input_depth, output_depth) && output_depth != input_depth)
return op.emitOpError()
<< "requires output depth (dimension 3) to be equal to input "
"depth (dimension 3), but got output depth "
<< output_depth << " and input depth " << input_depth;
}
return success();
}
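To summarize the shape relations enforced by this verifier, here is a small standalone sketch (a hypothetical helper, not the verifier itself) that computes the expected static output shape of tf.BatchToSpace from the input shape, block_size, and crops:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Expected NHWC output shape for tf.BatchToSpace with a fully static input,
// block_size, and crops flattened as [crop_top, crop_bottom, crop_left,
// crop_right] -- the same relations the verifier above checks.
std::array<int64_t, 4> BatchToSpaceOutputShape(
    const std::array<int64_t, 4>& input, int64_t block_size,
    const std::array<int64_t, 4>& crops) {
  assert(input[0] % (block_size * block_size) == 0);
  return {
      input[0] / (block_size * block_size),         // batch
      input[1] * block_size - crops[0] - crops[1],  // height
      input[2] * block_size - crops[2] - crops[3],  // width
      input[3],                                     // depth
  };
}

int main() {
  // Matches the testBatchToSpaceStatic case below: 36x8x8x8, block_size 3,
  // crops [[1, 2], [3, 4]] -> 4x21x17x8.
  auto out = BatchToSpaceOutputShape({36, 8, 8, 8}, 3, {1, 2, 3, 4});
  assert(out == (std::array<int64_t, 4>{4, 21, 17, 8}));
  return 0;
}
```

The dynamic-shape and missing-crops cases relax these checks: the verifier above only compares dimensions that are statically known, and without constant crops it only bounds the spatial dimensions by input * block_size.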
void BatchToSpaceOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<BatchToSpaceToBatchToSpaceND>(context);
}
//===----------------------------------------------------------------------===//
// BiasAddOp
//===----------------------------------------------------------------------===//

View File

@ -586,3 +586,27 @@ func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: testBatchToSpaceToBatchToSpaceND
// CHECK-SAME: ([[INPUT:%.*]]: tensor<?x?x?x?xf32>, [[CROPS:%.*]]: tensor<?x?xi32>)
func @testBatchToSpaceToBatchToSpaceND(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
// CHECK: [[BLOCK_SHAPE:%.*]] = "tf.Const"() {value = dense<8> : tensor<2xi64>}
// CHECK: [[BATCH_TO_SHAPE_ND:%.*]] = "tf.BatchToSpaceND"([[INPUT]], [[BLOCK_SHAPE]], [[CROPS]])
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<*xf32>
// CHECK: return [[BATCH_TO_SHAPE_ND]]
return %0 : tensor<*xf32>
}
// CHECK-LABEL: testBatchToSpaceDynamicInput
func @testBatchToSpaceDynamicInput(%arg0: tensor<*xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
// CHECK-NOT: "tf.BatchToSpaceND"
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64} : (tensor<*xf32>, tensor<?x?xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: testBatchToSpaceDynamicCrops
func @testBatchToSpaceDynamicCrops(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
// CHECK-NOT: "tf.BatchToSpaceND"
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64} : (tensor<?x?x?x?xf32>, tensor<*xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -368,6 +368,57 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3
// -----
// Tests that composite tf.ResourceApplyCenteredRMSProp operation is decomposed.
// CHECK-LABEL: func @decompose_resource_apply_centered_RMS_prop
// CHECK-SAME: [[VAR:%.*]]: tensor<f32>, [[MG:%.*]]: tensor<f32>, [[MS:%.*]]: tensor<f32>, [[MOM:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[RHO:%.*]]: tensor<f32>, [[MOMENTUM:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>
func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<f32>) -> () {
// CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
// CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
%3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
// CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
// CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
// CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]])
// CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]])
// CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]])
// CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]])
// CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]])
// CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
// CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]])
// CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]])
// CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]])
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]])
// CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]])
// CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]])
// CHECK: [[MOM_MOM:%.*]] = "tf.Mul"([[MOMENTUM]], [[MOM]])
// CHECK: [[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]])
// CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]])
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]])
// CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_NEW]])
// CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]])
// CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]])
// CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]])
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
"tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
return
}
// -----
// Tests that composite tf.ResourceScatterUpdate operation is decomposed.
// CHECK-LABEL: @decompose_resource_scatter_update_op

View File

@ -4,6 +4,7 @@ traces: {
value: {
file_line_cols: {
line : 1
col : 1
}
}
}
@ -12,9 +13,11 @@ traces: {
value: {
file_line_cols: {
line : 3
col : 1
}
file_line_cols: {
line : 4
col : 1
}
}
}
@ -23,6 +26,7 @@ traces: {
value: {
file_line_cols: {
line : 2
col : 1
}
}
}

View File

@ -773,6 +773,20 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
return %1 : tensor<1xf32>
}
func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0x7F800000" represents INF for f32.
%0 = xla_hlo.constant dense<0x7F800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.minimum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32>
}
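The init value matters here: reduce-min only lowers to tf.Min when the initial value is +INF. A tiny standalone check (illustrative only, not part of the test) that the 0x7F800000 bit pattern used above really is IEEE-754 +infinity for f32:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

int main() {
  // 0x7F800000 is the f32 splat used as the reduce-min init value above.
  const uint32_t bits = 0x7F800000u;
  float value;
  static_assert(sizeof(value) == sizeof(bits), "f32 expected to be 32 bits");
  std::memcpy(&value, &bits, sizeof(value));
  assert(std::isinf(value) && !std::signbit(value));  // +INF, not -INF
  return 0;
}
```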
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
// CHECK-LABEL: func @biasAdd_NHWC(
@ -1689,3 +1703,10 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// CHECK: [[VAL_418:%.*]] = "tf.Max"([[VAL_416:%.*]], [[VAL_417:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK: return [[VAL_418]] : tensor<1xf32>
// CHECK: }
// CHECK-LABEL: func @convert_reduce_to_min(
// CHECK-SAME: [[VAL_419:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
// CHECK: [[VAL_420:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
// CHECK: [[VAL_421:%.*]] = "tf.Min"([[VAL_419:%.*]], [[VAL_420:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
// CHECK: return [[VAL_421]] : tensor<1xf32>
// CHECK: }

View File

@ -2872,4 +2872,141 @@ func @testSendTPUEmbeddingGradients(%x: tensor<512x256xf32>) {
return
}
// -----
//===--------------------------------------------------------------------===//
// tf.BatchToSpace
//===--------------------------------------------------------------------===//
func @testBatchToSpaceDynamic(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
return
}
func @testBatchToSpaceRankedInput(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<*xi32>) {
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<?x?x?x?xf32>, tensor<*xi32>) -> tensor<*xf32>
return
}
func @testBatchToSpaceRankedCrops(%arg0: tensor<*xf32>, %arg1: tensor<?x?xi32>) {
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x?xi32>) -> tensor<*xf32>
return
}
func @testBatchToSpaceRankedOutput(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<?x?x?x?xf32>
return
}
func @testBatchToSpaceStatic(%arg0: tensor<36x8x8x8xf32>) {
%crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 3 : i64} : (tensor<36x8x8x8xf32>, tensor<2x2xi32>) -> tensor<4x21x17x8xf32>
return
}
// -----
func @testBatchToSpaceInvalidInputRank(%arg0: tensor<8xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires input to be a 4D tensor, but got 'tensor<8xf32>'}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<8xf32>, tensor<*xi32>) -> tensor<*xf32>
return
}
// -----
func @testBatchToSpaceInvalidInputBatch(%arg0: tensor<2x4x6x8xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires input batch (dimension 0) to be evenly divisible by (block_size * block_size), but got input batch 2 and block_size 2}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<2x4x6x8xf32>, tensor<*xi32>) -> tensor<*xf32>
return
}
// -----
func @testBatchToSpaceInvalidCropsRank(%arg0: tensor<*xf32>, %arg1: tensor<?x?x?xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a 2D tensor, but got 'tensor<?x?x?xi32>'}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x?x?xi32>) -> tensor<*xf32>
return
}
// -----
func @testBatchToSpaceInvalidCropsFirstDim(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a tensor<2x2>, but got 'tensor<3x?xi32>'}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<3x?xi32>) -> tensor<*xf32>
return
}
// -----
func @testBatchToSpaceInvalidCropsSecondDim(%arg0: tensor<*xf32>, %arg1: tensor<?x3xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a tensor<2x2>, but got 'tensor<?x3xi32>'}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x3xi32>) -> tensor<*xf32>
return
}
// -----
func @testBatchToSpaceBadCropValues(%arg0: tensor<*xf32>) {
%crops = "tf.Const"() {value = dense<[[-1, -2], [-3, -4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
// expected-error @+1 {{'tf.BatchToSpace' op requires all crop values to be nonnegative, but got dense<[[-1, -2], [-3, -4]]> : tensor<2x2xi32>}}
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<*xf32>, tensor<2x2xi32>) -> tensor<*xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputRank(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires output to be a 4D tensor, but got 'tensor<8xf32>'}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<8xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputBatch(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires output batch (dimension 0) to be equal to input batch (dimension 0) / (block_size * block_size), but got output batch 8, input batch 16, and block_size 2}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<8x8x8x3xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputHeight(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires output height (dimension 1) to be less than or equal to input height (dimension 1) * block_size, but got output height 17, input height 8, and block_size 2}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x17x8x3xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputHeightCrops(%arg0: tensor<16x8x8x3xf32>) {
%crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
// expected-error @+1 {{'tf.BatchToSpace' op requires output height (dimension 1) to be equal to input height (dimension 1) * block_size - crop_top - crop_bottom, but got output height 8, input height 8, crop_top 1, crop_bottom 2, and block_size 2}}
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<2x2xi32>) -> tensor<4x8x9x3xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputWidth(%arg0: tensor<16x4x4x3xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires output width (dimension 2) to be less than or equal to input width (dimension 2) * block_size, but got output width 9, input width 4, and block_size 2}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x4x4x3xf32>, tensor<*xi32>) -> tensor<4x4x9x3xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputWidthCrops(%arg0: tensor<16x8x8x3xf32>) {
%crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
// expected-error @+1 {{'tf.BatchToSpace' op requires output width (dimension 2) to be equal to input width (dimension 2) * block_size - crop_left - crop_right, but got output width 8, input width 8, crop_left 3, crop_right 4, and block_size 2}}
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<2x2xi32>) -> tensor<4x13x8x3xf32>
return
}
// -----
func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
// expected-error @+1 {{'tf.BatchToSpace' op requires output depth (dimension 3) to be equal to input depth (dimension 3), but got output depth 8 and input depth 3}}
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x8x8x8xf32>
return
}

View File

@ -0,0 +1,61 @@
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
module attributes {tf_saved_model.semantics} {
// Test case: Freezing VarHandleOp ops.
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: func @serving_default(
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Freezing shared VarHandleOp ops.
func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
func @f2(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/kernel"
// CHECK: "tf_saved_model.global_tensor"()
// CHECK: sym_name = "dense/bias"
// CHECK: func @f(
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
// CHECK: func @f2(
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
}

View File

@ -0,0 +1,33 @@
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail
// Test case: Invalid session.
// expected-error @+1 {{'module' op no session provided}}
module attributes {tf_saved_model.semantics} {
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
}
// -----
// Test case: No errors when there are no resource arguments.
module attributes {tf_saved_model.semantics} {
// CHECK-LABEL: @serving_default
func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
return %4 : tensor<100x50xf32>
}
}

View File

@ -27,6 +27,8 @@ def SingleResultAndOperandHaveSameType : Constraint<
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
def IsRank4Tensor : Type<HasAnyRankOfPred<[4]>, "Rank 4 tensor">;
// Checks if all the users are ReadVariableOp.
def HasOnlyReadVariableOpUsers : Constraint<
CPred<"llvm::all_of($0.getUsers(), [](mlir::OpOperand op) { "
@ -65,6 +67,21 @@ def BatchMatMulV2ToMatMul : Pat<(TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y),
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
//===----------------------------------------------------------------------===//
// BatchToSpace op patterns.
//===----------------------------------------------------------------------===//
def BatchToSpaceBlockSizeToBlockShape : NativeCodeCall<
"DenseElementsAttr::get(RankedTensorType::get({2}, $_builder.getI64Type()), "
"ArrayRef<APInt>{$0.getValue(), $0.getValue()})">;
def BatchToSpaceToBatchToSpaceND :
Pat<(TF_BatchToSpaceOp $input, $crops, $block_size),
(TF_BatchToSpaceNDOp $input,
(TF_ConstOp (BatchToSpaceBlockSizeToBlockShape $block_size)),
$crops),
[(IsRank4Tensor $input), (IsRank2Tensor $crops)]>;
//===----------------------------------------------------------------------===//
// BiasAddV1 op patterns.
//===----------------------------------------------------------------------===//

View File

@ -327,3 +327,66 @@ def DecomposeVariableShape : Pat<
(TF_VariableShapeOp:$src_op $resource),
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
[(CheckHasResourceSubtype $resource)]>;
// This decomposition is only correct inside XLA, as it ignores the use_locking
// attribute.
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
// mg = grad * (one - rho) + mg * rho;
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
//
def DecomposeResourceApplyCenteredRMSProp :
Pattern<
(TF_ResourceApplyCenteredRMSPropOp:$src_op
$var_resource, $mg_resource, $ms_resource, $mom_resource, $lr, $rho, $momentum, $epsilon,
$grad, ConstBoolAttrFalse:$use_locking
),
[(TF_ConstOp:$one (GetScalarOfType<1> $grad)),
(CreateTFReadVariableOp $src_op, $grad, $ms_resource),
(TF_AddOp:$ms_new
(TF_MulOp
(TF_MulOp $grad, $grad),
(TF_SubOp $one, $rho)
),
(TF_MulOp
(CreateTFReadVariableOp $src_op, $grad, $ms_resource),
$rho
)
),
(TF_AssignVariableOp $ms_resource, $ms_new),
// mg = grad * (one - rho) + mg * rho;
(TF_AddOp:$mg_new
(TF_MulOp
$grad,
(TF_SubOp $one, $rho)
),
(TF_MulOp
(CreateTFReadVariableOp $src_op, $grad, $mg_resource),
$rho
)
),
(TF_AssignVariableOp $mg_resource, $mg_new),
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
(TF_AddOp:$mom_new
(TF_MulOp $momentum,
(CreateTFReadVariableOp $src_op, $grad, $mom_resource)),
(TF_DivOp
(TF_MulOp $lr, $grad),
(TF_SqrtOp
(TF_SubOp
$ms_new,
(TF_AddOp
(TF_MulOp
$mg_new,
$mg_new
),
$epsilon
)
)
)
)
),
(TF_AssignVariableOp $mom_resource, $mom_new),
// var <- var - mom
(TF_AssignSubVariableOp $var_resource, $mom_new)
]
>;
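As a sanity check on the decomposition, a small standalone sketch (plain scalar arithmetic, not the MLIR pattern) that applies the update equations documented above once and prints the new state:

```cpp
#include <cmath>
#include <cstdio>

// One scalar step of the centered RMSProp update described above:
//   ms  <- rho * ms + (1 - rho) * grad * grad
//   mg  <- rho * mg + (1 - rho) * grad
//   mom <- momentum * mom + lr * grad / sqrt(ms - mg * mg + epsilon)
//   var <- var - mom
int main() {
  double var = 1.0, mg = 0.0, ms = 0.0, mom = 0.0;  // optimizer state
  const double lr = 0.1, rho = 0.9, momentum = 0.0, epsilon = 1e-7;
  const double grad = 0.5;

  ms = rho * ms + (1.0 - rho) * grad * grad;
  mg = rho * mg + (1.0 - rho) * grad;
  mom = momentum * mom + lr * grad / std::sqrt(ms - mg * mg + epsilon);
  var -= mom;

  std::printf("ms=%.6f mg=%.6f mom=%.6f var=%.6f\n", ms, mg, mom, var);
  return 0;
}
```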

View File

@ -32,11 +32,15 @@ namespace TF {
namespace {
// Note: This implements fusions performed in the old Remapper Grappler pass.
// That pass has specific cases for GPU and based on different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers the general CPU case and at the moment does not account for any
// target-specific configurations.
// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU and for different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers (some of) the general CPU case and at the moment does not account
// for any target-specific configurations.
// This pass is being ported over from the Grappler Remapper pass based on
// need/usage. File a bug to request porting over additional fusions.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

View File

@ -489,9 +489,9 @@ LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) {
return success();
}
// TODO(b/157192370): This "xla_hlo::ReduceOp" can corresponds to many TF ops
// with different ops in reduce_op.body. Now we only match to "tf.Max" and
// "tf.Sum".
// TODO(jingpu): This "xla_hlo::ReduceOp" can correspond to many TF ops
// with different ops in reduce_op.body. Now we only match "tf.Max", "tf.Min",
// and "tf.Sum".
class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
@ -504,15 +504,13 @@ class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<xla_hlo::AddOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we only match that the
// In `MatchReduceOpInput` function, we already match that the
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
auto input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
SmallVector<int64_t, 4> reduce_dims;
const int64_t input_rank = input.getType().cast<ShapedType>().getRank();
for (const int64_t &dim : dimension.getValues<int64_t>()) {
if (dim < 0 || dim >= input_rank) return failure();
reduce_dims.emplace_back(dim);
}
@ -545,15 +543,13 @@ class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<xla_hlo::MaxOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we only match that the
// In `MatchReduceOpInput` function, we already match that the
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
auto input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
SmallVector<int64_t, 4> reduce_dims;
const int64_t input_rank = input.getType().cast<ShapedType>().getRank();
for (const int64_t &dim : dimension.getValues<int64_t>()) {
if (dim < 0 || dim >= input_rank) return failure();
reduce_dims.emplace_back(dim);
}
@ -576,6 +572,47 @@ class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
};
};
class ConvertReduceOpToTfMin : public OpConversionPattern<xla_hlo::ReduceOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
ConversionPatternRewriter &rewriter) const final {
if (failed(MatchReduceOpInput(reduce_op))) return failure();
Operation *first_op = &reduce_op.body().front().front();
if (!llvm::isa<xla_hlo::MinOp>(first_op)) return failure();
// In `MatchReduceOpInput` function, we already match that the
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
Value input = reduce_op.operands()[0];
// Get reduction dimension.
DenseIntElementsAttr dimension = reduce_op.dimensions();
SmallVector<int64_t, 4> reduce_dims;
for (const int64_t &dim : dimension.getValues<int64_t>()) {
reduce_dims.emplace_back(dim);
}
// Check initial value is +INF.
DenseFPElementsAttr init_value;
if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
!init_value.isSplat() ||
!init_value.getSplatValue<APFloat>().isInfinity() ||
init_value.getSplatValue<APFloat>().isNegative())
return failure();
auto dim_type = RankedTensorType::get(
{static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
auto reduction_indices = rewriter.create<ConstOp>(
reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
rewriter.replaceOpWithNewOp<MinOp>(
reduce_op, reduce_op.getType(0), input, reduction_indices,
/*keep_dim=*/rewriter.getBoolAttr(false));
return success();
};
};
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
public:
LegalizeHloToTf() = default;
@ -709,7 +746,7 @@ void LegalizeHloToTf::runOnFunction() {
OwningRewritePatternList patterns;
populateWithGenerated(&context, &patterns);
patterns.insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
ConvertReduceOpToTfSum>(&context);
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum>(&context);
ConversionTarget target(context);
target.addLegalDialect<TensorFlowDialect>();

View File

@ -0,0 +1,183 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include <algorithm>
#include <iterator>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
using llvm::SmallSet;
using ::tensorflow::Device;
using ::tensorflow::DeviceMgr;
using ::tensorflow::mutex_lock;
using ::tensorflow::ResourceHandle;
using ::tensorflow::Session;
using ::tensorflow::Status;
using ::tensorflow::StatusOr;
using ::tensorflow::Tensor;
using ::tensorflow::Var;
namespace {
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
LogicalResult LiftVariablesFromSession(
ModuleOp module, Session* session,
const SmallSet<StringRef, 4>& resource_names) {
OpBuilder builder(module.getBodyRegion());
MLIRContext* context = module.getContext();
if (!session) return module.emitOpError() << "no session provided";
// Read all resource variables from the session.
std::vector<std::string> variable_names;
variable_names.reserve(resource_names.size());
for (StringRef name : resource_names) variable_names.push_back(name.str());
std::vector<Tensor> resource_tensors;
Status status = session->Run(
/*inputs=*/{}, variable_names,
/*target_node_names=*/{}, &resource_tensors);
if (!status.ok()) {
return module.emitOpError()
<< "failed to run the provided session: " << status.error_message();
}
const DeviceMgr* device_manager;
if (!(session->LocalDeviceManager(&device_manager).ok())) {
return module.emitOpError() << "failed to get local device manager";
}
// Read all underlying tensors of the variables from the session.
std::vector<Tensor> tensors;
tensors.reserve(resource_tensors.size());
for (const Tensor& resource_tensor : resource_tensors) {
if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
tensors.push_back(resource_tensor);
continue;
}
const ResourceHandle& resource_handle =
resource_tensor.scalar<ResourceHandle>()();
Device* device;
if (!(device_manager->LookupDevice(resource_handle.device(), &device)
.ok())) {
return module.emitOpError() << "failed to look up device";
}
tensorflow::Var* var_ptr;
if (!(device->resource_manager()
->Lookup(resource_handle.container(), resource_handle.name(),
&var_ptr)
.ok())) {
return module.emitOpError() << "failed to look up resource value";
}
tensorflow::core::RefCountPtr<Var> var(var_ptr);
// The variable tensor is already loaded into the corresponding device's
// resource manager when we load the saved model using LoadSavedModel().
// Here we just read its value.
mutex_lock ml(*var->mu());
tensors.push_back(*var->tensor());
}
for (const auto iter : llvm::zip(resource_names, tensors)) {
const StringRef name = std::get<0>(iter);
const Tensor& tensor = std::get<1>(iter);
// Create tensor attribute for this variable.
StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(tensor, &builder);
if (!tensor_attr_or.ok()) {
return module.emitOpError()
<< "failed to convert tensor (name: " << name.str() << ")";
}
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
builder.create<tf_saved_model::GlobalTensorOp>(
NameLoc::get(builder.getIdentifier(name.str()), context),
builder.getStringAttr(name), tensor_attr,
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
}
return success();
}
} // namespace
LogicalResult LiftVariables(ModuleOp module, Session* session) {
MLIRContext* context = module.getContext();
mlir::Builder builder(context);
Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr);
SmallSet<StringRef, 4> resource_names;
for (FuncOp func : module.getOps<FuncOp>()) {
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
auto resource_arg =
func.getArgAttrOfType<StringAttr>(i, kResourceNameArgAttr);
if (!resource_arg) continue;
StringRef resource_name = resource_arg.getValue();
auto flat_symbol_ref_attr =
FlatSymbolRefAttr::get(resource_name, context);
// Add the corresponding `tf_saved_model.bound_input` attribute.
func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr);
resource_names.insert(flat_symbol_ref_attr.getValue());
// Remove the existing `tf.resource_name` attribute.
func.removeArgAttr(i, resource_name_id);
}
}
if (resource_names.empty()) return success();
return LiftVariablesFromSession(module, session, resource_names);
}
} // namespace tf_saved_model
} // namespace mlir

View File

@ -0,0 +1,33 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
// Creates GlobalTensorOp for each variable from function arguments and converts
// them to the corresponding saved model arguments.
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
// This pass finds all variables among the function arguments and converts them
// to the corresponding global tensors, which are placed outside the functions.
// It also converts the resource arguments of the functions to the corresponding
// saved model arguments.
class LiftVariablesPass
: public PassWrapper<LiftVariablesPass, OperationPass<ModuleOp>> {
public:
explicit LiftVariablesPass(::tensorflow::Session* session)
: session_(session) {}
void runOnOperation() override {
ModuleOp module = getOperation();
if (failed(LiftVariables(module, session_))) signalPassFailure();
}
private:
::tensorflow::Session* session_;
};
// Creates a pass that creates GlobalTensorOp for each variable from function
// arguments and converts the function arguments to the corresponding saved
// model arguments.
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
::tensorflow::Session* session);
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
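A rough usage sketch for the pass declared above (a hypothetical wrapper; it assumes a live tensorflow::Session from a loaded SavedModel and an already-imported module, with error handling omitted):

```cpp
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Module.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h"
#include "tensorflow/core/public/session.h"

// Runs the lift-variables pass over `module`, reading variable values from
// `session`. Both are assumed to be set up elsewhere (e.g. a session obtained
// from loading the SavedModel and a module produced by the importer).
mlir::LogicalResult LiftSavedModelVariables(mlir::ModuleOp module,
                                            tensorflow::Session* session) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::tf_saved_model::CreateLiftVariablesPass(session));
  return pm.run(module);
}
```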

View File

@ -0,0 +1,146 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassManager.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/threadpool_options.h"
#include "tensorflow/core/public/session.h"
namespace mlir {
namespace tf_saved_model {
using ::tensorflow::DeviceMgr;
using ::tensorflow::Session;
using ::tensorflow::Status;
using ::tensorflow::Tensor;
// FakeSession is for testing only.
class FakeSession : public tensorflow::Session {
public:
FakeSession() {}
~FakeSession() override = default;
Status Create(const tensorflow::GraphDef& graph) override {
return tensorflow::errors::Unimplemented("not available");
}
Status Extend(const tensorflow::GraphDef& graph) override {
return tensorflow::errors::Unimplemented("not available");
}
Status Close() override {
return tensorflow::errors::Unimplemented("not available");
}
Status ListDevices(
std::vector<tensorflow::DeviceAttributes>* response) override {
return tensorflow::errors::Unimplemented("not available");
}
Status LocalDeviceManager(
const tensorflow::DeviceMgr** deviceMgrPtr) override {
// This method returns a null device manager without raising an error;
// callers will notice, since the data they get back is fake.
*deviceMgrPtr = nullptr;
return Status::OK();
}
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs) override {
tensorflow::RunMetadata run_metadata;
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
outputs, &run_metadata);
}
Status Run(const tensorflow::RunOptions& run_options,
const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs,
tensorflow::RunMetadata* run_metadata) override {
return Run(run_options, inputs, output_names, target_nodes, outputs,
run_metadata, tensorflow::thread::ThreadPoolOptions());
}
Status Run(const tensorflow::RunOptions& run_options,
const std::vector<std::pair<std::string, Tensor>>& inputs,
const std::vector<std::string>& output_names,
const std::vector<std::string>& target_nodes,
std::vector<Tensor>* outputs,
tensorflow::RunMetadata* run_metadata,
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
override {
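// Produce fixed-shape dummy tensors for the variable names the tests look up
// (dense/kernel -> [100, 50], dense/bias -> [50], scalar float otherwise).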
for (const std::string& output_name : output_names) {
Tensor output;
if (output_name == "dense/bias") {
outputs->push_back(
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})));
} else if (output_name == "dense/kernel") {
outputs->push_back(
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})));
} else {
// Create a scalar float tensor.
outputs->push_back(
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})));
}
}
return Status::OK();
}
};
// This pass is only available in the tf-opt binary for testing.
class LiftVariablesTestPass
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
public:
LiftVariablesTestPass() { session_ = new FakeSession(); }
~LiftVariablesTestPass() override { delete session_; }
void runOnOperation() override {
ModuleOp module = getOperation();
if (failed(LiftVariables(module, session_))) signalPassFailure();
}
private:
Session* session_;
};
// This pass is only available in the tf-opt binary for testing.
class LiftVariablesInvalidSessionTestPass
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
OperationPass<ModuleOp>> {
public:
void runOnOperation() override {
ModuleOp module = getOperation();
// Pass an invalid session argument, which is a nullptr.
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
}
};
} // namespace tf_saved_model
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_

View File

@ -0,0 +1,32 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h"
namespace mlir {
namespace tf_saved_model {
static PassRegistration<LiftVariablesTestPass> lift_variables_test_pass(
"tf-saved-model-lift-variables-test",
"Lift variables and save them as global tensors");
static PassRegistration<LiftVariablesInvalidSessionTestPass>
lift_variables_invalid_session_test_pass(
"tf-saved-model-lift-variables-invalid-session-test",
"Lift variables and save them as global tensors with an invalid "
"session");
} // namespace tf_saved_model
} // namespace mlir

View File

@ -73,9 +73,8 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
}
// Moves `cluster_ops` to associated `launch_op` body.
void MoveOutsideClusterOpsToLaunchOp(
tf_device::LaunchOp launch_op,
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
llvm::ArrayRef<Operation*> cluster_ops) {
MLIRContext* context = launch_op.getContext();
Operation* terminator = launch_op.GetBody().getTerminator();
@ -123,7 +122,7 @@ void PropagateParallelExecuteReturnToReplicate(
// Extracts all externally provided operands of `cluster_ops`.
llvm::SmallSetVector<Value, 4> GetExternalOperands(
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
llvm::ArrayRef<Operation*> cluster_ops) {
llvm::SmallSetVector<Value, 4> external_values;
for (Operation* op : cluster_ops) {
@ -143,7 +142,7 @@ llvm::SmallSetVector<Value, 4> GetExternalOperands(
// Extracts all externally used outputs of `cluster_ops`.
llvm::SmallVector<Value, 4> GetExternalOutputs(
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
llvm::ArrayRef<Operation*> cluster_ops) {
llvm::SmallSetVector<Value, 4> external_outputs;
for (Operation* op : cluster_ops) {
@ -166,7 +165,7 @@ llvm::SmallVector<Value, 4> GetExternalOutputs(
// as an operand. If there are no external_inputs, set insertion point to first
// cluster_op.
void SetHostComputeInsertion(
OpBuilder* builder, const llvm::SmallVector<Operation*, 8>& cluster_ops,
OpBuilder* builder, llvm::ArrayRef<Operation*> cluster_ops,
const llvm::SmallSetVector<Value, 4>& external_inputs) {
if (external_inputs.empty()) builder->setInsertionPoint(cluster_ops.front());
for (const auto& cluster_op : cluster_ops) {
@ -183,9 +182,9 @@ void SetHostComputeInsertion(
// using `communication_key`.
TF::_HostComputeMlirOp CreateHostCompute(
OpBuilder* builder, tf_device::ClusterOp tpu_cluster,
const llvm::SmallVector<Operation*, 8>& cluster_ops,
llvm::ArrayRef<Operation*> cluster_ops,
const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
const std::string& communication_key) {
llvm::StringRef communication_key) {
llvm::SmallVector<Type, 4> device_output_types;
for (const auto& output : outputs)
device_output_types.push_back(output.getType());
@ -201,10 +200,9 @@ TF::_HostComputeMlirOp CreateHostCompute(
void MoveOutsideCompiledOps(
tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name,
tf_device::LaunchOp host_launch_op,
const llvm::SmallVector<Operation*, 8>& cluster_ops,
tf_device::LaunchOp host_launch_op, llvm::ArrayRef<Operation*> cluster_ops,
const llvm::SmallSetVector<Value, 4>& external_inputs,
const llvm::SmallVector<Value, 4>& external_outputs) {
llvm::ArrayRef<Value> external_outputs) {
if (external_inputs.empty() && external_outputs.empty()) {
MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops);
return;

View File

@ -125,7 +125,8 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
pm.addNestedPass<mlir::FuncOp>(
absl::make_unique<MaterializeBroadcastsPass>());
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass());
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());
if (failed(pm.run(module))) {
@ -148,7 +149,12 @@ struct PropagateStaticKnowledge
// We do not change the signature so that we keep a somewhat stable ABI
// that is easy for tools to understand.
mlir::LLVM::LLVMFuncOp func = getOperation();
// This only works if the function is local and we can rewrite it.
if (func.isExternal()) return;
mlir::OpBuilder b(func.getBody());
// Steal the LLVM representation of the index type from the third argument.
auto index_type = func.getArgument(3).getType();
mlir::Value one = b.create<mlir::LLVM::ConstantOp>(
func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
@ -156,10 +162,21 @@ struct PropagateStaticKnowledge
func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
uint32_t arg_pos = 0;
std::vector<uint32_t> positions;
for (mlir::Type arg_type : func_type.getInputs()) {
// Collect the argument and return types of the surrounding function.
auto arg_types = llvm::to_vector<4>(llvm::concat<const mlir::Type>(
func_type.getInputs(), func_type.getResults()));
for (mlir::Type arg_type : arg_types) {
if (!arg_type.isa<mlir::MemRefType>()) {
func.emitError() << "argument of surrounding func is not ranked memref";
signalPassFailure();
return;
}
positions.push_back(arg_pos);
// Replace the offset with zero. Offset is argument number 3.
func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
arg_pos += 3 + arg_type.cast<mlir::ShapedType>().getRank() * 2;
// Forward over base_ptr, aligned_ptr, offset, size and stride arguments.
arg_pos += 3 + arg_type.cast<mlir::MemRefType>().getRank() * 2;
// Replace the last stride with constant 1.
func.getArgument(arg_pos - 1).replaceAllUsesWith(one);
}
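For reference, a small worked example of the flattened memref calling convention that the `arg_pos` arithmetic above relies on (an illustration under that assumption, not code from this patch):
// Assumed expansion of one ranked memref argument into scalar arguments:
//   [allocated_ptr, aligned_ptr, offset,
//    size_0 .. size_{R-1}, stride_0 .. stride_{R-1}]
// so a rank-R memref contributes 3 + 2 * R arguments; the pass rewrites the
// offset to the constant 0 and the last (innermost) stride to the constant 1.
constexpr unsigned ExpandedArgCount(unsigned rank) { return 3 + 2 * rank; }
static_assert(ExpandedArgCount(2) == 7, "a rank-2 memref expands to 7 arguments");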
@ -169,17 +186,17 @@ struct PropagateStaticKnowledge
if (!same_shape.empty()) {
auto first = same_shape.front();
auto first_offset = positions.at(first);
mlir::ShapedType first_type =
func_type.getInput(first).cast<mlir::ShapedType>();
auto first_type = arg_types[first].cast<mlir::ShapedType>();
uint32_t rank = first_type.getRank();
for (auto same : same_shape.drop_front(1)) {
uint32_t same_offset = positions.at(same);
auto same_type = func_type.getInput(same).cast<mlir::ShapedType>();
auto same_type = arg_types[same].cast<mlir::ShapedType>();
if (same_type.getRank() != rank) {
func.emitOpError() << "same shape constraints on arguments with "
"non-matching shapes: #"
<< first << " and #" << same;
signalPassFailure();
continue;
}
for (uint32_t i = 0; i < 2 * rank; ++i) {
@ -237,18 +254,16 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
TF_RETURN_IF_ERROR(
xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
/*collapseParallelLoops=*/false));
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
// TODO(b/156985522): Figure out why we get a segfault when generating Tanh
// with 'same_shape' containing {0, 1}. We would also get the crash if we
// unconditionally call PropagateStaticShapeKnowledgeToKernel while
// 'same_shape' is empty.
if (!same_shape.empty()) {
TF_RETURN_IF_ERROR(
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
{
xla::mlir_gpu::LowerLHLOToGPUOptions options;
options.tile_sizes = tile_sizes;
options.unroll_factors = unroll_factors;
options.collapse_parallel_loops = false;
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options));
}
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
TF_RETURN_IF_ERROR(
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
mlir::OwningModuleRef kernel_module =
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();

View File

@ -377,6 +377,7 @@ cc_library(
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TransformUtils",
],
alwayslink = 1,

View File

@ -115,6 +115,9 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
computation_name, func_type, attrs);
auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
: FuncOp::Visibility::Private;
function.setVisibility(visibility);
module_.push_back(function);
// Add to the map right away for function calls.

View File

@ -1112,21 +1112,13 @@ def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
let hasCustomHLOConverter = 1;
}
def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect,
[StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
BASE_HLO_ScatterOp {
let arguments = (ins
HLO_Tensor:$operand,
HLO_Tensor:$scatter_indices,
HLO_Tensor:$updates,
ScatterDimensionNumbers:$scatter_dimension_numbers,
ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
);

View File

@ -1098,6 +1098,15 @@ class BASE_HLO_ReshapeOp {
}];
}
class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
"ScatterDimensionNumbers", dialect, [
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
StructFieldAttr<"index_vector_dim", I64Attr>]> {
let description = "Structure of dimension information for scatter";
}
class BASE_HLO_ScatterOp {
string summary = "Scatter operator";

View File

@ -471,6 +471,12 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
);
}
// TODO(timshen): add a custom verifier.
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins
@ -578,6 +584,19 @@ def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
);
}
def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
);
let regions = (region SizedRegion<1>:$update_computation);
}
def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
let arguments = (ins
@ -712,6 +731,44 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
);
}
// TODO(timshen): add a custom verifier.
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$dimensions
);
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
let arguments = (ins
Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
I64Attr:$delta
);
}
// TODO(timshen): add a custom verifier.
def LHLO_SortOp: LHLO_Op<"sort", []>, BASE_HLO_SortOp {
let arguments = (ins
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
LHLO_BufferOrTuple:$output,
DefaultValuedAttr<I64Attr, "-1">:$dimension,
DefaultValuedAttr<BoolAttr, "false">:$is_stable
);
let regions = (region SizedRegion<1>:$comparator);
}
def LHLO_TupleSelectOp: LHLO_Op<"tuple_select", [SameOperandsShape]> {
let arguments = (ins
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
Arg<LHLO_Buffer, "", [MemWrite]>:$output
);
}
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//

View File

@ -1,12 +1,13 @@
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
// RUN: xla-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s
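// PRE checks the default mode, where each function result is written into a
// pre-allocated output argument; ESC checks results-escape-function=true,
// where results escape as returned buffers; BOTH lines apply to both modes.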
// CHECK-LABEL: func @attrs
// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -16,13 +17,16 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// CHECK-NEXT: return
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "xla_lhlo.copy"
// ESC-NEXT: return %[[ARG0]]
// -----
// CHECK-LABEL: func @func_op_long
// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
@ -31,89 +35,91 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// CHECK-NEXT: return
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
// -----
// CHECK-LABEL: func @fusion
// BOTH-LABEL: func @fusion
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
// BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return
return
}
// -----
// CHECK-LABEL: func @copy
// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @exp
// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @log
// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @select
// BOTH-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_pred = tensor_load %pred : memref<2x2xi1>
@ -121,34 +127,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @compare
// BOTH-LABEL: func @compare
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
// -----
// CHECK-LABEL: func @broadcast
// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
@ -157,55 +163,55 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
func @external_func() -> tensor<3xi64>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK-LABEL: func @dyn_broadcast
// BOTH-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = call @external_func()
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// BOTH: %[[SHAPE:.*]] = call @external_func()
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// BOTH: %[[C2:.*]] = constant 2 : index
// BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[C1__:.*]] = constant 1 : index
// BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// BOTH: %[[C0___:.*]] = constant 0 : index
// BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[C2_:.*]] = constant 2 : index
// BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// BOTH: %[[C1___:.*]] = constant 1 : index
// BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
@ -214,7 +220,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// -----
// CHECK-LABEL: func @complex
// BOTH-LABEL: func @complex
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) {
@ -222,164 +228,164 @@ func @complex(%real: memref<2x2xf32>,
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// CHECK: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return
}
// -----
// CHECK-LABEL: func @real
// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.real"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @imag
// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @iota
// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
%tensor_result = "xla_hlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
// -----
// CHECK-LABEL: func @abs
// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @ceil
// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @convert
// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// CHECK-NOT: tensor_store
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @cos
// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @neg
// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @rsqrt
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @sign
// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @sqrt
// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @tanh
// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
// -----
// CHECK-LABEL: func @remainder
// BOTH-LABEL: func @remainder
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -387,76 +393,79 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// -----
// Dynamic shape binary element-wise operation.
// CHECK-LABEL: func @add_dyn
// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// Dynamic shape unary element-wise operation.
// CHECK-LABEL: func @tanh_dyn
// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// BOTH: %[[C1:.*]] = constant 1 : index
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// BOTH: %[[C0_:.*]] = constant 0 : index
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// BOTH: %[[C1_:.*]] = constant 1 : index
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @dot
// BOTH-LABEL: func @dot
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32>
}
// -----
// CHECK-LABEL: func @conv
// BOTH-LABEL: func @conv
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// CHECK-SAME: padding = dense<[
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
// CHECK-SAME: window_strides = dense<[2, 1]>
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64,
dimension_numbers = {

View File

@ -1,13 +1,12 @@
// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP
// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
%temp_result = alloc() {temp = true} : memref<6x6xf32>
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
@ -19,7 +18,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
"xla_lhlo.terminator"() : () -> ()
return
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
@ -53,10 +52,12 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
func @fusion_of_three(%arg0: memref<100x10xf32>,
%arg1: memref<100xf32>,
%arg2: memref<100x10xf32>) {
%0 = alloc() {temp = true} : memref<100x10xf32>
%0 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
@ -66,7 +67,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}: memref<100xf32>, memref<100x10xf32>
%1 = alloc() {temp = true} : memref<100x10xf32>
%1 = alloc() : memref<100x10xf32>
linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
@ -126,11 +127,13 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
// PLOOP: linalg.generic
// PLOOP: exp
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
// -----
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
%temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
%temp_result = alloc() : memref<6x6x6x6xf32>
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
@ -142,7 +145,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
linalg.yield %out : f32
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
dealloc %temp_result : memref<6x6x6x6xf32>
"xla_lhlo.terminator"() : () -> ()
return
}
// CHECK-LABEL: func @fusion_4d
// CHECK: %[[C1:.*]] = constant 1
@ -177,3 +180,57 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf
// -----
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
%temp_result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
%out = addf %summand_1_in, %summand_2_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
%result = alloc() : memref<6x6xf32>
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
%out = mulf %temp_result_in, %multiplier_in : f32
linalg.yield %out : f32
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
dealloc %temp_result : memref<6x6xf32>
return %result : memref<6x6xf32>
}
// CHECK-LABEL: func @fusion
// CHECK: %[[C1:.*]] = constant 1
// CHECK-NOT: linalg.generic
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK: scf.for {{.*}} step %[[C1]]
// CHECK-NOT: scf.for
// CHECK: linalg.generic
// CHECK: addf
// CHECK: linalg.generic
// CHECK: mulf
// TILED-LABEL: func @fusion
// TILED-DAG: %[[C2:.*]] = constant 2
// TILED-DAG: %[[C3:.*]] = constant 3
// TILED-NOT: linalg.generic
// TILED: scf.for {{.*}} step %[[C2]]
// TILED: scf.for {{.*}} step %[[C3]]
// TILED-NOT: scf.for
// TILED: linalg.generic
// TILED: addf
// TILED: linalg.generic
// TILED: mulf
// PLOOP-LABEL: func @fusion
// PLOOP-NOT: linalg.generic
// PLOOP: scf.parallel
// PLOOP-NOT: scf.parallel
// PLOOP: linalg.generic
// PLOOP: addf
// PLOOP: linalg.generic
// PLOOP: mulf

View File

@ -38,15 +38,15 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<112x112xf32>) {
// Constants.
// CHECK: [[C56:%.*]] = constant 56 : index
// CHECK: [[C1:%.*]] = constant 1 : index
// CHECK: [[C0_F32:%.*]] = constant 0.000000e+00 : f32
// CHECK: [[CFALSE:%.*]] = constant false
// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[C2:%.*]] = constant 2 : index
// CHECK: [[C0:%.*]] = constant 0 : index
// CHECK: [[C112:%.*]] = constant 112 : index
// CHECK: [[CTRUE:%.*]] = constant true
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C0_F32:%.*]] = constant 0.000000e+00 : f32
// CHECK-DAG: [[CFALSE:%.*]] = constant false
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK-DAG: [[CTRUE:%.*]] = constant true
// Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
@ -80,23 +80,17 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// Compute index I of the ARG buffer and check whether it is in padding area.
// CHECK: [[START_I:%.*]] = muli [[II]], [[C2]] : index
// CHECK: [[OFFSET_I:%.*]] = subi [[WIN_I]], [[C0]] : index
// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index
// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[WIN_I]] : index
// CHECK: [[ARG_I_FITS:%.*]] = cmpi "ult", [[ARG_I]], [[C112]] : index
// Update `INBOUNDS`, i.e. whether the ARG indices are inside the buffer
// boundaries or in the padding area.
// CHECK: [[INBOUNDS_0:%.*]] = and [[ARG_I_FITS]], [[CTRUE]] : i1
// Compute index J of the ARG buffer and check whether it is in padding area.
// CHECK: [[START_J:%.*]] = muli [[JJ]], [[C2]] : index
// CHECK: [[OFFSET_J:%.*]] = subi [[WIN_J]], [[C0]] : index
// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index
// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[WIN_J]] : index
// CHECK: [[ARG_J_FITS:%.*]] = cmpi "ult", [[ARG_J]], [[C112]] : index
// Update `INBOUNDS`, i.e. whether the ARG indices are inside the buffer
// boundaries or in the padding area.
// CHECK: [[INBOUNDS_1:%.*]] = and [[INBOUNDS_0]], [[ARG_J_FITS]] : i1
// CHECK: [[INBOUNDS_1:%.*]] = and [[ARG_I_FITS]], [[ARG_J_FITS]] : i1
// If ARG ivs are in the padding area, then 'select' function does not have to
// be applied, current selected ivs (SEL_I, SEL_J) and value (SEL_VAL) are

View File

@ -151,7 +151,6 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) {
// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant true
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
@ -167,16 +166,13 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-SAME: init ([[INIT]]) -> f32 {
// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index
// CHECK: [[OFFSET_I:%.*]] = subi [[IW]], [[C0]] : index
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[IW]] : index
// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]]
// CHECK: [[IN_BOUNDS_0:%.*]] = and [[INDEX_I_FITS]], [[IN_BOUNDS]]
// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index
// CHECK: [[OFFSET_J:%.*]] = subi [[JW]], [[C0]] : index
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[JW]] : index
// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]]
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]]
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[INDEX_I_FITS]], [[INDEX_J_FITS]]
// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
// CHECK: [[OPERAND_ELEM:%.*]] =

View File

@ -863,3 +863,124 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
) : (memref<i64>, memref<i64>) -> ()
return
}
// -----
// CHECK-LABEL: func @bitcast_memrefs
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @scatter_memrefs
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
index_vector_dim = 1 : i64
},
indices_are_sorted = true,
unique_indices = true
} : (memref<200x100x300xf32>, memref<10x2xi32>, memref<10x300xf32>, memref<200x100x300xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @map_memrefs
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
// -----
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
return
}
// -----
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%arg_out: tuple<memref<16x16xf32>, memref<16x16xf16>>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %arg_out) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
return
}
// -----
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%arg_out: tuple<memref<16x16xf32>, memref<16x16xf16>>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %arg_out) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
return
}
// -----
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%arg_out: tuple<memref<16x16xf32>, memref<16x16xf16>>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %arg_out) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
}) : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
return
}
// -----
// CHECK-LABEL: func @tuple_select_memrefs
func @tuple_select_memrefs(%pred: memref<20xi1>, %true_values: memref<20xf32>,
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
: (memref<20xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
// -----
func @tuple_select_memrefs(%pred: memref<10xi1>, %true_values: memref<20xf32>,
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
: (memref<10xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}

View File

@ -1,4 +1,4 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -DPRIVATE="attributes {sym_visibility = \"private\"}"
HloModule main
@ -8,6 +8,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
}
// CHECK-LABEL: func @test_simple
// CHECK-SAME: [[PRIVATE]]
%test_simple (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] {
%Arg_0.1 = f32[4]{0} parameter(0)
%Arg_1.2 = f32[4]{0} parameter(1)
@ -21,7 +22,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
}
// CHECK-LABEL: func @test_after_all
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token [[PRIVATE]]
%test_after_all (token0: token[], token1: token[] ) -> token[] {
token0 = token[] parameter(0)
token1 = token[] parameter(1)
@ -95,7 +96,7 @@ add {
ROOT %batch-norm-grad = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %mean, f32[2] %variance, f32[2,2,2,2] %grad_output), epsilon=0.001, feature_index=1
}
// CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64>
%call (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
@ -136,7 +137,7 @@ add {
}
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> {
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1>
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] {
%Arg_0.1 = f32[3] parameter(0)
%Arg_1.2 = f32[3] parameter(1)
@ -162,7 +163,7 @@ add {
ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_concat(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32> {
// CHECK-LABEL: func @test_concat(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32>
%test_concat (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[4, 2] parameter(1)
@ -201,7 +202,7 @@ add {
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
// implementations with attributes, etc.
// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> {
// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>>
%test_conv {
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
@ -257,7 +258,7 @@ add {
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
}
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> {
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64>
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -272,7 +273,7 @@ add {
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
}
// CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
// CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
%test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
@ -289,7 +290,7 @@ add {
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
}
// CHECK-LABEL: func @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_div (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -298,7 +299,7 @@ add {
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_dot(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<f32> {
// CHECK-LABEL: func @test_dot(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<f32>
%test_dot (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[] {
%Arg_0.1 = f32[1, 4] parameter(0)
%Arg_1.2 = f32[4, 1] parameter(1)
@ -350,7 +351,7 @@ add {
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
}
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32> {
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32>
%test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1)
@ -371,7 +372,7 @@ add {
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
}
// CHECK-LABEL: func @test_exponential(%arg0: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-LABEL: func @test_exponential(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_exponential (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -379,7 +380,7 @@ add {
ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_expm1(%arg0: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-LABEL: func @test_expm1(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_expm1 (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -387,7 +388,7 @@ add {
ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> {
// CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
%test_fft {
%arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
// CHECK: "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"
@ -395,7 +396,7 @@ add {
}
// CHECK-LABEL: func @test_floor(
// CHECK-SAME: [[A0:%.+]]: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-SAME: [[A0:%.+]]: tensor<16xf32>) -> tensor<16xf32>
%test_floor (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -404,7 +405,7 @@ add {
}
// CHECK-LABEL: func @test_gather(
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32>
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
%arg.0 = f32[200,100,300] parameter(0)
%arg.1 = s32[10,2] parameter(1)
@ -442,7 +443,7 @@ add {
}
// CHECK-LABEL: func @test_infeed
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token> {
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token>
%test_infeed (token0: token[]) -> (s32[3], token[]) {
%token0 = token[] parameter(0)
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
@ -451,19 +452,19 @@ add {
}
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> {
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32>
%test_iota_1 () -> f32[4] {
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
}
// CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32> {
// CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32>
%test_iota_2 () -> f32[4, 5] {
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
}
// CHECK-LABEL: func @test_log(%arg0: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-LABEL: func @test_log(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_log (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -471,7 +472,7 @@ add {
ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_log1p(%arg0: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-LABEL: func @test_log1p(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_log1p (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -501,7 +502,7 @@ add {
// CHECK-LABEL: func @test_maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_maximum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -510,7 +511,7 @@ add {
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_minimum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -519,7 +520,7 @@ add {
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_multiply(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_multiply(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_multiply (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -528,7 +529,7 @@ add {
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_negate(%arg0: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-LABEL: func @test_negate(%arg0: tensor<16xf32>) -> tensor<16xf32>
%test_negate (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -536,7 +537,7 @@ add {
ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_not(%arg0: tensor<16xi1>) -> tensor<16xi1> {
// CHECK-LABEL: func @test_not(%arg0: tensor<16xi1>) -> tensor<16xi1>
%test_not (arg0.1: pred[16]) -> pred[16] {
%arg0.1 = pred[16] parameter(0)
@ -554,7 +555,7 @@ add {
}
// CHECK-LABEL: func @test_outfeed
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token {
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token
%test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
%Arg_0.1 = s32[3] parameter(0)
%Arg_1.2 = token[] parameter(1)
@ -563,7 +564,7 @@ add {
ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar"
}
// CHECK-LABEL: func @test_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32>
%test_pad (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
@ -572,7 +573,7 @@ add {
ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
}
// CHECK-LABEL: func @test_pad_edge(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32> {
// CHECK-LABEL: func @test_pad_edge(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32>
%test_pad_edge (Arg_0.1: f32[4, 4, 4], Arg_1.2: f32[]) -> f32[7, 11, 15] {
%Arg_0.1 = f32[4, 4, 4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
@ -581,7 +582,7 @@ add {
ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
}
// CHECK-LABEL: func @test_pad_interior(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<10xf32> {
// CHECK-LABEL: func @test_pad_interior(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<10xf32>
%test_pad_interior (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[10] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
@ -590,7 +591,7 @@ add {
ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2
}
// CHECK-LABEL: func @test_popcnt(%arg0: tensor<16xi32>) -> tensor<16xi32> {
// CHECK-LABEL: func @test_popcnt(%arg0: tensor<16xi32>) -> tensor<16xi32>
%test_popcnt (arg0.1: s32[16]) -> s32[16] {
%arg0.1 = s32[16] parameter(0)
@ -598,7 +599,7 @@ add {
ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1)
}
// CHECK-LABEL: func @test_pow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_pow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
%test_pow (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -659,7 +660,7 @@ add {
}
// CHECK-LABEL: func @test_reduce
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x4xf32>, [[ARG1:%.*]]: tensor<4xf32>, [[ARG2:%.*]]: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>> {
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x4xf32>, [[ARG1:%.*]]: tensor<4xf32>, [[ARG2:%.*]]: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>>
%test_reduce (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
@ -719,7 +720,7 @@ add {
ROOT %remainder.3 = f32[4] remainder(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_reverse_1d(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-LABEL: func @test_reverse_1d(%arg0: tensor<4xf32>) -> tensor<4xf32>
%test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
@ -727,7 +728,7 @@ add {
ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
}
// CHECK-LABEL: func @test_reverse_2d(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
// CHECK-LABEL: func @test_reverse_2d(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32>
%test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
@ -736,7 +737,7 @@ add {
}
// CHECK-LABEL: func @test_rsqrt(
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32>
%test_rsqrt (arg0.1: f32[16]) -> f32[16] {
%arg0.1 = f32[16] parameter(0)
@ -744,7 +745,7 @@ add {
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
}
// CHECK-LABEL: func @test_scalar(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK-LABEL: func @test_scalar(%arg0: tensor<f32>) -> tensor<f32>
%test_scalar (Arg_0.1: f32[]) -> f32[] {
// CHECK-NEXT: return %arg0 : tensor<f32>
ROOT %Arg_0.1 = f32[] parameter(0)
@ -781,7 +782,7 @@ add {
// CHECK-SAME: unique_indices = false
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32>
%test_select {
%Arg_0.1 = pred[2,3] parameter(0)
%Arg_1.2 = s32[2,3] parameter(1)
@ -838,7 +839,7 @@ add {
ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1}
}
// CHECK-LABEL: func @test_sine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
// CHECK-LABEL: func @test_sine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
%test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
@ -874,7 +875,7 @@ add {
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
// CHECK-LABEL: func @test_tanh(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
// CHECK-LABEL: func @test_tanh(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
%test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
@ -882,7 +883,7 @@ add {
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
}
// CHECK-LABEL: func @test_transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
// CHECK-LABEL: func @test_transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
%test_transpose {
%Arg_0.1 = s32[1,2,3,4] parameter(0)
@ -903,7 +904,7 @@ add {
ROOT %triangular-solve.3 = f32[4,3] triangular-solve(f32[4,4] %Arg_0.1, f32[4,3] %Arg_1.2), left_side=true, lower=true, transpose_a=NO_TRANSPOSE, unit_diagonal=true
}
// CHECK-LABEL: func @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
// CHECK-LABEL: func @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
%test_tuple(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) {
%Arg_0.1 = s32[1] parameter(0)
%Arg_1.2 = f32[1, 2] parameter(1)
@ -928,7 +929,7 @@ add {
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func @test_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-LABEL: func @test_while(%arg0: tensor<i64>) -> tensor<i64>
%test_while (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
// CHECK-NEXT: "xla_hlo.while"(%arg0) ( {

View File

@ -368,6 +368,15 @@ class HloToLhloTensorStoreOpConverter
struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
public:
HloLegalizeToLhlo() = default;
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
this->results_escape_function = o.results_escape_function.getValue();
}
explicit HloLegalizeToLhlo(bool results_escape_function) {
this->results_escape_function.setValue(results_escape_function);
}
void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext();
@ -398,10 +407,28 @@ struct HloLegalizeToLhlo
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
&converter, &patterns);
if (results_escape_function) {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
&converter, &patterns);
} else {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}
return WalkResult(
applyPartialConversion(func, target, patterns, &converter));
});
}
private:
Option<bool> results_escape_function{
*this, "results-escape-function",
llvm::cl::desc(
"Allocate the results of functions within the functions body"),
llvm::cl::init(false)};
};
} // namespace
@ -446,14 +473,11 @@ void populateHLOToLHLOConversionPattern(
HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter);
// clang-format on
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(context, bufferAssignment,
converter, patterns);
}
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_function) {
return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
}
static PassRegistration<HloLegalizeToLhlo> legalize_pass(

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@ -52,10 +53,17 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
// The fusion in Linalg is currently possible only when the consumer op is
// tiled. In order to greedily fuse the ops, we have to start from the tiled
// root linalg ops, i.e. linalg ops that write to output buffers of the
// function.
llvm::SmallDenseSet<Value> func_args;
// function or are returned in case of escaping allocations.
llvm::SmallDenseSet<Value> result_buffers;
for (auto func_arg : func.getArguments()) {
func_args.insert(func_arg);
result_buffers.insert(func_arg);
}
for (auto& block : func.getBlocks()) {
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
if (!returnOp) continue;
for (auto operand : returnOp.getOperands()) {
result_buffers.insert(operand);
}
}
MLIRContext* ctx = func.getContext();
OpBuilder b(func);
@ -68,7 +76,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
}
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value result : op.getOutputBuffers()) {
if (!func_args.count(result)) continue;
if (!result_buffers.count(result)) continue;
if (tileGenericOp(op, tile_sizes, &b)) {
generic_op.erase();
return;

View File

@ -59,9 +59,13 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();
// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_functions = false);
// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
@ -111,24 +115,6 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo
namespace xla {
/// Moves alloc nodes (and their associated dealloc nodes - if any) into the
/// right positions. If there is no associated dealloc node for a given alloc
/// node, this pass will automatically insert a proper dealloc node in the right
/// place. The intended use case of this pass is to store SSA values into
/// buffers using load/store operations. For this purpose, you need to know
/// proper positions to place the required allocs and deallocs.
/// 1) Note that the function signatures and all types for which buffers should
/// be allocated need to be converted in advance.
/// 2) All required alloc nodes have to be inserted in advance.
/// 3) Note that the current implementation does not support loops.
/// Refer to the class mlir::xla::BufferAssignmentLegalizer for more
/// information.
std::unique_ptr<OperationPass<FuncOp>> createBufferAssignmentPass();
} // namespace xla
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_

View File

@ -135,18 +135,16 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
self._VerifyTriangularSolve(
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)
@test_util.run_deprecated_v1
def testNonSquareCoefficientMatrixV1(self):
def testNonSquareCoefficientMatrix(self):
rng = np.random.RandomState(0)
for dtype in self.float_types:
a = rng.randn(3, 4).astype(dtype)
b = rng.randn(4, 4).astype(dtype)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.test_scope():
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
linalg_ops.matrix_triangular_solve(a, b)
@test_util.run_v2_only
@test_util.run_v2_only # Different error types
def testWrongDimensionsV2(self):
randn = np.random.RandomState(0).randn
for dtype in self.float_types:

View File

@ -61,6 +61,81 @@ def implicit_reparameterization_grad(a, x):
return -gen_math_ops.igamma_grad_a(a, x) / prob
@def_function.function(experimental_compile=True)
def _log1p(x):
return math_ops.log1p(x)
class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):
if flags.FLAGS.vary_seed:
entropy = os.urandom(64)
if six.PY2:
answer = int(entropy.encode('hex'), 16)
else:
answer = int.from_bytes(entropy, 'big')
np.random.seed(answer % (2**32 - 1))
super(Log1pTest, self).setUp()
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
if self.device not in ['TPU']:
return rtol, atol
if dtype == np.float32:
return 4e-4, 0.
return 1e-10, 0.
def _test_range(self, low, high, dtype, rtol, atol, is_negative=False):
# Test values near zero.
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
x = np.exp(np.random.uniform(
low=low, high=high, size=[NUM_SAMPLES])).astype(dtype)
if is_negative:
x = -x
expected_values = np.log1p(x)
with self.session() as sess:
with self.test_scope():
actual = _log1p(x)
actual = sess.run(actual)
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
@parameterized.parameters((np.float32, 1e-7, 0.),
(np.float64, 1e-15, 0.))
def testSmallX(self, dtype, rtol, atol):
self._test_range(-40., -20., dtype, rtol, atol, is_negative=False)
self._test_range(-40., -20., dtype, rtol, atol, is_negative=True)
@parameterized.parameters((np.float32, 2e-7, 0.),
(np.float64, 1e-15, 0.))
def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol):
self._test_range(-20., -10., dtype, rtol, atol, is_negative=False)
self._test_range(-20., -10., dtype, rtol, atol, is_negative=True)
@parameterized.parameters((np.float32, 2e-7, 0.),
(np.float64, 1e-15, 0.))
def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol):
self._test_range(-10., -5., dtype, rtol, atol, is_negative=False)
self._test_range(-10., -5., dtype, rtol, atol, is_negative=True)
@parameterized.parameters((np.float32, 2e-7, 0.),
(np.float64, 1e-15, 0.))
def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol):
self._test_range(-5., -1., dtype, rtol, atol, is_negative=False)
self._test_range(-5., -1., dtype, rtol, atol, is_negative=True)
@parameterized.parameters((np.float32, 4e-7, 0.),
(np.float64, 3e-14, 0.))
def testXGreaterThanOneTenth(self, dtype, rtol, atol):
self._test_range(-1., 0., dtype, rtol, atol, is_negative=False)
self._test_range(-1., 0., dtype, rtol, atol, is_negative=True)
@parameterized.parameters((np.float32, 2e-7, 0.),
(np.float64, 2e-15, 0.))
def testXGreaterThanOne(self, dtype, rtol, atol):
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
def setUp(self):

View File

@ -292,13 +292,17 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([[1, 2]], dtype=dtype),
expected=np.array([[0.540297, -0.41614]], dtype=dtype))
# Confirm that log1p will remain precise across a range of small values.
self._assertOpOutputMatchesExpected(
math_ops.log1p,
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]],
dtype=dtype)).astype(dtype),
rtol=1e-4,
atol=1e-6)
np.array([[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
dtype=dtype),
expected=np.log1p(
np.array(
[[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
dtype=dtype)).astype(dtype),
rtol=1e-15 if dtype == np.float64 else 1e-4,
atol=1e-15 if dtype == np.float64 else 1e-4)
self._assertOpOutputMatchesExpected(
math_ops.rint,
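Illustrative sketch (not part of the diff): the test above stresses tiny inputs because the naive formulation log(1 + x) rounds 1 + x to 1 and returns 0, while log1p keeps the leading term. A minimal NumPy check:

import numpy as np

x = 1e-17
# In float64, 1 + 1e-17 rounds to exactly 1.0, so the naive form loses x entirely.
print(np.log(1.0 + x))   # 0.0
print(np.log1p(x))       # ~1e-17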

View File

@ -211,6 +211,24 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
dtype=dtype))
def testPadNegative(self):
for dtype in self.numeric_types:
def pad_fn(x):
return xla.pad(
x,
padding_value=7,
padding_low=[0, -1],
padding_high=[1, -2],
padding_interior=[1, 2])
self._assertOpOutputMatchesExpected(
pad_fn,
args=(np.arange(6, dtype=np.int32).astype(dtype).reshape([2, 3]),),
expected=np.array(
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
dtype=dtype))
@test_util.disable_mlir_bridge('Not supported yet')
def testReduce(self):
for dtype in set(self.numeric_types).intersection(

View File

@ -50,6 +50,14 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
return;
}
auto lhs_size = lhs_shape.dims();
OP_REQUIRES(
ctx,
lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2),
errors::InvalidArgument("The coefficient matrix must be square in "
"the inner-most two dimensions: ",
lhs_shape.DebugString()));
xla::XlaOp a = ctx->Input(0);
xla::XlaOp b = ctx->Input(1);
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);

View File

@ -64,14 +64,6 @@ class XlaPadOp : public XlaOpKernel {
padding_interior.size(), " vs. ", rank, ")"));
auto non_negative = [](int64 x) { return x >= 0; };
OP_REQUIRES(
context, absl::c_all_of(padding_low, non_negative),
errors::InvalidArgument("padding_low must be non-negative, got [",
absl::StrJoin(padding_low, ","), "]"));
OP_REQUIRES(
context, absl::c_all_of(padding_high, non_negative),
errors::InvalidArgument("padding_high must be non-negative, got [",
absl::StrJoin(padding_high, ","), "]"));
OP_REQUIRES(
context, absl::c_all_of(padding_interior, non_negative),
errors::InvalidArgument("padding_interior must be non-negative, got [",

View File

@ -59,6 +59,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
return std::make_shared<PjRtClient>(
kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
}

View File

@ -72,18 +72,21 @@ TEST(GpuMultiStream, Basics) {
TF_ASSERT_OK_AND_ASSIGN(
auto dummy_buffer,
PjRtBuffer::FromHostBuffer(
dummy_inputs.data(), dummy_shape, /*force_copy=*/false,
dummy_inputs.data(), dummy_shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer0,
PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
/*buffer_reference=*/nullptr, client.get(),
device));
PjRtBuffer::FromHostBuffer(
inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer1,
PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
/*buffer_reference=*/nullptr, client.get(),
device));
PjRtBuffer::FromHostBuffer(
inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
// The execution may be enqueued before the transfers complete, requiring
// adequate device-side synchronization.
ExecuteOptions options;

View File

@ -53,6 +53,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
return std::make_shared<PjRtClient>(
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
}

View File

@ -316,6 +316,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
"gpu", xla_client, std::move(devices),
/*node_id=*/node_id, std::move(allocator),
std::move(host_memory_allocator),
/*should_stage_host_to_device_transfers=*/true,
/*gpu_run_options=*/std::move(gpu_run_options));
return pyclient;
}

View File

@ -95,6 +95,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@ -154,18 +155,35 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
return xla_assignment;
}
class CpuAllocator : public tensorflow::Allocator {
public:
CpuAllocator() = default;
string Name() override { return "cpu"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
return tensorflow::port::AlignedMalloc(num_bytes, alignment);
}
void DeallocateRaw(void* ptr) override {
return tensorflow::port::AlignedFree(ptr);
}
};
PjRtClient::PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
: platform_name_(std::move(platform_name)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
devices_(std::move(devices)),
host_id_(host_id),
owned_allocator_(std::move(allocator)),
host_memory_allocator_(std::move(host_memory_allocator)),
should_stage_host_to_device_transfers_(
should_stage_host_to_device_transfers),
gpu_run_options_(std::move(gpu_run_options)),
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
client->device_count()) {
@ -175,6 +193,10 @@ PjRtClient::PjRtClient(
allocator_ = client_->backend().memory_allocator();
}
if (!host_memory_allocator_) {
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<Device>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
@ -202,16 +224,58 @@ StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
const LocalExecutable& executable, bool tuple_inputs) const {
// TODO(b/149489114) support buffer donation on CPU/GPU when XLA supports it.
HloComputation* computation =
executable.executable()->module().entry_computation();
int number_of_parameters = [&]() -> int {
if (tuple_inputs) {
CHECK_EQ(computation->num_parameters(), 1);
const Shape& input_tuple_shape =
computation->parameter_instruction(0)->shape();
CHECK(input_tuple_shape.IsTuple());
return input_tuple_shape.tuple_shapes_size();
} else {
return computation->num_parameters();
}
}();
// If any buffer in a parameter is aliased we will donate the entire input
// parameter.
absl::flat_hash_set<int> parameters_to_donate;
const HloInputOutputAliasConfig& config =
executable.executable()->module().input_output_alias_config();
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
[](const ShapeIndex& output_index,
const HloInputOutputAliasConfig::Alias& alias) {
return InvalidArgument(
"Buffer aliasing is not supported by XLA for non-TPU backends.");
[&](const ShapeIndex& output_index,
const HloInputOutputAliasConfig::Alias& alias) {
if (tuple_inputs) {
if (alias.parameter_number != 0) {
return InvalidArgument(
"Unexpected parameter number %d in alias config with tupled "
"inputs",
alias.parameter_number);
}
const ShapeIndex& index = alias.parameter_index;
if (!index.empty()) {
int this_parameter = index.data()[0];
if (this_parameter >= number_of_parameters) {
return InvalidArgument(
"Unexpected parameter index %s in alias config with tupled "
"inputs and %d parameters",
index.ToString(), number_of_parameters);
}
parameters_to_donate.insert(this_parameter);
}
} else {
int this_parameter = alias.parameter_number;
if (this_parameter >= number_of_parameters) {
return InvalidArgument(
"Unexpected parameter number %d in alias config without tupled "
"inputs and %d parameters",
this_parameter, number_of_parameters);
}
parameters_to_donate.insert(this_parameter);
}
return Status::OK();
}));
return absl::flat_hash_set<int>();
return parameters_to_donate;
}
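Illustrative sketch (not part of the diff): a rough pure-Python mirror of the donation logic above, assuming aliases are handed in as (parameter_number, parameter_index) pairs; the function and argument names are hypothetical.

def parameters_that_must_be_donated(aliases, tuple_inputs, number_of_parameters):
    # aliases: iterable of (parameter_number, parameter_index) pairs.
    donate = set()
    for parameter_number, parameter_index in aliases:
        if tuple_inputs:
            if parameter_number != 0:
                raise ValueError("unexpected parameter number with tupled inputs")
            if parameter_index:  # non-empty ShapeIndex: first element picks the leaf
                this_parameter = parameter_index[0]
                if this_parameter >= number_of_parameters:
                    raise ValueError("parameter index out of range")
                donate.add(this_parameter)
        else:
            if parameter_number >= number_of_parameters:
                raise ValueError("parameter number out of range")
            donate.add(parameter_number)
    return donate

# With tupled inputs, an alias into element 2 of the single input tuple donates parameter 2.
assert parameters_that_must_be_donated([(0, (2,))], True, 4) == {2}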
namespace {
@ -484,7 +548,8 @@ void PjRtBuffer::ScopedHold::AddToInput(
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
const void* data, const Shape& shape, bool force_copy,
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
Device* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
@ -495,34 +560,63 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any further
// copies. At the time of writing we require a 16-byte alignment because XLA
// may generate code which requires it.
if (!force_copy &&
((absl::bit_cast<std::uintptr_t>(data) &
(cpu_function_runtime::kMinAlign - 1)) == 0) &&
local_device->executor()->platform()->id() == se::host::kHostPlatformId) {
std::function<void()> on_delete_callback =
[buffer_reference{std::move(buffer_reference)}]() {
// Frees buffer_reference.
};
se::DeviceMemoryBase buffer(const_cast<void*>(data),
ShapeUtil::ByteSizeOf(shape));
absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
std::move(on_delete_callback));
return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
client, device);
}
int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
// The CPU platform is special because the "host" and the "device" are in the
// same memory space. If the input shape is in the correct layout and we don't
// want to defer the copy onto a thread, we can use the following fast
// path.
bool is_cpu_platform =
local_device->executor()->platform()->id() == se::host::kHostPlatformId;
if (is_cpu_platform) {
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
bool can_use_zero_copy =
host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
((absl::bit_cast<std::uintptr_t>(data) &
(cpu_function_runtime::kMinAlign - 1)) == 0);
if (shape.layout() == compact_shape.layout() &&
(host_buffer_semantics ==
HostBufferSemantics::kImmutableOnlyDuringCall ||
can_use_zero_copy)) {
std::function<void()> on_delete_callback;
se::DeviceMemoryBase buffer;
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
if (can_use_zero_copy) {
on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
// Frees buffer_reference.
};
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else {
void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size);
on_delete_callback = [staging_buffer, client]() {
client->host_memory_allocator()->DeallocateRaw(staging_buffer);
};
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
}
absl::Span<const std::shared_ptr<BufferSequencingEvent>>
definition_events;
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer},
definition_events, std::move(on_delete_callback));
return absl::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client, device);
}
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
@ -531,17 +625,41 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok());
// If necessary, allocate a host-side buffer for staging host-to-device
// transfers. On GPU this is a buffer in pinned memory.
std::shared_ptr<void> staging_buffer;
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
client->should_stage_host_to_device_transfers()) {
void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
}
// Copy the buffer into a staging buffer before returning control to the
// caller if the caller only guaranteed that the buffer is valid for the
// duration of the call. Otherwise, we stage (if necessary) on a separate
// thread.
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
buffer_reference.reset();
data = nullptr;
}
// The host to device transfer is performed on a thread pool, mostly because
// it includes linearization that may be slow. It is OK to capture the
// py_buffer pointer because the py_buffer can't be deleted until all the
// usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals.
auto transfer_h2d = [client, transfer_manager, local_device,
movable_device_buffer{device_buffer.ToClosure()}, data,
shape, py_buffer{py_buffer.get()}, compact_shape,
auto transfer_h2d = [client, transfer_manager, local_device, data, size,
movable_device_buffer{device_buffer.ToClosure()}, shape,
py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()},
buffer_reference{std::move(buffer_reference)}]() {
staging_buffer{std::move(staging_buffer)},
buffer_reference{std::move(buffer_reference)},
host_buffer_semantics]() {
ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are
@ -551,20 +669,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
std::shared_ptr<void> staging_buffer;
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned
// memory.
if (client->host_memory_allocator()) {
int64 size = ShapeUtil::ByteSizeOf(shape);
void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
std::memcpy(ptr, data, size);
if (staging_buffer) {
// If we didn't already copy the input buffer into the staging buffer,
// do so now.
if (host_buffer_semantics !=
HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
}
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
shape);
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
@ -584,9 +698,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
local_device->ThenRelease(
local_device->host_to_device_stream(),
std::make_pair(buffer_reference, std::move(staging_buffer)));
std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
if (is_cpu_platform) {
// Using the h2d_transfer_pool would be a double thread hop; the code
// already defers its work onto a stream (= thread on CPU).
transfer_h2d();
} else {
client->h2d_transfer_pool()->Schedule(transfer_h2d);
}
return py_buffer;
}

View File

@ -128,6 +128,7 @@ class PjRtClient {
std::vector<std::unique_ptr<Device>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
virtual ~PjRtClient() = default;
@ -153,6 +154,9 @@ class PjRtClient {
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
@ -190,6 +194,9 @@ class PjRtClient {
std::string platform_name_;
LocalClient* client_;
// Allocator to be used for staging memory transfers to devices.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<Device>> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
@ -201,10 +208,10 @@ class PjRtClient {
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
// Allocator to be used for staging memory transfers to devices. Optional;
// only used on GPU where it is more efficient to copy buffers to and from the
// device via a staging area of pinned memory.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Should we always prefer to stage host-to-device transfers via memory
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
// transfer via pinned memory.
bool should_stage_host_to_device_transfers_;
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
@ -396,13 +403,35 @@ class PjRtBuffer {
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
};
// If `force_copy` is true, forces a copy of the input buffer on CPU.
// Otherwise the library is free to alias the output buffer with `data`.
// `buffer_reference` is an optional shared pointer that should be kept alive
// by the runtime as long as the contents of `data` may still be accessed by
// the runtime (may be nullptr).
// Describes the semantics that the caller of FromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `FromHostBuffer` completes. The caller promises only that `data` is immutable
// and will not be freed for the duration of the FromHostBuffer call.
// `buffer_reference` will be freed by the time `FromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `FromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive.
// The caller promises to keep `data` alive and not to mutate its contents
// as long as the buffer is alive; to notify the caller that `data` may be
// freed, the runtime will release its `buffer_reference` when the
// PjRtBuffer is freed. On non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
const void* data, const Shape& shape, bool force_copy,
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
Device* device);
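Illustrative sketch (not part of the diff): how the three semantics might be selected from Python via the bindings added later in this commit; get_local_backend is assumed to exist in this version of xla_client.

import numpy as np
from tensorflow.compiler.xla.python import xla_client

backend = xla_client.get_local_backend()  # assumed helper for picking a local backend
x = np.arange(16, dtype=np.float32)

# Safest: the runtime copies during the call; x may be mutated or freed right afterwards.
b0 = backend.buffer_from_pyval(
    x, host_buffer_semantics=xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL)

# x must stay valid and unmodified until the transfer completes; saves a copy where possible.
b1 = backend.buffer_from_pyval(
    x,
    host_buffer_semantics=xla_client.HostBufferSemantics.IMMUTABLE_UNTIL_TRANSFER_COMPLETES)

# On CPU the returned buffer may alias x for its whole lifetime; never mutate x while it lives.
b2 = backend.buffer_from_pyval(
    x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)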

View File

@ -84,7 +84,8 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
}
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy) {
const pybind11::object& argument, Device* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
device = pjrt_client_->local_devices().front();
@ -111,9 +112,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
std::move(py_buffer_ref),
pjrt_client_.get(), device));
buffer, PjRtBuffer::FromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics,
std::move(py_buffer_ref), pjrt_client_.get(), device));
}
auto traceback = Traceback::Get();
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),

View File

@ -120,7 +120,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
}
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy);
const pybind11::object& argument, Device* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
StatusOr<std::unique_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);

View File

@ -509,6 +509,13 @@ PYBIND11_MODULE(xla_extension, m) {
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
.value("IMMUTABLE_ONLY_DURING_CALL",
PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
.value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count)
@ -527,7 +534,9 @@ PYBIND11_MODULE(xla_extension, m) {
.def("create_host_to_device_channel_handle",
&PyClient::CreateHostToDeviceChannelHandle)
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
py::arg("device") = nullptr, py::arg("force_copy") = false)
py::arg("device") = nullptr, py::arg("force_copy") = false,
py::arg("host_buffer_semantics") =
PjRtBuffer::HostBufferSemantics::kZeroCopy)
.def("compile", &PyClient::Compile, py::arg("computation"),
py::arg("compile_options") = CompileOptions())
.def("heap_profile", &PyClient::HeapProfile);

View File

@ -304,6 +304,7 @@ def computation_count():
Device = _xla.Device
CompileOptions = _xla.CompileOptions
HostBufferSemantics = _xla.HostBufferSemantics
# An Executable is a C++ class that duck types with the following API:
# class Executable(object):

View File

@ -1909,11 +1909,6 @@ def TestFactory(xla_backend, cloud_tpu=False):
out = ops.Add(p1, p2)
c.setup_alias([], 0, [])
c = c.build(out)
if self.backend.platform != "tpu":
with self.assertRaisesRegex(
RuntimeError, "Buffer aliasing is not supported "
"by XLA for non-TPU backends"):
self.backend.compile(c)
tests.append(AliasTest)
@ -1991,7 +1986,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
def testRoundTrip(self, dtype, shape):
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
x_ptr = x.__array_interface__["data"][0]
buffer = self.backend.buffer_from_pyval(x)
buffer = self.backend.buffer_from_pyval(
x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
y = np.array(buffer, copy=False)
y_ptr = y.__array_interface__["data"][0]
np.testing.assert_array_equal(x, y)
@ -2000,7 +1996,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
buffer2 = self.backend.buffer_from_pyval(
x, host_buffer_semantics=during_call)
z = np.array(buffer2, copy=False)
self.assertNotEqual(x.__array_interface__["data"][0],
z.__array_interface__["data"][0])

View File

@ -3304,6 +3304,15 @@ tf_cc_test(
],
)
cc_library(
name = "memory_space_assignment_utils",
srcs = ["memory_space_assignment_utils.cc"],
hdrs = ["memory_space_assignment_utils.h"],
deps = [
":heap_simulator",
],
)
cc_library(
name = "memory_space_assignment",
srcs = ["memory_space_assignment.cc"],
@ -3311,6 +3320,7 @@ cc_library(
deps = [
":heap_simulator",
":hlo_cost_analysis",
":memory_space_assignment_utils",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/core/lib/math:math_util",
],

View File

@ -573,6 +573,7 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
auto bitcast = computation_->AddInstruction(
HloInstruction::CreateBitcast(instruction->shape(), operand));
bitcast->set_metadata(instruction->metadata());
TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
}
@ -2454,6 +2455,25 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
return Status::OK();
}
{
HloInstruction *convert_operand, *operand;
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
if (Match(multiply,
m::MultiplyAnyOrder(
m::Op(&operand),
m::Convert(
m::Op(&convert_operand)
.WithShape(m::Shape().WithElementType(PRED)))))) {
HloInstruction* zero_like_multiply =
BroadcastZeros(computation_, multiply->shape().element_type(),
multiply->shape().dimensions());
return ReplaceWithNewInstruction(
multiply, HloInstruction::CreateTernary(
multiply->shape(), HloOpcode::kSelect, convert_operand,
operand, zero_like_multiply));
}
}
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
HloInstruction *a, *c1, *c2;
if (Match(multiply,
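Illustrative sketch (not part of the diff): the new Mul(Convert(pred), x) => Select(pred, x, 0) rewrite can be sanity-checked numerically with NumPy.

import numpy as np

pred = np.array([True, False, True])
x = np.array([2.0, 3.0, 5.0], dtype=np.float32)

# multiply(convert(pred), x) ...
lhs = pred.astype(np.float32) * x
# ... equals select(pred, x, zeros_like(x))
rhs = np.where(pred, x, np.zeros_like(x))

assert np.array_equal(lhs, rhs)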

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
@ -255,7 +256,8 @@ Status DotOpEmitter::EmitLinalgMatmul() {
mlir::edsc::ScopedContext scope(*builder, function.getLoc());
mlir::Value a = function.getArgument(0), b = function.getArgument(1),
c = function.getArgument(2);
mlir::edsc::intrinsics::linalg_matmul(b, c, a);
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
mlir::ValueRange{b, c, a});
mlir::edsc::intrinsics::std_ret();
});
}

View File

@ -29,10 +29,12 @@ namespace xla {
namespace {
// Convert a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most major
// dimensions. This requires transposing and reshapes of the lhs and rhs and
// reshaping the output batch to the original shape.
// Convert a dot into a canonical form;
// * Non-contracting dimensions are reshaped together,
// * Contracting dimensions are reshaped together,
// * Batch dimensions are the most major dimensions.
// This requires transposing and reshaping of the lhs and rhs, and reshaping the
// output batch to the original shape.
Status CanonicalizeDot(HloInstruction* original_dot) {
auto computation = original_dot->parent();
const auto& original_dnums = original_dot->dot_dimension_numbers();
@ -63,7 +65,8 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
}
}
// The canonical form of the lhs is
// [BatchDims, NonContractingDims, ContractingsDims]
// [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
// If NonContractingDimsProduct is 1, it is omitted.
std::vector<int64> lhs_transpose;
lhs_transpose.reserve(lhs_rank);
lhs_transpose.insert(lhs_transpose.end(),
@ -109,7 +112,8 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
}
// The canonical form of the rhs is
// [BatchDims, ContractingsDims, NonContractingDims]
// [BatchDims, ContractingsDimsProduct, NonContractingDimsProduct]
// If NonContractingDimsProduct is 1, it is omitted.
std::vector<int64> rhs_transpose;
rhs_transpose.reserve(rhs_rank);
rhs_transpose.insert(rhs_transpose.end(),
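Illustrative sketch (not part of the diff): the canonical dot form described in the comments above, written out with NumPy for a hypothetical lhs with two non-contracting dimensions (the rhs here is already canonical).

import numpy as np

b, m1, m2, k, n = 2, 3, 4, 5, 6
lhs = np.random.randn(b, m1, m2, k).astype(np.float32)  # [Batch, M1, M2, K]
rhs = np.random.randn(b, k, n).astype(np.float32)       # [Batch, K, N]

# Canonical lhs: batch dims major, non-contracting dims flattened, then contracting dims.
canon_lhs = lhs.reshape(b, m1 * m2, k)                   # [Batch, M1*M2, K]
canon_out = np.matmul(canon_lhs, rhs)                    # [Batch, M1*M2, N]

# Reshape the output batch back to the original shape.
out = canon_out.reshape(b, m1, m2, n)

ref = np.einsum('bxyk,bkn->bxyn', lhs, rhs)
assert np.allclose(out, ref, atol=1e-4)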

View File

@ -1336,9 +1336,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
// When x is large, the naive evaluation of ln(x + 1) is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
// The Taylor series for ln(x+1) is x - x^2/2 + x^3/3 - ….
auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
const auto kAntilogarithmIsSmallThreshold = 1e-4;
// When x is small (defined to be less than sqrt(2) - 1), use a rational
// approximation. The approximation below is based on one from the Cephes
// Mathematical Library.
//
// sqrt(2) - 1.
const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
static const std::array<double, 7> kDenominatorCoeffs{
1.,
1.5062909083469192043167E1,
8.3047565967967209469434E1,
2.2176239823732856465394E2,
3.0909872225312059774938E2,
2.1642788614495947685003E2,
6.0118660497603843919306E1,
};
static const std::array<double, 7> kNumeratorCoeffs{
4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
6.5787325942061044846969E0, 2.9911919328553073277375E1,
6.0949667980987787057556E1, 5.7112963590585538103336E1,
2.0039553499201281259648E1,
};
auto x_squared = FMul(x, x);
TF_ASSIGN_OR_RETURN(auto denominator,
EvaluatePolynomial(type, x, kDenominatorCoeffs));
TF_ASSIGN_OR_RETURN(auto numerator,
EvaluatePolynomial(type, x, kNumeratorCoeffs));
auto for_small_x = FDiv(numerator, denominator);
for_small_x = FMul(FMul(x, x_squared), for_small_x);
for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
for_small_x = FAdd(x, for_small_x);
auto abs_x =
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
auto x_is_small = FCmpOLT(
@ -2699,4 +2730,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
}
}
// Evaluate polynomial using Horner's method.
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
for (const double c : coefficients) {
poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
}
return poly;
}
} // namespace xla
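Putting the two pieces above together, EmitLog1p now computes log1p(x) ~= x - x^2/2 + x^3 * P(x)/Q(x) for small |x|, with P and Q evaluated by Horner's method. A standalone C++ sketch of the same arithmetic (coefficients copied from the diff; the assembly order is my reading of the FMul/FAdd sequence rather than an authoritative restatement of the emitter):

```c++
#include <array>
#include <cmath>
#include <cstdio>

// Horner's method, mirroring ElementalIrEmitter::EvaluatePolynomial:
// coefficients run from the highest-degree term down to the constant.
double EvaluatePolynomial(double x, const double* coeffs, int n) {
  double poly = 0.0;
  for (int i = 0; i < n; ++i) poly = poly * x + coeffs[i];
  return poly;
}

int main() {
  // Coefficients from the diff (Cephes-derived rational approximation).
  static const std::array<double, 7> kNumeratorCoeffs = {
      4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
      6.5787325942061044846969E0,  2.9911919328553073277375E1,
      6.0949667980987787057556E1,  5.7112963590585538103336E1,
      2.0039553499201281259648E1};
  static const std::array<double, 7> kDenominatorCoeffs = {
      1.0,
      1.5062909083469192043167E1,
      8.3047565967967209469434E1,
      2.2176239823732856465394E2,
      3.0909872225312059774938E2,
      2.1642788614495947685003E2,
      6.0118660497603843919306E1};

  const double x = 0.1;  // |x| below the sqrt(2) - 1 threshold
  const double p = EvaluatePolynomial(x, kNumeratorCoeffs.data(), 7);
  const double q = EvaluatePolynomial(x, kDenominatorCoeffs.data(), 7);
  // Assembly order from the diff: x - x^2/2 + x^3 * P(x)/Q(x).
  const double approx = x + (-0.5 * x * x + x * x * x * (p / q));
  std::printf("approx=%.17g  std::log1p=%.17g\n", approx, std::log1p(x));
}
```

For x = 0.1 this sketch agrees with std::log1p to machine precision.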
View File
@ -258,6 +258,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
llvm::Value* a, llvm::Value* b,
llvm::Value* c, llvm::Value* d);
// Evaluates a polynomial using Horner's method.
StatusOr<llvm::Value*> EvaluatePolynomial(
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
};
} // namespace xla
View File
@ -39,6 +39,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@llvm-project//llvm:AMDGPUCodeGen",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:BitReader",
"@llvm-project//llvm:BitWriter",
@ -52,7 +53,6 @@ cc_library(
"@llvm-project//llvm:Scalar",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:Target",
"@llvm-project//llvm:amdgpu_code_gen",
],
)
View File
@ -101,6 +101,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
new_reduce_shape_layout);
HloInstruction *canonical_reduce_input = reduce->parent()->AddInstruction(
HloInstruction::CreateBitcast(new_operand_shape, operand));
canonical_reduce_input->set_metadata(reduce->metadata());
VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();
std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
View File
@ -539,6 +539,15 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
/*result_shape_bounds=*/broadcast_dimensions);
}
HloInstruction* BroadcastOnes(HloComputation* computation,
PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions) {
HloInstruction* one = computation->AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
/*result_shape_bounds=*/broadcast_dimensions);
}
// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
// while internal nodes are tuples.
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
View File
@ -276,6 +276,11 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions);
// Same as above, but fill the tensor with ones.
HloInstruction* BroadcastOnes(HloComputation* computation,
PrimitiveType element_type,
absl::Span<const int64> broadcast_dimensions);
// Creates a HLO computation that takes arguments of type `domain` and produces
// a value of type `range`.
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
View File
@ -698,7 +698,7 @@ bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
CHECK_EQ(collective_permute_done->opcode(),
HloOpcode::kCollectivePermuteDone);
bool changed = false;
// CollectivePermuteDone forwards the operand value at {0} to its output.
// CollectivePermuteDone forwards the operand value at {1} to its output.
const HloValueSet& operand_value_set =
GetValueSet(collective_permute_done->operand(0), {1});
HloValueSet& value_set = GetValueSet(collective_permute_done);
@ -945,6 +945,17 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// CopyDone consumes a tuple produced by CopyStart and produces an
// element. Its output aliases its input tuple element {0}.
break;
case HloOpcode::kCollectivePermuteStart:
// CollectivePermuteStart produces a tuple of
// {aliased operand, destination buffer, U32 context, U32 context}.
define_value_at(/*index=*/{});
define_value_at(/*index=*/{1});
define_value_at(/*index=*/{2});
define_value_at(/*index=*/{3});
break;
case HloOpcode::kCollectivePermuteDone:
// CollectivePermuteDone's output aliases its input tuple element {1}.
break;
case HloOpcode::kRecvDone:
// RecvDone produces a two-element tuple. Element zero aliases its
// input tuple element {0}; element one is a token.
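A toy model of the CollectivePermuteStart/Done rule added above (plain C++ maps standing in for XLA's value sets; the identifiers are invented for illustration): the start op defines fresh values at the tuple root and indices {1}, {2}, {3}, index {0} aliases the operand, and the done op forwards the value at index {1}.

```c++
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a ShapeIndex -> value-id mapping.
using ShapeIndex = std::vector<int>;
using ValueMap = std::map<ShapeIndex, std::string>;

int main() {
  std::string operand_value = "v_operand";

  // collective-permute-start: tuple of
  // {aliased operand, destination buffer, U32 context, U32 context}.
  ValueMap start;
  start[{}] = "v_start_tuple";   // fresh value for the tuple itself
  start[{0}] = operand_value;    // index {0} aliases the operand
  start[{1}] = "v_start_dest";   // fresh destination-buffer value
  start[{2}] = "v_start_ctx0";   // fresh context value
  start[{3}] = "v_start_ctx1";   // fresh context value

  // collective-permute-done: forwards the value at operand index {1}.
  std::string done_value = start[{1}];

  std::cout << "done output aliases: " << done_value << "\n";  // v_start_dest
}
```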
View File
@ -550,6 +550,7 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
const auto& casted_other =
static_cast<const HloCollectiveInstruction&>(other);
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
constrain_layout() == casted_other.constrain_layout() &&
absl::c_equal(replica_groups(), casted_other.replica_groups(),
[](const ReplicaGroup& a, const ReplicaGroup& b) {
return absl::c_equal(a.replica_ids(), b.replica_ids());
@ -1101,7 +1102,9 @@ bool HloMapInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
return eq_computations(to_apply(), other.to_apply());
const auto& casted_other = static_cast<const HloMapInstruction&>(other);
return eq_computations(to_apply(), casted_other.to_apply()) &&
dimensions() == casted_other.dimensions();
}
std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
@ -2515,7 +2518,8 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
return true;
const auto& casted_other = static_cast<const HloDynamicSliceInstruction&>(other);
return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
}
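Both IdenticalSlowPath fixes above follow the same pattern: downcast `other` to the concrete subclass and compare the subclass-specific fields in addition to whatever the base class already checked. A self-contained sketch of that pattern (the types here are illustrative, not XLA's class hierarchy):

```c++
#include <cassert>
#include <vector>

struct Instruction {
  virtual ~Instruction() = default;
  // Callers guarantee `other` has the same opcode/derived type.
  virtual bool IdenticalSlowPath(const Instruction& other) const = 0;
};

struct DynamicSliceLike : Instruction {
  std::vector<long> slice_sizes;
  bool IdenticalSlowPath(const Instruction& other) const override {
    // Downcast to the same derived type before comparing derived fields.
    const auto& casted_other = static_cast<const DynamicSliceLike&>(other);
    return slice_sizes == casted_other.slice_sizes;
  }
};

int main() {
  DynamicSliceLike a, b;
  a.slice_sizes = {1, 2, 4};
  b.slice_sizes = {1, 2, 8};
  assert(!a.IdenticalSlowPath(b));  // differ in slice sizes -> not identical
  b.slice_sizes = {1, 2, 4};
  assert(a.IdenticalSlowPath(b));
}
```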
std::unique_ptr<HloInstruction>
View File
@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace xla {
@ -30,6 +31,22 @@ const int kWhileExecutionCount = 5;
} // namespace
/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
MemorySpaceAssignmentCostAnalysis::Create(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
HloLiveRange::Run(module.schedule(), *alias_analysis,
module.entry_computation()));
auto call_graph = CallGraph::Build(&module);
return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
cost_analysis, async_copy_bandwidth_bytes_per_second,
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
std::move(hlo_live_range), std::move(call_graph)));
}
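The Create() factory above is an instance of a common C++ pattern: run the fallible setup (alias analysis, live range, call graph), then hand the results to a private constructor wrapped in a unique_ptr so callers can only obtain a fully initialized object. A small sketch of the pattern with invented types (a raw new inside std::unique_ptr is essentially what absl::WrapUnique does, since std::make_unique cannot reach a private constructor):

```c++
#include <memory>
#include <string>
#include <utility>

class Analysis {
 public:
  // Factory: perform the setup that can fail (elided here), then wrap the
  // result of the private constructor.
  static std::unique_ptr<Analysis> Create(float bandwidth) {
    auto dependency = std::make_unique<std::string>("alias analysis results");
    return std::unique_ptr<Analysis>(
        new Analysis(bandwidth, std::move(dependency)));
  }

  float bandwidth() const { return bandwidth_; }

 private:
  Analysis(float bandwidth, std::unique_ptr<std::string> dependency)
      : bandwidth_(bandwidth), dependency_(std::move(dependency)) {}

  float bandwidth_;
  std::unique_ptr<std::string> dependency_;  // owned, like alias_analysis_ above
};

int main() {
  std::unique_ptr<Analysis> analysis = Analysis::Create(1e9f);
  return analysis->bandwidth() > 0 ? 0 : 1;
}
```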
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
@ -73,19 +90,32 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
/*operand_in_alternate_mem=*/{},
/*output_in_alternate_mem=*/true),
cache);
for (const HloUse& use : interval.buffer->uses()) {
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
*use.instruction,
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
cache);
// If the benefit is positive (memory bound), add it to this buffer's
// benefit. If the benefit is negative (compute bound), calculate the
// maximum.
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
alternate_mem_benefit += use_alternate_mem_benefit;
} else {
alternate_mem_benefit =
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
interval.buffer->defining_position().instruction,
interval.buffer->defining_position().index)) {
for (const HloValue* value : buffer->values()) {
for (const HloUse& use : value->uses()) {
// We look inside the called computations of while and conditional, so
// don't use the benefit of while and conditional directly.
if (use.instruction->opcode() == HloOpcode::kWhile ||
use.instruction->opcode() == HloOpcode::kConditional) {
continue;
}
float use_alternate_mem_benefit =
GetAlternateMemoryBenefit(*use.instruction,
GetInstructionElapsedDueToMemory(
*use.instruction, use.operand_number),
cache);
// If the benefit is positive (memory bound), add it to this buffer's
// benefit. If the benefit is negative (compute bound), calculate the
// maximum.
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
alternate_mem_benefit += use_alternate_mem_benefit;
} else {
alternate_mem_benefit =
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
}
}
}
}
@ -94,17 +124,9 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
float alternate_mem_slowdown =
GetInstructionElapsedDueToMemorySlowdown(interval.size);
// Scale the slowdown based on the time of this buffer. We would want earlier
// buffers to have lower slowdown values, because they are less likely to
// overlap with other HLOs.
// TODO(yuemmawang): We may want a piecewise function, with a lower slowdown
// for early HLOs and full slowdown for mid-to-late HLOs.
// TODO(yuemmawang): Further, in a smarter way, we want buffers that overlap
// with more HLOs to have a higher slowdown, and vice versa.
float scale = interval.start * 1.0 / GetScheduleEndTime();
alternate_mem_slowdown *= scale;
return alternate_mem_benefit - alternate_mem_slowdown;
// Divide by the size of the buffer to prioritize smaller buffers that will
// give the largest alternate memory benefit.
return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size;
}
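The new return value ranks candidates by benefit per byte rather than absolute benefit, so a small buffer with a decent benefit can outrank a much larger one. A quick numeric sketch with made-up numbers:

```c++
#include <algorithm>
#include <cstdio>
#include <vector>

struct Interval {
  const char* name;
  float benefit;   // alternate_mem_benefit
  float slowdown;  // alternate_mem_slowdown
  float size;      // bytes
};

float MemoryBoundedness(const Interval& i) {
  return (i.benefit - i.slowdown) / i.size;  // benefit per byte
}

int main() {
  std::vector<Interval> intervals = {
      {"big", 100.0f, 10.0f, 1 << 20},   // large absolute benefit
      {"small", 30.0f, 1.0f, 1 << 12}};  // much better benefit per byte
  std::sort(intervals.begin(), intervals.end(),
            [](const Interval& x, const Interval& y) {
              return MemoryBoundedness(x) > MemoryBoundedness(y);
            });
  std::printf("best candidate: %s\n", intervals.front().name);  // "small"
}
```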
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
@ -112,7 +134,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
int nest_level = 0;
const HloComputation* computation = instruction->parent();
while (!computation->IsEntryComputation()) {
auto node = call_graph_.GetNode(computation);
auto node = call_graph_->GetNode(computation);
auto callsites = node.caller_callsites();
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
auto callsite = callsites[0];
@ -194,7 +216,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
}
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
return hlo_live_range_.schedule_end_time();
return hlo_live_range_->schedule_end_time();
}
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
@ -252,6 +274,13 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
0.0);
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
// To avoid double counting, don't include the elapsed time of while and
// conditional HLOs.
const HloInstruction* instruction = instruction_and_logical_time.first;
if (instruction->opcode() == HloOpcode::kWhile ||
instruction->opcode() == HloOpcode::kConditional) {
continue;
}
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
*instruction_and_logical_time.first);
int64 logical_time = instruction_and_logical_time.second;
@ -597,81 +626,6 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
return colocated_intervals;
}
bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
const BufferInterval& interval) const {
// If the buffer is a tuple, don't use this algorithm for now. The buffers
// that are pointed to by the tuple will still use this algorithm. Because
// tuples are cheap to place in the alternate memory (they are just pointers)
// we don't need to use prefetch/evict logic.
if (interval.buffer->shape().IsTuple()) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a tuple.";
return false;
}
// Don't place scalars in the alternate memory.
if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a scalar.";
return false;
}
// The semantics of TupleSelect are weird: TupleSelect doesn't define a
// buffer, but just forwards the buffers in either the left or right side.
// This means the two different inputs to TupleSelect must not alias, yet they
// should be allocated in the same memory space, and both buffers must be kept
// alive for the entire live range of TupleSelect. Instead, just don't
// allocate TupleSelect in the alternate memory space.
// TODO(berkin): Not allocating add-dependencies either since they need to be
// treated specially. We should revisit this later.
for (const HloPosition& position : interval.buffer->positions()) {
if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
position.instruction->opcode() == HloOpcode::kAddDependency) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it has a tuple-select or "
<< "add-dependency position.";
return false;
}
}
// Send and Recv HLOs return a request identifier. These should not be
// allocated in the alternate memory.
for (const HloPosition& position : interval.buffer->positions()) {
if ((position.instruction->opcode() == HloOpcode::kSend ||
position.instruction->opcode() == HloOpcode::kRecv)) {
// TODO(berkin): Send/recv buffers need a stable buffer allocation
// throughout sending/receiving. Disable memory space allocation for these
// for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a send/recv buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a request identifier for "
"send/recv.";
return false;
}
}
if ((position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
position.instruction->opcode() == HloOpcode::kCollectivePermuteDone)) {
// Disable memory space allocation for these for now.
if (position.index == ShapeIndex({0})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
} else if (position.index == ShapeIndex({1})) {
VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
<< " in default mem because it is a collective-permute buffer.";
return false;
}
}
}
return true;
}
bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
const AllocationValue& value, const HloUse& use) const {
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
@ -710,8 +664,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
shape, parameter_time, min_use_time)) {
VLOG(4) << "While allocation not allowed in alternate memory. "
<< "use time = " << min_use_time
<< ", root time = " << root_time;
<< "use time = " << min_use_time << ", root time = " << root_time;
return false;
}
// Check if there is a required assignment for the while loop output.
@ -897,7 +850,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}
if (!IsIntervalAllowedInAlternateMemory(interval)) {
if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
interval)) {
continue;
}
@ -2011,17 +1965,38 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
BufferInterval* alternate_mem_interval) const {
int64 end_time = request.end_time;
if (!preferred_offset) {
// First find the earliest use that is the same or later than the end time.
const auto& uses = request.allocation_value->uses();
auto use_it = uses.begin();
for (; use_it->time < end_time; ++use_it) {
}
CHECK(use_it != uses.end());
int64 earliest_use = use_it->time;
// Then find the latest use that can be allocated contiguously without
// copies.
const Shape& shape = request.allocation_value->defining_position().shape();
for (;
(use_it + 1) != uses.end() &&
options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
shape, use_it->time, (use_it + 1)->time);
++use_it) {
}
CHECK(use_it != uses.end());
int64 latest_contiguous_use = use_it->time;
// Find a chunk that's as long-living as possible, iterating in reverse over
// the use times.
for (auto use_it = request.allocation_value->uses().rbegin();
use_it != request.allocation_value->uses().rend() &&
use_it->time >= end_time;
++use_it) {
for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
alternate_mem_interval->end = use_it->time;
ChunkCandidate chunk_candidate =
FindChunkCandidate(*alternate_mem_interval);
if (chunk_candidate.heap_size <= available_heap_size()) {
alternate_mem_interval->end = end_time;
VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
<< ", latest contiguous use = " << latest_contiguous_use
<< ", use with available mem = " << use_it->time
<< ", offset = " << chunk_candidate.chunk.offset;
return chunk_candidate;
}
}
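The updated search walks the use times backwards from the latest contiguous use, keeping the first (i.e. longest-living) extent for which a chunk still fits. A simplified standalone sketch of that backward scan (the `fits` predicate stands in for FindChunkCandidate against the heap):

```c++
#include <cstdio>
#include <functional>
#include <vector>

// Returns the latest use time >= end_time for which `fits(use_time)` holds,
// scanning from the latest candidate backwards; returns -1 if none fits.
int FindLongestLivingEnd(const std::vector<int>& use_times, int end_time,
                         const std::function<bool(int)>& fits) {
  for (auto it = use_times.rbegin(); it != use_times.rend(); ++it) {
    if (*it < end_time) break;   // use times are sorted; nothing earlier qualifies
    if (fits(*it)) return *it;   // longest-living extent that still fits
  }
  return -1;
}

int main() {
  std::vector<int> use_times = {4, 9, 15, 23};  // sorted use times
  int end_time = 9;
  // Pretend the heap only has room if the interval ends by time 15.
  auto fits = [](int end) { return end <= 15; };
  std::printf("chosen end time: %d\n",
              FindLongestLivingEnd(use_times, end_time, fits));  // 15
}
```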
@ -2079,8 +2054,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
const MemorySpaceAssignmentCostAnalysis& cost_analysis,
MemorySpaceAssignmentCostAnalysis::Cache* cache) {
return [cost_analysis, cache](const BufferInterval& x,
const BufferInterval& y) {
return [&cost_analysis, cache](const BufferInterval& x,
const BufferInterval& y) {
float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
if (x_memory_boundedness != y_memory_boundedness) {
View File
@ -84,18 +84,10 @@ class MemorySpaceAssignmentCostAnalysis {
absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
};
MemorySpaceAssignmentCostAnalysis(
static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
: cost_analysis_(cost_analysis),
async_copy_bandwidth_bytes_per_second_(
async_copy_bandwidth_bytes_per_second),
alternate_mem_bandwidth_bytes_per_second_(
alternate_mem_bandwidth_bytes_per_second),
hlo_live_range_(hlo_live_range),
call_graph_(call_graph) {}
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);
const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }
@ -153,14 +145,31 @@ class MemorySpaceAssignmentCostAnalysis {
// 0 means it is not in a while loop.
int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;
const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }
private:
MemorySpaceAssignmentCostAnalysis(
const HloCostAnalysis& cost_analysis,
float async_copy_bandwidth_bytes_per_second,
float alternate_mem_bandwidth_bytes_per_second,
std::unique_ptr<HloAliasAnalysis> alias_analysis,
std::unique_ptr<HloLiveRange> hlo_live_range,
std::unique_ptr<CallGraph> call_graph)
: cost_analysis_(cost_analysis),
async_copy_bandwidth_bytes_per_second_(
async_copy_bandwidth_bytes_per_second),
alternate_mem_bandwidth_bytes_per_second_(
alternate_mem_bandwidth_bytes_per_second),
alias_analysis_(std::move(alias_analysis)),
hlo_live_range_(std::move(hlo_live_range)),
call_graph_(std::move(call_graph)) {}
const HloCostAnalysis& cost_analysis_;
float async_copy_bandwidth_bytes_per_second_;
float alternate_mem_bandwidth_bytes_per_second_;
const HloLiveRange& hlo_live_range_;
const CallGraph& call_graph_;
std::unique_ptr<HloAliasAnalysis> alias_analysis_;
std::unique_ptr<HloLiveRange> hlo_live_range_;
std::unique_ptr<CallGraph> call_graph_;
};
// Abstract base class that memory space assignment uses to pick prefetch
@ -909,10 +918,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);
// Returns true if this buffer is allowed to be placed in the alternate
// memory.
bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;
// Returns true if the use is allowed in the alternate memory.
bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
const HloUse& use) const;
Some files were not shown because too many files have changed in this diff.