Merge branch 'master' into op_tests_16x8
commit 81c8a6605d

.bazelrc (11 lines changed)
@ -30,6 +30,7 @@
# short_logs: Only log errors during build, skip warnings.
# monolithic: Build all TF C++ code into a single shared object.
# dynamic_kernels: Try to link all kernels dynamically (experimental).
# libc++: Link against libc++ instead of libstdc++
#
#
# TF version options:
@ -79,6 +80,14 @@
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.

# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass the flags in from the environment
# to keep this build file clean.
build:libc++ --action_env=CC
build:libc++ --action_env=CXX
build:libc++ --action_env=CXXFLAGS=-stdlib=libc++
build:libc++ --action_env=PATH
build:libc++ --define force_libcpp=enabled
build:libc++ --linkopt -fuse-ld=lld

# Android configs. Bazel needs to have --cpu and --fat_apk_cpu both set to the
# target CPU to build transient dependencies correctly. See
@ -200,6 +209,8 @@ build:nogcp --define=no_gcp_support=true
build:nohdfs --define=no_hdfs_support=true
build:nonccl --define=no_nccl_support=true

build:stackdriver_support --define=stackdriver_support=true

build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
.github/stale.yml (vendored, 2 lines changed)
@ -23,7 +23,7 @@
daysUntilStale: 7
# Number of days of inactivity before a stale Issue or Pull Request is closed
daysUntilClose: 7
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
# Only issues or pull requests with all of these labels are checked if stale. Defaults to `[]` (disabled)
onlyLabels:
  - stat:awaiting response
# Comment to post when marking as stale. Set to `false` to disable
|
@ -61,7 +61,6 @@ commands.
|
||||
*Nightly binaries are available for testing using the
|
||||
[tf-nightly](https://pypi.python.org/pypi/tf-nightly) and
|
||||
[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPI.*
|
||||
|
||||
#### *Try your first TensorFlow program*
|
||||
|
||||
```shell
|
||||
@ -114,6 +113,12 @@ Build Type | Status
|
||||
**Android** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/android.html) | [](https://bintray.com/google/tensorflow/tensorflow/_latestVersion)
|
||||
**Raspberry Pi 0 and 1** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi01-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv6l.whl)
|
||||
**Raspberry Pi 2 and 3** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/rpi23-py3.html) | [Py3](https://storage.googleapis.com/tensorflow-nightly/tensorflow-1.10.0-cp34-none-linux_armv7l.whl)
|
||||
**Libtensorflow MacOS CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-mac-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Linux GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-linux-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows CPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-cpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
**Libtensorflow Windows GPU** | [](https://storage.googleapis.com/tensorflow-kokoro-build-badges/libtensorflow-win-gpu.html) | [GCS](https://storage.googleapis.com/libtensorflow-nightly)
|
||||
|
||||
|
||||
### Community Supported Builds
|
||||
|
||||
|
@ -298,6 +298,13 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Experimental features
|
||||
config_setting(
|
||||
name = "stackdriver_support",
|
||||
define_values = {"stackdriver_support": "true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
# Crosses between platforms and file system libraries not supported on those
|
||||
# platforms due to limitations in nested select() statements.
|
||||
config_setting(
|
||||
|
@ -713,8 +713,8 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
|
||||
status->status = tfrt::ListOpHandlerChains(
|
||||
opts->session_options.options, &op_handler_chains, &device_attributes);
|
||||
if (!status->status.ok()) return nullptr;
|
||||
return tensorflow::wrap(
|
||||
new tfrt::ContextInterface(op_handler_chains, device_attributes));
|
||||
return tensorflow::wrap(new tfrt::ContextInterface(
|
||||
op_handler_chains, device_attributes, opts->async));
|
||||
#else
|
||||
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
|
||||
return nullptr;
|
||||
|
@ -212,6 +212,35 @@ TEST(CAPI, CancellationManager) {
|
||||
TFE_DeleteCancellationManager(c_mgr);
|
||||
}
|
||||
|
||||
TEST(CAPI, ExecutorContextDestructionOrder) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
|
||||
{
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_DeleteContext(ctx);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
|
||||
{
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
TFE_Executor* executor = TFE_NewExecutor(/*is_async=*/false);
|
||||
TFE_ContextSetExecutorForThread(ctx, executor);
|
||||
|
||||
TFE_DeleteExecutor(executor);
|
||||
TFE_DeleteContext(ctx);
|
||||
}
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
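The two blocks above check that a per-thread executor and its context can be torn down in either order. A minimal sketch, assuming only the TFE_*/TF_* calls already used in this test, of how RAII wrappers make that ordering explicit (the deleters and `Sketch` are hypothetical, not part of this commit):

```cpp
#include <memory>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"

struct ContextDeleter {
  void operator()(TFE_Context* ctx) const { TFE_DeleteContext(ctx); }
};
struct ExecutorDeleter {
  void operator()(TFE_Executor* executor) const { TFE_DeleteExecutor(executor); }
};

void Sketch() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  std::unique_ptr<TFE_Context, ContextDeleter> ctx(TFE_NewContext(opts, status));
  TFE_DeleteContextOptions(opts);
  std::unique_ptr<TFE_Executor, ExecutorDeleter> executor(
      TFE_NewExecutor(/*is_async=*/false));
  TFE_ContextSetExecutorForThread(ctx.get(), executor.get());
  TF_DeleteStatus(status);
  // Locals are destroyed in reverse declaration order, so the executor is
  // deleted before the context -- the ordering exercised by the second block.
}
```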
|
||||
TEST(CAPI, Function_ident_CPU) {
|
||||
// First create a simple identity function.
|
||||
TF_Graph* function_graph = TF_NewGraph();
|
||||
|
@ -37,6 +37,15 @@ class StatusDeleter {
|
||||
|
||||
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
|
||||
|
||||
class ExecutorDeleter {
|
||||
public:
|
||||
void operator()(TFE_Executor* to_delete) const {
|
||||
TFE_DeleteExecutor(to_delete);
|
||||
}
|
||||
};
|
||||
|
||||
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Allows a single op at a time to be launched without blocking.
|
||||
@ -51,6 +60,13 @@ class DeviceThread {
|
||||
explicit DeviceThread(const std::string& device)
|
||||
: status_(TF_NewStatus()),
|
||||
device_(device),
|
||||
// If the context's default executor is set to async, re-using that in
|
||||
// each thread would cause collectives to deadlock. For consistency we
|
||||
// create a new sync executor for every thread.
|
||||
//
|
||||
// TODO(allenl): We should have an async API that works with the
|
||||
// parallel device.
|
||||
executor_(TFE_NewExecutor(/*is_async=*/false)),
|
||||
op_(nullptr),
|
||||
thread_(tensorflow::Env::Default()->StartThread(
|
||||
tensorflow::ThreadOptions(), "parallel_device_execute",
|
||||
@ -105,6 +121,7 @@ class DeviceThread {
|
||||
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
|
||||
|
||||
const std::string device_;
|
||||
ExecutorPtr executor_ TF_GUARDED_BY(execution_mutex_);
|
||||
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
|
||||
std::unique_ptr<Thread> thread_;
|
||||
};
|
||||
@ -186,6 +203,7 @@ void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
|
||||
std::vector<TensorHandlePtr>* outputs,
|
||||
TF_Status* status) const {
|
||||
if (op_ == nullptr) {
|
||||
TFE_ContextSetExecutorForThread(context, executor_.get());
|
||||
op_.reset(TFE_NewOp(context, operation_name, status));
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
|
||||
|
@ -412,6 +412,7 @@ void TestCollective(bool async) {
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
|
||||
TFE_NewContextOptions(), TFE_DeleteContextOptions);
|
||||
TFE_ContextOptionsSetAsync(opts.get(), async);
|
||||
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
|
||||
TF_CreateConfig(
|
||||
/*xla*/ false,
|
||||
@ -423,9 +424,6 @@ void TestCollective(bool async) {
|
||||
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
|
||||
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
|
||||
TFE_NewExecutor(async), TFE_DeleteExecutor);
|
||||
TFE_ContextSetExecutorForThread(context.get(), executor.get());
|
||||
|
||||
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
|
||||
std::array<const char*, 2> underlying_devices{
|
||||
@ -455,8 +453,6 @@ void TestCollective(bool async) {
|
||||
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
|
||||
ExpectScalarEq<float>(result_components[0].get(), 3.);
|
||||
ExpectScalarEq<float>(result_components[1].get(), 3.);
|
||||
// Destroying the context's default executor first isn't safe.
|
||||
context.reset();
|
||||
}
|
||||
|
||||
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
|
||||
|
@ -27,5 +27,6 @@ cc_library(
|
||||
"//tensorflow/c:tf_status",
|
||||
"//tensorflow/c/experimental/filesystem:filesystem_interface",
|
||||
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "google/cloud/storage/client.h"
|
||||
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
@ -35,6 +36,45 @@ static inline void TF_SetStatusFromGCSStatus(
|
||||
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
|
||||
static void plugin_memory_free(void* ptr) { free(ptr); }
|
||||
|
||||
static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
|
||||
char** bucket, char** object, TF_Status* status) {
|
||||
size_t scheme_end = fname.find("://") + 2;
|
||||
if (fname.substr(0, scheme_end + 1) != "gs://") {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"GCS path doesn't start with 'gs://'.");
|
||||
return;
|
||||
}
|
||||
|
||||
size_t bucket_end = fname.find("/", scheme_end + 1);
|
||||
if (bucket_end == absl::string_view::npos) {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"GCS path doesn't contain a bucket name.");
|
||||
return;
|
||||
}
|
||||
absl::string_view bucket_view =
|
||||
fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
|
||||
*bucket =
|
||||
static_cast<char*>(plugin_memory_allocate(bucket_view.length() + 1));
|
||||
memcpy(*bucket, bucket_view.data(), bucket_view.length());
|
||||
(*bucket)[bucket_view.length()] = '\0';
|
||||
|
||||
absl::string_view object_view = fname.substr(bucket_end + 1);
|
||||
if (object_view.empty()) {
|
||||
if (object_empty_ok) {
|
||||
*object = nullptr;
|
||||
return;
|
||||
} else {
|
||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
||||
"GCS path doesn't contain an object name.");
|
||||
return;
|
||||
}
|
||||
}
|
||||
*object =
|
||||
static_cast<char*>(plugin_memory_allocate(object_view.length() + 1));
|
||||
// object_view.data() is a null-terminated string_view because fname is.
|
||||
strcpy(*object, object_view.data());
|
||||
}
|
||||
|
||||
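A hedged usage sketch of the parser above. The path is hypothetical, and the function would need to live in this translation unit because `ParseGCSPath` and `plugin_memory_free` are file-static:

```cpp
// For "gs://my-bucket/path/to/object.txt" the parser yields
// bucket == "my-bucket" and object == "path/to/object.txt". Both strings are
// allocated with plugin_memory_allocate, so release them with
// plugin_memory_free.
static void ParseGCSPathExample() {
  TF_Status* status = TF_NewStatus();
  char* bucket = nullptr;
  char* object = nullptr;
  ParseGCSPath("gs://my-bucket/path/to/object.txt", /*object_empty_ok=*/false,
               &bucket, &object, status);
  if (TF_GetCode(status) == TF_OK) {
    plugin_memory_free(bucket);
    plugin_memory_free(object);
  }
  TF_DeleteStatus(status);
}
```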
// SECTION 1. Implementation for `TF_RandomAccessFile`
|
||||
// ----------------------------------------------------------------------------
|
||||
namespace tf_random_access_file {
|
||||
|
tensorflow/c/experimental/saved_model/core/ops/BUILD (new file, 94 lines)
@ -0,0 +1,94 @@
|
||||
# This package contains hand-written convenience helpers for Eager Operations
|
||||
# used by SavedModel. Once we autogenerate C++ Eager Op wrappers, we can remove these.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
# Restricting visibility for now
|
||||
"//tensorflow/c/experimental/saved_model/core:__subpackages__",
|
||||
"//tensorflow/c/experimental/saved_model/internal:__subpackages__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_op",
|
||||
hdrs = [
|
||||
"owned_eager_op.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:operation_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor_handle",
|
||||
hdrs = [
|
||||
"owned_tensor_handle.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_eager_context",
|
||||
hdrs = ["owned_eager_context.h"],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "owned_tensor",
|
||||
hdrs = ["owned_tensor.h"],
|
||||
deps = [
|
||||
"//tensorflow/c:tensor_interface",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "variable_ops",
|
||||
srcs = [
|
||||
"variable_ops.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"variable_ops.h",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_op",
|
||||
":owned_tensor_handle",
|
||||
"//tensorflow/c/eager:context_interface",
|
||||
"//tensorflow/c/eager:tensor_handle_interface",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "variable_ops_test",
|
||||
srcs = [
|
||||
"variable_ops_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":owned_eager_context",
|
||||
":owned_tensor",
|
||||
":owned_tensor_handle",
|
||||
":variable_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractContextInterfaceDeleter {
|
||||
void operator()(AbstractContextInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct EagerContextDeleter {
|
||||
void operator()(EagerContext* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractContextPtr =
|
||||
std::unique_ptr<AbstractContextInterface,
|
||||
internal::AbstractContextInterfaceDeleter>;
|
||||
|
||||
using EagerContextPtr =
|
||||
std::unique_ptr<EagerContext, internal::EagerContextDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_EAGER_CONTEXT_H_
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/operation_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractOperationInterfaceDeleter {
|
||||
void operator()(AbstractOperationInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractOpPtr =
|
||||
std::unique_ptr<AbstractOperationInterface,
|
||||
internal::AbstractOperationInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_EAGER_OP_H_
|
@ -0,0 +1,42 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct AbstractTensorInterfaceDeleter {
|
||||
void operator()(AbstractTensorInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using AbstractTensorPtr =
|
||||
std::unique_ptr<AbstractTensorInterface,
|
||||
internal::AbstractTensorInterfaceDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_OWNED_TENSOR_H_
|
@ -0,0 +1,54 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
struct TensorHandleDeleter {
|
||||
void operator()(TensorHandle* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct AbstractTensorHandleDeleter {
|
||||
void operator()(AbstractTensorHandleInterface* p) const {
|
||||
if (p != nullptr) {
|
||||
p->Release();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
using TensorHandlePtr =
|
||||
std::unique_ptr<TensorHandle, internal::TensorHandleDeleter>;
|
||||
|
||||
using AbstractTensorHandlePtr =
|
||||
std::unique_ptr<AbstractTensorHandleInterface,
|
||||
internal::AbstractTensorHandleDeleter>;
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OWNED_TENSOR_HANDLE_H_
|
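The four owned_* headers above all follow the same pattern: a `std::unique_ptr` whose deleter calls `Release()` rather than `delete`, because these interfaces are reference-counted. A minimal sketch under that assumption, using the `AbstractOpPtr` alias defined earlier (the `Example` function is hypothetical):

```cpp
#include "tensorflow/c/eager/context_interface.h"
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"

void Example(tensorflow::AbstractContextInterface* ctx) {
  // CreateOperation() hands back a pointer the caller is responsible for.
  tensorflow::AbstractOpPtr op(ctx->CreateOperation());
  // ... configure and execute `op`, as variable_ops.cc below does ...
}  // AbstractOperationInterfaceDeleter calls op->Release() here, not delete.
```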
tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc (new file, 104 lines)
@ -0,0 +1,104 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_op.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
static const char kNoSharingResourceID[] =
|
||||
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
|
||||
|
||||
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle) {
|
||||
AbstractOpPtr varhandle_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Reset("VarHandleOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrType("dtype", dtype));
|
||||
|
||||
// Note that if shape is unknown rank, shape.dim_sizes() will be empty, and
|
||||
// shape.dims() will be -1.
|
||||
gtl::InlinedVector<int64, 4> dim_sizes = shape.dim_sizes();
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrShape(
|
||||
"shape", reinterpret_cast<const int64_t*>(dim_sizes.data()),
|
||||
shape.dims()));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString("container", "", 0));
|
||||
TF_RETURN_IF_ERROR(varhandle_op->SetAttrString(
|
||||
"shared_name", kNoSharingResourceID, strlen(kNoSharingResourceID)));
|
||||
|
||||
AbstractTensorHandleInterface* var_handle = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(varhandle_op->Execute(
|
||||
absl::MakeSpan(&var_handle, num_retvals), &num_retvals));
|
||||
handle->reset(var_handle);
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status AssignVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandleInterface* value) {
|
||||
AbstractOpPtr assign_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
|
||||
|
||||
int num_retvals = 0;
|
||||
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ReadVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output) {
|
||||
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
|
||||
|
||||
AbstractTensorHandleInterface* value = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
|
||||
output->reset(value);
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle) {
|
||||
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(destroy_op->Reset("DestroyResourceOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(destroy_op->SetAttrBool("ignore_lookup_error", true));
|
||||
TF_RETURN_IF_ERROR(destroy_op->AddInput(handle));
|
||||
|
||||
int num_retvals = 0;
|
||||
TF_RETURN_IF_ERROR(destroy_op->Execute({}, &num_retvals));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
@ -0,0 +1,62 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
||||
|
||||
#include "tensorflow/c/eager/context_interface.h"
|
||||
#include "tensorflow/c/eager/tensor_handle_interface.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Executes a VarHandleOp using `ctx`, and fills `handle` with the DT_RESOURCE
|
||||
// TensorHandle associated with the variable. This is equivalent to creating an
|
||||
// uninitialized TF2 tf.Variable.
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L1867-L1872
|
||||
Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle);
|
||||
|
||||
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
|
||||
// with `variable_handle` with `value`. `dtype` must be the datatype of the
|
||||
// underlying variable for `variable_handle`. Note that it is illegal to assign
|
||||
// a variable to a Tensor with a different dtype than what the variable was
|
||||
// created with.
|
||||
Status AssignVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandleInterface* value);
|
||||
|
||||
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
|
||||
// value of `variable_handle` and copies the value to `output`. `dtype` must be
|
||||
// the dtype of the variable associated with `variable_handle`.
|
||||
Status ReadVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output);
|
||||
|
||||
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
|
||||
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_VARIABLE_OPS_H
|
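A hedged end-to-end sketch chaining the four helpers declared above; `RoundTrip` is hypothetical, error handling beyond `TF_RETURN_IF_ERROR` is elided, and the includes pulled in by variable_ops.h plus tensorflow/core/platform/errors.h are assumed. The test file that follows gives a complete, runnable version:

```cpp
namespace tensorflow {

Status RoundTrip(AbstractContextInterface* ctx,
                 AbstractTensorHandleInterface* value) {
  // Create an uninitialized scalar float variable, assign `value` to it,
  // read it back, then destroy the underlying resource.
  AbstractTensorHandlePtr variable;
  TF_RETURN_IF_ERROR(internal::CreateUninitializedResourceVariable(
      ctx, DT_FLOAT, {}, &variable));
  TF_RETURN_IF_ERROR(
      internal::AssignVariable(ctx, variable.get(), DT_FLOAT, value));
  AbstractTensorHandlePtr read_value;
  TF_RETURN_IF_ERROR(
      internal::ReadVariable(ctx, variable.get(), DT_FLOAT, &read_value));
  return internal::DestroyResource(ctx, variable.get());
}

}  // namespace tensorflow
```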
@ -0,0 +1,107 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/variable_ops.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_eager_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/ops/owned_tensor_handle.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
|
||||
float value) {
|
||||
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
|
||||
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
|
||||
return handle;
|
||||
}
|
||||
|
||||
class VariableOpsTest : public ::testing::Test {
|
||||
public:
|
||||
VariableOpsTest()
|
||||
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0"))),
|
||||
ctx_(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr)) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Sanity check for variable creation
|
||||
TEST_F(VariableOpsTest, CreateVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
// The created TensorHandle should be a DT_Resource
|
||||
EXPECT_EQ(handle->DataType(), DT_RESOURCE);
|
||||
}
|
||||
|
||||
// Sanity check for variable destruction
|
||||
TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr handle;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &handle));
|
||||
|
||||
// Destroy the variable
|
||||
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
|
||||
}
|
||||
|
||||
// Sanity check for handle assignment and reading
|
||||
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr variable;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &variable));
|
||||
|
||||
// Create a Scalar float TensorHandle with value 42, and assign it to
|
||||
// the variable.
|
||||
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
|
||||
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
|
||||
my_value.get()));
|
||||
|
||||
// Read back the value from the variable, and check that it is 42.
|
||||
AbstractTensorHandlePtr read_value_handle;
|
||||
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
|
||||
&read_value_handle));
|
||||
Status status;
|
||||
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status);
|
||||
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -4,6 +4,7 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 1
|
||||
col: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,9 +13,11 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 3
|
||||
col: 1
|
||||
}
|
||||
file_line_cols: {
|
||||
line: 4
|
||||
col: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -23,6 +26,7 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line: 2
|
||||
col: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -195,53 +195,46 @@ XlaComputationLaunchContext::XlaComputationLaunchContext(
|
||||
}
|
||||
|
||||
void XlaComputationLaunchContext::PopulateInputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
// Build ShapedBuffers that point directly to the Tensor buffers.
|
||||
arg_ptrs_ = std::vector<ShapedBuffer*>(kernel->xla_input_shapes.size());
|
||||
arg_ptrs_ =
|
||||
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
|
||||
|
||||
// Pass remaining parameters.
|
||||
const Tensor* t;
|
||||
for (int i = 0; i < kernel->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = kernel->input_mapping[i];
|
||||
DCHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = kernel->xla_input_shapes[i];
|
||||
if (variables.count(arg_num)) {
|
||||
t = &(variables.at(arg_num).value);
|
||||
CHECK(t);
|
||||
} else {
|
||||
t = &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
}
|
||||
xla::TransferManager* transfer_manager =
|
||||
client_->backend().transfer_manager();
|
||||
for (int i = 0; i < compilation_result->xla_input_shapes.size(); ++i) {
|
||||
int arg_num = compilation_result->input_mapping[i];
|
||||
CHECK_GE(arg_num, missing_ctx_input_prefix);
|
||||
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
|
||||
const Tensor* t = variables.count(arg_num)
|
||||
? &(variables.at(arg_num).value)
|
||||
: &(ctx->input(arg_num - missing_ctx_input_prefix));
|
||||
CHECK(t);
|
||||
|
||||
if (use_multiple_streams_) {
|
||||
CHECK(stream) << "Must have a stream available when using XLA tensors!";
|
||||
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
|
||||
<< "Must have a stream available when using XLA tensors!";
|
||||
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(stream);
|
||||
xla_tensor->WaitForDefinitionEventOnStream(
|
||||
ctx->op_device_context()->stream());
|
||||
}
|
||||
|
||||
const xla::Shape on_device_shape =
|
||||
client_->backend().transfer_manager()->HostShapeToDeviceShape(shape);
|
||||
if (on_device_shape.IsTuple()) {
|
||||
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
} else {
|
||||
CHECK(xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape,
|
||||
on_device_shape))
|
||||
<< "On-device shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
|
||||
<< " not the same as on-host shape "
|
||||
<< xla::ShapeUtil::HumanStringWithLayout(shape);
|
||||
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
|
||||
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
|
||||
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
|
||||
arg_buffers_.emplace_back(
|
||||
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
|
||||
client_->platform(), client_->default_device_ordinal());
|
||||
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
|
||||
arg_ptrs_[i] = &arg_buffers_.back();
|
||||
} else {
|
||||
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
|
||||
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
|
||||
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -370,13 +363,94 @@ static Status SetBufferForResourceVarTensorUnderAllocateXlaTensors(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Sets output `output_num` for `ctx` provided it is known at compile time.
|
||||
static Status SetOutputForConstant(
|
||||
OpKernelContext* ctx, se::Stream* stream,
|
||||
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
|
||||
CHECK(compilation_result->outputs[output_num].is_constant);
|
||||
// Output is a constant.
|
||||
const Tensor& const_tensor =
|
||||
compilation_result->outputs[output_num].constant_value;
|
||||
Tensor* output_tensor;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
// Manually allocate memory using an XlaTensorBuffer so we can allocate
|
||||
// as much memory as the device requires (as given by
|
||||
// GetByteSizeRequirement). This avoids XlaTransferManager having to
|
||||
// reallocate the device buffer later.
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_output(output_num, const_tensor.shape(), &output_tensor));
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
if (device == nullptr) {
|
||||
return errors::Internal("DeviceBase was not a Device.");
|
||||
}
|
||||
ctx->op_device_context()->CopyCPUTensorToDevice(
|
||||
&const_tensor, device, output_tensor,
|
||||
[&](Status status) { TF_CHECK_OK(status); });
|
||||
|
||||
if (device->device_type() == DEVICE_GPU) {
|
||||
// The GPUDeviceContext enqueues the host->device transfer in a
|
||||
// separate stream from the main compute stream. We must ensure the
|
||||
// compute stream is synchronized with the host->device transfer
|
||||
// stream now otherwise we will create a race condition.
|
||||
auto* gpu_device_context =
|
||||
static_cast<GPUDeviceContext*>(ctx->op_device_context());
|
||||
gpu_device_context->stream()->ThenWaitFor(
|
||||
gpu_device_context->host_to_device_stream());
|
||||
}
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(output_num, const_tensor);
|
||||
output_tensor = ctx->mutable_output(output_num);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Creates a list of updated resource variables.
|
||||
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
int missing_ctx_input_prefix) {
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
variable_infos.reserve(compilation_result->resource_updates.size());
|
||||
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
|
||||
return errors::Internal("Invalid input index for variable write.");
|
||||
}
|
||||
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
}
|
||||
return variable_infos;
|
||||
}
|
||||
|
||||
Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots) {
|
||||
se::Stream* stream =
|
||||
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
|
||||
// Computation output should always be a tuple.
|
||||
if (VLOG_IS_ON(2)) {
|
||||
@ -384,7 +458,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
VLOG(2) << "Result tuple shape (on device): "
|
||||
<< output.on_device_shape().DebugString();
|
||||
}
|
||||
CHECK_EQ(ctx->num_outputs(), kernel->outputs.size());
|
||||
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
|
||||
|
||||
// If the on-host-shape isn't a tuple, create a new single-element tuple
|
||||
// buffer with a nullptr root index table. This allows the code below to treat
|
||||
@ -413,82 +487,41 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
if (kernel->outputs[i].is_constant) {
|
||||
// Output is a constant.
|
||||
const Tensor& const_tensor = kernel->outputs[i].constant_value;
|
||||
Tensor* output_tensor;
|
||||
const size_t total_bytes = const_tensor.TotalBytes();
|
||||
if (stream && total_bytes > 0) {
|
||||
// Copy host -> device. (Empty tensors don't have backing buffers.)
|
||||
// Manually allocate memory using an XlaTensorBuffer so we can allocate
|
||||
// as much memory as the device requires (as given by
|
||||
// GetByteSizeRequirement). This avoids XlaTransferManager having to
|
||||
// reallocate the device buffer later.
|
||||
VLOG(1) << "Constant output tensor on device";
|
||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_VARIANT) {
|
||||
return errors::Unimplemented(
|
||||
"Support for TensorList crossing the XLA/TF boundary "
|
||||
"is not implemented");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ctx->allocate_output(i, const_tensor.shape(), &output_tensor));
|
||||
|
||||
Device* device = dynamic_cast<Device*>(ctx->device());
|
||||
if (device == nullptr) {
|
||||
return errors::Internal("DeviceBase was not a Device.");
|
||||
}
|
||||
ctx->op_device_context()->CopyCPUTensorToDevice(
|
||||
&const_tensor, device, output_tensor,
|
||||
[&](Status status) { TF_CHECK_OK(status); });
|
||||
|
||||
if (device->device_type() == DEVICE_GPU) {
|
||||
// The GPUDeviceContext enqueues the host->device transfer in a
|
||||
// separate stream from the main compute stream. We must ensure the
|
||||
// compute stream is synchronized with the host->device transfer
|
||||
// stream now otherwise we will create a race condition.
|
||||
auto* gpu_device_context =
|
||||
static_cast<GPUDeviceContext*>(ctx->op_device_context());
|
||||
gpu_device_context->stream()->ThenWaitFor(
|
||||
gpu_device_context->host_to_device_stream());
|
||||
}
|
||||
} else {
|
||||
// No copy required.
|
||||
ctx->set_output(i, const_tensor);
|
||||
output_tensor = ctx->mutable_output(i);
|
||||
}
|
||||
if (XlaTensor* xla_tensor = XlaTensor::FromTensor(output_tensor)) {
|
||||
xla_tensor->set_host_tensor(const_tensor);
|
||||
}
|
||||
if (compilation_result->outputs[i].is_constant) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetOutputForConstant(ctx, stream, compilation_result, i));
|
||||
} else if (type == DT_RESOURCE) {
|
||||
int input_index =
|
||||
compilation_result->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
const TensorShape& shape = kernel->outputs[i].shape;
|
||||
const DataType& type = kernel->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
if (type == DT_RESOURCE) {
|
||||
int input_index =
|
||||
kernel->outputs[i].input_index - missing_ctx_input_prefix;
|
||||
TF_RET_CHECK(input_index >= 0 && input_index < ctx->num_inputs())
|
||||
<< "Invalid input for outputs " << i << ": " << input_index;
|
||||
ctx->set_output(i, ctx->input(input_index));
|
||||
} else {
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
definition_event, stream, use_multiple_streams_));
|
||||
} else {
|
||||
if (type == DT_VARIANT) {
|
||||
return errors::Unimplemented(
|
||||
"Support for TensorList crossing the XLA/TF boundary "
|
||||
"is not implemented");
|
||||
}
|
||||
if (allocate_xla_tensors_) {
|
||||
TF_RETURN_IF_ERROR(SetBufferForTensorUnderAllocateXlaTensors(
|
||||
input_output_alias, output_num, ctx, i, shape, &output,
|
||||
definition_event, stream, use_multiple_streams_));
|
||||
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
} else {
|
||||
se::DeviceMemoryBase buffer = output.buffer({output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
compilation_result->input_mapping, resource_var_snapshots,
|
||||
ctx->expected_output_dtype(i), shape, buffer, allocator);
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
ctx->set_output(i, output_tensor);
|
||||
}
|
||||
++output_num;
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(3)) {
|
||||
@ -498,34 +531,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
|
||||
// Apply variable updates, if any.
|
||||
VLOG(2) << "Applying variable updates";
|
||||
std::vector<VariableInfo> variable_infos;
|
||||
variable_infos.reserve(kernel->resource_updates.size());
|
||||
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
int actual_input_index = write.input_index - missing_ctx_input_prefix;
|
||||
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
|
||||
return errors::Internal("Invalid input index for variable write.");
|
||||
}
|
||||
|
||||
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
|
||||
// not a Tensor.
|
||||
Var* variable = nullptr;
|
||||
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
|
||||
ctx, HandleFromInput(ctx, actual_input_index), &variable,
|
||||
[&write](Var** ptr) {
|
||||
*ptr = new Var(write.type);
|
||||
return Status::OK();
|
||||
}));
|
||||
variable_infos.emplace_back(actual_input_index, variable);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<VariableInfo> variable_infos,
|
||||
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
|
||||
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
|
||||
|
||||
for (int i = 0; i < kernel->resource_updates.size(); ++i) {
|
||||
Allocator* allocator = ctx->device()->GetAllocator({});
|
||||
const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i];
|
||||
|
||||
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
|
||||
const XlaCompiler::ResourceUpdate& write =
|
||||
compilation_result->resource_updates[i];
|
||||
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
|
||||
return errors::Internal("Mismatched type in variable write");
|
||||
}
|
||||
@ -539,7 +552,7 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
output.set_buffer(se::OwningDeviceMemory(), {output_num});
|
||||
Tensor output_tensor = GetOrCreateTensorForOutput(
|
||||
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
|
||||
kernel->input_mapping, resource_var_snapshots, write.type,
|
||||
compilation_result->input_mapping, resource_var_snapshots, write.type,
|
||||
write.shape, buffer, allocator);
|
||||
*variable_infos[i].var()->tensor() = output_tensor;
|
||||
variable_infos[i].var()->is_initialized |= write.modified;
|
||||
|
@ -136,7 +136,7 @@ class XlaComputationLaunchContext {
|
||||
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
|
||||
// (in other words, no inputs actually required by the kernel can be missing).
|
||||
void PopulateInputs(OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* kernel,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
const std::map<int, OptionalTensor>& variables,
|
||||
int missing_ctx_input_prefix);
|
||||
|
||||
@ -148,10 +148,11 @@ class XlaComputationLaunchContext {
|
||||
// See jit/resource_operation_safety_analysis for details.
|
||||
//
|
||||
//
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the kernel are
|
||||
// missing and adjusts input indices accordingly.
|
||||
// Assumes that the first `missing_ctx_input_prefix` inputs to the
|
||||
// compilation_result are missing and adjusts input indices accordingly.
|
||||
Status PopulateOutputs(
|
||||
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
|
||||
OpKernelContext* ctx,
|
||||
const XlaCompiler::CompilationResult* compilation_result,
|
||||
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
|
||||
const xla::HloInputOutputAliasConfig& input_output_alias,
|
||||
const std::map<int, OptionalTensor>& resource_var_snapshots);
|
||||
|
@ -55,7 +55,7 @@ class XlaTensor {
|
||||
// manage the memory for these tensors a ShapedBuffer may be required.
|
||||
|
||||
// Return true if this XlaTensor contains a ShapedBuffer.
|
||||
bool has_shaped_buffer() const { return shaped_buffer_ != nullptr; }
|
||||
bool has_shaped_buffer() const { return shaped_buffer_.has_value(); }
|
||||
// Return the contained ShapedBuffer.
|
||||
// REQUIRES: has_shaped_buffer()
|
||||
const xla::ShapedBuffer& shaped_buffer() const {
|
||||
@ -68,8 +68,7 @@ class XlaTensor {
|
||||
}
|
||||
// Mutates the XlaTensor to set the ShapedBuffer.
|
||||
void set_shaped_buffer(xla::ScopedShapedBuffer shaped_buffer) {
|
||||
shaped_buffer_ =
|
||||
absl::make_unique<xla::ScopedShapedBuffer>(std::move(shaped_buffer));
|
||||
shaped_buffer_ = std::move(shaped_buffer);
|
||||
}
|
||||
|
||||
// Some tensors on the device may have known values on the host. We use these
|
||||
@ -77,14 +76,12 @@ class XlaTensor {
|
||||
// host value already.
|
||||
|
||||
// Return true if this XlaTensor contains a host tensor.
|
||||
bool has_host_tensor() const { return host_tensor_ != nullptr; }
|
||||
bool has_host_tensor() const { return host_tensor_.has_value(); }
|
||||
// Return the contained host tensor.
|
||||
// REQUIRES: has_host_tensor()
|
||||
const Tensor& host_tensor() const { return *host_tensor_; }
|
||||
// Sets the contained host tensor.
|
||||
void set_host_tensor(const Tensor& tensor) {
|
||||
host_tensor_.reset(new Tensor(tensor));
|
||||
}
|
||||
void set_host_tensor(const Tensor& tensor) { host_tensor_.emplace(tensor); }
|
||||
|
||||
// Adds synchronization events to 'stream' that wait for this tensor to be
|
||||
// defined on 'stream'. Does nothing if the tensor is already defined on that
|
||||
@ -111,9 +108,9 @@ class XlaTensor {
|
||||
|
||||
private:
|
||||
// The optional contained ShapedBuffer.
|
||||
std::unique_ptr<xla::ScopedShapedBuffer> shaped_buffer_;
|
||||
absl::optional<xla::ScopedShapedBuffer> shaped_buffer_;
|
||||
// An optional host tensor value.
|
||||
std::unique_ptr<Tensor> host_tensor_;
|
||||
absl::optional<Tensor> host_tensor_;
|
||||
// An optional event that is triggered when the tensor's content has been
|
||||
// defined. If this event is nullptr, it is assumed that the tensor's content
|
||||
// is always defined.
|
||||
|
@ -314,6 +314,7 @@ tf_cc_test(
|
||||
cc_library(
|
||||
name = "tensorflow_lite_legalize_tf",
|
||||
srcs = [
|
||||
"transforms/device_index_selector.cc",
|
||||
"transforms/dilated_conv.cc",
|
||||
"transforms/generated_legalize_tf.inc",
|
||||
"transforms/generated_lower_static_tensor_list.inc",
|
||||
|
@ -1406,22 +1406,67 @@ BufferOffset<tflite::SparsityParameters> Translator::BuildSparsityParameters(
|
||||
for (int j = 0; j < segments.size(); j++) {
|
||||
vector_segments[j] = segments[j].dyn_cast<mlir::IntegerAttr>().getInt();
|
||||
}
|
||||
auto array_segments =
|
||||
tflite::CreateInt32Vector(builder_,
|
||||
builder_.CreateVector(vector_segments))
|
||||
.Union();
|
||||
tflite::SparseIndexVector segments_type;
|
||||
BufferOffset<void> array_segments;
|
||||
// The segment array is sorted.
|
||||
// TODO(b/147449640): Clean this up with util functions.
|
||||
int max_of_segments = vector_segments[segments.size() - 1];
|
||||
if (max_of_segments <= UINT8_MAX) {
|
||||
segments_type = tflite::SparseIndexVector_Uint8Vector;
|
||||
std::vector<uint8_t> uint8_vector(vector_segments.begin(),
|
||||
vector_segments.end());
|
||||
array_segments = tflite::CreateUint8Vector(
|
||||
builder_, builder_.CreateVector(uint8_vector))
|
||||
.Union();
|
||||
} else if (max_of_segments <= UINT16_MAX) {
|
||||
segments_type = tflite::SparseIndexVector_Uint16Vector;
|
||||
std::vector<uint16_t> uint16_vector(vector_segments.begin(),
|
||||
vector_segments.end());
|
||||
array_segments = tflite::CreateUint16Vector(
|
||||
builder_, builder_.CreateVector(uint16_vector))
|
||||
.Union();
|
||||
} else {
|
||||
segments_type = tflite::SparseIndexVector_Int32Vector;
|
||||
array_segments = tflite::CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(vector_segments))
|
||||
.Union();
|
||||
}
|
||||
|
||||
auto indices = dim_metadata.indices();
|
||||
std::vector<int> vector_indices(indices.size(), 0);
|
||||
int max_of_indices = 0;
|
||||
for (int j = 0; j < indices.size(); j++) {
|
||||
vector_indices[j] = indices[j].dyn_cast<mlir::IntegerAttr>().getInt();
|
||||
if (vector_indices[j] > max_of_indices) {
|
||||
max_of_indices = vector_indices[j];
|
||||
}
|
||||
}
|
||||
auto array_indices = tflite::CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(vector_indices))
|
||||
.Union();
|
||||
tflite::SparseIndexVector indices_type;
|
||||
BufferOffset<void> array_indices;
|
||||
if (max_of_indices <= UINT8_MAX) {
|
||||
indices_type = tflite::SparseIndexVector_Uint8Vector;
|
||||
std::vector<uint8_t> uint8_vector(vector_indices.begin(),
|
||||
vector_indices.end());
|
||||
array_indices = tflite::CreateUint8Vector(
|
||||
builder_, builder_.CreateVector(uint8_vector))
|
||||
.Union();
|
||||
} else if (max_of_indices <= UINT16_MAX) {
|
||||
indices_type = tflite::SparseIndexVector_Uint16Vector;
|
||||
std::vector<uint16_t> uint16_vector(vector_indices.begin(),
|
||||
vector_indices.end());
|
||||
array_indices = tflite::CreateUint16Vector(
|
||||
builder_, builder_.CreateVector(uint16_vector))
|
||||
.Union();
|
||||
} else {
|
||||
indices_type = tflite::SparseIndexVector_Int32Vector;
|
||||
array_indices = tflite::CreateInt32Vector(
|
||||
builder_, builder_.CreateVector(vector_indices))
|
||||
.Union();
|
||||
      }

      fb_dim_metadata[i] = tflite::CreateDimensionMetadata(
          builder_, tflite::DimensionType_SPARSE_CSR, 0,
          tflite::SparseIndexVector_Int32Vector, array_segments,
          tflite::SparseIndexVector_Int32Vector, array_indices);
          builder_, tflite::DimensionType_SPARSE_CSR, 0, segments_type,
          array_segments, indices_type, array_indices);
    }
  }

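In the hunk above, the flatbuffer writer now stores the CSR segment and index arrays in the narrowest unsigned container that can hold their largest value (Uint8Vector, Uint16Vector, or Int32Vector) instead of always emitting Int32Vector. A standalone sketch of that width-selection logic, with a local enum standing in for `tflite::SparseIndexVector` and no FlatBuffers dependency:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for the tflite::SparseIndexVector union type tag.
enum class IndexWidth { kUint8, kUint16, kInt32 };

// Picks the narrowest storage type able to represent max_value.
IndexWidth SelectIndexWidth(int max_value) {
  if (max_value <= UINT8_MAX) return IndexWidth::kUint8;
  if (max_value <= UINT16_MAX) return IndexWidth::kUint16;
  return IndexWidth::kInt32;
}

int main() {
  // The segment array is sorted, so its maximum is simply the last element.
  std::vector<int> segments = {0, 3, 7, 250};
  std::cout << static_cast<int>(SelectIndexWidth(segments.back())) << "\n";  // 0 (uint8)

  // Indices are not sorted, so the maximum is tracked while converting them.
  std::vector<int> indices = {12, 70000, 5};
  int max_of_indices = 0;
  for (int v : indices) max_of_indices = std::max(max_of_indices, v);
  std::cout << static_cast<int>(SelectIndexWidth(max_of_indices)) << "\n";   // 2 (int32)
  return 0;
}
```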
@ -436,7 +436,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
|
||||
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
|
||||
TFL_ChannelDimIndexInterface, AffineOpCoefficient<index, 1>,
|
||||
TFL_GpuTargetOp]> {
|
||||
TFL_GpuTargetOp, TFL_SparseOp]> {
|
||||
let summary = opSummary # " operator";
|
||||
|
||||
let description = [{
|
||||
@ -571,7 +571,8 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
|
||||
TFL_OperandHasRank<2, 4>,
|
||||
PredOpTrait<"input and output must have same element type",
|
||||
TFL_TCresVTEtIsSameAsOp<0, 2>>,
|
||||
TFL_GpuTargetOp]> {
|
||||
TFL_GpuTargetOp,
|
||||
TFL_SparseOp]> {
|
||||
let summary = "Transpose convolution operator";
|
||||
|
||||
let description = [{
|
||||
@ -593,6 +594,13 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
|
||||
let hasOptions = 1;
|
||||
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// SparseOpInterface:
|
||||
std::vector<int> GetSparseOperands() { return {1}; }
|
||||
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
|
||||
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
|
||||
}];
|
||||
}
|
||||
|
||||
def TFL_AveragePool2DOp:
|
||||
@ -826,6 +834,10 @@ def TFL_Conv2DOp : TFL_ConvOp<"conv_2d", "Convolution", 0> {
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 0; }
|
||||
// SparseOpInterface:
|
||||
std::vector<int> GetSparseOperands() { return {1}; }
|
||||
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
|
||||
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
|
||||
}];
|
||||
}
|
||||
|
||||
@ -866,6 +878,10 @@ def TFL_DepthwiseConv2DOp :
|
||||
let extraClassDeclaration = [{
|
||||
// ChannelDimIndexInterface:
|
||||
int GetChannelDimIndex() { return 3; }
|
||||
// SparseOpInterface:
|
||||
std::vector<int> GetSparseOperands() { return {1}; }
|
||||
std::vector<std::vector<int>> GetFloatBlockSize() { return {}; }
|
||||
std::vector<std::vector<int>> GetQuantizedBlockSize() { return {}; }
|
||||
}];
|
||||
}
|
||||
|
||||
|
@ -157,7 +157,7 @@
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
|
@ -154,7 +154,7 @@ func @main(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x1001xf32> {
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 51, 46, 49, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
|
@ -190,7 +190,7 @@
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
|
@ -190,7 +190,7 @@
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-EMPTY:
|
||||
// CHECK-NEXT: }, {
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 52, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: data: [ 49, 46, 49, 53, 46, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
|
||||
// CHECK-NEXT: } ],
|
||||
// CHECK-NEXT: metadata: [ {
|
||||
// CHECK-NEXT: name: "min_runtime_version",
|
||||
|
@ -0,0 +1,25 @@
|
||||
// Test DeviceIndex selector.
|
||||
|
||||
// RUN: tf-opt --tfl-device-index-selector %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
func @select(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<i32>, tensor<f32>) {
|
||||
// CHECK: %[[first:.*]] = "tf.DeviceIndex"
|
||||
// CHECK: constant dense<2>
|
||||
// CHECK: return %[[first]],
|
||||
%0 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
|
||||
%1 = "tf.DeviceIndex"() {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
|
||||
%4 = "tf.Case"(%1, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>]} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
|
||||
return %0, %4 : tensor<i32>, tensor<f32>
|
||||
}
|
||||
|
||||
func @add(%i: tensor<i32>, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @sub(%i: tensor<i32>, %arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
@ -63,6 +63,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
standard_pipeline_options.enable_inliner = false;
|
||||
standard_pipeline_options.form_clusters = pass_config.form_clusters;
|
||||
mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
|
||||
pass_manager->addPass(mlir::TFL::CreateDeviceIndexSelectorPass());
|
||||
|
||||
if (pass_config.shape_inference) {
|
||||
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
|
||||
|
@ -40,14 +40,22 @@ void PopulateEncodingParams(const std::vector<int>& block_size,
|
||||
std::vector<int>* traversal_order,
|
||||
std::vector<TfLiteDimensionType>* format,
|
||||
std::vector<int>* b_map, std::vector<int>* b_size) {
|
||||
*traversal_order = {0, 1};
|
||||
*format = {kTfLiteDimDense, kTfLiteDimSparseCSR};
|
||||
const int dims_count = block_size.size();
|
||||
traversal_order->resize(dims_count);
|
||||
format->resize(dims_count);
|
||||
for (int i = 0; i < dims_count; i++) {
|
||||
(*traversal_order)[i] = i;
|
||||
}
|
||||
for (int i = 0; i < dims_count - 1; i++) {
|
||||
(*format)[i] = kTfLiteDimDense;
|
||||
}
|
||||
(*format)[dims_count - 1] = kTfLiteDimSparseCSR;
|
||||
*b_map = {};
|
||||
*b_size = {};
|
||||
int block_rank = 0;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
for (int i = 0; i < dims_count; i++) {
|
||||
if (block_size[i] != 1) {
|
||||
traversal_order->push_back(block_rank + 2);
|
||||
traversal_order->push_back(block_rank + dims_count);
|
||||
format->push_back(kTfLiteDimDense);
|
||||
block_rank++;
|
||||
b_map->push_back(i);
|
||||
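The hunk above generalizes PopulateEncodingParams from the hard-coded 2-D case ({0, 1} traversal, dense row plus CSR column) to any rank: every leading dimension is dense, the last dimension is CSR, and each blocked dimension appends one extra dense entry to the traversal order. A standalone sketch of that construction; the TfLite enums are replaced by a local one, and the block-dimension size update is an assumption since the tail of the loop is cut off by the hunk boundary:

```cpp
#include <iostream>
#include <vector>

enum class DimFormat { kDense, kSparseCSR };  // stand-in for TfLiteDimensionType

struct EncodingParams {
  std::vector<int> traversal_order;
  std::vector<DimFormat> format;
  std::vector<int> block_map;       // which original dims are blocked
  std::vector<int> block_dim_size;  // assumed: size of each blocked dim
};

EncodingParams PopulateEncodingParams(const std::vector<int>& block_size) {
  EncodingParams p;
  const int dims_count = block_size.size();
  p.traversal_order.resize(dims_count);
  p.format.assign(dims_count, DimFormat::kDense);
  for (int i = 0; i < dims_count; ++i) p.traversal_order[i] = i;
  p.format[dims_count - 1] = DimFormat::kSparseCSR;

  int block_rank = 0;
  for (int i = 0; i < dims_count; ++i) {
    if (block_size[i] != 1) {
      // Each blocked dim becomes an extra, trailing dense dimension.
      p.traversal_order.push_back(block_rank + dims_count);
      p.format.push_back(DimFormat::kDense);
      ++block_rank;
      p.block_map.push_back(i);
      p.block_dim_size.push_back(block_size[i]);  // assumption, see lead-in
    }
  }
  return p;
}

int main() {
  // A 4-D filter blocked by 4 along its last dimension.
  EncodingParams p = PopulateEncodingParams({1, 1, 1, 4});
  for (int d : p.traversal_order) std::cout << d << " ";  // 0 1 2 3 4
  std::cout << "\n";
  return 0;
}
```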
@ -58,27 +66,18 @@ void PopulateEncodingParams(const std::vector<int>& block_size,
|
||||
|
||||
float CalculateRandomSparsity(const ElementsAttr& attr,
|
||||
const ShapedType& type) {
|
||||
int num_elements = 1;
|
||||
for (int i = 0; i < 2; i++) {
|
||||
num_elements *= type.getDimSize(i);
|
||||
}
|
||||
int num_elements = type.getNumElements();
|
||||
int num_zeros = 0;
|
||||
|
||||
if (type.getElementType().isF32()) {
|
||||
std::vector<float> data;
|
||||
data.reserve(type.getNumElements());
|
||||
for (const auto val : attr.getValues<float>()) data.push_back(val);
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
if (data[i] == 0) {
|
||||
for (const auto val : attr.getValues<float>()) {
|
||||
if (val == 0.f) {
|
||||
num_zeros++;
|
||||
}
|
||||
}
|
||||
} else if (type.getElementType().isa<quant::QuantizedType>()) {
|
||||
std::vector<int8_t> data;
|
||||
data.reserve(type.getNumElements());
|
||||
for (const auto val : attr.getValues<int8_t>()) data.push_back(val);
|
||||
for (int i = 0; i < data.size(); i++) {
|
||||
if (data[i] == 0) {
|
||||
for (const auto val : attr.getValues<int8_t>()) {
|
||||
if (val == 0) {
|
||||
num_zeros++;
|
||||
}
|
||||
}
|
||||
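CalculateRandomSparsity now walks the attribute's value iterator directly, counting zeros in a single pass instead of first copying everything into a temporary vector, and it uses getNumElements() so tensors of any rank are handled. A small standalone sketch of the same idea over a flat buffer; returning the zeros-to-elements ratio is an assumption, since the end of the function is not visible in this hunk:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Fraction of elements that are exactly zero.
template <typename T>
float CalculateRandomSparsity(const std::vector<T>& values) {
  if (values.empty()) return 0.0f;
  int num_zeros = 0;
  for (const T& v : values) {
    if (v == T(0)) ++num_zeros;  // single pass, no temporary copy
  }
  return static_cast<float>(num_zeros) / static_cast<float>(values.size());
}

int main() {
  std::vector<float> weights = {0.f, 1.5f, 0.f, 0.f, -2.f, 0.f, 0.f, 3.f};
  std::cout << CalculateRandomSparsity(weights) << "\n";    // 0.625
  std::vector<int8_t> quantized = {0, 7, 0, -3};
  std::cout << CalculateRandomSparsity(quantized) << "\n";  // 0.5
  return 0;
}
```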
@ -150,9 +149,10 @@ InspectResult InspectWeight(
|
||||
type = cst.getType().cast<ShapedType>();
|
||||
}
|
||||
|
||||
// TODO(b/147449640): Add ability to encode weights more than 2-D, e.g. Conv
|
||||
// weights.
|
||||
if (type.getRank() != 2) {
|
||||
// Currently we only support compressing weights of ops:
|
||||
// Conv, DepthwiseConv, TransposeConv, whose filter has rank 4, and
|
||||
// FullyConnected, whose filter has rank 2.
|
||||
if (type.getRank() != 2 && type.getRank() != 4) {
|
||||
result.can_compress = false;
|
||||
return result;
|
||||
}
|
||||
@ -195,9 +195,11 @@ std::vector<T> BuildSparsityParameterAttribute(
|
||||
attr = cst.value();
|
||||
type = cst.getType().cast<ShapedType>();
|
||||
}
|
||||
std::vector<int> shape(2);
|
||||
shape[0] = type.getDimSize(0);
|
||||
shape[1] = type.getDimSize(1);
|
||||
const int dims_count = type.getRank();
|
||||
std::vector<int> shape(dims_count);
|
||||
for (int i = 0; i < dims_count; i++) {
|
||||
shape[i] = type.getDimSize(i);
|
||||
}
|
||||
|
||||
std::vector<int> traversal_order = {};
|
||||
std::vector<TfLiteDimensionType> format = {};
|
||||
@ -271,10 +273,13 @@ void DenseToSparse::runOnFunction() {
|
||||
continue;
|
||||
}
|
||||
|
||||
ShapedType type;
|
||||
if (isa<ConstOp>(inst)) {
|
||||
supported_block_size = sparse_op.GetFloatBlockSize();
|
||||
type = dyn_cast<ConstOp>(inst).getType().cast<ShapedType>();
|
||||
} else if (isa<QConstOp>(inst)) {
|
||||
supported_block_size = sparse_op.GetQuantizedBlockSize();
|
||||
type = dyn_cast<QConstOp>(inst).getType().cast<ShapedType>();
|
||||
} else {
|
||||
continue;
|
||||
}
|
||||
@ -286,7 +291,7 @@ void DenseToSparse::runOnFunction() {
|
||||
|
||||
// The weight is not block sparse. Encode with random sparsity.
|
||||
if (result.selected_block_size.empty()) {
|
||||
result.selected_block_size = {1, 1};
|
||||
result.selected_block_size = std::vector<int>(type.getRank(), 1);
|
||||
}
|
||||
|
||||
builder.setInsertionPoint(op);
|
||||
|
@ -0,0 +1,85 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Converts DeviceIndex to constant device.
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
#include "mlir/IR/PatternMatch.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFL {
|
||||
namespace {
|
||||
|
||||
// Folds the DeviceIndex op to a constant value. The DeviceIndex returns the
|
||||
// index of the device the op should run on. The user can use this to provide
|
||||
// different op specializations. E.g.,
|
||||
//
|
||||
// ```mlir
|
||||
// %1 = "tf.DeviceIndex"()
|
||||
// {device = "", device_names = ["CPU", "GPU"]} : () -> tensor<i32>
|
||||
// %4 = "tf.Case"(%1, %arg0, %arg1)
|
||||
// {branches = [@foo, @baz], output_shapes = [#tf.shape<>]} :
|
||||
// (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
// ```
|
||||
//
|
||||
// Shows an example where there are 2 different functions which could be
|
||||
// executed to produce the same values but with different functions optimized
|
||||
// for CPU or GPU.
|
||||
struct DeviceIndexSelector
|
||||
: public PassWrapper<DeviceIndexSelector, OperationPass<FuncOp>> {
|
||||
void runOnOperation() override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void DeviceIndexSelector::runOnOperation() {
|
||||
FuncOp func = getOperation();
|
||||
// Convert all the DeviceIndex ops to constant values.
|
||||
func.getBody().walk([](TF::DeviceIndexOp op) {
|
||||
// This just selects the default in all cases where DeviceIndex feeds into
|
||||
// tf.Case. This could be enhanced based on explicit TFLite specification or
|
||||
// TAC in future.
|
||||
OpBuilder b(op);
|
||||
RankedTensorType type = RankedTensorType::get({}, b.getIntegerType(32));
|
||||
int index = op.device_names().size();
|
||||
for (auto use : op.getOperation()->getUsers()) {
|
||||
// Skip if it doesn't feed into case. Alternatively this could always
|
||||
// return the CPU device index if it exists.
|
||||
if (!isa<TF::CaseOp>(use)) return;
|
||||
}
|
||||
DenseElementsAttr attr =
|
||||
DenseElementsAttr::get(type, b.getI32IntegerAttr(index));
|
||||
auto constant = b.create<ConstantOp>(op.getLoc(), type, attr);
|
||||
op.replaceAllUsesWith(constant.getOperation());
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
|
||||
// Creates an instance of the TensorFlow DeviceIndex selector pass.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass() {
|
||||
return std::make_unique<DeviceIndexSelector>();
|
||||
}
|
||||
|
||||
static PassRegistration<DeviceIndexSelector> pass(
|
||||
"tfl-device-index-selector", "Fold tf.DeviceIndex to constant");
|
||||
|
||||
} // namespace TFL
|
||||
} // namespace mlir
|
@ -91,6 +91,9 @@ std::unique_ptr<OperationPass<ModuleOp>> CreateWhileOutlinePass();
|
||||
// Verifies runtime constraints.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
|
||||
|
||||
// Creates function pass to select device index/fold tf.DeviceIndex.
|
||||
std::unique_ptr<OperationPass<FuncOp>> CreateDeviceIndexSelectorPass();
|
||||
|
||||
} // namespace TFL
|
||||
|
||||
} // namespace mlir
|
||||
|
@ -397,6 +397,73 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_lib",
|
||||
srcs = [
|
||||
"transforms/lift_variables.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"transforms/lift_variables.h",
|
||||
],
|
||||
deps = [
|
||||
":convert_tensor",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_pass",
|
||||
hdrs = [
|
||||
"transforms/lift_variables_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_lib",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lift_variables_test_pass",
|
||||
hdrs = [
|
||||
"transforms/lift_variables_test_pass.h",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_lib",
|
||||
":tensorflow",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/platform:status",
|
||||
"//tensorflow/core/platform:threadpool_options",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorflow_passes",
|
||||
srcs = [
|
||||
@ -520,9 +587,11 @@ cc_library(
|
||||
cc_library(
|
||||
name = "tensorflow_test_passes",
|
||||
srcs = [
|
||||
"transforms/lift_variables_test_pass_registration.cc",
|
||||
"transforms/lower_tf_pass.cc",
|
||||
],
|
||||
deps = [
|
||||
":lift_variables_test_pass",
|
||||
":lower_tf_lib",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
|
@ -820,6 +820,12 @@ followed by cropping along the `height` and `width` dimensions.
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
|
||||
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
|
||||
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
|
||||
@ -7280,6 +7286,49 @@ $$\text{variable} := \text{variable} - \text{lr}_t * m_t / (\sqrt{v_t} + \epsilo
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<3>;
|
||||
}
|
||||
|
||||
def TF_ResourceApplyCenteredRMSPropOp : TF_Op<"ResourceApplyCenteredRMSProp", []> {
|
||||
let summary = "Update '*var' according to the centered RMSProp algorithm.";
|
||||
|
||||
let description = [{
|
||||
The centered RMSProp algorithm uses an estimate of the centered second moment
|
||||
(i.e., the variance) for normalization, as opposed to regular RMSProp, which
|
||||
uses the (uncentered) second moment. This often helps with training, but is
|
||||
slightly more expensive in terms of computation and memory.
|
||||
|
||||
Note that in dense implementation of this algorithm, mg, ms, and mom will
|
||||
update even if the grad is zero, but in this sparse implementation, mg, ms,
|
||||
and mom will not update in iterations during which the grad is zero.
|
||||
|
||||
mean_square = decay * mean_square + (1-decay) * gradient ** 2
|
||||
mean_grad = decay * mean_grad + (1-decay) * gradient
|
||||
|
||||
Delta = learning_rate * gradient / sqrt(mean_square + epsilon - mean_grad ** 2)
|
||||
|
||||
mg <- rho * mg_{t-1} + (1-rho) * grad
|
||||
ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
var <- var - mom
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
TF_ResourceTensor:$var,
|
||||
TF_ResourceTensor:$mg,
|
||||
TF_ResourceTensor:$ms,
|
||||
TF_ResourceTensor:$mom,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$lr,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$rho,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$momentum,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$epsilon,
|
||||
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$grad,
|
||||
|
||||
DefaultValuedAttr<BoolAttr, "false">:$use_locking
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<4>;
|
||||
}
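The update rule in the op description above is a handful of arithmetic operations. A scalar sketch of one centered RMSProp step (plain C++, illustrative only, not the TensorFlow kernel; names follow the description):

```cpp
#include <cmath>
#include <iostream>

struct CenteredRmsPropState {
  double var, mg, ms, mom;
};

// One update step:
//   mg  <- rho * mg  + (1 - rho) * grad
//   ms  <- rho * ms  + (1 - rho) * grad * grad
//   mom <- momentum * mom + lr * grad / sqrt(ms - mg * mg + epsilon)
//   var <- var - mom
void ApplyCenteredRmsProp(CenteredRmsPropState& s, double grad, double lr,
                          double rho, double momentum, double epsilon) {
  s.mg = rho * s.mg + (1.0 - rho) * grad;
  s.ms = rho * s.ms + (1.0 - rho) * grad * grad;
  s.mom = momentum * s.mom + lr * grad / std::sqrt(s.ms - s.mg * s.mg + epsilon);
  s.var -= s.mom;
}

int main() {
  CenteredRmsPropState s{1.0, 0.0, 0.0, 0.0};
  ApplyCenteredRmsProp(s, /*grad=*/0.5, /*lr=*/0.01, /*rho=*/0.9,
                       /*momentum=*/0.0, /*epsilon=*/1e-10);
  std::cout << s.var << "\n";  // slightly below 1.0
  return 0;
}
```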
|
||||
|
||||
def TF_ResourceApplyGradientDescentOp : TF_Op<"ResourceApplyGradientDescent", []> {
|
||||
let summary = "Update '*var' by subtracting 'alpha' * 'delta' from it.";
|
||||
|
||||
@ -9055,11 +9104,11 @@ particular,
|
||||
begin = [1, 2, x, x, 0, x] # x denotes don't care (usually 0)
|
||||
end = [2, 4, x, x, -3, x]
|
||||
strides = [1, 1, x, x, -1, 1]
|
||||
begin_mask = 1<<4 | 1 << 5 = 48
|
||||
begin_mask = 1<<4 | 1<<5 = 48
|
||||
end_mask = 1<<5 = 32
|
||||
ellipsis_mask = 1<<3 = 8
|
||||
new_axis_mask = 1<<2 4
|
||||
shrink_axis_mask = 1<<0
|
||||
new_axis_mask = 1<<2 = 4
|
||||
shrink_axis_mask = 1<<0 = 1
|
||||
```
|
||||
|
||||
In this case if `foo.shape` is (5, 5, 5, 5, 5, 5) the final shape of
|
||||
|
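As a quick check of the corrected mask arithmetic in the example above (48, 32, 8, 4, and 1), here is a standalone snippet that evaluates the same shift expressions:

```cpp
#include <iostream>

int main() {
  int begin_mask = 1 << 4 | 1 << 5;  // 48
  int end_mask = 1 << 5;             // 32
  int ellipsis_mask = 1 << 3;        // 8
  int new_axis_mask = 1 << 2;        // 4
  int shrink_axis_mask = 1 << 0;     // 1
  std::cout << begin_mask << " " << end_mask << " " << ellipsis_mask << " "
            << new_axis_mask << " " << shrink_axis_mask << "\n";  // 48 32 8 4 1
  return 0;
}
```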
@ -695,6 +695,154 @@ void BatchMatMulV2Op::getCanonicalizationPatterns(
|
||||
results.insert<BatchMatMulV2ToMatMul>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BatchToSpaceOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult Verify(BatchToSpaceOp op) {
|
||||
// Op already has a constraint that block_size >= 2.
|
||||
int64_t block_size = op.block_size().getSExtValue();
|
||||
|
||||
llvm::SmallVector<int64_t, 4> input_shape(4, ShapedType::kDynamicSize);
|
||||
auto input_type = op.input().getType().cast<TensorType>();
|
||||
if (input_type.hasRank()) {
|
||||
if (input_type.getRank() != 4)
|
||||
return op.emitOpError()
|
||||
<< "requires input to be a 4D tensor, but got " << input_type;
|
||||
|
||||
int64_t input_batch = input_type.getDimSize(0);
|
||||
if (input_batch != ShapedType::kDynamicSize &&
|
||||
input_batch % (block_size * block_size) != 0) {
|
||||
return op.emitOpError()
|
||||
<< "requires input batch (dimension 0) to be evenly divisible "
|
||||
"by (block_size * block_size), but got input batch "
|
||||
<< input_batch << " and block_size " << block_size;
|
||||
}
|
||||
|
||||
input_shape.assign(input_type.getShape().begin(),
|
||||
input_type.getShape().end());
|
||||
}
|
||||
|
||||
auto crops_type = op.crops().getType().cast<TensorType>();
|
||||
if (crops_type.hasRank()) {
|
||||
if (crops_type.getRank() != 2)
|
||||
return op.emitOpError()
|
||||
<< "requires crops to be a 2D tensor, but got " << crops_type;
|
||||
|
||||
auto dim_of_size = [&](int64_t dim, int64_t size) {
|
||||
if (crops_type.isDynamicDim(dim)) return true;
|
||||
return crops_type.getDimSize(dim) == size;
|
||||
};
|
||||
if (!dim_of_size(0, 2) || !dim_of_size(1, 2))
|
||||
return op.emitOpError()
|
||||
<< "requires crops to be a tensor<2x2>, but got " << crops_type;
|
||||
}
|
||||
|
||||
DenseIntElementsAttr crops_attr;
|
||||
// Crops are defined as [[crop_top, crop_bottom], [crop_left, crop_right]],
|
||||
// and flattened as [crop_top, crop_bottom, crop_left, crop_right]
|
||||
llvm::SmallVector<int64_t, 4> crops_values;
|
||||
if (matchPattern(op.crops(), m_Constant(&crops_attr))) {
|
||||
assert(crops_attr.getNumElements() == 4 &&
|
||||
"tf.BatchToSpace crops must have 4 elements");
|
||||
|
||||
auto crops_range = crops_attr.getIntValues();
|
||||
for (const auto &crops_value : crops_range) {
|
||||
int64_t crops_value_int = crops_value.getSExtValue();
|
||||
if (crops_value_int < 0)
|
||||
return op.emitOpError()
|
||||
<< "requires all crop values to be nonnegative, but got "
|
||||
<< crops_attr;
|
||||
|
||||
crops_values.push_back(crops_value_int);
|
||||
}
|
||||
}
|
||||
|
||||
auto output_type = op.output().getType().cast<TensorType>();
|
||||
if (output_type.hasRank()) {
|
||||
if (output_type.getRank() != 4)
|
||||
return op.emitOpError()
|
||||
<< "requires output to be a 4D tensor, but got " << output_type;
|
||||
|
||||
auto static_dims = [](int64_t dim_a, int64_t dim_b) {
|
||||
return dim_a != ShapedType::kDynamicSize &&
|
||||
dim_b != ShapedType::kDynamicSize;
|
||||
};
|
||||
|
||||
auto output_shape = output_type.getShape();
|
||||
|
||||
// output batch = input batch / (block_size * block_size).
|
||||
int64_t input_batch = input_shape[0];
|
||||
int64_t output_batch = output_shape[0];
|
||||
if (static_dims(input_batch, output_batch) &&
|
||||
(output_batch * block_size * block_size) != input_batch)
|
||||
return op.emitOpError()
|
||||
<< "requires output batch (dimension 0) to be equal to input "
|
||||
"batch (dimension 0) / (block_size * block_size), but got "
|
||||
"output batch "
|
||||
<< output_batch << ", input batch " << input_batch
|
||||
<< ", and block_size " << block_size;
|
||||
|
||||
auto check_spatial_dim = [&](int64_t spatial_dim_index,
|
||||
llvm::StringRef dim_name,
|
||||
llvm::StringRef crop_a_name,
|
||||
llvm::StringRef crop_b_name) -> LogicalResult {
|
||||
int64_t input_dim = input_shape[spatial_dim_index];
|
||||
int64_t output_dim = output_shape[spatial_dim_index];
|
||||
if (!static_dims(input_dim, output_dim)) return success();
|
||||
|
||||
int64_t input_dim_pad = input_dim * block_size;
|
||||
// If crops are unknown, the maximum output spatial dim size is input
|
||||
// spatial dim size * block_size, as crops can be minimum 0.
|
||||
if (crops_values.empty() && output_dim > input_dim * block_size)
|
||||
return op.emitOpError()
|
||||
<< "requires output " << dim_name << " (dimension "
|
||||
<< spatial_dim_index << ") to be less than or equal to input "
|
||||
<< dim_name << " (dimension " << spatial_dim_index
|
||||
<< ") * block_size, but got output " << dim_name << " "
|
||||
<< output_dim << ", input " << dim_name << " " << input_dim
|
||||
<< ", and block_size " << block_size;
|
||||
|
||||
if (!crops_values.empty()) {
|
||||
// output spatial dim = input spatial dim * block_size - crops.
|
||||
int64_t crop_a = crops_values[2 * (spatial_dim_index - 1)];
|
||||
int64_t crop_b = crops_values[2 * (spatial_dim_index - 1) + 1];
|
||||
if (output_dim != input_dim_pad - crop_a - crop_b)
|
||||
return op.emitOpError()
|
||||
<< "requires output " << dim_name << " (dimension "
|
||||
<< spatial_dim_index << ") to be equal to input " << dim_name
|
||||
<< " (dimension " << spatial_dim_index << ") * block_size - "
|
||||
<< crop_a_name << " - " << crop_b_name << ", but got output "
|
||||
<< dim_name << " " << output_dim << ", input " << dim_name
|
||||
<< " " << input_dim << ", " << crop_a_name << " " << crop_a
|
||||
<< ", " << crop_b_name << " " << crop_b << ", and block_size "
|
||||
<< block_size;
|
||||
}
|
||||
|
||||
return success();
|
||||
};
|
||||
|
||||
if (failed(check_spatial_dim(1, "height", "crop_top", "crop_bottom")) ||
|
||||
failed(check_spatial_dim(2, "width", "crop_left", "crop_right")))
|
||||
return failure();
|
||||
|
||||
int64_t input_depth = input_shape[3];
|
||||
int64_t output_depth = output_shape[3];
|
||||
if (static_dims(input_depth, output_depth) && output_depth != input_depth)
|
||||
return op.emitOpError()
|
||||
<< "requires output depth (dimension 3) to be equal to input "
|
||||
"depth (dimension 3), but got output depth "
|
||||
<< output_depth << " and input depth " << input_depth;
|
||||
}
|
||||
|
||||
return success();
|
||||
}
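The verifier above encodes a simple shape relationship for the fully static case: the batch dimension shrinks by block_size squared, each spatial dimension grows by block_size minus its two crops, and depth is unchanged. A standalone sketch that computes the expected output shape for static inputs (illustrative only, not the MLIR verifier itself):

```cpp
#include <array>
#include <iostream>

// input : [batch, height, width, depth], crops : {top, bottom, left, right}.
// Returns the expected tf.BatchToSpace output shape.
std::array<long long, 4> BatchToSpaceShape(const std::array<long long, 4>& in,
                                           long long block_size,
                                           const std::array<long long, 4>& crops) {
  return {in[0] / (block_size * block_size),
          in[1] * block_size - crops[0] - crops[1],
          in[2] * block_size - crops[2] - crops[3],
          in[3]};
}

int main() {
  // Mirrors the static test case further down: 36x8x8x8 input, block_size 3,
  // crops [[1, 2], [3, 4]].
  auto out = BatchToSpaceShape({36, 8, 8, 8}, 3, {1, 2, 3, 4});
  std::cout << out[0] << "x" << out[1] << "x" << out[2] << "x" << out[3] << "\n";
  // Prints 4x21x17x8, matching the expected result type in that test.
  return 0;
}
```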
|
||||
|
||||
void BatchToSpaceOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<BatchToSpaceToBatchToSpaceND>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BiasAddOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -586,3 +586,27 @@ func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testBatchToSpaceToBatchToSpaceND
|
||||
// CHECK-SAME: ([[INPUT:%.*]]: tensor<?x?x?x?xf32>, [[CROPS:%.*]]: tensor<?x?xi32>)
|
||||
func @testBatchToSpaceToBatchToSpaceND(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
|
||||
// CHECK: [[BLOCK_SHAPE:%.*]] = "tf.Const"() {value = dense<8> : tensor<2xi64>}
|
||||
// CHECK: [[BATCH_TO_SHAPE_ND:%.*]] = "tf.BatchToSpaceND"([[INPUT]], [[BLOCK_SHAPE]], [[CROPS]])
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64} : (tensor<?x?x?x?xf32>, tensor<?x?xi32>) -> tensor<*xf32>
|
||||
// CHECK: return [[BATCH_TO_SHAPE_ND]]
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testBatchToSpaceDynamicInput
|
||||
func @testBatchToSpaceDynamicInput(%arg0: tensor<*xf32>, %arg1: tensor<?x?xi32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: "tf.BatchToSpaceND"
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64} : (tensor<*xf32>, tensor<?x?xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testBatchToSpaceDynamicCrops
|
||||
func @testBatchToSpaceDynamicCrops(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
|
||||
// CHECK-NOT: "tf.BatchToSpaceND"
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 8 : i64} : (tensor<?x?x?x?xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
@ -368,6 +368,57 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceApplyCenteredRMSProp operation is decomposed.
|
||||
|
||||
// CHECK-LABEL: func @decompose_resource_apply_centered_RMS_prop
|
||||
// CHECK-SAME: [[VAR:%.*]]: tensor<f32>, [[MG:%.*]]: tensor<f32>, [[MS:%.*]]: tensor<f32>, [[MOM:%.*]]: tensor<f32>, [[LR:%.*]]: tensor<f32>, [[RHO:%.*]]: tensor<f32>, [[MOMENTUM:%.*]]: tensor<f32>, [[EPSILON:%.*]]: tensor<f32>, [[GRAD:%.*]]: tensor<f32>
|
||||
func @decompose_resource_apply_centered_RMS_prop(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<f32>, %arg6: tensor<f32>, %arg7: tensor<f32>, %arg8: tensor<f32>) -> () {
|
||||
// CHECK: [[ONE:%.*]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
|
||||
// CHECK: [[VAR_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: [[MG_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: [[MS_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
// CHECK: [[MOM_HANDLE:%.*]] = "tf.VarHandleOp"
|
||||
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%2 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
%3 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource>
|
||||
|
||||
// CHECK: [[GRADSQ:%.*]] = "tf.Mul"([[GRAD]], [[GRAD]])
|
||||
// CHECK: [[SB:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
|
||||
// CHECK: [[GRAD_SUB:%.*]] = "tf.Mul"([[GRADSQ]], [[SB]])
|
||||
// CHECK: [[MS:%.*]] = "tf.ReadVariableOp"([[MS_HANDLE]])
|
||||
// CHECK: [[MS_RHO:%.*]] = "tf.Mul"([[MS]], [[RHO]])
|
||||
// CHECK: [[MS_NEW:%.*]] = "tf.Add"([[GRAD_SUB]], [[MS_RHO]])
|
||||
// CHECK: "tf.AssignVariableOp"([[MS_HANDLE]], [[MS_NEW]])
|
||||
|
||||
// CHECK: [[SUB_RHO:%.*]] = "tf.Sub"([[ONE]], [[RHO]])
|
||||
// CHECK: [[SUB_GRAD:%.*]] = "tf.Mul"([[GRAD]], [[SUB_RHO]])
|
||||
// CHECK: [[MG:%.*]] = "tf.ReadVariableOp"([[MG_HANDLE]])
|
||||
// CHECK: [[MG_RHO:%.*]] = "tf.Mul"([[MG]], [[RHO]])
|
||||
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[SUB_GRAD]], [[MG_RHO]])
|
||||
// CHECK: "tf.AssignVariableOp"([[MG_HANDLE]], [[MG_NEW]])
|
||||
|
||||
// CHECK: [[MOM:%.*]] = "tf.ReadVariableOp"([[MOM_HANDLE]])
|
||||
// CHECK: [[MOM_MOM:%.*]] = "tf.Mul"([[MOMENTUM]], [[MOM]])
|
||||
// CHECK: [[LR_GRAD:%.*]] = "tf.Mul"([[LR]], [[GRAD]])
|
||||
|
||||
// CHECK: [[MG_MG:%.*]] = "tf.Mul"([[MG_NEW]], [[MG_NEW]])
|
||||
// CHECK: [[MG_NEW:%.*]] = "tf.Add"([[MG_MG]], [[EPSILON]])
|
||||
// CHECK: [[MG_SUB:%.*]] = "tf.Sub"([[MS_NEW]], [[MG_NEW]])
|
||||
// CHECK: [[MG_SQRT:%.*]] = "tf.Sqrt"([[MG_SUB]])
|
||||
// CHECK: [[MOM_DIV:%.*]] = "tf.Div"([[LR_GRAD]], [[MG_SQRT]])
|
||||
// CHECK: [[MOM_NEW:%.*]] = "tf.Add"([[MOM_MOM]], [[MOM_DIV]])
|
||||
|
||||
// CHECK: [[VAR:%.*]] = "tf.ReadVariableOp"([[VAR_HANDLE]])
|
||||
// CHECK: [[VAR_NEW:%.*]] = "tf.Sub"([[VAR]], [[MOM_NEW]])
|
||||
// CHECK: "tf.AssignVariableOp"([[VAR_HANDLE]], [[VAR_NEW]])
|
||||
|
||||
"tf.ResourceApplyCenteredRMSProp"(%0, %1, %2, %3, %arg4, %arg5, %arg6, %arg7, %arg8) {use_locking = false} : (tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<*x!tf.resource>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>, tensor<f32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that composite tf.ResourceScatterUpdate operation is decomposed.
|
||||
|
||||
// CHECK-LABEL: @decompose_resource_scatter_update_op
|
||||
|
@ -4,6 +4,7 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line : 1
|
||||
col : 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -12,9 +13,11 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line : 3
|
||||
col : 1
|
||||
}
|
||||
file_line_cols: {
|
||||
line : 4
|
||||
col : 1
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -23,6 +26,7 @@ traces: {
|
||||
value: {
|
||||
file_line_cols: {
|
||||
line : 2
|
||||
col : 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -773,6 +773,20 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
return %1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
|
||||
func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
// "0x7F800000" represents INF for f32.
|
||||
%0 = xla_hlo.constant dense<0x7F800000> : tensor<f32>
|
||||
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%2 = xla_hlo.minimum %arg1, %arg2 : tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
|
||||
return %1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
|
||||
|
||||
// NOTE: Assertions have been autogenerated by utils/generate-test-checks.py
|
||||
|
||||
// CHECK-LABEL: func @biasAdd_NHWC(
|
||||
@ -1689,3 +1703,10 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
// CHECK: [[VAL_418:%.*]] = "tf.Max"([[VAL_416:%.*]], [[VAL_417:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: return [[VAL_418]] : tensor<1xf32>
|
||||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: func @convert_reduce_to_min(
|
||||
// CHECK-SAME: [[VAL_419:%.*]]: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
// CHECK: [[VAL_420:%.*]] = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
|
||||
// CHECK: [[VAL_421:%.*]] = "tf.Min"([[VAL_419:%.*]], [[VAL_420:%.*]]) {keep_dims = false} : (tensor<1x256xf32>, tensor<1xi64>) -> tensor<1xf32>
|
||||
// CHECK: return [[VAL_421]] : tensor<1xf32>
|
||||
// CHECK: }
|
||||
|
@ -2872,4 +2872,141 @@ func @testSendTPUEmbeddingGradients(%x: tensor<512x256xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// tf.BatchToSpace
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
func @testBatchToSpaceDynamic(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @testBatchToSpaceRankedInput(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<*xi32>) {
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<?x?x?x?xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @testBatchToSpaceRankedCrops(%arg0: tensor<*xf32>, %arg1: tensor<?x?xi32>) {
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x?xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @testBatchToSpaceRankedOutput(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<?x?x?x?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
func @testBatchToSpaceStatic(%arg0: tensor<36x8x8x8xf32>) {
|
||||
%crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 3 : i64} : (tensor<36x8x8x8xf32>, tensor<2x2xi32>) -> tensor<4x21x17x8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidInputRank(%arg0: tensor<8xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires input to be a 4D tensor, but got 'tensor<8xf32>'}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<8xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidInputBatch(%arg0: tensor<2x4x6x8xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires input batch (dimension 0) to be evenly divisible by (block_size * block_size), but got input batch 2 and block_size 2}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<2x4x6x8xf32>, tensor<*xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidCropsRank(%arg0: tensor<*xf32>, %arg1: tensor<?x?x?xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a 2D tensor, but got 'tensor<?x?x?xi32>'}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x?x?xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidCropsFirstDim(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a tensor<2x2>, but got 'tensor<3x?xi32>'}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<3x?xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidCropsSecondDim(%arg0: tensor<*xf32>, %arg1: tensor<?x3xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires crops to be a tensor<2x2>, but got 'tensor<?x3xi32>'}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<?x3xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceBadCropValues(%arg0: tensor<*xf32>) {
|
||||
%crops = "tf.Const"() {value = dense<[[-1, -2], [-3, -4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires all crop values to be nonnegative, but got dense<[[-1, -2], [-3, -4]]> : tensor<2x2xi32>}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<*xf32>, tensor<2x2xi32>) -> tensor<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputRank(%arg0: tensor<*xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output to be a 4D tensor, but got 'tensor<8xf32>'}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<*xf32>, tensor<*xi32>) -> tensor<8xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputBatch(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output batch (dimension 0) to be equal to input batch (dimension 0) / (block_size * block_size), but got output batch 8, input batch 16, and block_size 2}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<8x8x8x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputHeight(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output height (dimension 1) to be less than or equal to input height (dimension 1) * block_size, but got output height 17, input height 8, and block_size 2}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x17x8x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputHeightCrops(%arg0: tensor<16x8x8x3xf32>) {
|
||||
%crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output height (dimension 1) to be equal to input height (dimension 1) * block_size - crop_top - crop_bottom, but got output height 8, input height 8, crop_top 1, crop_bottom 2, and block_size 2}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<2x2xi32>) -> tensor<4x8x9x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputWidth(%arg0: tensor<16x4x4x3xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output width (dimension 2) to be less than or equal to input width (dimension 2) * block_size, but got output width 9, input width 4, and block_size 2}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x4x4x3xf32>, tensor<*xi32>) -> tensor<4x4x9x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputWidthCrops(%arg0: tensor<16x8x8x3xf32>) {
|
||||
%crops = "tf.Const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output width (dimension 2) to be equal to input width (dimension 2) * block_size - crop_left - crop_right, but got output width 8, input width 8, crop_left 3, crop_right 4, and block_size 2}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %crops) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<2x2xi32>) -> tensor<4x13x8x3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @testBatchToSpaceInvalidOutputDepth(%arg0: tensor<16x8x8x3xf32>, %arg1: tensor<*xi32>) {
|
||||
// expected-error @+1 {{'tf.BatchToSpace' op requires output depth (dimension 3) to be equal to input depth (dimension 3), but got output depth 8 and input depth 3}}
|
||||
%0 = "tf.BatchToSpace"(%arg0, %arg1) {block_size = 2 : i64} : (tensor<16x8x8x3xf32>, tensor<*xi32>) -> tensor<4x8x8x8xf32>
|
||||
return
|
||||
}
|
||||
|
@ -0,0 +1,61 @@
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Freezing VarHandleOp ops.
|
||||
|
||||
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: func @serving_default(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// Test case: Freezing shared VarHandleOp ops.
|
||||
|
||||
func @f(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
|
||||
func @f2(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["f2"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/kernel"
|
||||
// CHECK: "tf_saved_model.global_tensor"()
|
||||
// CHECK: sym_name = "dense/bias"
|
||||
// CHECK: func @f(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
|
||||
// CHECK: func @f2(
|
||||
// CHECK: %arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf_saved_model.bound_input = @"dense/kernel"},
|
||||
// CHECK: %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf_saved_model.bound_input = @"dense/bias"})
|
||||
}
|
@ -0,0 +1,33 @@
|
||||
// RUN: tf-opt -verify-diagnostics -tf-saved-model-lift-variables-invalid-session-test -split-input-file %s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test case: Invalid session.
|
||||
// expected-error @+1 {{'module' op no session provided}}
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
func @serving_default(%arg0: tensor<!tf.resource<tensor<100x50xf32>>> {tf.resource_name = "dense/kernel"}, %arg1: tensor<!tf.resource<tensor<50xf32>>> {tf.resource_name = "dense/bias"}) -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test case: No errors on no resource arguments.
|
||||
module attributes {tf_saved_model.semantics} {
|
||||
|
||||
// CHECK-LABEL: @serving_default
|
||||
func @serving_default() -> (tensor<100x50xf32> {tf_saved_model.index_path = ["dense_2"]})
|
||||
attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "dense_2/Add:0"}, tf_saved_model.exported_names = ["serving_default"]} {
|
||||
%0 = "tf.VarHandleOp"() {_class = ["loc:@dense/kernel"], allowed_devices = [], container = "", device = "", shared_name = "dense/kernel"} : () -> tensor<!tf.resource<tensor<100x50xf32>>>
|
||||
%1 = "tf.ReadVariableOp"(%0) {device = ""} : (tensor<!tf.resource<tensor<100x50xf32>>>) -> tensor<100x50xf32>
|
||||
%2 = "tf.VarHandleOp"() {_class = ["loc:@dense/bias"], allowed_devices = [], container = "", device = "", shared_name = "dense/bias"} : () -> tensor<!tf.resource<tensor<50xf32>>>
|
||||
%3 = "tf.ReadVariableOp"(%2) {device = ""} : (tensor<!tf.resource<tensor<50xf32>>>) -> tensor<50xf32>
|
||||
%4 = "tf.Add"(%1, %3) {device = ""} : (tensor<100x50xf32>, tensor<50xf32>) -> tensor<100x50xf32>
|
||||
return %4 : tensor<100x50xf32>
|
||||
}
|
||||
}
|
@ -27,6 +27,8 @@ def SingleResultAndOperandHaveSameType : Constraint<
|
||||
|
||||
def IsRank2Tensor : Type<HasAnyRankOfPred<[2]>, "Rank 2 tensor">;
|
||||
|
||||
def IsRank4Tensor : Type<HasAnyRankOfPred<[4]>, "Rank 4 tensor">;
|
||||
|
||||
// Checks if all the users are ReadVariableOp.
|
||||
def HasOnlyReadVariableOpUsers : Constraint<
|
||||
CPred<"llvm::all_of($0.getUsers(), [](mlir::OpOperand op) { "
|
||||
@ -65,6 +67,21 @@ def BatchMatMulV2ToMatMul : Pat<(TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y),
|
||||
(TF_MatMulOp $x, $y, $adj_x, $adj_y),
|
||||
[(IsRank2Tensor $x), (IsRank2Tensor $y)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BatchToSpace op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def BatchToSpaceBlockSizeToBlockShape : NativeCodeCall<
|
||||
"DenseElementsAttr::get(RankedTensorType::get({2}, $_builder.getI64Type()), "
|
||||
"ArrayRef<APInt>{$0.getValue(), $0.getValue()})">;
|
||||
|
||||
def BatchToSpaceToBatchToSpaceND :
|
||||
Pat<(TF_BatchToSpaceOp $input, $crops, $block_size),
|
||||
(TF_BatchToSpaceNDOp $input,
|
||||
(TF_ConstOp (BatchToSpaceBlockSizeToBlockShape $block_size)),
|
||||
$crops),
|
||||
[(IsRank4Tensor $input), (IsRank2Tensor $crops)]>;
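BatchToSpaceBlockSizeToBlockShape expands the scalar block_size attribute of tf.BatchToSpace into the rank-1 block_shape operand that tf.BatchToSpaceND expects, repeating the value once per spatial dimension. A tiny standalone analogue of that expansion in plain C++ (the real pattern builds a DenseElementsAttr of i64 instead):

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// tf.BatchToSpace carries a scalar block_size; tf.BatchToSpaceND takes a
// rank-1 block_shape. For the 4-D case that is the scalar repeated twice.
std::array<int64_t, 2> BlockSizeToBlockShape(int64_t block_size) {
  return {block_size, block_size};
}

int main() {
  auto shape = BlockSizeToBlockShape(8);
  // Matches the dense<8> : tensor<2xi64> constant checked in the
  // testBatchToSpaceToBatchToSpaceND canonicalization test earlier in this diff.
  std::cout << shape[0] << " " << shape[1] << "\n";  // 8 8
  return 0;
}
```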
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BiasAddV1 op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -327,3 +327,66 @@ def DecomposeVariableShape : Pat<
|
||||
(TF_VariableShapeOp:$src_op $resource),
|
||||
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
|
||||
[(CheckHasResourceSubtype $resource)]>;
|
||||
|
||||
// This decomposition is only correct inside XLA as it ignores use_locking
|
||||
// attribute.
|
||||
// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
|
||||
// mg = grad * (one - rho) + mg * rho;
|
||||
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
//
|
||||
def DecomposeResourceApplyCenteredRMSProp :
|
||||
Pattern<
|
||||
(TF_ResourceApplyCenteredRMSPropOp:$src_op
|
||||
$var_resource, $mg_resource, $ms_resource, $mom_resource, $lr, $rho, $momentum, $epsilon,
|
||||
$grad, ConstBoolAttrFalse:$use_locking
|
||||
),
|
||||
[(TF_ConstOp:$one (GetScalarOfType<1> $grad)),
|
||||
(CreateTFReadVariableOp $src_op, $grad, $ms_resource),
|
||||
(TF_AddOp:$ms_new
|
||||
(TF_MulOp
|
||||
(TF_MulOp $grad, $grad),
|
||||
(TF_SubOp $one, $rho)
|
||||
),
|
||||
(TF_MulOp
|
||||
(CreateTFReadVariableOp $src_op, $grad, $ms_resource),
|
||||
$rho
|
||||
)
|
||||
),
|
||||
(TF_AssignVariableOp $ms_resource, $ms_new),
|
||||
// mg = grad * (one - rho) + mg * rho;
|
||||
(TF_AddOp:$mg_new
|
||||
(TF_MulOp
|
||||
$grad,
|
||||
(TF_SubOp $one, $rho)
|
||||
),
|
||||
(TF_MulOp
|
||||
(CreateTFReadVariableOp $src_op, $grad, $mg_resource),
|
||||
$rho
|
||||
)
|
||||
),
|
||||
(TF_AssignVariableOp $mg_resource, $mg_new),
|
||||
// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms - mg * mg + epsilon)
|
||||
(TF_AddOp:$mom_new
|
||||
(TF_MulOp $momentum,
|
||||
(CreateTFReadVariableOp $src_op, $grad, $mom_resource)),
|
||||
(TF_DivOp
|
||||
(TF_MulOp $lr, $grad),
|
||||
(TF_SqrtOp
|
||||
(TF_SubOp
|
||||
$ms_new,
|
||||
(TF_AddOp
|
||||
(TF_MulOp
|
||||
$mg_new,
|
||||
$mg_new
|
||||
),
|
||||
$epsilon
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
),
|
||||
(TF_AssignVariableOp $mom_resource, $mom_new),
|
||||
// var <- var - mom
|
||||
(TF_AssignSubVariableOp $var_resource, $mom_new)
|
||||
]
|
||||
>;
|
||||
|
@ -32,11 +32,15 @@ namespace TF {
|
||||
|
||||
namespace {
|
||||
|
||||
// Note: This implements fusions performed in the old Remapper Grappler pass.
|
||||
// That pass has specific cases for GPU and based on different target
|
||||
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
|
||||
// covers the general CPU case and at the moment does not account for any
|
||||
// target-specific configurations.
|
||||
// Note: This implements the fusions performed in the old Remapper Grappler
|
||||
// pass. That pass has specific cases for GPU and based on different
|
||||
// target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
|
||||
// pass covers (some of) the general CPU case and at the moment does not account
|
||||
// for any target-specific configurations.
|
||||
|
||||
// This pass is being ported over from the Grappler Remapper pass based on
|
||||
// need/usage. File a bug to request porting over additional fusions.
|
||||
|
||||
// TODO(b/158265178): Support GPU-specific fusions.
|
||||
// TODO(b/158266710): Support CPU MKL configurations.
|
||||
|
||||
|
@ -489,9 +489,9 @@ LogicalResult MatchReduceOpInput(xla_hlo::ReduceOp reduce_op) {
|
||||
return success();
|
||||
}
|
||||
|
||||
// TODO(b/157192370): This "xla_hlo::ReduceOp" can correspond to many TF ops
// with different ops in reduce_op.body. Now we only match "tf.Max" and
// "tf.Sum".
// TODO(jingpu): This "xla_hlo::ReduceOp" can correspond to many TF ops
// with different ops in reduce_op.body. Now we only match "tf.Max", "tf.Min"
// and "tf.Sum".
|
||||
class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
@ -504,15 +504,13 @@ class ConvertReduceOpToTfSum : public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||
Operation *first_op = &reduce_op.body().front().front();
|
||||
if (!llvm::isa<xla_hlo::AddOp>(first_op)) return failure();
|
||||
|
||||
// In `MatchReduceOpInput` function, we only match that the
|
||||
// In `MatchReduceOpInput` function, we already match that the
|
||||
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
|
||||
auto input = reduce_op.operands()[0];
|
||||
// Get reduction dimension.
|
||||
DenseIntElementsAttr dimension = reduce_op.dimensions();
|
||||
SmallVector<int64_t, 4> reduce_dims;
|
||||
const int64_t input_rank = input.getType().cast<ShapedType>().getRank();
|
||||
for (const int64_t &dim : dimension.getValues<int64_t>()) {
|
||||
if (dim < 0 || dim >= input_rank) return failure();
|
||||
reduce_dims.emplace_back(dim);
|
||||
}
|
||||
|
||||
@ -545,15 +543,13 @@ class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||
Operation *first_op = &reduce_op.body().front().front();
|
||||
if (!llvm::isa<xla_hlo::MaxOp>(first_op)) return failure();
|
||||
|
||||
// In `MatchReduceOpInput` function, we only match that the
|
||||
// In `MatchReduceOpInput` function, we already match that the
|
||||
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
|
||||
auto input = reduce_op.operands()[0];
|
||||
// Get reduction dimension.
|
||||
DenseIntElementsAttr dimension = reduce_op.dimensions();
|
||||
SmallVector<int64_t, 4> reduce_dims;
|
||||
const int64_t input_rank = input.getType().cast<ShapedType>().getRank();
|
||||
for (const int64_t &dim : dimension.getValues<int64_t>()) {
|
||||
if (dim < 0 || dim >= input_rank) return failure();
|
||||
reduce_dims.emplace_back(dim);
|
||||
}
|
||||
|
||||
@ -576,6 +572,47 @@ class ConvertReduceOpToTfMax : public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||
};
|
||||
};
|
||||
|
||||
class ConvertReduceOpToTfMin : public OpConversionPattern<xla_hlo::ReduceOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_hlo::ReduceOp reduce_op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
if (failed(MatchReduceOpInput(reduce_op))) return failure();
|
||||
|
||||
Operation *first_op = &reduce_op.body().front().front();
|
||||
if (!llvm::isa<xla_hlo::MinOp>(first_op)) return failure();
|
||||
|
||||
// In `MatchReduceOpInput` function, we already match that the
|
||||
// "xla_hlo::ReduceOp" only has one input, one init_value and one result.
|
||||
Value input = reduce_op.operands()[0];
|
||||
// Get reduction dimension.
|
||||
DenseIntElementsAttr dimension = reduce_op.dimensions();
|
||||
SmallVector<int64_t, 4> reduce_dims;
|
||||
for (const int64_t &dim : dimension.getValues<int64_t>()) {
|
||||
reduce_dims.emplace_back(dim);
|
||||
}
|
||||
|
||||
// Check initial value is +INF.
|
||||
DenseFPElementsAttr init_value;
|
||||
if (!matchPattern(reduce_op.init_values()[0], m_Constant(&init_value)) ||
|
||||
!init_value.isSplat() ||
|
||||
!init_value.getSplatValue<APFloat>().isInfinity() ||
|
||||
init_value.getSplatValue<APFloat>().isNegative())
|
||||
return failure();
|
||||
|
||||
auto dim_type = RankedTensorType::get(
|
||||
{static_cast<int64_t>(reduce_dims.size())}, rewriter.getI64Type());
|
||||
auto reduction_indices = rewriter.create<ConstOp>(
|
||||
reduce_op.getLoc(), dim_type, rewriter.getI64TensorAttr(reduce_dims));
|
||||
rewriter.replaceOpWithNewOp<MinOp>(
|
||||
reduce_op, reduce_op.getType(0), input, reduction_indices,
|
||||
/*keep_dim=*/rewriter.getBoolAttr(false));
|
||||
return success();
|
||||
};
|
||||
};
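The min and max conversions differ only in the reduction body op and in the sign of the infinity they expect as the init value. A hypothetical helper (not in the commit) that factors out that check could look like this; the name is illustrative only.

#include "llvm/ADT/APFloat.h"
#include "mlir/IR/Attributes.h"

// Returns true if `init_value` is a splat of +infinity (negative == false) or
// -infinity (negative == true), which is what the min/max conversions require.
bool IsSplatInfinity(mlir::DenseFPElementsAttr init_value, bool negative) {
  if (!init_value || !init_value.isSplat()) return false;
  llvm::APFloat splat = init_value.getSplatValue<llvm::APFloat>();
  return splat.isInfinity() && splat.isNegative() == negative;
}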
|
||||
|
||||
class LegalizeHloToTf : public PassWrapper<LegalizeHloToTf, FunctionPass> {
|
||||
public:
|
||||
LegalizeHloToTf() = default;
|
||||
@ -709,7 +746,7 @@ void LegalizeHloToTf::runOnFunction() {
|
||||
OwningRewritePatternList patterns;
|
||||
populateWithGenerated(&context, &patterns);
|
||||
patterns.insert<ConvertConvOp, ConvertSliceOp, ConvertReduceOpToTfMax,
|
||||
ConvertReduceOpToTfSum>(&context);
|
||||
ConvertReduceOpToTfMin, ConvertReduceOpToTfSum>(&context);
|
||||
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<TensorFlowDialect>();
|
||||
|
183
tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.cc
Normal file
@ -0,0 +1,183 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_var.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using llvm::SmallSet;
|
||||
using ::tensorflow::Device;
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::mutex_lock;
|
||||
using ::tensorflow::ResourceHandle;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::StatusOr;
|
||||
using ::tensorflow::Tensor;
|
||||
using ::tensorflow::Var;
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
|
||||
constexpr char kSavedModelArgAttr[] = "tf_saved_model.bound_input";
|
||||
|
||||
LogicalResult LiftVariablesFromSession(
|
||||
ModuleOp module, Session* session,
|
||||
const SmallSet<StringRef, 4>& resource_names) {
|
||||
OpBuilder builder(module.getBodyRegion());
|
||||
MLIRContext* context = module.getContext();
|
||||
|
||||
if (!session) return module.emitOpError() << "no session provided";
|
||||
|
||||
// Read all resource variables from the session.
|
||||
std::vector<std::string> variable_names;
|
||||
variable_names.reserve(resource_names.size());
|
||||
for (StringRef name : resource_names) variable_names.push_back(name.str());
|
||||
|
||||
std::vector<Tensor> resource_tensors;
|
||||
Status status = session->Run(
|
||||
/*inputs=*/{}, variable_names,
|
||||
/*target_node_names=*/{}, &resource_tensors);
|
||||
if (!status.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to run the provided session: " << status.error_message();
|
||||
}
|
||||
|
||||
const DeviceMgr* device_manager;
|
||||
if (!(session->LocalDeviceManager(&device_manager).ok())) {
|
||||
return module.emitOpError() << "failed to get local device manager";
|
||||
}
|
||||
|
||||
// Read all underlying tensors of the variables from the session.
|
||||
std::vector<Tensor> tensors;
|
||||
tensors.reserve(resource_tensors.size());
|
||||
for (const Tensor& resource_tensor : resource_tensors) {
|
||||
if (resource_tensor.dtype() != tensorflow::DT_RESOURCE) {
|
||||
tensors.push_back(resource_tensor);
|
||||
continue;
|
||||
}
|
||||
|
||||
const ResourceHandle& resource_handle =
|
||||
resource_tensor.scalar<ResourceHandle>()();
|
||||
|
||||
Device* device;
|
||||
if (!(device_manager->LookupDevice(resource_handle.device(), &device)
|
||||
.ok())) {
|
||||
return module.emitOpError() << "failed to look up device";
|
||||
}
|
||||
|
||||
tensorflow::Var* var_ptr;
|
||||
if (!(device->resource_manager()
|
||||
->Lookup(resource_handle.container(), resource_handle.name(),
|
||||
&var_ptr)
|
||||
.ok())) {
|
||||
return module.emitOpError() << "failed to look up resource value";
|
||||
}
|
||||
tensorflow::core::RefCountPtr<Var> var(var_ptr);
|
||||
|
||||
// The variable tensor is already loaded into the corresponding device's
// resource manager when we load the saved model using LoadSavedModel().
// Here we just read its value.
|
||||
mutex_lock ml(*var->mu());
|
||||
tensors.push_back(*var->tensor());
|
||||
}
|
||||
|
||||
for (const auto iter : llvm::zip(resource_names, tensors)) {
|
||||
const StringRef name = std::get<0>(iter);
|
||||
const Tensor& tensor = std::get<1>(iter);
|
||||
|
||||
// Create tensor attribute for this variable.
|
||||
StatusOr<ElementsAttr> tensor_attr_or = ConvertTensor(tensor, &builder);
|
||||
if (!tensor_attr_or.ok()) {
|
||||
return module.emitOpError()
|
||||
<< "failed to convert tensor (name: " << name.str() << ")";
|
||||
}
|
||||
ElementsAttr tensor_attr = tensor_attr_or.ValueOrDie();
|
||||
|
||||
builder.create<tf_saved_model::GlobalTensorOp>(
|
||||
NameLoc::get(builder.getIdentifier(name.str()), context),
|
||||
builder.getStringAttr(name), tensor_attr,
|
||||
TypeAttr::get(tensor_attr.getType()), builder.getUnitAttr());
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LogicalResult LiftVariables(ModuleOp module, Session* session) {
|
||||
MLIRContext* context = module.getContext();
|
||||
mlir::Builder builder(context);
|
||||
Identifier resource_name_id = builder.getIdentifier(kResourceNameArgAttr);
|
||||
|
||||
SmallSet<StringRef, 4> resource_names;
|
||||
|
||||
for (FuncOp func : module.getOps<FuncOp>()) {
|
||||
for (int i = 0, e = func.getNumArguments(); i < e; ++i) {
|
||||
auto resource_arg =
|
||||
func.getArgAttrOfType<StringAttr>(i, kResourceNameArgAttr);
|
||||
if (!resource_arg) continue;
|
||||
|
||||
StringRef resource_name = resource_arg.getValue();
|
||||
auto flat_symbol_ref_attr =
|
||||
FlatSymbolRefAttr::get(resource_name, context);
|
||||
|
||||
// Add the corresponding `tf_saved_model.bound_input` attribute.
|
||||
func.setArgAttr(i, kSavedModelArgAttr, flat_symbol_ref_attr);
|
||||
|
||||
resource_names.insert(flat_symbol_ref_attr.getValue());
|
||||
|
||||
// Remove the existing `tf.resource_name` attribute.
|
||||
func.removeArgAttr(i, resource_name_id);
|
||||
}
|
||||
}
|
||||
|
||||
if (resource_names.empty()) return success();
|
||||
|
||||
return LiftVariablesFromSession(module, session, resource_names);
|
||||
}
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
@ -0,0 +1,33 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
// Creates a GlobalTensorOp for each variable from the function arguments and
// converts them to the corresponding saved model arguments.
|
||||
LogicalResult LiftVariables(ModuleOp module, ::tensorflow::Session* session);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_H_
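A hypothetical caller of this API (not part of the commit): lifting variables from a module that was imported from a saved model. `LiftFromBundle` is an illustrative name; the assumption is that the bundle's session already holds the restored variable values.

#include "mlir/IR/Module.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"

mlir::LogicalResult LiftFromBundle(const tensorflow::SavedModelBundle &bundle,
                                   mlir::ModuleOp module) {
  // LiftVariables reads each variable's value through the session and emits a
  // GlobalTensorOp for it, rewriting the bound function arguments accordingly.
  return mlir::tf_saved_model::LiftVariables(module, bundle.session.get());
}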
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
// This pass finds all variables among the function arguments and converts them
// to the corresponding global tensors, which will live outside the functions.
// It also converts the resource arguments in the function types to the
// corresponding saved model arguments.
|
||||
class LiftVariablesPass
|
||||
: public PassWrapper<LiftVariablesPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
explicit LiftVariablesPass(::tensorflow::Session* session)
|
||||
: session_(session) {}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
::tensorflow::Session* session_;
|
||||
};
|
||||
|
||||
// Creates a pass that creates a GlobalTensorOp for each variable from the
// function arguments and converts the function arguments to the corresponding
// saved model arguments.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateLiftVariablesPass(
|
||||
::tensorflow::Session* session);
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_PASS_H_
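A hedged wiring sketch (not in the commit) showing how the declared factory could be added to a module-level pass pipeline; `AddLiftVariablesPass` is an illustrative name, and the session must stay alive for the duration of the run.

#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_pass.h"

void AddLiftVariablesPass(mlir::PassManager &pm,
                          ::tensorflow::Session *session) {
  // The pass operates on the whole module, so it is added at module level.
  pm.addPass(mlir::tf_saved_model::CreateLiftVariablesPass(session));
}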
|
@ -0,0 +1,146 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
||||
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Pass/PassManager.h" // from @llvm-project
|
||||
#include "mlir/Support/LLVM.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/threadpool_options.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
using ::tensorflow::DeviceMgr;
|
||||
using ::tensorflow::Session;
|
||||
using ::tensorflow::Status;
|
||||
using ::tensorflow::Tensor;
|
||||
|
||||
// FakeSession is for testing only.
|
||||
class FakeSession : public tensorflow::Session {
|
||||
public:
|
||||
FakeSession() {}
|
||||
~FakeSession() override = default;
|
||||
|
||||
Status Create(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
Status Extend(const tensorflow::GraphDef& graph) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status Close() override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status ListDevices(
|
||||
std::vector<tensorflow::DeviceAttributes>* response) override {
|
||||
return tensorflow::errors::Unimplemented("not available");
|
||||
}
|
||||
|
||||
Status LocalDeviceManager(
|
||||
const tensorflow::DeviceMgr** deviceMgrPtr) override {
|
||||
// This method returns a null device manager without raising an error.
// Callers of this method will notice, since the session only returns fake data.
|
||||
*deviceMgrPtr = nullptr;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Run(const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs) override {
|
||||
tensorflow::RunMetadata run_metadata;
|
||||
return Run(tensorflow::RunOptions(), inputs, output_names, target_nodes,
|
||||
outputs, &run_metadata);
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata) override {
|
||||
return Run(run_options, inputs, output_names, target_nodes, outputs,
|
||||
run_metadata, tensorflow::thread::ThreadPoolOptions());
|
||||
}
|
||||
|
||||
Status Run(const tensorflow::RunOptions& run_options,
|
||||
const std::vector<std::pair<std::string, Tensor>>& inputs,
|
||||
const std::vector<std::string>& output_names,
|
||||
const std::vector<std::string>& target_nodes,
|
||||
std::vector<Tensor>* outputs,
|
||||
tensorflow::RunMetadata* run_metadata,
|
||||
const tensorflow::thread::ThreadPoolOptions& thread_pool_options)
|
||||
override {
|
||||
for (const std::string& output_name : output_names) {
|
||||
Tensor output;
|
||||
if (output_name == "dense/bias") {
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({50})));
|
||||
} else if (output_name == "dense/kernel") {
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({100, 50})));
|
||||
} else {
|
||||
// Create a scalar float tensor.
|
||||
outputs->push_back(
|
||||
Tensor(tensorflow::DT_FLOAT, tensorflow::TensorShape({})));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesTestPass
|
||||
: public PassWrapper<LiftVariablesTestPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
LiftVariablesTestPass() { session_ = new FakeSession(); }
|
||||
|
||||
~LiftVariablesTestPass() override { delete session_; }
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
if (failed(LiftVariables(module, session_))) signalPassFailure();
|
||||
}
|
||||
|
||||
private:
|
||||
Session* session_;
|
||||
};
|
||||
|
||||
// This pass is only available in the tf-opt binary for testing.
|
||||
class LiftVariablesInvalidSessionTestPass
|
||||
: public PassWrapper<LiftVariablesInvalidSessionTestPass,
|
||||
OperationPass<ModuleOp>> {
|
||||
public:
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
// Pass an invalid session argument, which is a nullptr.
|
||||
if (failed(LiftVariables(module, /*session=*/nullptr))) signalPassFailure();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_TRANSFORMS_LIFT_VARIABLES_TEST_PASS_H_
|
@ -0,0 +1,32 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/lift_variables_test_pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tf_saved_model {
|
||||
|
||||
static PassRegistration<LiftVariablesTestPass> lift_variables_test_pass(
|
||||
"tf-saved-model-lift-variables-test",
|
||||
"Lift variables and save them as global tensors");
|
||||
|
||||
static PassRegistration<LiftVariablesInvalidSessionTestPass>
|
||||
lift_variables_invalid_session_test_pass(
|
||||
"tf-saved-model-lift-variables-invalid-session-test",
|
||||
"Lift variables and save them as global tensors with an invalid "
|
||||
"session");
|
||||
|
||||
} // namespace tf_saved_model
|
||||
} // namespace mlir
|
@ -73,9 +73,8 @@ LogicalResult CollectAndGroupOutsideClusterOps(Block* block,
|
||||
}
|
||||
|
||||
// Moves `cluster_ops` to associated `launch_op` body.
|
||||
void MoveOutsideClusterOpsToLaunchOp(
|
||||
tf_device::LaunchOp launch_op,
|
||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
||||
void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
|
||||
llvm::ArrayRef<Operation*> cluster_ops) {
|
||||
MLIRContext* context = launch_op.getContext();
|
||||
Operation* terminator = launch_op.GetBody().getTerminator();
|
||||
|
||||
@ -123,7 +122,7 @@ void PropagateParallelExecuteReturnToReplicate(
|
||||
|
||||
// Extracts all externally provided operands of `cluster_ops`.
|
||||
llvm::SmallSetVector<Value, 4> GetExternalOperands(
|
||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
||||
llvm::ArrayRef<Operation*> cluster_ops) {
|
||||
llvm::SmallSetVector<Value, 4> external_values;
|
||||
|
||||
for (Operation* op : cluster_ops) {
|
||||
@ -143,7 +142,7 @@ llvm::SmallSetVector<Value, 4> GetExternalOperands(
|
||||
|
||||
// Extracts all externally used outputs of `cluster_ops`.
|
||||
llvm::SmallVector<Value, 4> GetExternalOutputs(
|
||||
const llvm::SmallVector<Operation*, 8>& cluster_ops) {
|
||||
llvm::ArrayRef<Operation*> cluster_ops) {
|
||||
llvm::SmallSetVector<Value, 4> external_outputs;
|
||||
|
||||
for (Operation* op : cluster_ops) {
|
||||
@ -166,7 +165,7 @@ llvm::SmallVector<Value, 4> GetExternalOutputs(
|
||||
// as an operand. If there are no external_inputs, set insertion point to first
|
||||
// cluster_op.
|
||||
void SetHostComputeInsertion(
|
||||
OpBuilder* builder, const llvm::SmallVector<Operation*, 8>& cluster_ops,
|
||||
OpBuilder* builder, llvm::ArrayRef<Operation*> cluster_ops,
|
||||
const llvm::SmallSetVector<Value, 4>& external_inputs) {
|
||||
if (external_inputs.empty()) builder->setInsertionPoint(cluster_ops.front());
|
||||
for (const auto& cluster_op : cluster_ops) {
|
||||
@ -183,9 +182,9 @@ void SetHostComputeInsertion(
|
||||
// using `communication_key`.
|
||||
TF::_HostComputeMlirOp CreateHostCompute(
|
||||
OpBuilder* builder, tf_device::ClusterOp tpu_cluster,
|
||||
const llvm::SmallVector<Operation*, 8>& cluster_ops,
|
||||
llvm::ArrayRef<Operation*> cluster_ops,
|
||||
const llvm::SmallSetVector<Value, 4>& inputs, llvm::ArrayRef<Value> outputs,
|
||||
const std::string& communication_key) {
|
||||
llvm::StringRef communication_key) {
|
||||
llvm::SmallVector<Type, 4> device_output_types;
|
||||
for (const auto& output : outputs)
|
||||
device_output_types.push_back(output.getType());
|
||||
@ -201,10 +200,9 @@ TF::_HostComputeMlirOp CreateHostCompute(
|
||||
|
||||
void MoveOutsideCompiledOps(
|
||||
tf_device::ClusterOp tpu_cluster, llvm::StringRef outside_cluster_name,
|
||||
tf_device::LaunchOp host_launch_op,
|
||||
const llvm::SmallVector<Operation*, 8>& cluster_ops,
|
||||
tf_device::LaunchOp host_launch_op, llvm::ArrayRef<Operation*> cluster_ops,
|
||||
const llvm::SmallSetVector<Value, 4>& external_inputs,
|
||||
const llvm::SmallVector<Value, 4>& external_outputs) {
|
||||
llvm::ArrayRef<Value> external_outputs) {
|
||||
if (external_inputs.empty() && external_outputs.empty()) {
|
||||
MoveOutsideClusterOpsToLaunchOp(host_launch_op, cluster_ops);
|
||||
return;
|
||||
|
@ -125,7 +125,8 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
|
||||
pm.addNestedPass<mlir::FuncOp>(
|
||||
absl::make_unique<MaterializeBroadcastsPass>());
|
||||
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
|
||||
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass());
|
||||
pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass(
|
||||
/*results_escape_functions=*/true));
|
||||
pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());
|
||||
|
||||
if (failed(pm.run(module))) {
|
||||
@ -148,7 +149,12 @@ struct PropagateStaticKnowledge
|
||||
// We do not change the signature so that we keep a somewhat stable ABI
// that is easy to understand by tools.
|
||||
mlir::LLVM::LLVMFuncOp func = getOperation();
|
||||
|
||||
// This only works if the function is local and we can rewrite it.
|
||||
if (func.isExternal()) return;
|
||||
|
||||
mlir::OpBuilder b(func.getBody());
|
||||
// Steal the LLVM representation of the index type from the third argument.
|
||||
auto index_type = func.getArgument(3).getType();
|
||||
mlir::Value one = b.create<mlir::LLVM::ConstantOp>(
|
||||
func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 1));
|
||||
@ -156,10 +162,21 @@ struct PropagateStaticKnowledge
|
||||
func.getLoc(), index_type, b.getIntegerAttr(b.getIndexType(), 0));
|
||||
uint32_t arg_pos = 0;
|
||||
std::vector<uint32_t> positions;
|
||||
for (mlir::Type arg_type : func_type.getInputs()) {
|
||||
// Collect the argument and return types of the surrounding function.
|
||||
auto arg_types = llvm::to_vector<4>(llvm::concat<const mlir::Type>(
|
||||
func_type.getInputs(), func_type.getResults()));
|
||||
for (mlir::Type arg_type : arg_types) {
|
||||
if (!arg_type.isa<mlir::MemRefType>()) {
|
||||
func.emitError() << "argument of surrounding func is not ranked memref";
|
||||
signalPassFailure();
|
||||
return;
|
||||
}
|
||||
positions.push_back(arg_pos);
|
||||
// Replace the offset with zero. Offset is argument number 3.
|
||||
func.getArgument(arg_pos + 2).replaceAllUsesWith(zero);
|
||||
arg_pos += 3 + arg_type.cast<mlir::ShapedType>().getRank() * 2;
|
||||
// Forward over base_ptr, aligned_ptr, offset, size and stride arguments.
|
||||
arg_pos += 3 + arg_type.cast<mlir::MemRefType>().getRank() * 2;
|
||||
// Replace the last stride with constant 1.
|
||||
func.getArgument(arg_pos - 1).replaceAllUsesWith(one);
|
||||
}
|
||||
|
||||
@ -169,17 +186,17 @@ struct PropagateStaticKnowledge
|
||||
if (!same_shape.empty()) {
|
||||
auto first = same_shape.front();
|
||||
auto first_offset = positions.at(first);
|
||||
mlir::ShapedType first_type =
|
||||
func_type.getInput(first).cast<mlir::ShapedType>();
|
||||
auto first_type = arg_types[first].cast<mlir::ShapedType>();
|
||||
uint32_t rank = first_type.getRank();
|
||||
for (auto same : same_shape.drop_front(1)) {
|
||||
uint32_t same_offset = positions.at(same);
|
||||
auto same_type = func_type.getInput(same).cast<mlir::ShapedType>();
|
||||
auto same_type = arg_types[same].cast<mlir::ShapedType>();
|
||||
if (same_type.getRank() != rank) {
|
||||
func.emitOpError() << "same shape constraints on arguments with "
|
||||
"non-matching shapes: #"
|
||||
<< first << " and #" << same;
|
||||
signalPassFailure();
|
||||
continue;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < 2 * rank; ++i) {
|
||||
@ -237,18 +254,16 @@ StatusOr<std::vector<uint8_t>> tensorflow::kernel_gen::GenerateCubinForTfCode(
|
||||
mlir::OwningModuleRef module = mlir::parseSourceString(tf_code, &context);
|
||||
|
||||
TF_RETURN_IF_ERROR(LowerTfOpToLhloWithDynamicShapes(module.get()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
xla::mlir_gpu::LowerLHLOToGPU(module.get(), tile_sizes, unroll_factors,
|
||||
/*collapseParallelLoops=*/false));
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
// TODO(b/156985522): Figure out why we get a segfault when generating Tanh
|
||||
// with 'same_shape' containing {0, 1}. We would also get the crash if we
|
||||
// unconditionally call PropagateStaticShapeKnowledgeToKernel while
|
||||
// 'same_shape' is empty.
|
||||
if (!same_shape.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
|
||||
{
|
||||
xla::mlir_gpu::LowerLHLOToGPUOptions options;
|
||||
options.tile_sizes = tile_sizes;
|
||||
options.unroll_factors = unroll_factors;
|
||||
options.collapse_parallel_loops = false;
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerLHLOToGPU(module.get(), options));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(xla::mlir_gpu::LowerKernelBodiesToNVVM(module.get()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
PropagateStaticShapeKnowledgeToKernel(module.get(), same_shape));
|
||||
|
||||
mlir::OwningModuleRef kernel_module =
|
||||
xla::mlir_gpu::ExtractKernelModule(*module).ValueOrDie();
|
||||
|
@ -377,6 +377,7 @@ cc_library(
|
||||
"@llvm-project//mlir:LinalgOps",
|
||||
"@llvm-project//mlir:LinalgTransforms",
|
||||
"@llvm-project//mlir:Pass",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -115,6 +115,9 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs;
|
||||
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
|
||||
computation_name, func_type, attrs);
|
||||
auto visibility = computation_name == "main" ? FuncOp::Visibility::Public
|
||||
: FuncOp::Visibility::Private;
|
||||
function.setVisibility(visibility);
|
||||
module_.push_back(function);
|
||||
|
||||
// Add to the map right away for function calls.
|
||||
|
@ -1112,21 +1112,13 @@ def HLO_DynamicReshapeOp: HLO_Op<"dynamic_reshape", [NoSideEffect]> {
|
||||
let hasCustomHLOConverter = 1;
|
||||
}
|
||||
|
||||
def ScatterDimensionNumbers : StructAttr<"ScatterDimensionNumbers", HLO_Dialect,
|
||||
[StructFieldAttr<"update_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||
let description = "Structure of dimension information for scatter";
|
||||
}
|
||||
|
||||
def HLO_ScatterOp: HLO_Op<"scatter", [RecursiveSideEffects]>,
|
||||
BASE_HLO_ScatterOp {
|
||||
let arguments = (ins
|
||||
HLO_Tensor:$operand,
|
||||
HLO_Tensor:$scatter_indices,
|
||||
HLO_Tensor:$updates,
|
||||
ScatterDimensionNumbers:$scatter_dimension_numbers,
|
||||
ScatterDimensionNumbers<HLO_Dialect>:$scatter_dimension_numbers,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||
);
|
||||
|
@ -1098,6 +1098,15 @@ class BASE_HLO_ReshapeOp {
|
||||
}];
|
||||
}
|
||||
|
||||
class ScatterDimensionNumbers<Dialect dialect> : StructAttr<
|
||||
"ScatterDimensionNumbers", dialect, [
|
||||
StructFieldAttr<"update_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"inserted_window_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"scatter_dims_to_operand_dims", I64ElementsAttr>,
|
||||
StructFieldAttr<"index_vector_dim", I64Attr>]> {
|
||||
let description = "Structure of dimension information for scatter";
|
||||
}
|
||||
|
||||
class BASE_HLO_ScatterOp {
|
||||
string summary = "Scatter operator";
|
||||
|
||||
|
@ -471,6 +471,12 @@ def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(timshen): add a custom verifier.
|
||||
def LHLO_BitcastOp: LHLO_Op<"bitcast", []> {
|
||||
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
|
||||
}
|
||||
|
||||
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
|
||||
[]>, BASE_HLO_BroadcastOp {
|
||||
let arguments = (ins
|
||||
@ -578,6 +584,19 @@ def LHLO_ReshapeOp: LHLO_Op<"reshape", []>, BASE_HLO_ReshapeOp {
|
||||
);
|
||||
}
|
||||
|
||||
def LHLO_ScatterOp: LHLO_Op<"scatter", []>, BASE_HLO_ScatterOp {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$scatter_indices,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$updates,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
ScatterDimensionNumbers<LHLO_Dialect>:$scatter_dimension_numbers,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$indices_are_sorted,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$unique_indices
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$update_computation);
|
||||
}
|
||||
|
||||
def LHLO_SelectOp: LHLO_Op<"select", []>, BASE_HLO_SelectOp {
|
||||
let arguments = (ins
|
||||
@ -712,6 +731,44 @@ def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(timshen): add a custom verifier.
|
||||
def LHLO_MapOp: LHLO_Op<"map", [SameOperandsShape]>, BASE_HLO_MapOp {
|
||||
let arguments = (ins
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
|
||||
I64ElementsAttr:$dimensions
|
||||
);
|
||||
let regions = (region SizedRegion<1>:$computation);
|
||||
}
|
||||
|
||||
def LHLO_RngGetAndUpdateStateOp: LHLO_Op<"rng_get_and_update_state", []> {
|
||||
let arguments = (ins
|
||||
Arg<MemRefOf<[UI64]>, "", [MemRead, MemWrite]>:$state,
|
||||
I64Attr:$delta
|
||||
);
|
||||
}
|
||||
|
||||
// TODO(timshen): add a custom verifier.
|
||||
def LHLO_SortOp: LHLO_Op<"sort", []>, BASE_HLO_SortOp {
|
||||
let arguments = (ins
|
||||
Arg<Variadic<LHLO_Buffer>, "", [MemRead]>:$operands,
|
||||
LHLO_BufferOrTuple:$output,
|
||||
DefaultValuedAttr<I64Attr, "-1">:$dimension,
|
||||
DefaultValuedAttr<BoolAttr, "false">:$is_stable
|
||||
);
|
||||
|
||||
let regions = (region SizedRegion<1>:$comparator);
|
||||
}
|
||||
|
||||
def LHLO_TupleSelectOp: LHLO_Op<"tuple_select", [SameOperandsShape]> {
|
||||
let arguments = (ins
|
||||
Arg<LHLO_PredBuffer, "", [MemRead]>:$pred,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$on_true,
|
||||
Arg<LHLO_Buffer, "", [MemRead]>:$on_false,
|
||||
Arg<LHLO_Buffer, "", [MemWrite]>:$output
|
||||
);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Late operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1,12 +1,13 @@
|
||||
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck %s
|
||||
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
|
||||
// RUN: xla-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s
|
||||
|
||||
// CHECK-LABEL: func @attrs
|
||||
// BOTH-LABEL: func @attrs
|
||||
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
|
||||
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -16,13 +17,16 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
return %arg0 : tensor<4xf32>
|
||||
}
|
||||
// CHECK: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
// CHECK-NEXT: return
|
||||
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
|
||||
// PRE-NEXT: return
|
||||
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// ESC-NOT: "xla_lhlo.copy"
|
||||
// ESC-NEXT: return %[[ARG0]]
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_op_long
|
||||
// BOTH-LABEL: func @func_op_long
|
||||
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
|
||||
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
|
||||
@ -31,89 +35,91 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
|
||||
return %5 : tensor<4xf32>
|
||||
}
|
||||
// CHECK: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
// CHECK-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// CHECK-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
|
||||
// CHECK-NEXT: return
|
||||
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
||||
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: return
|
||||
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @fusion
|
||||
// BOTH-LABEL: func @fusion
|
||||
func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%summand_2: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
// CHECK: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
|
||||
// CHECK-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
// BOTH: (%{{.*}}: {{.*}}, {{.*}}: {{.*}}, {{.*}}: {{.*}}, %[[RESULT:.*]]: {{.*}})
|
||||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
|
||||
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
|
||||
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// CHECK-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// CHECK-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
// CHECK-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: return
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @copy
|
||||
// BOTH-LABEL: func @copy
|
||||
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.copy"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @exp
|
||||
// BOTH-LABEL: func @exp
|
||||
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.exponential"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @log
|
||||
// BOTH-LABEL: func @log
|
||||
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.log"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @select
|
||||
// BOTH-LABEL: func @select
|
||||
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_pred = tensor_load %pred : memref<2x2xi1>
|
||||
@ -121,34 +127,34 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @compare
|
||||
// BOTH-LABEL: func @compare
|
||||
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
{comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
// CHECK: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
tensor_store %tensor_result, %result : memref<2x2xi1>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @broadcast
|
||||
// BOTH-LABEL: func @broadcast
|
||||
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<5xf32>
|
||||
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
|
||||
{broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
: (tensor<5xf32>) -> tensor<10x5xf32>
|
||||
// CHECK: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<10x5xf32>
|
||||
return
|
||||
}
|
||||
@ -157,55 +163,55 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
|
||||
func @external_func() -> tensor<3xi64>
|
||||
|
||||
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
|
||||
// BOTH: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
|
||||
|
||||
// CHECK-LABEL: func @dyn_broadcast
|
||||
// BOTH-LABEL: func @dyn_broadcast
|
||||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||
%shape = call @external_func() : () -> tensor<3xi64>
|
||||
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// CHECK: %[[SHAPE:.*]] = call @external_func()
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// CHECK: %[[C2:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
|
||||
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
|
||||
// BOTH: %[[SHAPE:.*]] = call @external_func()
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
// BOTH: %[[C1:.*]] = constant 1 : index
|
||||
// BOTH: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
|
||||
// BOTH: %[[C2:.*]] = constant 2 : index
|
||||
// BOTH: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
|
||||
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// BOTH: %[[C0_:.*]] = constant 0 : index
|
||||
// BOTH: %[[C1_:.*]] = constant 1 : index
|
||||
|
||||
// CHECK: %[[C1__:.*]] = constant 1 : index
|
||||
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
|
||||
// CHECK: %[[C0___:.*]] = constant 0 : index
|
||||
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
|
||||
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
|
||||
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
|
||||
// BOTH: %[[C1__:.*]] = constant 1 : index
|
||||
// BOTH: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
|
||||
// BOTH: %[[C0___:.*]] = constant 0 : index
|
||||
// BOTH: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
|
||||
// BOTH: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
|
||||
// BOTH: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
|
||||
// BOTH: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// CHECK: %[[C2_:.*]] = constant 2 : index
|
||||
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
|
||||
// CHECK: %[[C1___:.*]] = constant 1 : index
|
||||
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
|
||||
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
|
||||
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
// BOTH: %[[C2_:.*]] = constant 2 : index
|
||||
// BOTH: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
|
||||
// BOTH: %[[C1___:.*]] = constant 1 : index
|
||||
// BOTH: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
|
||||
// BOTH: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
|
||||
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
|
||||
// CHECK-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
|
||||
// CHECK-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// CHECK-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
|
||||
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
|
||||
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
|
||||
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
|
||||
|
||||
// CHECK: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
|
||||
// Do not store the value back to avoid the tensor-store being rewritten to
|
||||
// a copy into the pre-allocated argument.
|
||||
@ -214,7 +220,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @complex
|
||||
// BOTH-LABEL: func @complex
|
||||
func @complex(%real: memref<2x2xf32>,
|
||||
%imag: memref<2x2xf32>,
|
||||
%result: memref<2x2xcomplex<f32>>) {
|
||||
@ -222,164 +228,164 @@ func @complex(%real: memref<2x2xf32>,
|
||||
%tensor_imag = tensor_load %imag : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
|
||||
// CHECK: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @real
|
||||
// BOTH-LABEL: func @real
|
||||
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "xla_hlo.real"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.real"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @imag
|
||||
// BOTH-LABEL: func @imag
|
||||
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "xla_hlo.imag"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @iota
|
||||
// BOTH-LABEL: func @iota
|
||||
func @iota(%result: memref<10xi32>) {
|
||||
%tensor_result = "xla_hlo.iota"()
|
||||
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
|
||||
// CHECK: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
tensor_store %tensor_result, %result : memref<10xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @abs
|
||||
// BOTH-LABEL: func @abs
|
||||
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.abs"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @ceil
|
||||
// BOTH-LABEL: func @ceil
|
||||
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.ceil"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @convert
|
||||
// BOTH-LABEL: func @convert
|
||||
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.convert"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// CHECK-NOT: tensor_store
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH-NOT: tensor_store
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @cos
|
||||
// BOTH-LABEL: func @cos
|
||||
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.cosine"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @neg
|
||||
// BOTH-LABEL: func @neg
|
||||
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.negate"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rsqrt
|
||||
// BOTH-LABEL: func @rsqrt
|
||||
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sign
|
||||
// BOTH-LABEL: func @sign
|
||||
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.sign"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sqrt
|
||||
// BOTH-LABEL: func @sqrt
|
||||
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.sqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tanh
|
||||
// BOTH-LABEL: func @tanh
|
||||
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.tanh"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @remainder
|
||||
// BOTH-LABEL: func @remainder
|
||||
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_lhs = tensor_load %lhs : memref<2x2xf32>
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// CHECK: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -387,76 +393,79 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
||||
// -----
|
||||
|
||||
// Dynamic shape binary element-wise operation.
|
||||
// CHECK-LABEL: func @add_dyn
|
||||
// BOTH-LABEL: func @add_dyn
|
||||
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||
%result = "xla_hlo.add"(%lhs, %rhs)
|
||||
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// BOTH: %[[C1:.*]] = constant 1 : index
|
||||
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
|
||||
// BOTH: %[[C0_:.*]] = constant 0 : index
|
||||
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// BOTH: %[[C1_:.*]] = constant 1 : index
|
||||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Dynamic shape unary element-wise operation.
|
||||
// CHECK-LABEL: func @tanh_dyn
|
||||
// BOTH-LABEL: func @tanh_dyn
|
||||
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
|
||||
%result = "xla_hlo.tanh"(%arg0)
|
||||
: (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
|
||||
// CHECK: %[[C0_:.*]] = constant 0 : index
|
||||
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// CHECK: %[[C1_:.*]] = constant 1 : index
|
||||
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
|
||||
// BOTH: %[[C1:.*]] = constant 1 : index
|
||||
// BOTH: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
|
||||
// BOTH: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
|
||||
// BOTH: %[[C0_:.*]] = constant 0 : index
|
||||
// BOTH: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
|
||||
// BOTH: %[[C1_:.*]] = constant 1 : index
|
||||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @dot
|
||||
// BOTH-LABEL: func @dot
|
||||
func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: [[TYPE:.*]],
|
||||
// CHECK-SAME: %[[RESULT:.*]]: [[TYPE]])
|
||||
// CHECK: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %{{.*}}) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "xla_hlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// ESC: return %[[ALLOC]]
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @conv
|
||||
// BOTH-LABEL: func @conv
|
||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
// CHECK: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// CHECK: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// CHECK-SAME: padding = dense<[
|
||||
// CHECK-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// CHECK-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// CHECK-SAME: window_strides = dense<[2, 1]>
|
||||
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// BOTH-SAME: padding = dense<[
|
||||
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
|
||||
// BOTH-SAME: window_strides = dense<[2, 1]>
|
||||
%out = "xla_hlo.convolution"(%filter, %input) {
|
||||
batch_group_count = 1 : i64,
|
||||
dimension_numbers = {
|
||||
|
@ -1,13 +1,12 @@
|
||||
// RUN: xla-opt -lhlo-fuse-linalg %s -o - | FileCheck %s --dump-input=always
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -o - | FileCheck %s -check-prefix=TILED
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -o - | FileCheck %s -check-prefix=PLOOP
|
||||
|
||||
// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
|
||||
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>, %result: memref<6x6xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<6x6xf32>
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
@ -19,7 +18,7 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
@ -53,10 +52,12 @@ func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
||||
// -----
|
||||
|
||||
func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
%arg1: memref<100xf32>,
|
||||
%arg2: memref<100x10xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<100x10xf32>
|
||||
%0 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 1 : i64,
|
||||
args_out = 1 : i64,
|
||||
@ -66,7 +67,7 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
^bb0(%arg3: f32, %arg4: f32): // no predecessors
|
||||
linalg.yield %arg3 : f32
|
||||
}: memref<100xf32>, memref<100x10xf32>
|
||||
%1 = alloc() {temp = true} : memref<100x10xf32>
|
||||
%1 = alloc() : memref<100x10xf32>
|
||||
linalg.generic {
|
||||
args_in = 2 : i64,
|
||||
args_out = 1 : i64,
|
||||
@ -126,11 +127,13 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: exp
|
||||
|
||||
#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map1, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
||||
#pointwise_4d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
||||
func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32>,
|
||||
%summand_2: memref<6x6x6x6xf32>, %result: memref<6x6x6x6xf32>) {
|
||||
%temp_result = alloc() {temp = true} : memref<6x6x6x6xf32>
|
||||
%temp_result = alloc() : memref<6x6x6x6xf32>
|
||||
linalg.generic #pointwise_4d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
@ -142,7 +145,7 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6x6x6xf32>, memref<6x6x6x6xf32>, memref<6x6x6x6xf32>
|
||||
dealloc %temp_result : memref<6x6x6x6xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @fusion_4d
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
@ -177,3 +180,57 @@ func @fusion_4d(%multiplier: memref<6x6x6x6xf32>, %summand_1: memref<6x6x6x6xf32
|
||||
// PLOOP: addf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
||||
// -----
|
||||
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
|
||||
func @fusion(%multiplier: memref<6x6xf32>, %summand_1: memref<6x6xf32>,
|
||||
%summand_2: memref<6x6xf32>) -> memref<6x6xf32> {
|
||||
%temp_result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %summand_1, %summand_2, %temp_result {
|
||||
^bb0(%summand_1_in: f32, %summand_2_in: f32, %temp_result_in: f32):
|
||||
%out = addf %summand_1_in, %summand_2_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
%result = alloc() : memref<6x6xf32>
|
||||
linalg.generic #pointwise_2d_trait %temp_result, %multiplier, %result {
|
||||
^bb0(%temp_result_in: f32, %multiplier_in: f32, %result_in: f32):
|
||||
%out = mulf %temp_result_in, %multiplier_in : f32
|
||||
linalg.yield %out : f32
|
||||
} : memref<6x6xf32>, memref<6x6xf32>, memref<6x6xf32>
|
||||
dealloc %temp_result : memref<6x6xf32>
|
||||
return %result : memref<6x6xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fusion
|
||||
// CHECK: %[[C1:.*]] = constant 1
|
||||
// CHECK-NOT: linalg.generic
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK: scf.for {{.*}} step %[[C1]]
|
||||
// CHECK-NOT: scf.for
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: addf
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: mulf
|
||||
|
||||
// TILED-LABEL: func @fusion
|
||||
// TILED-DAG: %[[C2:.*]] = constant 2
|
||||
// TILED-DAG: %[[C3:.*]] = constant 3
|
||||
// TILED-NOT: linalg.generic
|
||||
// TILED: scf.for {{.*}} step %[[C2]]
|
||||
// TILED: scf.for {{.*}} step %[[C3]]
|
||||
// TILED-NOT: scf.for
|
||||
// TILED: linalg.generic
|
||||
// TILED: addf
|
||||
// TILED: linalg.generic
|
||||
// TILED: mulf
|
||||
|
||||
// PLOOP-LABEL: func @fusion
|
||||
// PLOOP-NOT: linalg.generic
|
||||
// PLOOP: scf.parallel
|
||||
// PLOOP-NOT: scf.parallel
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: addf
|
||||
// PLOOP: linalg.generic
|
||||
// PLOOP: mulf
|
||||
|
@@ -38,15 +38,15 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<112x112xf32>) {

// Constants.
// CHECK: [[C56:%.*]] = constant 56 : index
// CHECK: [[C1:%.*]] = constant 1 : index
// CHECK: [[C0_F32:%.*]] = constant 0.000000e+00 : f32
// CHECK: [[CFALSE:%.*]] = constant false
// CHECK: [[C3:%.*]] = constant 3 : index
// CHECK: [[C2:%.*]] = constant 2 : index
// CHECK: [[C0:%.*]] = constant 0 : index
// CHECK: [[C112:%.*]] = constant 112 : index
// CHECK: [[CTRUE:%.*]] = constant true
// CHECK-DAG: [[C56:%.*]] = constant 56 : index
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C0_F32:%.*]] = constant 0.000000e+00 : f32
// CHECK-DAG: [[CFALSE:%.*]] = constant false
// CHECK-DAG: [[C3:%.*]] = constant 3 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK-DAG: [[C112:%.*]] = constant 112 : index
// CHECK-DAG: [[CTRUE:%.*]] = constant true

// Parallel loop to initialize the output buffer.
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]][] : memref<f32>
@@ -80,23 +80,17 @@ func @select_and_scatter(%arg: memref<112x112xf32>,

// Compute index I of the ARG buffer and check whether it is in padding area.
// CHECK: [[START_I:%.*]] = muli [[II]], [[C2]] : index
// CHECK: [[OFFSET_I:%.*]] = subi [[WIN_I]], [[C0]] : index
// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index
// CHECK: [[ARG_I:%.*]] = addi [[START_I]], [[WIN_I]] : index
// CHECK: [[ARG_I_FITS:%.*]] = cmpi "ult", [[ARG_I]], [[C112]] : index

// Update `INBOUNDS`, i.e. whether or not ARG indices are inside the boundaries
// of the buffer or they are in the padding area.
// CHECK: [[INBOUNDS_0:%.*]] = and [[ARG_I_FITS]], [[CTRUE]] : i1

// Compute index J of the ARG buffer and check whether it is in padding area.
// CHECK: [[START_J:%.*]] = muli [[JJ]], [[C2]] : index
// CHECK: [[OFFSET_J:%.*]] = subi [[WIN_J]], [[C0]] : index
// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index
// CHECK: [[ARG_J:%.*]] = addi [[START_J]], [[WIN_J]] : index
// CHECK: [[ARG_J_FITS:%.*]] = cmpi "ult", [[ARG_J]], [[C112]] : index

// Update `INBOUNDS`, i.e. whether or not ARG indices are inside the boundaries
// of the buffer or they are in the padding area.
// CHECK: [[INBOUNDS_1:%.*]] = and [[INBOUNDS_0]], [[ARG_J_FITS]] : i1
// CHECK: [[INBOUNDS_1:%.*]] = and [[ARG_I_FITS]], [[ARG_J_FITS]] : i1

// If ARG ivs are in the padding area, then 'select' function does not have to
// be applied, current selected ivs (SEL_I, SEL_J) and value (SEL_VAL) are
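The same simplification appears in this select_and_scatter lowering and in the reduce_window lowering below: with zero low padding the window offset no longer needs a `subi` against `[[C0]]`, so the operand index is just start plus the window induction variable, and the padding check collapses to one unsigned compare. A minimal standalone sketch of that arithmetic; the variable names are illustrative and not taken from the pass:

#include <cstdint>

// With a window stride of 2 and zero low padding, the operand index is
// start + window_iv, and being outside the buffer (padding area) is
// detected by a single unsigned comparison against the buffer bound.
bool InWindowBounds(int64_t outer_iv, int64_t window_iv, int64_t bound) {
  int64_t start = outer_iv * 2;          // muli [[II]], [[C2]]
  int64_t index = start + window_iv;     // addi [[START_I]], [[WIN_I]]
  return static_cast<uint64_t>(index) <
         static_cast<uint64_t>(bound);   // cmpi "ult", [[ARG_I]], [[C112]]
}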
@@ -151,7 +151,6 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-SAME: [[OPERAND_BUF:%.*]]: memref<112x112xf32>,
// CHECK-SAME: [[INIT_BUF:%.*]]: memref<f32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<56x56xf32>) {
// CHECK-DAG: [[IN_BOUNDS:%.*]] = constant true
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
@@ -167,16 +166,13 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK-SAME: init ([[INIT]]) -> f32 {

// CHECK: [[START_I:%.*]] = muli [[I]], [[C2]] : index
// CHECK: [[OFFSET_I:%.*]] = subi [[IW]], [[C0]] : index
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[OFFSET_I]] : index
// CHECK: [[INDEX_I:%.*]] = addi [[START_I]], [[IW]] : index
// CHECK: [[INDEX_I_FITS:%.*]] = cmpi "ult", [[INDEX_I]], [[C112]]
// CHECK: [[IN_BOUNDS_0:%.*]] = and [[INDEX_I_FITS]], [[IN_BOUNDS]]

// CHECK: [[START_J:%.*]] = muli [[J]], [[C2]] : index
// CHECK: [[OFFSET_J:%.*]] = subi [[JW]], [[C0]] : index
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[OFFSET_J]] : index
// CHECK: [[INDEX_J:%.*]] = addi [[START_J]], [[JW]] : index
// CHECK: [[INDEX_J_FITS:%.*]] = cmpi "ult", [[INDEX_J]], [[C112]]
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[IN_BOUNDS_0]], [[INDEX_J_FITS]]
// CHECK: [[IN_BOUNDS_1:%.*]] = and [[INDEX_I_FITS]], [[INDEX_J_FITS]]

// CHECK: [[ELEM_TO_REDUCE:%.*]] = scf.if [[IN_BOUNDS_1]] -> (f32) {
// CHECK: [[OPERAND_ELEM:%.*]] =
@ -863,3 +863,124 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
|
||||
) : (memref<i64>, memref<i64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @bitcast_memrefs
|
||||
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
||||
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @scatter_memrefs
|
||||
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
|
||||
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
||||
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
|
||||
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
|
||||
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
|
||||
}) {
|
||||
scatter_dimension_numbers = {
|
||||
update_window_dims = dense<[1]> : tensor<1xi64>,
|
||||
inserted_window_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||
scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>,
|
||||
index_vector_dim = 1 : i64
|
||||
},
|
||||
indices_are_sorted = true,
|
||||
unique_indices = true
|
||||
} : (memref<200x100x300xf32>, memref<10x2xi32>, memref<10x300xf32>, memref<200x100x300xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @map_memrefs
|
||||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = xla_hlo.add %a, %b : tensor<f32>
|
||||
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = xla_hlo.add %a, %b : tensor<f32>
|
||||
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
|
||||
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
|
||||
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%arg_out: tuple<memref<16x16xf32>, memref<16x16xf16>>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %arg_out) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%arg_out: tuple<memref<16x16xf32>, memref<16x16xf16>>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %arg_out) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%arg_out: tuple<memref<16x16xf32>, memref<16x16xf16>>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %arg_out) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
|
||||
}) : (memref<16x16xf32>, memref<16x16xf16>, tuple<memref<16x16xf32>, memref<16x16xf16>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @tuple_select_memrefs
|
||||
func @tuple_select_memrefs(%pred: memref<20xi1>, %true_values: memref<20xf32>,
|
||||
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
|
||||
: (memref<20xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @tuple_select_memrefs(%pred: memref<10xi1>, %true_values: memref<20xf32>,
|
||||
%false_values: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.tuple_select"(%pred, %true_values, %false_values, %arg_out)
|
||||
: (memref<10xi1>, memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s -DPRIVATE="attributes {sym_visibility = \"private\"}"
|
||||
|
||||
HloModule main
|
||||
|
||||
@ -8,6 +8,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_simple
|
||||
// CHECK-SAME: [[PRIVATE]]
|
||||
%test_simple (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] {
|
||||
%Arg_0.1 = f32[4]{0} parameter(0)
|
||||
%Arg_1.2 = f32[4]{0} parameter(1)
|
||||
@ -21,7 +22,7 @@ ENTRY %dummy_main (Arg_0.1: f32[]) -> f32[] {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_after_all
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token
|
||||
// CHECK-SAME: ([[VAL_0:%.*]]: !xla_hlo.token, [[VAL_1:%.*]]: !xla_hlo.token) -> !xla_hlo.token [[PRIVATE]]
|
||||
%test_after_all (token0: token[], token1: token[] ) -> token[] {
|
||||
token0 = token[] parameter(0)
|
||||
token1 = token[] parameter(1)
|
||||
@ -95,7 +96,7 @@ add {
|
||||
ROOT %batch-norm-grad = (f32[2,2,2,2], f32[2], f32[2]) batch-norm-grad(f32[2,2,2,2] %input, f32[2] %scale, f32[2] %mean, f32[2] %variance, f32[2,2,2,2] %grad_output), epsilon=0.001, feature_index=1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64>
|
||||
%call (arg_1: s64[]) -> s64[] {
|
||||
%arg_1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
|
||||
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
|
||||
@ -136,7 +137,7 @@ add {
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1> {
|
||||
// CHECK-LABEL: func @test_compare(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<3xf32>) -> tensor<3xi1>
|
||||
%test_compare (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[3]) -> pred[3] {
|
||||
%Arg_0.1 = f32[3] parameter(0)
|
||||
%Arg_1.2 = f32[3] parameter(1)
|
||||
@ -162,7 +163,7 @@ add {
|
||||
ROOT %complex.3 = c64[4] complex(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_concat(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32> {
|
||||
// CHECK-LABEL: func @test_concat(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32>
|
||||
%test_concat (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] {
|
||||
%Arg_0.1 = f32[4, 1] parameter(0)
|
||||
%Arg_1.2 = f32[4, 2] parameter(1)
|
||||
@ -201,7 +202,7 @@ add {
|
||||
|
||||
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
|
||||
// implementations with attributes, etc.
|
||||
// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>> {
|
||||
// CHECK-LABEL: func @test_conv(%arg0: tensor<256x32x32x6xf32>) -> tuple<tensor<256x30x30x16xf32>>
|
||||
%test_conv {
|
||||
%arg0.1 = f32[256,32,32,6]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
@ -257,7 +258,7 @@ add {
|
||||
ROOT %convolution = f32[1,5,1] convolution(f32[1,2,1] %input, f32[1,1,1] %filter), feature_group_count=1, dim_labels=b0f_0io->b0f, window={pad=1_2 size=1}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64> {
|
||||
// CHECK-LABEL: func @test_convert(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf64>
|
||||
%test_convert (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f64[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -272,7 +273,7 @@ add {
|
||||
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[4] %convert.4)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
|
||||
// CHECK-LABEL: func @test_cosine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
%test_cosine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
@ -289,7 +290,7 @@ add {
|
||||
ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[2,3] %arg1, f32[5,5] %arg2), custom_call_target="foo", backend_config="bar", custom_call_has_side_effect=true
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_div(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%test_div (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -298,7 +299,7 @@ add {
|
||||
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_dot(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<f32> {
|
||||
// CHECK-LABEL: func @test_dot(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<f32>
|
||||
%test_dot (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[] {
|
||||
%Arg_0.1 = f32[1, 4] parameter(0)
|
||||
%Arg_1.2 = f32[4, 1] parameter(1)
|
||||
@ -350,7 +351,7 @@ add {
|
||||
ROOT %dynamic-slice = s32[1,1,32] dynamic-slice(s32[2,2,258] %operand, s32[] %start_idx_1, s32[] %start_idx_2, s32[] %start_idx_3), dynamic_slice_sizes={1,1,32}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32> {
|
||||
// CHECK-LABEL: func @test_dynamic_update_slice_1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<i32>, %arg3: tensor<i32>) -> tensor<4x4xf32>
|
||||
%test_dynamic_update_slice_1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
|
||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||
%Arg_1.2 = f32[1, 4] parameter(1)
|
||||
@ -371,7 +372,7 @@ add {
|
||||
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_exponential(%arg0: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-LABEL: func @test_exponential(%arg0: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_exponential (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -379,7 +380,7 @@ add {
|
||||
ROOT %exp.2 = f32[16] exponential(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_expm1(%arg0: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-LABEL: func @test_expm1(%arg0: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_expm1 (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -387,7 +388,7 @@ add {
|
||||
ROOT %expm1.2 = f32[16] exponential-minus-one(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>> {
|
||||
// CHECK-LABEL: func @test_fft(%arg0: tensor<3x9xf32>) -> tensor<3x5xcomplex<f32>>
|
||||
%test_fft {
|
||||
%arg0.1 = f32[3,9]{1,0} parameter(0), parameter_replication={false}, metadata={op_name="XLA_Args"}
|
||||
// CHECK: "xla_hlo.fft"(%arg0) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"
|
||||
@ -395,7 +396,7 @@ add {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_floor(
|
||||
// CHECK-SAME: [[A0:%.+]]: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-SAME: [[A0:%.+]]: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_floor (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -404,7 +405,7 @@ add {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_gather(
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32> {
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<200x100x300xf32>, [[ARG1:%.+]]: tensor<10x2xi32>) -> tensor<10x300xf32>
|
||||
%test_gather (arg.0: f32[200,100,300], arg.1: s32[10,2]) -> f32[10,300] {
|
||||
%arg.0 = f32[200,100,300] parameter(0)
|
||||
%arg.1 = s32[10,2] parameter(1)
|
||||
@ -442,7 +443,7 @@ add {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_infeed
|
||||
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token> {
|
||||
// CHECK-SAME: ([[TOKEN:%.*]]: !xla_hlo.token) -> tuple<tensor<3xi32>, !xla_hlo.token>
|
||||
%test_infeed (token0: token[]) -> (s32[3], token[]) {
|
||||
%token0 = token[] parameter(0)
|
||||
// CHECK-NEXT: "xla_hlo.infeed"([[TOKEN]])
|
||||
@ -451,19 +452,19 @@ add {
|
||||
}
|
||||
|
||||
|
||||
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_iota_1() -> tensor<4xf32>
|
||||
%test_iota_1 () -> f32[4] {
|
||||
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32> {
|
||||
// CHECK-LABEL: func @test_iota_2() -> tensor<4x5xf32>
|
||||
%test_iota_2 () -> f32[4, 5] {
|
||||
// CHECK-NEXT: "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
|
||||
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_log(%arg0: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-LABEL: func @test_log(%arg0: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_log (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -471,7 +472,7 @@ add {
|
||||
ROOT %log.2 = f32[16] log(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_log1p(%arg0: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-LABEL: func @test_log1p(%arg0: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_log1p (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -501,7 +502,7 @@ add {
|
||||
|
||||
|
||||
|
||||
// CHECK-LABEL: func @test_maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%test_maximum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -510,7 +511,7 @@ add {
|
||||
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%test_minimum (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -519,7 +520,7 @@ add {
|
||||
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_multiply(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_multiply(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%test_multiply (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -528,7 +529,7 @@ add {
|
||||
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_negate(%arg0: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-LABEL: func @test_negate(%arg0: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_negate (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -536,7 +537,7 @@ add {
|
||||
ROOT %negate.2 = f32[16] negate(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_not(%arg0: tensor<16xi1>) -> tensor<16xi1> {
|
||||
// CHECK-LABEL: func @test_not(%arg0: tensor<16xi1>) -> tensor<16xi1>
|
||||
%test_not (arg0.1: pred[16]) -> pred[16] {
|
||||
%arg0.1 = pred[16] parameter(0)
|
||||
|
||||
@ -554,7 +555,7 @@ add {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_outfeed
|
||||
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token {
|
||||
// CHECK-SAME: ([[DATA:%.*]]: tensor<3xi32>, [[TOKEN:%.*]]: !xla_hlo.token) -> !xla_hlo.token
|
||||
%test_outfeed (Arg_0.1: s32[3], Arg_1.2: token[]) -> token[] {
|
||||
%Arg_0.1 = s32[3] parameter(0)
|
||||
%Arg_1.2 = token[] parameter(1)
|
||||
@ -563,7 +564,7 @@ add {
|
||||
ROOT %outfeed.3 = token[] outfeed(s32[3] %Arg_0.1, token[] %Arg_1.2), outfeed_config="foobar"
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_pad(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32>
|
||||
%test_pad (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
@ -572,7 +573,7 @@ add {
|
||||
ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_pad_edge(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32> {
|
||||
// CHECK-LABEL: func @test_pad_edge(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32>
|
||||
%test_pad_edge (Arg_0.1: f32[4, 4, 4], Arg_1.2: f32[]) -> f32[7, 11, 15] {
|
||||
%Arg_0.1 = f32[4, 4, 4] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
@ -581,7 +582,7 @@ add {
|
||||
ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_pad_interior(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<10xf32> {
|
||||
// CHECK-LABEL: func @test_pad_interior(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<10xf32>
|
||||
%test_pad_interior (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[10] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
@ -590,7 +591,7 @@ add {
|
||||
ROOT %pad.3 = f32[10] pad(%Arg_0.1, %Arg_1.2), padding=0_0_2
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_popcnt(%arg0: tensor<16xi32>) -> tensor<16xi32> {
|
||||
// CHECK-LABEL: func @test_popcnt(%arg0: tensor<16xi32>) -> tensor<16xi32>
|
||||
%test_popcnt (arg0.1: s32[16]) -> s32[16] {
|
||||
%arg0.1 = s32[16] parameter(0)
|
||||
|
||||
@ -598,7 +599,7 @@ add {
|
||||
ROOT %popcnt.2 = s32[16] popcnt(s32[16] %arg0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_pow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_pow(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
%test_pow (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -659,7 +660,7 @@ add {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_reduce
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x4xf32>, [[ARG1:%.*]]: tensor<4xf32>, [[ARG2:%.*]]: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>> {
|
||||
// CHECK-SAME: ([[ARG0:%.*]]: tensor<4x4xf32>, [[ARG1:%.*]]: tensor<4xf32>, [[ARG2:%.*]]: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>>
|
||||
%test_reduce (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) {
|
||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
@ -719,7 +720,7 @@ add {
|
||||
ROOT %remainder.3 = f32[4] remainder(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_reverse_1d(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-LABEL: func @test_reverse_1d(%arg0: tensor<4xf32>) -> tensor<4xf32>
|
||||
%test_reverse_1d (Arg_0.1: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
|
||||
@ -727,7 +728,7 @@ add {
|
||||
ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_reverse_2d(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
|
||||
// CHECK-LABEL: func @test_reverse_2d(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
%test_reverse_2d (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
|
||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||
|
||||
@ -736,7 +737,7 @@ add {
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_rsqrt(
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32> {
|
||||
// CHECK-SAME: [[ARG0:%.+]]: tensor<16xf32>) -> tensor<16xf32>
|
||||
%test_rsqrt (arg0.1: f32[16]) -> f32[16] {
|
||||
%arg0.1 = f32[16] parameter(0)
|
||||
|
||||
@ -744,7 +745,7 @@ add {
|
||||
ROOT %rsqrt.2 = f32[16] rsqrt(f32[16] %arg0.1)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_scalar(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
// CHECK-LABEL: func @test_scalar(%arg0: tensor<f32>) -> tensor<f32>
|
||||
%test_scalar (Arg_0.1: f32[]) -> f32[] {
|
||||
// CHECK-NEXT: return %arg0 : tensor<f32>
|
||||
ROOT %Arg_0.1 = f32[] parameter(0)
|
||||
@ -781,7 +782,7 @@ add {
|
||||
// CHECK-SAME: unique_indices = false
|
||||
|
||||
|
||||
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-LABEL: func @test_select(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%test_select {
|
||||
%Arg_0.1 = pred[2,3] parameter(0)
|
||||
%Arg_1.2 = s32[2,3] parameter(1)
|
||||
@ -838,7 +839,7 @@ add {
|
||||
ROOT %set-dimension-size.2 = f32[4,<=4] set-dimension-size(f32[4,4] %Arg_0.1, s32[] %Arg_1.2), dimensions={1}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_sine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
|
||||
// CHECK-LABEL: func @test_sine(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
%test_sine (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
@ -874,7 +875,7 @@ add {
|
||||
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_tanh(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
|
||||
// CHECK-LABEL: func @test_tanh(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
%test_tanh (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="HLO_Args"}
|
||||
|
||||
@ -882,7 +883,7 @@ add {
|
||||
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
// CHECK-LABEL: func @test_transpose(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
%test_transpose {
|
||||
%Arg_0.1 = s32[1,2,3,4] parameter(0)
|
||||
|
||||
@ -903,7 +904,7 @@ add {
|
||||
ROOT %triangular-solve.3 = f32[4,3] triangular-solve(f32[4,4] %Arg_0.1, f32[4,3] %Arg_1.2), left_side=true, lower=true, transpose_a=NO_TRANSPOSE, unit_diagonal=true
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
|
||||
// CHECK-LABEL: func @test_tuple(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
|
||||
%test_tuple(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) {
|
||||
%Arg_0.1 = s32[1] parameter(0)
|
||||
%Arg_1.2 = f32[1, 2] parameter(1)
|
||||
@ -928,7 +929,7 @@ add {
|
||||
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @test_while(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
// CHECK-LABEL: func @test_while(%arg0: tensor<i64>) -> tensor<i64>
|
||||
%test_while (arg0.1: s64[]) -> s64[] {
|
||||
%arg0.1 = s64[] parameter(0), metadata={op_name="HLO_Args"}
|
||||
// CHECK-NEXT: "xla_hlo.while"(%arg0) ( {
|
||||
|
@@ -368,6 +368,15 @@ class HloToLhloTensorStoreOpConverter

struct HloLegalizeToLhlo
: public PassWrapper<HloLegalizeToLhlo, OperationPass<ModuleOp>> {
public:
HloLegalizeToLhlo() = default;
HloLegalizeToLhlo(const HloLegalizeToLhlo& o) {
this->results_escape_function = o.results_escape_function.getValue();
}
explicit HloLegalizeToLhlo(bool results_escape_function) {
this->results_escape_function.setValue(results_escape_function);
}

void runOnOperation() override {
OwningRewritePatternList patterns;
auto& context = getContext();
@@ -398,10 +407,28 @@ struct HloLegalizeToLhlo
OwningRewritePatternList patterns;
populateHLOToLHLOConversionPattern(func.getContext(), &bufferAssignment,
&converter, &patterns);
if (results_escape_function) {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
&converter, &patterns);
} else {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}
return WalkResult(
applyPartialConversion(func, target, patterns, &converter));
});
}

private:
Option<bool> results_escape_function{
*this, "results-escape-function",
llvm::cl::desc(
"Allocate the results of functions within the functions body"),
llvm::cl::init(false)};
};
} // namespace

@@ -446,14 +473,11 @@ void populateHLOToLHLOConversionPattern(
HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter);
// clang-format on
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(context, bufferAssignment,
converter, patterns);
}

std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass() {
return absl::make_unique<HloLegalizeToLhlo>();
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_function) {
return absl::make_unique<HloLegalizeToLhlo>(results_escape_function);
}

static PassRegistration<HloLegalizeToLhlo> legalize_pass(
@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/FoldUtils.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
@@ -52,10 +53,17 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
// The fusion in Linalg is currently possible only when the consumer op is
// tiled. In order to greedily fuse the ops, we have to start from the tiled
// root linalg ops, i.e. linalg ops that write to output buffers of the
// function.
llvm::SmallDenseSet<Value> func_args;
// function or are returned in case of escaping allocations.
llvm::SmallDenseSet<Value> result_buffers;
for (auto func_arg : func.getArguments()) {
func_args.insert(func_arg);
result_buffers.insert(func_arg);
}
for (auto& block : func.getBlocks()) {
auto returnOp = mlir::dyn_cast<mlir::ReturnOp>(block.getTerminator());
if (!returnOp) continue;
for (auto operand : returnOp.getOperands()) {
result_buffers.insert(operand);
}
}
MLIRContext* ctx = func.getContext();
OpBuilder b(func);
@@ -68,7 +76,7 @@ class LhloFuseLinalg : public PassWrapper<LhloFuseLinalg, FunctionPass> {
}
auto op = cast<LinalgOp>(generic_op.getOperation());
for (const Value result : op.getOutputBuffers()) {
if (!func_args.count(result)) continue;
if (!result_buffers.count(result)) continue;
if (tileGenericOp(op, tile_sizes, &b)) {
generic_op.erase();
return;
@@ -59,9 +59,13 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
/// Lowers from HLO dialect to Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToStdPass();

// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
// buffers if necessary.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass();
/// Lowers from HLO dialect to LHLO dialect allocating/deallocating temporary
/// buffers if necessary. If `results_escape_functions` is set to true,
/// allocated buffers for function results will be returned and escape the
/// function. Otherwise, the signature is rewritten with extra arguments for the
/// buffers that are to be used for results.
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
bool results_escape_functions = false);

// Lowers from HLO dialect to Linalg dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass();
@@ -111,24 +115,6 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();

} // namespace xla_lhlo

namespace xla {

/// Moves alloc nodes (and their associated dealloc nodes - if any) into the
/// right positions. If there is no associated dealloc node for a given alloc
/// node, this pass will automatically insert a proper dealloc node in the right
/// place. The intended use case of this pass is to store SSA values into
/// buffers using load/store operations. For this purpose, you need to know
/// proper positions to place the required allocs and deallocs.
/// 1) Note that the function signatures and all types for which buffers should
/// be allocated need to be converted in advance.
/// 2) All required alloc nodes have the be inserted in advance.
/// 3) Note that the current implementation does not support loops.
/// Refer to the class mlir::xla::BufferAssignmentLegalizer for more
/// information.
std::unique_ptr<OperationPass<FuncOp>> createBufferAssignmentPass();

} // namespace xla
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_PASSES_H_
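A minimal usage sketch for the overload declared above; the PassManager wiring and the mlir::xla_hlo namespace for the factory are assumptions made for illustration and are not shown in this change:

#include "mlir/IR/Module.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"

// Illustrative only: run the HLO->LHLO legalization so that result buffers
// are returned (escape the function) instead of being added as extra
// arguments. The namespace of createLegalizeToLhloPass is assumed here.
mlir::LogicalResult LowerToLhlo(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::xla_hlo::createLegalizeToLhloPass(
      /*results_escape_functions=*/true));
  return pm.run(module);
}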
@@ -135,18 +135,16 @@ class MatrixTriangularSolveOpTest(xla_test.XLATestCase):
self._VerifyTriangularSolve(
a.astype(np.float32), b.astype(np.float32), True, False, 1e-4)

@test_util.run_deprecated_v1
def testNonSquareCoefficientMatrixV1(self):
def testNonSquareCoefficientMatrix(self):
rng = np.random.RandomState(0)
for dtype in self.float_types:
a = rng.randn(3, 4).astype(dtype)
b = rng.randn(4, 4).astype(dtype)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.assertRaises(ValueError):
linalg_ops.matrix_triangular_solve(a, b)
with self.test_scope():
with self.assertRaises((ValueError, errors.InvalidArgumentError)):
linalg_ops.matrix_triangular_solve(a, b)

@test_util.run_v2_only
@test_util.run_v2_only # Different error types
def testWrongDimensionsV2(self):
randn = np.random.RandomState(0).randn
for dtype in self.float_types:
@ -61,6 +61,81 @@ def implicit_reparameterization_grad(a, x):
|
||||
return -gen_math_ops.igamma_grad_a(a, x) / prob
|
||||
|
||||
|
||||
@def_function.function(experimental_compile=True)
|
||||
def _log1p(x):
|
||||
return math_ops.log1p(x)
|
||||
|
||||
|
||||
class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
if flags.FLAGS.vary_seed:
|
||||
entropy = os.urandom(64)
|
||||
if six.PY2:
|
||||
answer = int(entropy.encode('hex'), 16)
|
||||
else:
|
||||
answer = int.from_bytes(entropy, 'big')
|
||||
np.random.seed(answer % (2**32 - 1))
|
||||
super(Log1pTest, self).setUp()
|
||||
|
||||
def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
|
||||
if self.device not in ['TPU']:
|
||||
return rtol, atol
|
||||
|
||||
if dtype == np.float32:
|
||||
return 4e-4, 0.
|
||||
return 1e-10, 0.
|
||||
|
||||
def _test_range(self, low, high, dtype, rtol, atol, is_negative=False):
|
||||
# Test values near zero.
|
||||
rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
|
||||
x = np.exp(np.random.uniform(
|
||||
low=low, high=high, size=[NUM_SAMPLES])).astype(dtype)
|
||||
if is_negative:
|
||||
x = -x
|
||||
expected_values = np.log1p(x)
|
||||
with self.session() as sess:
|
||||
with self.test_scope():
|
||||
actual = _log1p(x)
|
||||
actual = sess.run(actual)
|
||||
self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)
|
||||
|
||||
@parameterized.parameters((np.float32, 1e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testSmallX(self, dtype, rtol, atol):
|
||||
self._test_range(-40., -20., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-40., -20., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol):
|
||||
self._test_range(-20., -10., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-20., -10., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol):
|
||||
self._test_range(-10., -5., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-10., -5., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-7, 0.),
|
||||
(np.float64, 1e-15, 0.))
|
||||
def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol):
|
||||
self._test_range(-5., -1., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-5., -1., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 4e-7, 0.),
|
||||
(np.float64, 3e-14, 0.))
|
||||
def testXGreaterThanOneTenth(self, dtype, rtol, atol):
|
||||
self._test_range(-1., 0., dtype, rtol, atol, is_negative=False)
|
||||
self._test_range(-1., 0., dtype, rtol, atol, is_negative=True)
|
||||
|
||||
@parameterized.parameters((np.float32, 2e-7, 0.),
|
||||
(np.float64, 2e-15, 0.))
|
||||
def testXGreaterThanOne(self, dtype, rtol, atol):
|
||||
self._test_range(0., 3., dtype, rtol, atol, is_negative=False)
|
||||
|
||||
|
||||
class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
@ -292,13 +292,17 @@ class UnaryOpsTest(xla_test.XLATestCase):
|
||||
np.array([[1, 2]], dtype=dtype),
|
||||
expected=np.array([[0.540297, -0.41614]], dtype=dtype))
|
||||
|
||||
# Confirm that log1p will remain precise across a range of small values.
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.log1p,
|
||||
np.array([[1e-14, 1e-15, 0.6]], dtype=dtype),
|
||||
expected=np.log1p(np.array([[1e-14, 1e-15, 0.6]],
|
||||
dtype=dtype)).astype(dtype),
|
||||
rtol=1e-4,
|
||||
atol=1e-6)
|
||||
np.array([[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
|
||||
dtype=dtype),
|
||||
expected=np.log1p(
|
||||
np.array(
|
||||
[[1e-14, 1e-15, 0.6, 2] + [x * 1e-5 for x in range(1, 20)]],
|
||||
dtype=dtype)).astype(dtype),
|
||||
rtol=1e-15 if dtype == np.float64 else 1e-4,
|
||||
atol=1e-15 if dtype == np.float64 else 1e-4)
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
math_ops.rint,
|
||||
|
@ -211,6 +211,24 @@ class XlaOpsNumericalTest(xla_test.XLATestCase, parameterized.TestCase):
|
||||
[7, 7, 7, 7, 7], [7, 2, 3, 7, 7], [7, 7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
def testPadNegative(self):
|
||||
for dtype in self.numeric_types:
|
||||
|
||||
def pad_fn(x):
|
||||
return xla.pad(
|
||||
x,
|
||||
padding_value=7,
|
||||
padding_low=[0, -1],
|
||||
padding_high=[1, -2],
|
||||
padding_interior=[1, 2])
|
||||
|
||||
self._assertOpOutputMatchesExpected(
|
||||
pad_fn,
|
||||
args=(np.arange(6, dtype=np.int32).astype(dtype).reshape([2, 3]),),
|
||||
expected=np.array(
|
||||
[[7, 7, 1, 7], [7, 7, 7, 7], [7, 7, 4, 7], [7, 7, 7, 7]],
|
||||
dtype=dtype))
|
||||
|
||||
@test_util.disable_mlir_bridge('Not supported yet')
|
||||
def testReduce(self):
|
||||
for dtype in set(self.numeric_types).intersection(
|
||||
|
@ -50,6 +50,14 @@ class MatrixTriangularSolveOp : public XlaOpKernel {
|
||||
return;
|
||||
}
|
||||
|
||||
auto lhs_size = lhs_shape.dims();
|
||||
OP_REQUIRES(
|
||||
ctx,
|
||||
lhs_shape.dim_size(lhs_size - 1) == lhs_shape.dim_size(lhs_size - 2),
|
||||
errors::InvalidArgument("The coefficient matrix must be square in "
|
||||
"the inner-most two dimensions: ",
|
||||
lhs_shape.DebugString()));
|
||||
|
||||
xla::XlaOp a = ctx->Input(0);
|
||||
xla::XlaOp b = ctx->Input(1);
|
||||
std::tie(a, b) = Broadcast(a, lhs_shape, b, rhs_shape, bcast);
|
||||
|
@ -64,14 +64,6 @@ class XlaPadOp : public XlaOpKernel {
|
||||
padding_interior.size(), " vs. ", rank, ")"));
|
||||
|
||||
auto non_negative = [](int64 x) { return x >= 0; };
|
||||
OP_REQUIRES(
|
||||
context, absl::c_all_of(padding_low, non_negative),
|
||||
errors::InvalidArgument("padding_low must be non-negative, got [",
|
||||
absl::StrJoin(padding_low, ","), "]"));
|
||||
OP_REQUIRES(
|
||||
context, absl::c_all_of(padding_high, non_negative),
|
||||
errors::InvalidArgument("padding_high must be non-negative, got [",
|
||||
absl::StrJoin(padding_high, ","), "]"));
|
||||
OP_REQUIRES(
|
||||
context, absl::c_all_of(padding_interior, non_negative),
|
||||
errors::InvalidArgument("padding_interior must be non-negative, got [",
|
||||
|
@ -59,6 +59,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
|
||||
return std::make_shared<PjRtClient>(
|
||||
kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
}
|
||||
|
||||
|
@ -72,18 +72,21 @@ TEST(GpuMultiStream, Basics) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto dummy_buffer,
|
||||
PjRtBuffer::FromHostBuffer(
|
||||
dummy_inputs.data(), dummy_shape, /*force_copy=*/false,
|
||||
dummy_inputs.data(), dummy_shape,
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, client.get(), device));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto in_buffer0,
|
||||
PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
|
||||
/*buffer_reference=*/nullptr, client.get(),
|
||||
device));
|
||||
PjRtBuffer::FromHostBuffer(
|
||||
inputs.data(), shape,
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, client.get(), device));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto in_buffer1,
|
||||
PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
|
||||
/*buffer_reference=*/nullptr, client.get(),
|
||||
device));
|
||||
PjRtBuffer::FromHostBuffer(
|
||||
inputs.data(), shape,
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, client.get(), device));
|
||||
// The execution may be enqueued before the transfers complete, requiring
|
||||
// adequate device-side synchronization.
|
||||
ExecuteOptions options;
|
||||
|
@ -53,6 +53,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
|
||||
return std::make_shared<PjRtClient>(
|
||||
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr);
|
||||
}
|
||||
|
||||
|
@ -316,6 +316,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
|
||||
"gpu", xla_client, std::move(devices),
|
||||
/*node_id=*/node_id, std::move(allocator),
|
||||
std::move(host_memory_allocator),
|
||||
/*should_stage_host_to_device_transfers=*/true,
|
||||
/*gpu_run_options=*/std::move(gpu_run_options));
|
||||
return pyclient;
|
||||
}
|
||||
|
@ -95,6 +95,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
@ -154,18 +155,35 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
|
||||
return xla_assignment;
|
||||
}
|
||||
|
||||
class CpuAllocator : public tensorflow::Allocator {
|
||||
public:
|
||||
CpuAllocator() = default;
|
||||
|
||||
string Name() override { return "cpu"; }
|
||||
|
||||
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
|
||||
return tensorflow::port::AlignedMalloc(num_bytes, alignment);
|
||||
}
|
||||
void DeallocateRaw(void* ptr) override {
|
||||
return tensorflow::port::AlignedFree(ptr);
|
||||
}
|
||||
};
|
||||
|
||||
PjRtClient::PjRtClient(
|
||||
std::string platform_name, LocalClient* client,
|
||||
std::vector<std::unique_ptr<Device>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
|
||||
: platform_name_(std::move(platform_name)),
|
||||
client_(client),
|
||||
host_memory_allocator_(std::move(host_memory_allocator)),
|
||||
devices_(std::move(devices)),
|
||||
host_id_(host_id),
|
||||
owned_allocator_(std::move(allocator)),
|
||||
host_memory_allocator_(std::move(host_memory_allocator)),
|
||||
should_stage_host_to_device_transfers_(
|
||||
should_stage_host_to_device_transfers),
|
||||
gpu_run_options_(std::move(gpu_run_options)),
|
||||
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
|
||||
client->device_count()) {
|
||||
@ -175,6 +193,10 @@ PjRtClient::PjRtClient(
|
||||
allocator_ = client_->backend().memory_allocator();
|
||||
}
|
||||
|
||||
if (!host_memory_allocator_) {
|
||||
host_memory_allocator_ = std::make_unique<CpuAllocator>();
|
||||
}
|
||||
|
||||
for (const std::unique_ptr<Device>& device : devices_) {
|
||||
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
|
||||
<< "Duplicate device id: " << device->id();
|
||||
@ -202,16 +224,58 @@ StatusOr<DeviceAssignment> PjRtClient::GetDefaultDeviceAssignment(
|
||||
|
||||
StatusOr<absl::flat_hash_set<int>> PjRtClient::GetParametersThatMustBeDonated(
|
||||
const LocalExecutable& executable, bool tuple_inputs) const {
|
||||
// TODO(b/149489114) support buffer donation on CPU/GPU when XLA supports it.
|
||||
HloComputation* computation =
|
||||
executable.executable()->module().entry_computation();
|
||||
int number_of_parameters = [&]() -> int {
|
||||
if (tuple_inputs) {
|
||||
CHECK_EQ(computation->num_parameters(), 1);
|
||||
const Shape& input_tuple_shape =
|
||||
computation->parameter_instruction(0)->shape();
|
||||
CHECK(input_tuple_shape.IsTuple());
|
||||
return input_tuple_shape.tuple_shapes_size();
|
||||
} else {
|
||||
return computation->num_parameters();
|
||||
}
|
||||
}();
|
||||
// If any buffer in a parameter is aliased we will donate the entire input
|
||||
// parameter.
|
||||
absl::flat_hash_set<int> parameters_to_donate;
|
||||
const HloInputOutputAliasConfig& config =
|
||||
executable.executable()->module().input_output_alias_config();
|
||||
TF_RETURN_IF_ERROR(config.ForEachAliasWithStatus(
|
||||
[](const ShapeIndex& output_index,
|
||||
const HloInputOutputAliasConfig::Alias& alias) {
|
||||
return InvalidArgument(
|
||||
"Buffer aliasing is not supported by XLA for non-TPU backends.");
|
||||
[&](const ShapeIndex& output_index,
|
||||
const HloInputOutputAliasConfig::Alias& alias) {
|
||||
if (tuple_inputs) {
|
||||
if (alias.parameter_number != 0) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter number %d in alias config with tupled "
|
||||
"inputs",
|
||||
alias.parameter_number);
|
||||
}
|
||||
const ShapeIndex& index = alias.parameter_index;
|
||||
if (!index.empty()) {
|
||||
int this_parameter = index.data()[0];
|
||||
if (this_parameter >= number_of_parameters) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter index %s in alias config with tupled "
|
||||
"inputs and %d parameters",
|
||||
index.ToString(), number_of_parameters);
|
||||
}
|
||||
parameters_to_donate.insert(this_parameter);
|
||||
}
|
||||
} else {
|
||||
int this_parameter = alias.parameter_number;
|
||||
if (this_parameter >= number_of_parameters) {
|
||||
return InvalidArgument(
|
||||
"Unexpected parameter number %d in alias config without tupled "
|
||||
"inputs and %d parameters",
|
||||
this_parameter, number_of_parameters);
|
||||
}
|
||||
parameters_to_donate.insert(this_parameter);
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
return absl::flat_hash_set<int>();
|
||||
return parameters_to_donate;
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -484,7 +548,8 @@ void PjRtBuffer::ScopedHold::AddToInput(
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
const void* data, const Shape& shape, bool force_copy,
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtClient* client,
|
||||
Device* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
|
||||
@ -495,34 +560,63 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
|
||||
// If we are on the host platform and the input buffer is sufficiently
|
||||
// aligned, we can simply point to the input array's data without any further
|
||||
// copies. At the time of writing we require a 16-byte alignment because XLA
|
||||
// may generate code which requires it.
|
||||
if (!force_copy &&
|
||||
((absl::bit_cast<std::uintptr_t>(data) &
|
||||
(cpu_function_runtime::kMinAlign - 1)) == 0) &&
|
||||
local_device->executor()->platform()->id() == se::host::kHostPlatformId) {
|
||||
std::function<void()> on_delete_callback =
|
||||
[buffer_reference{std::move(buffer_reference)}]() {
|
||||
// Frees buffer_reference.
|
||||
};
|
||||
se::DeviceMemoryBase buffer(const_cast<void*>(data),
|
||||
ShapeUtil::ByteSizeOf(shape));
|
||||
absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
|
||||
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
|
||||
/*allocator=*/nullptr, local_device->device_ordinal(),
|
||||
std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
|
||||
std::move(on_delete_callback));
|
||||
return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
|
||||
client, device);
|
||||
}
|
||||
int64 size = ShapeUtil::ByteSizeOf(shape);
|
||||
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(shape));
|
||||
|
||||
// The CPU platform is special because the "host" and the "device" are in the
|
||||
// same memory space. If the input shape is in the correct layout and we don't
|
||||
// want to defer the copy onto a thread, we can use the following fast
|
||||
// path.
|
||||
bool is_cpu_platform =
|
||||
local_device->executor()->platform()->id() == se::host::kHostPlatformId;
|
||||
if (is_cpu_platform) {
|
||||
// If we are on the host platform and the input buffer is sufficiently
|
||||
// aligned, we can simply point to the input array's data without any
|
||||
// further copies. At the time of writing we require a 16-byte alignment
|
||||
// because XLA may generate code which requires it.
|
||||
bool can_use_zero_copy =
|
||||
host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
|
||||
((absl::bit_cast<std::uintptr_t>(data) &
|
||||
(cpu_function_runtime::kMinAlign - 1)) == 0);
|
||||
if (shape.layout() == compact_shape.layout() &&
|
||||
(host_buffer_semantics ==
|
||||
HostBufferSemantics::kImmutableOnlyDuringCall ||
|
||||
can_use_zero_copy)) {
|
||||
std::function<void()> on_delete_callback;
|
||||
se::DeviceMemoryBase buffer;
|
||||
// If we are on the host platform and the input buffer is sufficiently
|
||||
// aligned, we can simply point to the input array's data without any
|
||||
// further copies. At the time of writing we require a 16-byte alignment
|
||||
// because XLA may generate code which requires it.
|
||||
if (can_use_zero_copy) {
|
||||
on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
|
||||
// Frees buffer_reference.
|
||||
};
|
||||
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
|
||||
} else {
|
||||
void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
|
||||
cpu_function_runtime::kMinAlign, size);
|
||||
on_delete_callback = [staging_buffer, client]() {
|
||||
client->host_memory_allocator()->DeallocateRaw(staging_buffer);
|
||||
};
|
||||
buffer = se::DeviceMemoryBase(staging_buffer, size);
|
||||
std::memcpy(staging_buffer, data, size);
|
||||
}
|
||||
absl::Span<const std::shared_ptr<BufferSequencingEvent>>
|
||||
definition_events;
|
||||
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
|
||||
/*allocator=*/nullptr, local_device->device_ordinal(),
|
||||
std::initializer_list<se::DeviceMemoryBase>{buffer},
|
||||
definition_events, std::move(on_delete_callback));
|
||||
return absl::make_unique<PjRtBuffer>(
|
||||
shape, shape, std::move(device_buffer), client, device);
|
||||
}
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtBuffer> py_buffer,
|
||||
AllocateDestinationBuffer(compact_shape, device, local_device,
|
||||
@ -531,17 +625,41 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
|
||||
CHECK(device_buffer.ok());
|
||||
|
||||
// If necessary, allocate a host-side buffer for staging host-to-device
|
||||
// transfers. On GPU this is a buffer in pinned memory.
|
||||
std::shared_ptr<void> staging_buffer;
|
||||
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
|
||||
client->should_stage_host_to_device_transfers()) {
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
tensorflow::Allocator::kAllocatorAlignment, size);
|
||||
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
|
||||
client->host_memory_allocator()->DeallocateRaw(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
// Copy the buffer into a staging buffer before returning control to the
|
||||
// caller if the caller only guaranteed that the buffer is valid for the
|
||||
// duration of the call. Otherwise, we stage (if necessary) on a separate
|
||||
// thread.
|
||||
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
|
||||
std::memcpy(staging_buffer.get(), data, size);
|
||||
buffer_reference.reset();
|
||||
data = nullptr;
|
||||
}
|
||||
|
||||
// The host to device transfer is performed on a thread pool, mostly because
|
||||
// it includes linearization that may be slow. It is OK to capture the
|
||||
// py_buffer pointer because the py_buffer can't be deleted until all the
|
||||
// usage holds have gone away.
|
||||
// TODO(misard) assess if it would be preferable to introduce a heuristic to
|
||||
// put the transfer into the calling thread for small literals.
|
||||
auto transfer_h2d = [client, transfer_manager, local_device,
|
||||
movable_device_buffer{device_buffer.ToClosure()}, data,
|
||||
shape, py_buffer{py_buffer.get()}, compact_shape,
|
||||
auto transfer_h2d = [client, transfer_manager, local_device, data, size,
|
||||
movable_device_buffer{device_buffer.ToClosure()}, shape,
|
||||
py_buffer{py_buffer.get()}, compact_shape,
|
||||
on_device_shape{py_buffer->on_device_shape()},
|
||||
buffer_reference{std::move(buffer_reference)}]() {
|
||||
staging_buffer{std::move(staging_buffer)},
|
||||
buffer_reference{std::move(buffer_reference)},
|
||||
host_buffer_semantics]() {
|
||||
ScopedHold device_buffer(movable_device_buffer);
|
||||
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
|
||||
// to report failures from a callback. However, the operations here are
|
||||
@ -551,20 +669,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
||||
compact_shape, on_device_shape, client->client()->platform());
|
||||
|
||||
std::shared_ptr<void> staging_buffer;
|
||||
|
||||
// If applicable on the backend, stage the transfer via host memory
|
||||
// allocated via the host_memory_allocator. On GPU, this is pinned
|
||||
// memory.
|
||||
if (client->host_memory_allocator()) {
|
||||
int64 size = ShapeUtil::ByteSizeOf(shape);
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
tensorflow::Allocator::kAllocatorAlignment, size);
|
||||
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
|
||||
client->host_memory_allocator()->DeallocateRaw(ptr);
|
||||
});
|
||||
std::memcpy(ptr, data, size);
|
||||
if (staging_buffer) {
|
||||
// If we didn't already copy the input buffer into the staging buffer,
|
||||
// do so now.
|
||||
if (host_buffer_semantics !=
|
||||
HostBufferSemantics::kImmutableOnlyDuringCall) {
|
||||
std::memcpy(staging_buffer.get(), data, size);
|
||||
}
|
||||
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
|
||||
shape);
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
@ -584,9 +698,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
|
||||
local_device->ThenRelease(
|
||||
local_device->host_to_device_stream(),
|
||||
std::make_pair(buffer_reference, std::move(staging_buffer)));
|
||||
std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
|
||||
};
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
if (is_cpu_platform) {
|
||||
// Using the h2d_transfer_pool would be a double thread hop; the code
|
||||
// already defers its work onto a stream (= thread on CPU).
|
||||
transfer_h2d();
|
||||
} else {
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
}
|
||||
return py_buffer;
|
||||
}
|
||||
|
||||
|
@ -128,6 +128,7 @@ class PjRtClient {
|
||||
std::vector<std::unique_ptr<Device>> devices, int host_id,
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
|
||||
bool should_stage_host_to_device_transfers,
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
|
||||
virtual ~PjRtClient() = default;
|
||||
|
||||
@ -153,6 +154,9 @@ class PjRtClient {
|
||||
tensorflow::Allocator* host_memory_allocator() const {
|
||||
return host_memory_allocator_.get();
|
||||
}
|
||||
bool should_stage_host_to_device_transfers() const {
|
||||
return should_stage_host_to_device_transfers_;
|
||||
}
|
||||
|
||||
GpuExecutableRunOptions* gpu_run_options() const {
|
||||
return gpu_run_options_.get();
|
||||
@ -190,6 +194,9 @@ class PjRtClient {
|
||||
std::string platform_name_;
|
||||
LocalClient* client_;
|
||||
|
||||
// Allocator to be used for staging memory transfers to devices.
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
||||
|
||||
// Includes all devices, including non-local devices on multi-host platforms.
|
||||
std::vector<std::unique_ptr<Device>> devices_;
|
||||
// Maps Device::id() to the corresponding Device. Includes all devices.
|
||||
@ -201,10 +208,10 @@ class PjRtClient {
|
||||
se::DeviceMemoryAllocator* allocator_;
|
||||
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
|
||||
|
||||
// Allocator to be used for staging memory transfers to devices. Optional;
|
||||
// only used on GPU where it is more efficient to copy buffers to and from the
|
||||
// device via a staging area of pinned memory.
|
||||
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
|
||||
// Should we always prefer to stage host-to-device transfers via memory
|
||||
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
|
||||
// transfer via pinned memory.
|
||||
bool should_stage_host_to_device_transfers_;
|
||||
|
||||
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
|
||||
|
||||
@ -396,13 +403,35 @@ class PjRtBuffer {
|
||||
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
|
||||
};
|
||||
|
||||
// If `force_copy` is true, forces a copy of the input buffer on CPU.
|
||||
// Otherwise the library is free to alias the output buffer with `data`.
|
||||
// `buffer_reference` is an optional shared pointer that should be kept alive
|
||||
// by the runtime as long as the contents of `data` may still be accessed by
|
||||
// the runtime (may be nullptr).
|
||||
// Describes the semantics the caller to FromHostBuffer expects from the
|
||||
// runtime, in a total order from most restrictive to least restrictive.
|
||||
enum class HostBufferSemantics {
|
||||
// The runtime may not hold references to `data` after the call to
|
||||
// `FromHostBuffer` completes. The caller promises that `data` is immutable
|
||||
// and will not be freed only for the duration of the FromHostBuffer call.
|
||||
// `buffer_reference` will be freed by the time `FromHostBuffer` returns.
|
||||
kImmutableOnlyDuringCall,
|
||||
|
||||
// The runtime may hold onto `data` after the call to `FromHostBuffer`
|
||||
// returns while the runtime completes a transfer to the device. The caller
|
||||
// promises not to mutate or free `data` until the transfer completes, at
|
||||
// which point the runtime will release `buffer_reference`. It is also
|
||||
// correct to wait on the host (directly or indirectly) for the buffer's
|
||||
// definition event to complete.
|
||||
kImmutableUntilTransferCompletes,
|
||||
|
||||
// The PjRtBuffer may alias `data` internally and the runtime may use the
|
||||
// `data` contents as long as the buffer is alive.
|
||||
// The caller promises to keep `data` alive and not to mutate its contents
|
||||
// as long as the buffer is alive; to notify the caller that the buffer may
|
||||
// be freed, the runtime will release its `buffer_reference` when the
|
||||
// PjRtBuffer is freed. On non-CPU platforms this acts identically to
|
||||
// kImmutableUntilTransferCompletes.
|
||||
kZeroCopy,
|
||||
};
|
||||
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
|
||||
const void* data, const Shape& shape, bool force_copy,
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtClient* client,
|
||||
Device* device);
|
||||
|
||||
|
@ -84,7 +84,8 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
|
||||
const pybind11::object& argument, Device* device, bool force_copy) {
|
||||
const pybind11::object& argument, Device* device, bool force_copy,
|
||||
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
|
||||
if (device == nullptr) {
|
||||
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
||||
device = pjrt_client_->local_devices().front();
|
||||
@ -111,9 +112,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
|
||||
std::move(py_buffer_ref),
|
||||
pjrt_client_.get(), device));
|
||||
buffer, PjRtBuffer::FromHostBuffer(
|
||||
c->buf_ptr, c->shape, host_buffer_semantics,
|
||||
std::move(py_buffer_ref), pjrt_client_.get(), device));
|
||||
}
|
||||
auto traceback = Traceback::Get();
|
||||
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
|
||||
|
@ -120,7 +120,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
|
||||
const pybind11::object& argument, Device* device, bool force_copy);
|
||||
const pybind11::object& argument, Device* device, bool force_copy,
|
||||
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
|
||||
|
||||
StatusOr<std::unique_ptr<PyExecutable>> Compile(
|
||||
const XlaComputation& computation, CompileOptions options);
|
||||
|
@ -509,6 +509,13 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
|
||||
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
|
||||
|
||||
py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
|
||||
.value("IMMUTABLE_ONLY_DURING_CALL",
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
|
||||
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
|
||||
.value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
|
||||
|
||||
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
|
||||
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
|
||||
.def("device_count", &PyClient::device_count)
|
||||
@ -527,7 +534,9 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("create_host_to_device_channel_handle",
|
||||
&PyClient::CreateHostToDeviceChannelHandle)
|
||||
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
|
||||
py::arg("device") = nullptr, py::arg("force_copy") = false)
|
||||
py::arg("device") = nullptr, py::arg("force_copy") = false,
|
||||
py::arg("host_buffer_semantics") =
|
||||
PjRtBuffer::HostBufferSemantics::kZeroCopy)
|
||||
.def("compile", &PyClient::Compile, py::arg("computation"),
|
||||
py::arg("compile_options") = CompileOptions())
|
||||
.def("heap_profile", &PyClient::HeapProfile);
|
||||
|
@ -304,6 +304,7 @@ def computation_count():
|
||||
Device = _xla.Device
|
||||
CompileOptions = _xla.CompileOptions
|
||||
|
||||
HostBufferSemantics = _xla.HostBufferSemantics
|
||||
|
||||
# An Executable is a C++ class that duck types with the following API:
|
||||
# class Executable(object):
|
||||
|
@ -1909,11 +1909,6 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
out = ops.Add(p1, p2)
|
||||
c.setup_alias([], 0, [])
|
||||
c = c.build(out)
|
||||
if self.backend.platform != "tpu":
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError, "Buffer aliasing is not supported "
|
||||
"by XLA for non-TPU backends"):
|
||||
self.backend.compile(c)
|
||||
|
||||
tests.append(AliasTest)
|
||||
|
||||
@ -1991,7 +1986,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
|
||||
x_ptr = x.__array_interface__["data"][0]
|
||||
buffer = self.backend.buffer_from_pyval(x)
|
||||
buffer = self.backend.buffer_from_pyval(
|
||||
x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
|
||||
y = np.array(buffer, copy=False)
|
||||
y_ptr = y.__array_interface__["data"][0]
|
||||
np.testing.assert_array_equal(x, y)
|
||||
@ -2000,7 +1996,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
|
||||
self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
|
||||
self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
|
||||
|
||||
buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
|
||||
during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
|
||||
buffer2 = self.backend.buffer_from_pyval(
|
||||
x, host_buffer_semantics=during_call)
|
||||
z = np.array(buffer2, copy=False)
|
||||
self.assertNotEqual(x.__array_interface__["data"][0],
|
||||
z.__array_interface__["data"][0])
|
||||
|
@ -3304,6 +3304,15 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "memory_space_assignment_utils",
|
||||
srcs = ["memory_space_assignment_utils.cc"],
|
||||
hdrs = ["memory_space_assignment_utils.h"],
|
||||
deps = [
|
||||
":heap_simulator",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "memory_space_assignment",
|
||||
srcs = ["memory_space_assignment.cc"],
|
||||
@ -3311,6 +3320,7 @@ cc_library(
|
||||
deps = [
|
||||
":heap_simulator",
|
||||
":hlo_cost_analysis",
|
||||
":memory_space_assignment_utils",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/core/lib/math:math_util",
|
||||
],
|
||||
|
@ -573,6 +573,7 @@ void AlgebraicSimplifierVisitor::ReplaceWithBitcast(HloInstruction* instruction,
|
||||
|
||||
auto bitcast = computation_->AddInstruction(
|
||||
HloInstruction::CreateBitcast(instruction->shape(), operand));
|
||||
bitcast->set_metadata(instruction->metadata());
|
||||
TF_CHECK_OK(ReplaceInstruction(instruction, bitcast));
|
||||
}
|
||||
|
||||
@ -2454,6 +2455,25 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
{
|
||||
HloInstruction *convert_operand, *operand;
|
||||
// Mul(Convert(Pred), operand) => select(pred, operand, 0)
|
||||
if (Match(multiply,
|
||||
m::MultiplyAnyOrder(
|
||||
m::Op(&operand),
|
||||
m::Convert(
|
||||
m::Op(&convert_operand)
|
||||
.WithShape(m::Shape().WithElementType(PRED)))))) {
|
||||
HloInstruction* zero_like_multiply =
|
||||
BroadcastZeros(computation_, multiply->shape().element_type(),
|
||||
multiply->shape().dimensions());
|
||||
return ReplaceWithNewInstruction(
|
||||
multiply, HloInstruction::CreateTernary(
|
||||
multiply->shape(), HloOpcode::kSelect, convert_operand,
|
||||
operand, zero_like_multiply));
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(10) << "trying transform [(A * C1) * C2 => A * (C1 * C2)]";
|
||||
HloInstruction *a, *c1, *c2;
|
||||
if (Match(multiply,
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/OperationSupport.h" // from @llvm-project
|
||||
#include "mlir/IR/Value.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
@ -255,7 +256,8 @@ Status DotOpEmitter::EmitLinalgMatmul() {
|
||||
mlir::edsc::ScopedContext scope(*builder, function.getLoc());
|
||||
mlir::Value a = function.getArgument(0), b = function.getArgument(1),
|
||||
c = function.getArgument(2);
|
||||
mlir::edsc::intrinsics::linalg_matmul(b, c, a);
|
||||
mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{},
|
||||
mlir::ValueRange{b, c, a});
|
||||
mlir::edsc::intrinsics::std_ret();
|
||||
});
|
||||
}
|
||||
|
@ -29,10 +29,12 @@ namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Convert a dot into a canonical form where non-contracting and contracting
|
||||
// dimensions are reshaped together and batch dimensions are the most major
|
||||
// dimensions. This requires transposing and reshapes of the lhs and rhs and
|
||||
// reshaping the output batch to the original shape.
|
||||
// Convert a dot into a canonical form;
|
||||
// * Non-contracting dimensions are reshaped together,
|
||||
// * Contracting dimensions are reshaped together,
|
||||
// * Batch dimensions are the most major dimensions.
|
||||
// This requires transposing and reshaping of the lhs and rhs, and reshaping the
|
||||
// output batch to the original shape.
|
||||
Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
auto computation = original_dot->parent();
|
||||
const auto& original_dnums = original_dot->dot_dimension_numbers();
|
||||
@ -63,7 +65,8 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
}
|
||||
}
|
||||
// The canonical form of the lhs is
|
||||
// [BatchDims, NonContractingDims, ContractingsDims]
|
||||
// [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
|
||||
// If NonContractingDimsProduct is 1, it is omitted.
|
||||
std::vector<int64> lhs_transpose;
|
||||
lhs_transpose.reserve(lhs_rank);
|
||||
lhs_transpose.insert(lhs_transpose.end(),
|
||||
@ -109,7 +112,8 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
}
|
||||
|
||||
// The canonical form of the rhs is
|
||||
// [BatchDims, ContractingsDims, NonContractingDims]
|
||||
// [BatchDims, NonContractingDimsProduct, ContractingsDimsProduct]
|
||||
// If NonContractingDimsProduct is 1, it is omitted.
|
||||
std::vector<int64> rhs_transpose;
|
||||
rhs_transpose.reserve(rhs_rank);
|
||||
rhs_transpose.insert(rhs_transpose.end(),
|
||||
|
@ -1336,9 +1336,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
|
||||
// When x is large, the naive evaluation of ln(x + 1) is more
|
||||
// accurate than the Taylor series.
|
||||
TF_ASSIGN_OR_RETURN(auto for_large_x, EmitLog(prim_type, FAdd(x, one)));
|
||||
// The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + ….
|
||||
auto for_small_x = FMul(FAdd(FMul(negative_half, x), one), x);
|
||||
const auto kAntilogarithmIsSmallThreshold = 1e-4;
|
||||
// When x is small, (defined to be less than sqrt(2) / 2), use a rational
|
||||
// approximation. The approximation below is based on one from the Cephes
|
||||
// Mathematical Library.
|
||||
//
|
||||
// sqrt(2) - 1.
|
||||
const auto kAntilogarithmIsSmallThreshold = 0.41421356237309504880;
|
||||
|
||||
static const std::array<double, 7> kDenominatorCoeffs{
|
||||
1.,
|
||||
1.5062909083469192043167E1,
|
||||
8.3047565967967209469434E1,
|
||||
2.2176239823732856465394E2,
|
||||
3.0909872225312059774938E2,
|
||||
2.1642788614495947685003E2,
|
||||
6.0118660497603843919306E1,
|
||||
};
|
||||
|
||||
static const std::array<double, 7> kNumeratorCoeffs{
|
||||
4.5270000862445199635215E-5, 4.9854102823193375972212E-1,
|
||||
6.5787325942061044846969E0, 2.9911919328553073277375E1,
|
||||
6.0949667980987787057556E1, 5.7112963590585538103336E1,
|
||||
2.0039553499201281259648E1,
|
||||
};
|
||||
|
||||
auto x_squared = FMul(x, x);
|
||||
TF_ASSIGN_OR_RETURN(auto denominator,
|
||||
EvaluatePolynomial(type, x, kDenominatorCoeffs));
|
||||
TF_ASSIGN_OR_RETURN(auto numerator,
|
||||
EvaluatePolynomial(type, x, kNumeratorCoeffs));
|
||||
auto for_small_x = FDiv(numerator, denominator);
|
||||
for_small_x = FMul(FMul(x, x_squared), for_small_x);
|
||||
for_small_x = FAdd(FMul(negative_half, x_squared), for_small_x);
|
||||
for_small_x = FAdd(x, for_small_x);
|
||||
|
||||
auto abs_x =
|
||||
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
|
||||
auto x_is_small = FCmpOLT(
|
||||
@ -2699,4 +2730,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduce(
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate polynomial using Horner's method.
|
||||
StatusOr<llvm::Value*> ElementalIrEmitter::EvaluatePolynomial(
|
||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients) {
|
||||
llvm::Value* poly = llvm::ConstantFP::get(type, 0.0);
|
||||
for (const double c : coefficients) {
|
||||
poly = FAdd(FMul(poly, x), llvm::ConstantFP::get(type, c));
|
||||
}
|
||||
return poly;
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -258,6 +258,10 @@ class ElementalIrEmitter : public IrBuilderMixin<ElementalIrEmitter> {
|
||||
StatusOr<llvm::Value*> EmitComplexPower(const HloInstruction* op,
|
||||
llvm::Value* a, llvm::Value* b,
|
||||
llvm::Value* c, llvm::Value* d);
|
||||
|
||||
// Evaluates a polynomial using Horner's method.
|
||||
StatusOr<llvm::Value*> EvaluatePolynomial(
|
||||
llvm::Type* type, llvm::Value* x, absl::Span<const double> coefficients);
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -39,6 +39,7 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
"@llvm-project//llvm:AMDGPUCodeGen",
|
||||
"@llvm-project//llvm:Analysis",
|
||||
"@llvm-project//llvm:BitReader",
|
||||
"@llvm-project//llvm:BitWriter",
|
||||
@ -52,7 +53,6 @@ cc_library(
|
||||
"@llvm-project//llvm:Scalar",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:amdgpu_code_gen",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -101,6 +101,7 @@ class EnforceMinorToMajorReduceOpVisitor : public DfsHloRewriteVisitor {
|
||||
new_reduce_shape_layout);
|
||||
HloInstruction *canonical_reduce_input = reduce->parent()->AddInstruction(
|
||||
HloInstruction::CreateBitcast(new_operand_shape, operand));
|
||||
canonical_reduce_input->set_metadata(reduce->metadata());
|
||||
|
||||
VLOG(5) << "Reduction input: " << canonical_reduce_input->ToString();
|
||||
std::unique_ptr<HloInstruction> new_reduce = HloInstruction::CreateReduce(
|
||||
|
@ -539,6 +539,15 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
|
||||
/*result_shape_bounds=*/broadcast_dimensions);
|
||||
}
|
||||
|
||||
HloInstruction* BroadcastOnes(HloComputation* computation,
|
||||
PrimitiveType element_type,
|
||||
absl::Span<const int64> broadcast_dimensions) {
|
||||
HloInstruction* one = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::One(element_type)));
|
||||
return MakeBroadcastHlo(one, /*broadcast_dimensions=*/{},
|
||||
/*result_shape_bounds=*/broadcast_dimensions);
|
||||
}
|
||||
|
||||
// Recursively creates a dummy op given a shape. Leaf nodes are broadcasted zero
|
||||
// while internal nodes are tuples.
|
||||
HloInstruction* CreateDummyOp(HloComputation::Builder* b, const Shape& shape) {
|
||||
|
@ -276,6 +276,11 @@ HloInstruction* BroadcastZeros(HloComputation* computation,
|
||||
PrimitiveType element_type,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
|
||||
// Same as above, but fill the tensor with ones.
|
||||
HloInstruction* BroadcastOnes(HloComputation* computation,
|
||||
PrimitiveType element_type,
|
||||
absl::Span<const int64> broadcast_dimensions);
|
||||
|
||||
// Creates a HLO computation that takes arguments of type `domain` and produces
|
||||
// a value of type `range`.
|
||||
StatusOr<std::unique_ptr<HloComputation>> CreateComputationWithSignature(
|
||||
|
@ -698,7 +698,7 @@ bool HloDataflowAnalysis::UpdateCollectivePermuteDoneValueSet(
|
||||
CHECK_EQ(collective_permute_done->opcode(),
|
||||
HloOpcode::kCollectivePermuteDone);
|
||||
bool changed = false;
|
||||
// CollectivePermuteDone forwards the operand value at {0} to its output.
|
||||
// CollectivePermuteDone forwards the operand value at {1} to its output.
|
||||
const HloValueSet& operand_value_set =
|
||||
GetValueSet(collective_permute_done->operand(0), {1});
|
||||
HloValueSet& value_set = GetValueSet(collective_permute_done);
|
||||
@ -945,6 +945,17 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
// CopyDone consumes a tuple produced by CopyStart and produces an
|
||||
// element. Its output aliases its input tuple element {0}.
|
||||
break;
|
||||
case HloOpcode::kCollectivePermuteStart:
|
||||
// CollectivePermuteStart produces a tuple of
|
||||
// {aliased operand, destination buffer, U32 context, U32 context}.
|
||||
define_value_at(/*index=*/{});
|
||||
define_value_at(/*index=*/{1});
|
||||
define_value_at(/*index=*/{2});
|
||||
define_value_at(/*index=*/{3});
|
||||
break;
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
// CollectivePermuteDone's output aliases its input tuple element {1}.
|
||||
break;
|
||||
case HloOpcode::kRecvDone:
|
||||
// RecvDone produces a two-element tuple. Element zero aliases its
|
||||
// input tuple element {0}; element one is a token.
|
||||
|
@ -550,6 +550,7 @@ bool HloCollectiveInstruction::IdenticalSlowPath(
|
||||
const auto& casted_other =
|
||||
static_cast<const HloCollectiveInstruction&>(other);
|
||||
return HloChannelInstruction::IdenticalSlowPath(other, eq_computations) &&
|
||||
constrain_layout() == casted_other.constrain_layout() &&
|
||||
absl::c_equal(replica_groups(), casted_other.replica_groups(),
|
||||
[](const ReplicaGroup& a, const ReplicaGroup& b) {
|
||||
return absl::c_equal(a.replica_ids(), b.replica_ids());
|
||||
@ -1101,7 +1102,9 @@ bool HloMapInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const {
|
||||
return eq_computations(to_apply(), other.to_apply());
|
||||
const auto& casted_other = static_cast<const HloMapInstruction&>(other);
|
||||
return eq_computations(to_apply(), casted_other.to_apply()) &&
|
||||
dimensions() == casted_other.dimensions();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction> HloMapInstruction::CloneWithNewOperandsImpl(
|
||||
@ -2515,7 +2518,8 @@ bool HloDynamicSliceInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const {
|
||||
return true;
|
||||
const auto& casted_other = static_cast<const HloMapInstruction&>(other);
|
||||
return dynamic_slice_sizes() == casted_other.dynamic_slice_sizes();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction>
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/memory_space_assignment.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/memory_space_assignment_utils.h"
|
||||
#include "tensorflow/core/lib/math/math_util.h"
|
||||
namespace xla {
|
||||
|
||||
@ -30,6 +31,22 @@ const int kWhileExecutionCount = 5;
|
||||
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>>
|
||||
MemorySpaceAssignmentCostAnalysis::Create(
|
||||
const HloCostAnalysis& cost_analysis,
|
||||
float async_copy_bandwidth_bytes_per_second,
|
||||
float alternate_mem_bandwidth_bytes_per_second, const HloModule& module) {
|
||||
TF_ASSIGN_OR_RETURN(auto alias_analysis, HloAliasAnalysis::Run(&module));
|
||||
TF_ASSIGN_OR_RETURN(auto hlo_live_range,
|
||||
HloLiveRange::Run(module.schedule(), *alias_analysis,
|
||||
module.entry_computation()));
|
||||
auto call_graph = CallGraph::Build(&module);
|
||||
return absl::WrapUnique(new MemorySpaceAssignmentCostAnalysis(
|
||||
cost_analysis, async_copy_bandwidth_bytes_per_second,
|
||||
alternate_mem_bandwidth_bytes_per_second, std::move(alias_analysis),
|
||||
std::move(hlo_live_range), std::move(call_graph)));
|
||||
}
|
||||
|
||||
float MemorySpaceAssignmentCostAnalysis::GetAlternateMemoryBenefit(
|
||||
const HloInstruction& instruction, float elapsed_time_due_to_alternate_mem,
|
||||
MemorySpaceAssignmentCostAnalysis::Cache* cache) const {
|
||||
@ -73,19 +90,32 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
|
||||
/*operand_in_alternate_mem=*/{},
|
||||
/*output_in_alternate_mem=*/true),
|
||||
cache);
|
||||
for (const HloUse& use : interval.buffer->uses()) {
|
||||
float use_alternate_mem_benefit = GetAlternateMemoryBenefit(
|
||||
*use.instruction,
|
||||
GetInstructionElapsedDueToMemory(*use.instruction, use.operand_number),
|
||||
cache);
|
||||
// If the benefit is positive (memory bound), add it to this buffer's
|
||||
// benefit. If the benefit is negative (compute bound), calculate the
|
||||
// maximum.
|
||||
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
|
||||
alternate_mem_benefit += use_alternate_mem_benefit;
|
||||
} else {
|
||||
alternate_mem_benefit =
|
||||
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
|
||||
for (const HloBuffer* buffer : alias_analysis_->ComputeBuffersAt(
|
||||
interval.buffer->defining_position().instruction,
|
||||
interval.buffer->defining_position().index)) {
|
||||
for (const HloValue* value : buffer->values()) {
|
||||
for (const HloUse& use : value->uses()) {
|
||||
// We look inside the called computations of while and conditional, so
|
||||
// don't use the benefit of while and conditional directly.
|
||||
if (use.instruction->opcode() == HloOpcode::kWhile ||
|
||||
use.instruction->opcode() == HloOpcode::kConditional) {
|
||||
continue;
|
||||
}
|
||||
float use_alternate_mem_benefit =
|
||||
GetAlternateMemoryBenefit(*use.instruction,
|
||||
GetInstructionElapsedDueToMemory(
|
||||
*use.instruction, use.operand_number),
|
||||
cache);
|
||||
// If the benefit is positive (memory bound), add it to this buffer's
|
||||
// benefit. If the benefit is negative (compute bound), calculate the
|
||||
// maximum.
|
||||
if (alternate_mem_benefit > 0 && use_alternate_mem_benefit > 0) {
|
||||
alternate_mem_benefit += use_alternate_mem_benefit;
|
||||
} else {
|
||||
alternate_mem_benefit =
|
||||
std::max(alternate_mem_benefit, use_alternate_mem_benefit);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -94,17 +124,9 @@ float MemorySpaceAssignmentCostAnalysis::GetMemoryBoundedness(
|
||||
float alternate_mem_slowdown =
|
||||
GetInstructionElapsedDueToMemorySlowdown(interval.size);
|
||||
|
||||
// Scale the slowdown based on the time of this buffer. We would want earlier
|
||||
// buffers have lower slowdown values, because they are less likely to overlap
|
||||
// with other HLOs.
|
||||
// TODO(yuemmawang): We may want a piecewise function, where a lower slowdown
|
||||
// for early HLOs, and full slowdown for mid-to-late HLOs.
|
||||
// TODO(yuemmawang): Further in a smarter way, we want buffers overlapped with
|
||||
// more HLOs have higher slowdown, and vice versa.
|
||||
float scale = interval.start * 1.0 / GetScheduleEndTime();
|
||||
alternate_mem_slowdown *= scale;
|
||||
|
||||
return alternate_mem_benefit - alternate_mem_slowdown;
|
||||
// Divide by the size of the buffer to prioritize smaller buffers that will
|
||||
// give the largest alternate memory benefit.
|
||||
return (alternate_mem_benefit - alternate_mem_slowdown) / interval.size;
|
||||
}
|
||||
|
||||
int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
|
||||
@ -112,7 +134,7 @@ int MemorySpaceAssignmentCostAnalysis::CalculateWhileLoopNestLevel(
|
||||
int nest_level = 0;
|
||||
const HloComputation* computation = instruction->parent();
|
||||
while (!computation->IsEntryComputation()) {
|
||||
auto node = call_graph_.GetNode(computation);
|
||||
auto node = call_graph_->GetNode(computation);
|
||||
auto callsites = node.caller_callsites();
|
||||
CHECK_EQ(callsites.size(), 1) << "The module is not flattened!";
|
||||
auto callsite = callsites[0];
|
||||
@ -194,7 +216,7 @@ float MemorySpaceAssignmentCostAnalysis::GetAsyncCopyElapsed(
|
||||
}
|
||||
|
||||
int64 MemorySpaceAssignmentCostAnalysis::GetScheduleEndTime() const {
|
||||
return hlo_live_range_.schedule_end_time();
|
||||
return hlo_live_range_->schedule_end_time();
|
||||
}
|
||||
|
||||
bool InstructionCountPrefetchIntervalPicker::CanAllocateInAlternateMemoryNoCopy(
|
||||
@ -252,6 +274,13 @@ CostAnalysisPrefetchIntervalPicker::CostAnalysisPrefetchIntervalPicker(
|
||||
std::vector<float> instructions_elapsed_time(instruction_schedule_->size(),
|
||||
0.0);
|
||||
for (const auto& instruction_and_logical_time : *instruction_schedule_) {
|
||||
// To avoid double counting, don't include the elapsed time of while and
|
||||
// conditional HLOs.
|
||||
const HloInstruction* instruction = instruction_and_logical_time.first;
|
||||
if (instruction->opcode() == HloOpcode::kWhile ||
|
||||
instruction->opcode() == HloOpcode::kConditional) {
|
||||
continue;
|
||||
}
|
||||
float elapsed_time = cost_analysis_.cost_analysis().optimal_seconds(
|
||||
*instruction_and_logical_time.first);
|
||||
int64 logical_time = instruction_and_logical_time.second;
|
||||
@ -597,81 +626,6 @@ AlternateMemoryBestFitHeap::GetSortedColocatedIntervals(
  return colocated_intervals;
}

bool AlternateMemoryBestFitHeap::IsIntervalAllowedInAlternateMemory(
    const BufferInterval& interval) const {
  // If the buffer is a tuple, don't use this algorithm for now. The buffers
  // that are pointed to by the tuple will still use this algorithm. Because
  // tuples are cheap to place in the alternate memory (they are just pointers),
  // we don't need to use prefetch/evict logic.
  if (interval.buffer->shape().IsTuple()) {
    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
            << " in default mem because it is a tuple.";
    return false;
  }

  // Don't place scalars in the alternate memory.
  if (ShapeUtil::IsEffectiveScalar(interval.buffer->shape())) {
    VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
            << " in default mem because it is a scalar.";
    return false;
  }

  // The semantics of TupleSelect are weird: TupleSelect doesn't define a
  // buffer, but just forwards the buffers on either the left or right side.
  // This means the two different inputs to TupleSelect must not alias, yet they
  // should be allocated in the same memory space, and both buffers must be kept
  // alive for the entire live range of TupleSelect. Instead, just don't
  // allocate TupleSelect in the alternate memory space.
  // TODO(berkin): Not allocating add-dependencies either since they need to be
  // treated specially. We should revisit this later.
  for (const HloPosition& position : interval.buffer->positions()) {
    if (position.instruction->opcode() == HloOpcode::kTupleSelect ||
        position.instruction->opcode() == HloOpcode::kAddDependency) {
      VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
              << " in default mem because it has a tuple-select or "
              << "add-dependency position.";
      return false;
    }
  }

  // Send and Recv HLOs return a request identifier. These should not be
  // allocated in the alternate memory.
  for (const HloPosition& position : interval.buffer->positions()) {
    if (position.instruction->opcode() == HloOpcode::kSend ||
        position.instruction->opcode() == HloOpcode::kRecv) {
      // TODO(berkin): Send/recv buffers need a stable buffer allocation
      // throughout sending/receiving. Disable memory space allocation for
      // these for now.
      if (position.index == ShapeIndex({0})) {
        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
                << " in default mem because it is a send/recv buffer.";
        return false;
      } else if (position.index == ShapeIndex({1})) {
        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
                << " in default mem because it is a request identifier for "
                   "send/recv.";
        return false;
      }
    }

    if (position.instruction->opcode() == HloOpcode::kCollectivePermuteStart ||
        position.instruction->opcode() == HloOpcode::kCollectivePermuteDone) {
      // Disable memory space allocation for these for now.
      if (position.index == ShapeIndex({0})) {
        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
                << " in default mem because it is a collective-permute buffer.";
        return false;
      } else if (position.index == ShapeIndex({1})) {
        VLOG(4) << "Keeping value " << interval.buffer->ToShortString()
                << " in default mem because it is a collective-permute buffer.";
        return false;
      }
    }
  }

  return true;
}

bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
    const AllocationValue& value, const HloUse& use) const {
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
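The removed function above (relocated by this change, as the later hunks show) walks every HloPosition of a candidate buffer and keeps the value in default memory if any position makes alternate-memory placement unsafe: tuples, effective scalars, tuple-select/add-dependency positions, send/recv buffers and their request identifiers, and collective-permute buffers. The sketch below restates that decision order with simplified, hypothetical stand-in types (Opcode, Position, Buffer are not the real XLA classes); it is illustrative only.

```cpp
// Minimal sketch of the filtering order, using hypothetical stand-in types.
// The real implementation operates on HloValue/HloPosition/HloOpcode.
#include <vector>

enum class Opcode { kTupleSelect, kAddDependency, kSend, kRecv, kOther };

struct Position {
  Opcode opcode;
  int tuple_index;  // For send/recv: 0 = data buffer, 1 = request identifier.
};

struct Buffer {
  bool is_tuple = false;
  bool is_scalar = false;
  std::vector<Position> positions;
};

bool AllowedInAlternateMemory(const Buffer& buffer) {
  if (buffer.is_tuple) return false;   // Tuples are just pointers; keep them in default mem.
  if (buffer.is_scalar) return false;  // Scalars are not worth prefetching.
  for (const Position& p : buffer.positions) {
    switch (p.opcode) {
      case Opcode::kTupleSelect:
      case Opcode::kAddDependency:
      case Opcode::kSend:
      case Opcode::kRecv:
        return false;  // These need stable default-memory allocations.
      default:
        break;
    }
  }
  return true;
}
```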
@ -710,8 +664,7 @@ bool AlternateMemoryBestFitHeap::IsUseAllowedInAlternateMemory(
  if (!options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
          shape, parameter_time, min_use_time)) {
    VLOG(4) << "While allocation not allowed in alternate memory. "
            << "use time = " << min_use_time
            << ", root time = " << root_time;
            << "use time = " << min_use_time << ", root time = " << root_time;
    return false;
  }
  // Check if there is a required assignment for the while loop output.
@ -897,7 +850,8 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
      continue;
    }

    if (!IsIntervalAllowedInAlternateMemory(interval)) {
    if (!MemorySpaceAssignmentUtils::IsIntervalAllowedInAlternateMemory(
            interval)) {
      continue;
    }

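This hunk switches Finish() from the heap's private IsIntervalAllowedInAlternateMemory (whose declaration is removed in the header hunk at the end of this diff) to a shared MemorySpaceAssignmentUtils helper, so the same eligibility check can be reused outside AlternateMemoryBestFitHeap. A minimal sketch of that refactoring pattern, using hypothetical simplified types rather than the real XLA ones:

```cpp
// Sketch only: hoisting a member predicate into a static utility so any pass
// can reuse it. Interval is a hypothetical stand-in for BufferInterval.
#include <vector>

struct Interval {
  int start = 0;
  int end = 0;
  bool is_tuple = false;
};

struct MemorySpaceAssignmentUtilsSketch {
  // Real check is much richer (see the removed function above).
  static bool IsIntervalAllowedInAlternateMemory(const Interval& interval) {
    return !interval.is_tuple;
  }
};

// Same control flow as the Finish() loop: skip disallowed intervals early.
int CountEligible(const std::vector<Interval>& sorted_intervals) {
  int eligible = 0;
  for (const Interval& interval : sorted_intervals) {
    if (!MemorySpaceAssignmentUtilsSketch::IsIntervalAllowedInAlternateMemory(
            interval)) {
      continue;
    }
    ++eligible;
  }
  return eligible;
}
```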
@ -2011,17 +1965,38 @@ AlternateMemoryBestFitHeap::FindBestChunkCandidate(
    BufferInterval* alternate_mem_interval) const {
  int64 end_time = request.end_time;
  if (!preferred_offset) {
    // First find the earliest use that is the same or later than the end time.
    const auto& uses = request.allocation_value->uses();
    auto use_it = uses.begin();
    for (; use_it->time < end_time; ++use_it) {
    }
    CHECK(use_it != uses.end());
    int64 earliest_use = use_it->time;

    // Then find the latest use that can be allocated contiguously without
    // copies.
    const Shape& shape = request.allocation_value->defining_position().shape();
    for (;
         (use_it + 1) != uses.end() &&
         options_.prefetch_interval_picker->CanAllocateInAlternateMemoryNoCopy(
             shape, use_it->time, (use_it + 1)->time);
         ++use_it) {
    }
    CHECK(use_it != uses.end());
    int64 latest_contiguous_use = use_it->time;

    // Find a chunk that's as long-living as possible, iterating in reverse
    // over the use times.
    for (auto use_it = request.allocation_value->uses().rbegin();
         use_it != request.allocation_value->uses().rend() &&
         use_it->time >= end_time;
         ++use_it) {
    for (; use_it >= uses.begin() && use_it->time >= end_time; --use_it) {
      alternate_mem_interval->end = use_it->time;
      ChunkCandidate chunk_candidate =
          FindChunkCandidate(*alternate_mem_interval);
      if (chunk_candidate.heap_size <= available_heap_size()) {
        alternate_mem_interval->end = end_time;
        VLOG(3) << "FindBestChunkCandidate earliest use = " << earliest_use
                << ", latest contiguous use = " << latest_contiguous_use
                << ", use with available mem = " << use_it->time
                << ", offset = " << chunk_candidate.chunk.offset;
        return chunk_candidate;
      }
    }
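The replacement logic in FindBestChunkCandidate runs in three phases when no preferred offset is given: (1) advance to the earliest use at or after end_time, (2) extend forward while consecutive uses can stay in alternate memory without a copy, and (3) walk back from that latest contiguous use until the extended interval fits in the remaining heap. The sketch below mirrors those phases with hypothetical stand-ins (Use, can_extend_without_copy, fits_in_heap) for the real BufferInterval, prefetch-interval-picker, and FindChunkCandidate machinery.

```cpp
// Hedged sketch of the three-phase scan; not the real implementation.
#include <cstdint>
#include <functional>
#include <optional>
#include <vector>

struct Use { int64_t time; };

std::optional<int64_t> FindBestEnd(
    const std::vector<Use>& uses, int64_t end_time,
    const std::function<bool(int64_t, int64_t)>& can_extend_without_copy,
    const std::function<bool(int64_t)>& fits_in_heap) {
  // (1) Find the earliest use at or after end_time.
  size_t i = 0;
  while (i < uses.size() && uses[i].time < end_time) ++i;
  if (i == uses.size()) return std::nullopt;

  // (2) Extend to the latest use reachable without an intervening copy.
  while (i + 1 < uses.size() &&
         can_extend_without_copy(uses[i].time, uses[i + 1].time)) {
    ++i;
  }

  // (3) Walk back toward end_time until the extended interval fits in the heap.
  for (int64_t j = static_cast<int64_t>(i);
       j >= 0 && uses[j].time >= end_time; --j) {
    if (fits_in_heap(uses[j].time)) return uses[j].time;
  }
  return std::nullopt;
}
```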
@ -2079,8 +2054,8 @@ MemorySpaceAssignment::CalculateAsyncCopyStats() const {
MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
    const MemorySpaceAssignmentCostAnalysis& cost_analysis,
    MemorySpaceAssignmentCostAnalysis::Cache* cache) {
  return [cost_analysis, cache](const BufferInterval& x,
                                const BufferInterval& y) {
  return [&cost_analysis, cache](const BufferInterval& x,
                                 const BufferInterval& y) {
    float x_memory_boundedness = cost_analysis.GetMemoryBoundedness(x, cache);
    float y_memory_boundedness = cost_analysis.GetMemoryBoundedness(y, cache);
    if (x_memory_boundedness != y_memory_boundedness) {
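The only change in this hunk is the lambda capture: cost_analysis is now captured by reference rather than by value. Because MemorySpaceAssignmentCostAnalysis now owns its alias analysis, live range, and call graph through std::unique_ptr members (see the header hunks below), copying it is no longer possible, and capturing by reference also avoids duplicating a heavyweight analysis object. A small self-contained illustration of the same point, with a hypothetical CostAnalysisSketch standing in for the real class:

```cpp
// Sketch: capturing a non-copyable analysis object by reference inside a
// comparator. CostAnalysisSketch only mirrors the fact that the real class
// holds unique_ptr members and therefore cannot be copied.
#include <functional>
#include <memory>

struct CostAnalysisSketch {
  std::unique_ptr<int> owned_state = std::make_unique<int>(0);  // non-copyable
  float MemoryBoundedness(float size) const { return size; }
};

using Compare = std::function<bool(float, float)>;

// Caller must keep cost_analysis alive for as long as the comparator is used.
Compare MakeCompare(const CostAnalysisSketch& cost_analysis) {
  // [&cost_analysis] works; [cost_analysis] would need a copy constructor,
  // which the unique_ptr member implicitly deletes.
  return [&cost_analysis](float x, float y) {
    return cost_analysis.MemoryBoundedness(x) >
           cost_analysis.MemoryBoundedness(y);
  };
}
```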
@ -84,18 +84,10 @@ class MemorySpaceAssignmentCostAnalysis {
    absl::flat_hash_map<const HloInstruction*, float> while_nest_multiplier;
  };

  MemorySpaceAssignmentCostAnalysis(
  static StatusOr<std::unique_ptr<MemorySpaceAssignmentCostAnalysis>> Create(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second,
      const HloLiveRange& hlo_live_range, const CallGraph& call_graph)
      : cost_analysis_(cost_analysis),
        async_copy_bandwidth_bytes_per_second_(
            async_copy_bandwidth_bytes_per_second),
        alternate_mem_bandwidth_bytes_per_second_(
            alternate_mem_bandwidth_bytes_per_second),
        hlo_live_range_(hlo_live_range),
        call_graph_(call_graph) {}
      float alternate_mem_bandwidth_bytes_per_second, const HloModule& module);

  const HloCostAnalysis& cost_analysis() const { return cost_analysis_; }

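Here the public constructor, which required the caller to supply an already-built HloLiveRange and CallGraph, is replaced by a static Create factory that takes the HloModule and returns StatusOr<std::unique_ptr<...>>, so construction can run the fallible analyses itself (see the private constructor and the new unique_ptr members in the next hunk). The sketch below shows the general constructor-to-factory pattern with hypothetical names (ModuleSketch, LiveRangeSketch, CostAnalysisSketch) and absl::StatusOr standing in for the StatusOr type used in XLA.

```cpp
// Sketch of the factory pattern: fallible setup happens in Create(), the
// constructor stays trivial and private, and the object owns its sub-analyses.
#include <memory>
#include <utility>

#include "absl/status/statusor.h"

struct ModuleSketch { bool valid = true; };
struct LiveRangeSketch {};

class CostAnalysisSketch {
 public:
  static absl::StatusOr<std::unique_ptr<CostAnalysisSketch>> Create(
      const ModuleSketch& module, float async_copy_bandwidth) {
    if (!module.valid) {
      return absl::InvalidArgumentError("module failed analysis");
    }
    // Run the (possibly failing) analyses here, then hand ownership over.
    auto live_range = std::make_unique<LiveRangeSketch>();
    return std::unique_ptr<CostAnalysisSketch>(
        new CostAnalysisSketch(async_copy_bandwidth, std::move(live_range)));
  }

  const LiveRangeSketch& live_range() const { return *live_range_; }

 private:
  CostAnalysisSketch(float async_copy_bandwidth,
                     std::unique_ptr<LiveRangeSketch> live_range)
      : async_copy_bandwidth_(async_copy_bandwidth),
        live_range_(std::move(live_range)) {}

  float async_copy_bandwidth_;
  std::unique_ptr<LiveRangeSketch> live_range_;
};
```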
@ -153,14 +145,31 @@ class MemorySpaceAssignmentCostAnalysis {
  // 0 means it is not in a while loop.
  int CalculateWhileLoopNestLevel(const HloInstruction* instruction) const;

  const HloLiveRange& hlo_live_range() const { return hlo_live_range_; }
  const HloLiveRange& hlo_live_range() const { return *hlo_live_range_; }

 private:
  MemorySpaceAssignmentCostAnalysis(
      const HloCostAnalysis& cost_analysis,
      float async_copy_bandwidth_bytes_per_second,
      float alternate_mem_bandwidth_bytes_per_second,
      std::unique_ptr<HloAliasAnalysis> alias_analysis,
      std::unique_ptr<HloLiveRange> hlo_live_range,
      std::unique_ptr<CallGraph> call_graph)
      : cost_analysis_(cost_analysis),
        async_copy_bandwidth_bytes_per_second_(
            async_copy_bandwidth_bytes_per_second),
        alternate_mem_bandwidth_bytes_per_second_(
            alternate_mem_bandwidth_bytes_per_second),
        alias_analysis_(std::move(alias_analysis)),
        hlo_live_range_(std::move(hlo_live_range)),
        call_graph_(std::move(call_graph)) {}

  const HloCostAnalysis& cost_analysis_;
  float async_copy_bandwidth_bytes_per_second_;
  float alternate_mem_bandwidth_bytes_per_second_;
  const HloLiveRange& hlo_live_range_;
  const CallGraph& call_graph_;
  std::unique_ptr<HloAliasAnalysis> alias_analysis_;
  std::unique_ptr<HloLiveRange> hlo_live_range_;
  std::unique_ptr<CallGraph> call_graph_;
};

// Abstract base class that memory space assignment uses to pick prefetch
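For orientation, a hedged sketch of what a call site for the new factory might look like; the surrounding function, variable names, and bandwidth values are hypothetical, and only the Create signature and hlo_live_range() accessor come from the header hunks above.

```cpp
// Hypothetical call site: build the cost analysis from a module instead of
// passing in a pre-built live range and call graph.
Status RunMemorySpaceAssignmentSketch(const HloModule& module,
                                      const HloCostAnalysis& hlo_cost_analysis) {
  TF_ASSIGN_OR_RETURN(
      auto cost_analysis,
      MemorySpaceAssignmentCostAnalysis::Create(
          hlo_cost_analysis,
          /*async_copy_bandwidth_bytes_per_second=*/1e9,
          /*alternate_mem_bandwidth_bytes_per_second=*/1e10, module));
  // cost_analysis now owns its alias analysis, live range, and call graph.
  const HloLiveRange& live_range = cost_analysis->hlo_live_range();
  (void)live_range;
  return Status::OK();
}
```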
@ -909,10 +918,6 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
  static MemorySpaceAssignment::Allocation* GetLiveAllocationAt(
      const MemorySpaceAssignment::AllocationSequence& allocations, int64 time);

  // Returns true if this buffer is allowed to be placed in the alternate
  // memory.
  bool IsIntervalAllowedInAlternateMemory(const BufferInterval& interval) const;

  // Returns true if the use is allowed in the alternate memory.
  bool IsUseAllowedInAlternateMemory(const AllocationValue& value,
                                     const HloUse& use) const;

Some files were not shown because too many files have changed in this diff.