Merge remote-tracking branch 'upstream/master' into offline_memory_planner

This commit is contained in:
Jens Elofsson 2020-06-15 10:06:36 +02:00
commit 708ecda43e
709 changed files with 15423 additions and 6384 deletions

View File

@ -202,7 +202,6 @@ cc_library(
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@ -1473,14 +1473,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
}
void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
tensorflow::AttrValueMap m;
tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
tensorflow::EagerOperation* operation =
OperationFromInterface(tensorflow::unwrap(op));
tensorflow::AttrBuilder* destination = operation->MutableAttrs();
for (const auto& attribute : m) {
destination->Set(attribute.first, attribute.second);
}
destination->CopyAttributes(*tensorflow::unwrap(attrs));
}
void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,

View File

@ -21,8 +21,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@ -84,11 +84,10 @@ class AbstractContextInterface {
// Create an operation to perform op execution
virtual AbstractOperationInterface* CreateOperation() = 0;
// Load a SavedModelAPI object from the given directory and tags
virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
tensorflow::Status* status) = 0;
// Returns whether the runtime is backed by TFRT or the legacy TF Eager
// Runtime. This is necessary to decouple runtime-dependent
// code that is layered on top of the runtime.
virtual bool UsesTFRT() = 0;
// List attributes of available devices
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;

View File

@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
namespace tensorflow {
namespace parallel_device {
@ -28,21 +30,198 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;
// Creates a vector of `count` new executors (threads).
std::vector<ExecutorPtr> MakeExecutors(size_t count) {
std::vector<ExecutorPtr> executors;
executors.reserve(count);
for (int i = 0; i < count; ++i) {
executors.emplace_back(TFE_NewExecutor(true /* is_async */));
}
return executors;
}
class StatusDeleter {
public:
void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};
using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;
} // namespace
// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
std::bind(&DeviceThread::Run, this))) {}
~DeviceThread();
// Requests that the worker thread execute the specified operation. Blocks
// until the previously pending operation (a StartExecute without a Join) has
// finished, if any.
void StartExecute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs);
// Block until the previous `StartExecute` operation has executed. Forwards
// the status from `TFE_Execute` and returns outputs if the status is OK.
std::vector<TensorHandlePtr> Join(TF_Status* status);
private:
void Run();
void Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
enum class ExecutionState {
kReadyToExecute,
kHasResult,
kIdle,
kShuttingDown,
};
tensorflow::mutex execution_mutex_;
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
ExecutionState::kIdle;
// Tells the worker thread that there is new work.
tensorflow::condition_variable start_execute_;
// The worker thread notifies that work has finished.
tensorflow::condition_variable finished_execute_;
// Notifies a StartExecute that the previous Join has finished.
tensorflow::condition_variable finished_join_;
// Temporary state between `StartExecute` and `Join`.
// Inputs
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
DeviceThread::~DeviceThread() {
{
tensorflow::mutex_lock l(execution_mutex_);
execution_state_ = ExecutionState::kShuttingDown;
}
start_execute_.notify_one();
}
void DeviceThread::Run() {
while (true) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ == ExecutionState::kIdle ||
execution_state_ == ExecutionState::kHasResult) {
start_execute_.wait(l);
}
if (execution_state_ == ExecutionState::kShuttingDown) {
return;
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
// op_outputs_ may have been std::moved
op_outputs_ = std::vector<TensorHandlePtr>();
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
expected_max_outputs_, &op_outputs_, status_.get());
execution_state_ = ExecutionState::kHasResult;
}
}
finished_execute_.notify_one();
}
}
void DeviceThread::StartExecute(TFE_Context* context,
const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kIdle) {
// If there's already a pending execution, wait until Join finishes before
// starting on the next operation.
finished_join_.wait(l);
}
context_ = context;
operation_name_ = operation_name;
op_inputs_ = inputs;
attributes_ = attributes;
expected_max_outputs_ = expected_max_outputs;
execution_state_ = ExecutionState::kReadyToExecute;
}
start_execute_.notify_one();
}
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
std::vector<TensorHandlePtr> result;
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kHasResult) {
finished_execute_.wait(l);
}
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
}
finished_join_.notify_one();
return result;
}
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
} else {
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TFE_OpAddAttrs(op_.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
TFE_OpAddInput(op_.get(), inputs[input_index], status);
if (TF_GetCode(status) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
if (TF_GetCode(status) != TF_OK) return;
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
if (TF_GetCode(status) != TF_OK) return;
unwrapped_results.resize(real_num_outputs);
outputs->reserve(real_num_outputs);
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
outputs->emplace_back(unwrapped_result);
}
}
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
: underlying_devices_(devices),
executors_(MakeExecutors(underlying_devices_.size())) {}
: underlying_devices_(devices) {
device_threads_.reserve(devices.size());
for (int device_index = 0; device_index < devices.size(); ++device_index) {
device_threads_.emplace_back(
new DeviceThread(devices[device_index].c_str()));
}
}
// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;
std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
@ -108,72 +287,34 @@ ParallelDevice::Execute(TFE_Context* context,
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
per_device_output_tensors.reserve(underlying_devices_.size());
// TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
// setting the thread-local executor like this.
TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
auto reset_executor =
tensorflow::gtl::MakeCleanup([context, previous_executor]() {
TFE_ContextSetExecutorForThread(context, previous_executor);
TFE_DeleteExecutor(previous_executor);
});
int first_op_output_count;
int first_op_output_count = 0;
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// Note that the `reset_executor` cleanup sets the thread's executor back to
// the value before this function ran.
TFE_ContextSetExecutorForThread(context, executor);
OpPtr op(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return result;
TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
status);
TFE_OpAddAttrs(op.get(), attributes);
DeviceThread* device_thread = device_threads_[device_index].get();
std::vector<TFE_TensorHandle*> device_inputs;
device_inputs.reserve(device_inputs.size());
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
// Parallel tensors are divided between operations by device.
TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
status);
if (TF_GetCode(status) != TF_OK) return result;
device_inputs.push_back(inputs[input_index]->tensor(device_index));
}
std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
// For nested devices, the inner device sees the async executor we've
// set. Inner parallel devices will just overwrite this with their own and
// then set it back to ours before returning. This means parallel devices
// which consist of several aliased parallel devices would hypothetically
// deadlock if the outer parallel device ran one collective with a group
// size equal to the total number of aliased physical devices. Currently
// physical devices cannot participate in a single collective reduction
// multiple times, so this would fail earlier.
//
// TODO(allenl): Keep a map from outer executor to list of inner executors
// rather than a single list of executors so aliased nested parallel devices
// don't re-use an executor.
TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
DeviceThread* device_thread = device_threads_[device_index].get();
per_device_output_tensors.push_back(device_thread->Join(status));
if (TF_GetCode(status) != TF_OK) return result;
if (device_index == 0) {
first_op_output_count = real_num_outputs;
first_op_output_count = per_device_output_tensors.rbegin()->size();
} else {
if (real_num_outputs != first_op_output_count) {
if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
TF_SetStatus(status, TF_INTERNAL,
"Parallel ops produced different numbers of tensors.");
return result;
}
}
if (TF_GetCode(status) != TF_OK) return result;
std::vector<TensorHandlePtr> this_outputs;
this_outputs.reserve(real_num_outputs);
for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
this_outputs.emplace_back(op_outputs[output_num]);
}
per_device_output_tensors.push_back(std::move(this_outputs));
}
for (int device_index = 0; device_index < underlying_devices_.size();
++device_index) {
TFE_Executor* executor = executors_[device_index].get();
// TODO(b/157523095): Syncing the executor here shouldn't be
// necessary. Currently async+remote is missing cross-executor
// coordination.
TFE_ExecutorWaitForAllPendingNodes(executor, status);
if (TF_GetCode(status) != TF_OK) return result;
}
// For each output of the original operation, pack the per-device
// TensorHandles we've computed into a single parallel TensorHandle.

View File

@ -41,16 +41,8 @@ class TensorHandleDeleter {
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;
class ExecutorDeleter {
public:
void operator()(TFE_Executor* to_delete) const {
TFE_DeleteExecutor(to_delete);
}
};
using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
class DeviceThread;
// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
@ -58,6 +50,8 @@ class ParallelDevice {
public:
explicit ParallelDevice(const std::vector<std::string>& devices);
~ParallelDevice();
// Helper to copy a tensor handle from another device once for each component
// of the ParallelDevice.
//
@ -94,9 +88,19 @@ class ParallelDevice {
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
const std::vector<std::string> underlying_devices_;
// A sequence of TFE_Executors, one per device, for executing operations in
// A sequence of thread wrappers, one per device, for executing operations in
// parallel.
const std::vector<ExecutorPtr> executors_;
//
// Conceptually this is a thread pool with one thread per device. It requires
// less synchronization than a thread pool would for this task, since Execute
// acquires each thread in order (and so only one Execute will schedule
// blocking collective operations at a time), and avoids some dynamic
// allocation/scheduling.
//
// TODO(allenl): Keep a map from outer thread to list of inner threads rather
// than a single list of threads so aliased nested parallel devices don't
// re-use a thread.
std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};
// Contains a tuple of tensors, one on each of the `underlying_devices_` of the

View File

@ -407,7 +407,7 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
return TensorHandlePtr(result_handle);
}
TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@ -423,6 +423,9 @@ TEST(PARALLEL_DEVICE, TestCollective) {
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
TFE_NewExecutor(async), TFE_DeleteExecutor);
TFE_ContextSetExecutorForThread(context.get(), executor.get());
const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
std::array<const char*, 2> underlying_devices{
@ -452,8 +455,16 @@ TEST(PARALLEL_DEVICE, TestCollective) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
ExpectScalarEq<float>(result_components[0].get(), 3.);
ExpectScalarEq<float>(result_components[1].get(), 3.);
// Destroying the context's default executor first isn't safe.
context.reset();
}
TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }
void RegisterCollectiveMulFunction(TFE_Context* context,
const char* function_name, int group_size,
TF_Status* status) {

View File

@ -26,5 +26,6 @@ cc_library(
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
],
)

View File

@ -15,11 +15,22 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
@ -52,6 +63,20 @@ namespace tf_read_only_memory_region {
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
(*gcs_client) = client.value();
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
} // namespace tf_gcs_filesystem
@ -60,6 +85,10 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -57,6 +57,7 @@ cc_library(
":concrete_function",
":saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out) {
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently");

View File

@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
class TFSavedModelAPIImpl : public SavedModelAPI {
public:
TFSavedModelAPIImpl() = default;
Status GetFunction(const std::string& function_path,
ConcreteFunction** function) override;
@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
static Status Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
TFSavedModelAPIImpl* out);
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);
std::vector<ConcreteFunction*> ListFunctions() override;
~TFSavedModelAPIImpl() override = default;
private:
TFSavedModelAPIImpl() = default;
std::vector<ConcreteFunction> functions_;
};

View File

@ -144,7 +144,9 @@ cc_library(
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
extern "C" {
@ -34,10 +38,21 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
TF_Status* status) {
std::string saved_model_dir(dirname);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, absl::nullopt,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
&status->status);
if (!status->status.ok()) {
return nullptr;
}
@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
tagset.insert(std::string(tags[i]));
}
std::unique_ptr<tensorflow::SavedModelAPI> result =
tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
&status->status);
std::unique_ptr<tensorflow::SavedModelAPI> result;
if (tensorflow::unwrap(ctx)->UsesTFRT()) {
status->status = tensorflow::errors::Unimplemented(
"TFRT SavedModel implementation will be added in the future");
} else {
std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
status->status = tensorflow::TFSavedModelAPIImpl::Load(
dirname, tagset,
tensorflow::down_cast<tensorflow::EagerContext*>(
tensorflow::unwrap(ctx)),
&saved_model);
result = std::move(saved_model);
}
if (!status->status.ok()) {
return nullptr;
}

View File

@ -106,6 +106,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
":loader_util",
":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
@ -132,6 +133,17 @@ cc_library(
],
)
cc_library(
name = "loader_util",
srcs = ["loader_util.cc"],
hdrs = ["loader_util.h"],
deps = [":constants"] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
]),
)
tf_cc_test(
name = "bundle_v2_test",
srcs = ["bundle_v2_test.cc"],

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
return Status::OK();
}
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
const StringPiece restore_op_name,
const StringPiece variable_filename_const_op_name,
@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name;
TF_RETURN_IF_ERROR(
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(),
init_op_name));

View File

@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader_util.h"
#include <vector>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
} // namespace internal
} // namespace tensorflow

View File

@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name);
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_

View File

@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;
std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
jitter_flags = new IntroduceFloatingPointJitterPassFlags;
jitter_flags->jitter_amount = 1e-5;
mlir_flags = new MlirCommonFlags;
mlir_flags->tf_mlir_enable_mlir_bridge = false;
auto setter_for_jitter_tensor_names = [](string sequence) {
jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
return true;
@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
Flag("tf_introduce_floating_point_jitter_amount",
&jitter_flags->jitter_amount,
"The amount of jitter to introduce. This amount is added to each "
"element in the tensors named in `tensor_names.")});
"element in the tensors named in `tensor_names."),
Flag("tf_mlir_enable_mlir_bridge",
&mlir_flags->tf_mlir_enable_mlir_bridge,
"Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});
AppendMarkForCompilationPassFlagsInternal(flag_list);
xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
return *jitter_flags;
}
MlirCommonFlags* GetMlirCommonFlags() {
absl::call_once(flags_init, &AllocateAndParseFlags);
return mlir_flags;
}
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
absl::call_once(flags_init, &AllocateAndParseFlags);
AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
std::vector<string> tensor_names;
};
// Flags for common MLIR configurations.
struct MlirCommonFlags {
bool tf_mlir_enable_mlir_bridge;
};
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();
MlirCommonFlags* GetMlirCommonFlags();
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//

View File

@ -30,7 +30,7 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -42,7 +42,7 @@ cc_library(
":init_mlir",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
@ -86,7 +86,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -102,7 +102,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
@ -155,7 +155,7 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

View File

@ -225,7 +225,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
@ -253,7 +253,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -272,7 +272,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -289,7 +289,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
@ -304,7 +304,7 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -357,7 +357,7 @@ cc_library(
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -383,7 +383,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -416,7 +416,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -441,7 +441,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -494,8 +494,8 @@ tf_native_cc_binary(
"converter_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -541,8 +541,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:analysis",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@ -619,7 +619,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -653,7 +653,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -713,7 +713,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
@ -743,7 +743,7 @@ cc_library(
"tf_tfl_translate_cl.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@ -755,7 +755,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -780,7 +780,7 @@ tf_cc_binary(
":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -805,7 +805,7 @@ tf_cc_binary(
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@ -874,7 +874,7 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -894,6 +894,6 @@ cc_library(
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

View File

@ -868,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
subgraph, &builder, "outputs", func_outputs));
}
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
}
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -56,7 +56,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -85,7 +85,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",

View File

@ -80,7 +80,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -106,7 +106,7 @@ cc_library(
deps = [
"//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -125,7 +125,7 @@ cc_library(
deps = [
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -135,8 +135,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -157,7 +157,7 @@ cc_library(
deps = [
":numerical_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support",
@ -172,7 +172,7 @@ cc_library(
":device_target",
":quantization_lib",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",

View File

@ -36,7 +36,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
@ -54,7 +54,7 @@ cc_library(
deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@ -73,7 +73,7 @@ tf_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
],
)

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",

View File

@ -32,7 +32,7 @@ cc_library(
"//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],

View File

@ -54,7 +54,7 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -70,6 +70,6 @@ tf_native_cc_binary(
"//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

View File

@ -0,0 +1,12 @@
// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
return %0: tensor<3x3xbf16>
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
}

View File

@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}
func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
// CHECK-LABEL: concat2Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
}
func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat3Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>

View File

@ -13,14 +13,12 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
return %3 : tensor<f32>
}
// CHECK-NOT: add
// CHECK-NOT: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
@ -42,65 +40,31 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
return %3 : tensor<f32>
}
// CHECK-NOT: addormul
// CHECK-NOT: sub
// CHECK-NOT: mul
// CHECK-NOT: add
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = constant dense<false> : tensor<i1>
%1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Multiply"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Verify that branch functions with multiple references are not erased.
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
%2 = constant dense<true> : tensor<i1>
// CHECK: tf.Add
%3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tf.If
%4 = "tf.If"(%arg2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %3, %4 : tensor<f32>, tensor<f32>
}
// CHECK: add
// CHECK: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Verify unused if with functions without side-effects are removed.
// Verify unused if with functions without side-effects is removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@ -118,26 +82,22 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<true> : tensor<i1>
return %cst : tensor<i1>
}
// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_00
// CHECK-NOT: func @_functionalize_if_then_branch_00
// -----
// Verify unused if with function with side-effects is not removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@ -155,27 +115,25 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK: func @main
// CHECK: tf.If
// CHECK: return
// CHECK: func @_functionalize_if_else_branch_01
// CHECK: func @_functionalize_if_then_branch_01
// -----
// Verify unused if with function with side-effects is removed if op says
// stateless.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@ -193,18 +151,15 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32>
}
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_02
// CHECK-NOT: func @_functionalize_if_then_branch_02

View File

@ -94,12 +94,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
@ -162,6 +156,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// so that it can target constants introduced once TensorFlow Identity ops
// are removed during legalization.
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
// This pass should be always at the end of the floating point model
@ -237,6 +232,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pm.addPass(mlir::createSymbolDCEPass());
// Canonicalize, CSE etc.
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());

View File

@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) {
// operands are properly supported in declarative rewrite rule specification.
DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
return failure();
}
LogicalResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
fused_activation_function);
return success();
}
// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure.
@ -517,6 +497,14 @@ StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
@ -756,11 +744,11 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list.
populateWithGenerated(context, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFReciprocalOp,
patterns
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
// Ophint python converter converted tf node pattern.

View File

@ -32,8 +32,6 @@ namespace mlir {
namespace TFL {
namespace {
using FuncSet = llvm::SmallSet<FuncOp, 4>;
// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
@ -44,8 +42,8 @@ struct OptimizeFunctionalOpsPass
// op operands' types.
//
// Requires the function has exactly one block.
static void UpdateFuncType(FuncOp func) {
Operation* terminator = &func.getBlocks().front().back();
void UpdateFuncType(FuncOp func) {
Operation* terminator = func.front().getTerminator();
auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
FunctionType func_type = func.getType();
@ -57,7 +55,7 @@ static void UpdateFuncType(FuncOp func) {
}
// TODO(jpienaar): Remove when recursive side-effect modeling is added.
static bool IsSideEffectFree(FuncOp func) {
bool IsSideEffectFree(FuncOp func) {
return !func.getBody()
.walk([&](Operation* op) {
if (!MemoryEffectOpInterface::hasNoEffect(op) &&
@ -72,8 +70,8 @@ static bool IsSideEffectFree(FuncOp func) {
// function body based on the conditional value.
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
public:
explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs)
: OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {}
explicit FoldIfOp(MLIRContext* context)
: OpRewritePattern<TF::IfOp>(context) {}
LogicalResult matchAndRewrite(TF::IfOp op,
PatternRewriter& rewriter) const override {
@ -82,7 +80,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// updated if operands' shapes change after inlining. Without this
// restriction, it would require tensor cast ops.
FuncOp parent_op = op.getParentOfType<FuncOp>();
if (parent_op.getBlocks().size() != 1) return failure();
if (!llvm::hasSingleElement(parent_op)) return failure();
// Find the then and else branch functions.
SymbolTable table(op.getParentOfType<ModuleOp>());
@ -95,8 +93,6 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
if (op.use_empty() &&
(op.is_stateless() ||
(IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) {
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
rewriter.eraseOp(op.getOperation());
return success();
}
@ -118,14 +114,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// Make sure that the function has exactly one block to simplify inlining.
// TFLite doesn't use control flow with blocks so functions with more than
// one blocks are not encountered in practice.
if (func.getBody().getBlocks().size() != 1) return failure();
if (!llvm::hasSingleElement(func)) return failure();
BlockAndValueMapping mapper;
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
mapper.map(func.getArgument(i), op.getOperand(i + 1));
llvm::SmallVector<Value, 4> updated_results;
for (auto& op_to_inline : func.getBody().front()) {
for (auto& op_to_inline : func.front()) {
// If this is a terminator, identify the values to use to replace the
// original If op.
if (op_to_inline.isKnownTerminator()) {
@ -145,64 +141,26 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// return type should be updated.
UpdateFuncType(parent_op);
// Track functions that could be erased if this op was the last reference
// of the function.
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
return success();
}
private:
FuncSet* inlined_funcs_;
};
// Erases functions from the given candidates that are not referenced by any of
// the ops in the module.
static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
if (candidate_funcs.empty()) return;
SymbolTable manager(module);
// Identify the functions that are used as symbols in the module and shouldn't
// be erased.
FuncSet in_use_funcs;
manager.getOp()->walk([&](Operation* op) {
for (auto attr : op->getAttrs()) {
if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
auto func = manager.lookup<FuncOp>(symbol.getValue());
in_use_funcs.insert(func);
}
}
});
for (FuncOp func : candidate_funcs) {
if (!in_use_funcs.count(func)) manager.erase(func);
}
}
void OptimizeFunctionalOpsPass::runOnOperation() {
OwningRewritePatternList patterns;
FuncSet inlined_funcs;
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
patterns.insert<FoldIfOp>(&getContext());
ModuleOp module = getOperation();
applyPatternsAndFoldGreedily(module, patterns);
// Erase inlined functions that don't have any references.
//
// TODO(hinsu): Update this to not erase entry points once TFLite support to
// have multiple entry points is implemented. Until then, it is safe to
// erase these functions.
EraseDeadFuncs(inlined_funcs, module);
}
PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
return std::make_unique<OptimizeFunctionalOpsPass>();
}
static PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace TFL
} // namespace mlir

View File

@ -29,7 +29,7 @@ cc_library(
# place for core related components.
"//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:import_utils",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",

View File

@ -20,7 +20,7 @@ tf_python_pybind_extension(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@pybind11",
@ -35,7 +35,7 @@ tf_python_pybind_extension(
deps = [
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@pybind11",
],
)

View File

@ -187,7 +187,7 @@ gentbl(
td_file = "transforms/legalize_hlo_patterns.td",
td_srcs = [
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
@ -204,7 +204,7 @@ cc_library(
":tensorflow",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/core:framework",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -225,7 +225,7 @@ cc_library(
"ir/tf_attributes.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -240,7 +240,7 @@ cc_library(
"ir/tf_types.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
],
@ -293,7 +293,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfacesIncGen",
"@llvm-project//mlir:DerivedAttributeOpInterface",
@ -388,7 +388,7 @@ cc_library(
":tensorflow",
"//tensorflow/core:framework",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -416,6 +416,7 @@ cc_library(
"transforms/fold_switch.cc",
"transforms/freeze_global_tensors.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/fused_kernel_matcher.cc",
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/gpu_fusion.cc",
@ -424,7 +425,6 @@ cc_library(
"transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/op_fusion.cc",
"transforms/optimize.cc",
"transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc",
@ -500,7 +500,7 @@ cc_library(
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -609,7 +609,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -639,7 +639,7 @@ cc_library(
":parse_text_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -669,7 +669,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@ -712,7 +712,7 @@ cc_library(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -722,7 +722,7 @@ cc_library(
srcs = ["translate/translate_tf_dialect_op.cc"],
deps = [
":export_tf_dialect_op",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
@ -781,7 +781,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -799,7 +799,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -816,7 +816,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -838,7 +838,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -882,7 +882,7 @@ cc_library(
deps = [
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -925,7 +925,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:Support",
@ -957,7 +957,7 @@ cc_library(
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
@ -986,7 +986,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -1011,7 +1011,7 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
@ -1027,7 +1027,7 @@ cc_library(
"translate/tf_mlir_translate_cl.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@ -1044,7 +1044,7 @@ cc_library(
":translate_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Translation",
],
@ -1060,7 +1060,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -1071,8 +1071,8 @@ tf_native_cc_binary(
"translate/derived_attr_populator_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@ -1134,7 +1134,7 @@ COMPILE_MLIR_UTIL_DEPS = [
":tensorflow_passes",
":translate_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
@ -1266,7 +1266,7 @@ cc_library(
":tensorflow",
":tensorflow_types",
"//tensorflow/core:framework",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
@ -1285,7 +1285,7 @@ cc_library(
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -1300,7 +1300,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -1313,7 +1313,7 @@ cc_library(
":tensorflow",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -1331,7 +1331,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -1344,7 +1344,7 @@ cc_library(
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -1359,7 +1359,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:test",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -1378,7 +1378,7 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
@ -1398,7 +1398,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:test",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -1409,7 +1409,7 @@ cc_library(
hdrs = ["utils/bridge_logger.h"],
deps = [
":dump_mlir_util",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
@ -1425,7 +1425,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -1443,7 +1443,7 @@ cc_library(
":tensorflow",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],

View File

@ -36,7 +36,7 @@ tf_cuda_library(
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",

View File

@ -794,6 +794,34 @@ This op is deprecated. Prefer `tf.nn.batch_normalization`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_BatchToSpaceOp : TF_Op<"BatchToSpace", [NoSideEffect]> {
let summary = "BatchToSpace for 4-D tensors of type T.";
let description = [{
This is a legacy version of the more general BatchToSpaceND.
Rearranges (permutes) data from batch into blocks of spatial data, followed by
cropping. This is the reverse transformation of SpaceToBatch. More specifically,
this op outputs a copy of the input tensor where values from the `batch`
dimension are moved in spatial blocks to the `height` and `width` dimensions,
followed by cropping along the `height` and `width` dimensions.
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$crops,
Confined<I64Attr, [IntMinValue<2>]>:$block_size
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
let summary = "BatchToSpace for N-D tensors of type T.";
@ -1219,6 +1247,35 @@ subsequent operation and then be optimized away, however.)
}];
}
def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Bucketizes 'input' based on 'boundaries'.";
let description = [{
For example, if the inputs are
boundaries = [0, 10, 100]
input = [[-5, 10000]
[150, 10]
[5, 100]]
then the output will be
output = [[0, 3]
[3, 2]
[1, 3]]
}];
let arguments = (ins
TensorOf<[F32, F64, I32, I64]>:$input,
F32ArrayAttr:$boundaries
);
let results = (outs
I32Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_CaseOp : TF_Op<"Case", []> {
let summary = [{
An n-way switch statement which calls a single branch function.
@ -1257,6 +1314,8 @@ An n-way switch statement, implementing the following:
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let hasCanonicalizer = 1;
}
def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
@ -1519,6 +1578,8 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
}
def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {
@ -2320,6 +2381,21 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
let summary = "Return the index of device the op runs.";
let description = [{
}];
let arguments = (ins
StrArrayAttr:$device_names
);
let results = (outs
I32Tensor:$index
);
}
def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> {
let summary = "Returns the diagonal part of the tensor.";
@ -2968,6 +3044,63 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> {
let summary = "Fast Fourier transform.";
let description = [{
Computes the 1-dimensional discrete Fourier transform over the inner-most
dimension of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_FFT2DOp : TF_Op<"FFT2D", [NoSideEffect]> {
let summary = "2D fast Fourier transform.";
let description = [{
Computes the 2-dimensional discrete Fourier transform over the inner-most
2 dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_FFT3DOp : TF_Op<"FFT3D", [NoSideEffect]> {
let summary = "3D fast Fourier transform.";
let description = [{
Computes the 3-dimensional discrete Fourier transform over the inner-most 3
dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
@ -3727,6 +3860,161 @@ table will be immutable.
);
}
def TF_IFFTOp : TF_Op<"IFFT", [NoSideEffect]> {
let summary = "Inverse fast Fourier transform.";
let description = [{
Computes the inverse 1-dimensional discrete Fourier transform over the
inner-most dimension of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IFFT2DOp : TF_Op<"IFFT2D", [NoSideEffect]> {
let summary = "Inverse 2D fast Fourier transform.";
let description = [{
Computes the inverse 2-dimensional discrete Fourier transform over the
inner-most 2 dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IFFT3DOp : TF_Op<"IFFT3D", [NoSideEffect]> {
let summary = "Inverse 3D fast Fourier transform.";
let description = [{
Computes the inverse 3-dimensional discrete Fourier transform over the
inner-most 3 dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IRFFTOp : TF_Op<"IRFFT", [NoSideEffect]> {
let summary = "Inverse real-valued fast Fourier transform.";
let description = [{
Computes the inverse 1-dimensional discrete Fourier transform of a real-valued
signal over the inner-most dimension of `input`.
The inner-most dimension of `input` is assumed to be the result of `RFFT`: the
`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If
`fft_length` is not provided, it is computed from the size of the inner-most
dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to
compute `input` is odd, it should be provided since it cannot be inferred
properly.
Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
than the corresponding dimension of `input`, the dimension is cropped. If it is
larger, the dimension is padded with zeros.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IRFFT2DOp : TF_Op<"IRFFT2D", [NoSideEffect]> {
let summary = "Inverse 2D real-valued fast Fourier transform.";
let description = [{
Computes the inverse 2-dimensional discrete Fourier transform of a real-valued
signal over the inner-most 2 dimensions of `input`.
The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:
The inner-most dimension contains the `fft_length / 2 + 1` unique components of
the DFT of a real-valued signal. If `fft_length` is not provided, it is computed
from the size of the inner-most 2 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
Along each axis `IRFFT2D` is computed on, if `fft_length` (or
`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IRFFT3DOp : TF_Op<"IRFFT3D", [NoSideEffect]> {
let summary = "Inverse 3D real-valued fast Fourier transform.";
let description = [{
Computes the inverse 3-dimensional discrete Fourier transform of a real-valued
signal over the inner-most 3 dimensions of `input`.
The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:
The inner-most dimension contains the `fft_length / 2 + 1` unique components of
the DFT of a real-valued signal. If `fft_length` is not provided, it is computed
from the size of the inner-most 3 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
Along each axis `IRFFT3D` is computed on, if `fft_length` (or
`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> {
let summary = [{
Returns a list of tensors with the same shapes and contents as the input
@ -4142,6 +4430,30 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> {
let summary = "Gradients for Local Response Normalization.";
let description = [{
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$input_grads,
TensorOf<[BF16, F16, F32]>:$input_image,
TensorOf<[BF16, F16, F32]>:$output_image,
DefaultValuedAttr<I64Attr, "5">:$depth_radius,
DefaultValuedAttr<F32Attr, "1.0f">:$bias,
DefaultValuedAttr<F32Attr, "1.0f">:$alpha,
DefaultValuedAttr<F32Attr, "0.5f">:$beta
);
let results = (outs
TensorOf<[BF16, F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes rectified linear: `max(features, features * alpha)`.";
@ -6333,6 +6645,66 @@ the dimension is padded with zeros.
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RFFT2DOp : TF_Op<"RFFT2D", [NoSideEffect]> {
let summary = "2D real-valued fast Fourier transform.";
let description = [{
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 2 dimensions of `input`.
Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TF_F32OrF64Tensor:$input,
I32Tensor:$fft_length
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Treal = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RFFT3DOp : TF_Op<"RFFT3D", [NoSideEffect]> {
let summary = "3D real-valued fast Fourier transform.";
let description = [{
Computes the 3-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 3 dimensions of `input`.
Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TF_F32OrF64Tensor:$input,
I32Tensor:$fft_length
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Treal = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder {
let summary = [{
@ -8134,6 +8506,34 @@ def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResu
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_SpaceToBatchOp : TF_Op<"SpaceToBatch", [NoSideEffect]> {
let summary = "SpaceToBatch for 4-D tensors of type T.";
let description = [{
This is a legacy version of the more general SpaceToBatchND.
Zero-pads and then rearranges (permutes) blocks of spatial data into batch.
More specifically, this op outputs a copy of the input tensor where values from
the `height` and `width` dimensions are moved to the `batch` dimension. After
the zero-padding, both `height` and `width` of the input must be divisible by the
block size.
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$paddings,
Confined<I64Attr, [IntMinValue<2>]>:$block_size
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> {
let summary = "SpaceToBatch for N-D tensors of type T.";
@ -10892,6 +11292,50 @@ create these operators.
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
}
def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect]> {
let summary = [{
Performs a MatMul followed by a specified series of operations.
}];
let description = [{
The inputs to the MatMul are specified by `a` and `b`. The series of operations
that follows is specified by the `fused_ops` attribute, which is a list of TF op
names specified as strings (e.g. "Relu"). They are performed in order, where the
(first) input to each op is the output of the preceding op. The first input and
the output of each fused_op must be of type T.
Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd",A],
where A is one of {"Elu","Relu","Relu6"}.
* The first input to BiasAdd is the Conv2D result, and the additional BiasAdd
input is specified by `args`.
* If there is an op A specified, the output of the BiasAdd is the input to op A,
and op A produces the _FusedConv2D output. Otherwise, the BiasAdd produces the
_FusedConv2D output.
*NOTE*: Do not invoke this operator directly in Python. Grappler is
expected to create these operators.
}];
let arguments = (ins
F32Tensor:$a,
F32Tensor:$b,
Variadic<F32Tensor>:$args,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
);
let results = (outs
F32Tensor:$product
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
}
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device.";

View File

@ -44,6 +44,7 @@ limitations under the License.
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
@ -68,6 +69,17 @@ limitations under the License.
namespace mlir {
namespace TF {
// Propagates underscore and device attributes from src to dst.
// TODO(b/158769932): This should be a general feature instead post some policy
// discussion.
static void PropagateAttributes(Operation *src, Operation *dst) {
auto device = mlir::Identifier::get("device", src->getContext());
for (auto named_attr : src->getAttrs()) {
if (*named_attr.first.begin() == '_' || named_attr.first == device)
dst->setAttr(named_attr.first, named_attr.second);
}
}
//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//
@ -786,11 +798,49 @@ static LogicalResult Verify(BroadcastToOp op) {
}
//===----------------------------------------------------------------------===//
// CastOp
// CaseOp
//===----------------------------------------------------------------------===//
class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
public:
explicit FoldConstantCaseOp(MLIRContext *context)
: OpRewritePattern<TF::CaseOp>(context) {}
LogicalResult matchAndRewrite(TF::CaseOp op,
PatternRewriter &rewriter) const override;
};
LogicalResult FoldConstantCaseOp::matchAndRewrite(
TF::CaseOp op, PatternRewriter &rewriter) const {
// Extract the constant cond value.
DenseIntElementsAttr branch;
if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();
// Only attempt to fold scalar valued case statements.
// TODO(jpienaar): This can be removed if CaseOp's verifier covers it.
if (!branch.getType().cast<RankedTensorType>().getShape().empty())
return failure();
int index = *branch.getValues<int>().begin();
// TODO(jpienaar): This can be removed if CaseOp's verifier covers it.
if (index >= op.branches().size()) return failure();
auto func = op.branches()[index].cast<SymbolRefAttr>();
auto empty = rewriter.getStringAttr("");
auto call_op = rewriter.create<PartitionedCallOp>(
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
PropagateAttributes(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
return success();
}
void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<FoldConstantCaseOp>(context);
}
//===----------------------------------------------------------------------===//
// LeakyReluOp
// CastOp
//===----------------------------------------------------------------------===//
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
@ -823,6 +873,11 @@ static LogicalResult Verify(OpT op) {
/*mask_one_dim=*/true, op.getOperation());
}
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ConvertToConcatV2>(context);
}
//===----------------------------------------------------------------------===//
// ConcatOffsetOp
//===----------------------------------------------------------------------===//

View File

@ -334,7 +334,7 @@ SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
bool IsExported(Operation *op) {
auto exported_names =
op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
return exported_names && exported_names.size() != 0;
return exported_names && !exported_names.empty();
}
bool HasTfSavedModelSemantics(ModuleOp module) {

View File

@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3
// CHECK: return %arg0
}
// CHECK-LABEL: testConcatCanonicalization
func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
// CHECK: %[[AXIS:.*]] = "tf.Const"
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
// CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]])
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
// CHECK-LABEL: testLogOfSoftmax
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
@ -550,3 +560,29 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
}
// CHECK-LABEL: foldCase
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
%2 = constant dense<1> : tensor<i32>
%3 = constant dense<0> : tensor<i32>
// CHECK: PartitionedCall
// CHECK-SAME: device = "noodle"
// CHECK-SAME: f = @add
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: PartitionedCall
// CHECK-SAME: _cluster_launch = "not_ready"
// CHECK-SAME: f = @sub
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %5 : tensor<f32>
}
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -370,7 +370,6 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3
// Tests that composite tf.ResourceScatterUpdate operation is decomposed.
// CHECK-LABEL: @decompose_resource_scatter_update_op
// CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
@ -384,3 +383,34 @@ func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates:
return
}
// -----
// Tests that tf.VariableShape operation is decomposed.
// CHECK-LABEL: @decompose_variable_shape_i32
func @decompose_variable_shape_i32(%input: tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi32> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi32>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%arg0)
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[READ]])
// CHECK: return %[[SHAPE]]
return %0 : tensor<3xi32>
}
// CHECK-LABEL: @decompose_variable_shape_i64
func @decompose_variable_shape_i64(%input: tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi64> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi64>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%arg0)
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[READ]])
// CHECK: return %[[SHAPE]]
return %0 : tensor<3xi64>
}
// CHECK-LABEL: @decompose_variable_shape_no_subtype
func @decompose_variable_shape_no_subtype(%input: tensor<!tf.resource>) -> tensor<3xi32> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource>) -> tensor<3xi32>
// CHECK: "tf.VariableShape"
// CHECK-NOT: "tf.ReadVariableOp"
// CHECK-NOT: "tf.Shape"
return %0 : tensor<3xi32>
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tf-op-fusion | FileCheck %s
// RUN: tf-opt %s -tf-fused-kernel-matcher | FileCheck %s
//===----------------------------------------------------------------------===//
// Conv2D + BiasAdd + <Activation> fusions.
@ -107,3 +107,54 @@ func @conv2D_dataFormatMismatch(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128x
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}
//===----------------------------------------------------------------------===//
// MatMul + BiasAdd + <Activation> fusions.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: matmulBiasAdd
func @matmulBiasAdd(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Identity"(%4) : (tensor<*xf32>) -> tensor<*xf32>
return %5 : tensor<*xf32>
}
// CHECK-LABEL: matmulBiasAdd_relu
func @matmulBiasAdd_relu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Relu"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}
// CHECK-LABEL: matmulBiasAdd_relu6
func @matmulBiasAdd_relu6(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu6"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Relu6"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}
// CHECK-LABEL: matmulBiasAdd_elu
func @matmulBiasAdd_elu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Elu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Elu"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}

View File

@ -49,5 +49,5 @@ library {
}
}
# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
# CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}

View File

@ -124,5 +124,5 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}
# CHECK-LABEL: func @foo110() {
# CHECK-LABEL: func @foo111() {
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}

View File

@ -57,7 +57,7 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @foo0() {
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @bar0() {
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}

View File

@ -7,6 +7,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_single_outside_compiled_op(%arg0: tensor<i32>) {
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -27,6 +28,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_single_outside_compiled_op_no_operands() {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -49,6 +51,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%a = "tf.A"() : () -> tensor<i32>
// CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -69,6 +72,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -96,8 +100,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_all_cluster_op(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -117,8 +124,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_multiple_outside_compiled_ops(%arg0: tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.C"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -141,6 +151,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
//
@ -173,6 +184,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
@ -199,6 +211,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( {
@ -226,7 +239,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
@ -258,6 +273,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster:5 = "tf_device.cluster"() ( {
@ -286,6 +302,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
@ -320,6 +337,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -335,6 +353,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( {
@ -353,6 +372,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
//
@ -370,6 +390,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-tpu-host-computation-expansion | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -split-input-file -tf-tpu-host-computation-expansion | FileCheck %s
// Tests expansion of a outside compiled ops at head/tail of TPU computation.
@ -26,7 +26,7 @@ func @cast_at_head_expanded(%arg0: tensor<?xi32>) {
"tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}
@ -44,7 +44,7 @@ func @check_consecutive_unary_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%2) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}
@ -59,7 +59,7 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%1) : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}
@ -67,9 +67,9 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Cast"
// CHECK-NOT: _xla_outside_compilation = ""
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation = ""
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.B"
"tf_device.cluster"() ( {
%1 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
@ -77,6 +77,19 @@ func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<
"tf.B"(%2) : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) : () -> ()
return
}
// CHECK-LABEL: func @check_op_without_usage_not_outside_compiled
func @check_op_without_usage_not_outside_compiled(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
"tf_device.cluster"() ( {
"tf.Identity"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
"tf.C"() : () -> ()
tf_device.return
}) : () -> ()
return
}

View File

@ -87,7 +87,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
// because DecomposeResourceOpsPass uses pattern rewriter which hoists
// changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
pm.addNestedPass<FuncOp>(CreateTPUHostComputationExpansionPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());

View File

@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg),
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
(TF_BitcastOp $arg)>;
//===----------------------------------------------------------------------===//
// Concat op patterns.
//===----------------------------------------------------------------------===//
def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs),
(TF_ConcatV2Op $inputs, $axis)>;
//===----------------------------------------------------------------------===//
// Conj op patterns.
//===----------------------------------------------------------------------===//

View File

@ -50,6 +50,24 @@ static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
return UnrankedTensorType::get(element_type);
}
static bool HasResourceSubtype(Value resource) {
return resource.getType()
.cast<TensorType>()
.getElementType()
.cast<ResourceType>()
.getSubtypes()
.size() == 1;
}
static Type GetResourceSubtype(Value resource) {
return resource.getType()
.cast<TensorType>()
.getElementType()
.cast<ResourceType>()
.getSubtypes()
.front();
}
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
} // namespace

View File

@ -23,7 +23,7 @@ class GetScalarOfType<int value> : NativeCodeCall<
// Creates a tf.ReadVariable op that reads a resource `$2` that has the same
// element type as `$1`. The op created will use location of `$0`.
def CreateTFReadVariableOp: NativeCodeCall<
def CreateTFReadVariableOp : NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
" $0.getLoc(),"
" GetResourceSubtypeOrDefault("
@ -31,6 +31,12 @@ def CreateTFReadVariableOp: NativeCodeCall<
" $2)"
>;
def CheckHasResourceSubtype : Constraint<CPred<"HasResourceSubtype($0)">>;
def CreateTFReadVariableOpFromResourceHandle : NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
"$0.getLoc(), GetResourceSubtype($1), $1)">;
def DecomposeAssignAddVariableOp :
Pat<
(TF_AssignAddVariableOp:$src_op $resource, $value),
@ -315,3 +321,9 @@ def DecomposeResourceScatterUpdate : Pat<
$updates
)
)>;
// Pattern to decompose tf.VariableShape into tf.ReadVariable and tf.Shape.
def DecomposeVariableShape : Pat<
(TF_VariableShapeOp:$src_op $resource),
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
[(CheckHasResourceSubtype $resource)]>;

View File

@ -0,0 +1,225 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <iostream>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Note: This implements fusions performed in the old Remapper Grappler pass.
// That pass has specific cases for GPU and based on different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers the general CPU case and at the moment does not account for any
// target-specific configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.
// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
: public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
void runOnFunction() override;
};
// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}
// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}
// Performs a fusion of the following pattern(s), if possible:
// <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with activation is preferred, but a contraction and BiasAdd
// can also be replaced by a _FusedConv2D if there is no other activation
// function.
// i.e., this class also supports the following fusion:
// <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
public:
using OpRewritePattern<SrcOpT>::OpRewritePattern;
// Class users should override this method if there are any op-specific
// compatibility requirements between the contraction op and the BiasAdd op.
virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
PatternRewriter &rewriter) const {
return true;
}
LogicalResult matchAndRewrite(SrcOpT contraction,
PatternRewriter &rewriter) const override {
auto context = rewriter.getContext();
// If the contraction is used in multiple places, fusing it will only create
// more contraction nodes, which is slower.
if (!contraction.getResult().hasOneUse())
return rewriter.notifyMatchFailure(contraction,
"result is used by multiple ops");
BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
if (!bias_add) {
return rewriter.notifyMatchFailure(
contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}
if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
return rewriter.notifyMatchFailure(
contraction, "cannot fuse with the subsequent BiasAdd op");
}
SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{
StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
Type result_type;
// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(
StringAttr::get(GetOpNameWithoutDialect(activation), context));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}
auto fused_loc = rewriter.getFusedLoc(locations);
// The fused contraction has the same operands as the original contraction
// with `bias` from the BiasAddOp appended.
SmallVector<Value, 4> operands(contraction.operand_begin(),
contraction.operand_end());
operands.push_back(bias_add.bias());
// The fused contraction has the same attributes as the original
// contraction, with two additions: the list of ops which have been fused
// together; epsilon (only with FusedBatchNorm).
std::vector<NamedAttribute> attrs = contraction.getAttrs();
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, context);
attrs.push_back(
NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
// Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
// here.
Attribute epsilon = rewriter.getF32FloatAttr(0);
attrs.push_back(
NamedAttribute(Identifier::get("epsilon", context), epsilon));
Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
ValueRange(operands), attrs);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
return success();
}
};
// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
: public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
public:
using FuseContractionWithBiasAdd<Conv2DOp,
_FusedConv2DOp>::FuseContractionWithBiasAdd;
// Verify that the Conv2D and BiasAdd data formats match. This is necessary
// for the ops to fuse correctly, the fused Conv2D op has one data format
// attribute which is shared.
bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
PatternRewriter &rewriter) const override {
// Verify that the data formats match and are valid for fusion.
if (conv.data_format() != bias_add.data_format()) {
rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << conv.data_format() << ")";
});
return false;
}
return true;
}
};
// Performs a fusion of the following pattern(s), if possible:
// MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
using FuseMatMulBiasAdd = FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp>;
void FusedKernelMatcherPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
return std::make_unique<FusedKernelMatcherPass>();
}
static PassRegistration<FusedKernelMatcherPass> pass(
"tf-fused-kernel-matcher",
"Matches computations corresponding to optimized fused kernels");
} // namespace TF
} // namespace mlir

View File

@ -1,174 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <iostream>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU and based on different
// target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers the general CPU case and at the moment does not account for any
// specific target configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.
// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct OpFusionPass : public PassWrapper<OpFusionPass, FunctionPass> {
void runOnFunction() override;
};
// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}
// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}
// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
//
// Note that fusion with activation is preferred, but a Conv2D and BiasAdd can
// also be replaced by a _FusedConv2D if there is no other activation function.
// i.e., this class also supports the following fusion:
// Conv2D + BiasAdd -> _FusedConv2D
//
// TODO(b/158266331): Support fusing Conv2D + BiasAdd + a chain of activations.
class FuseConv2DBiasAdd : public OpRewritePattern<Conv2DOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Conv2DOp op,
PatternRewriter &rewriter) const override {
// If the convolution is used in multiple places, fusing it will only create
// more convolutions, which is slower.
if (!op.getResult().hasOneUse())
return rewriter.notifyMatchFailure(op, "result is used by multiple ops");
BiasAddOp bias_add = GetBiasAdd(op);
if (!bias_add) {
return rewriter.notifyMatchFailure(
op, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}
// Check that Conv and BiasAdd formats match.
if (op.data_format() != bias_add.data_format()) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << op.data_format() << ")";
});
}
SmallVector<Location, 3> locations{op.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
GetOpNameWithoutDialect(bias_add), rewriter.getContext())};
Type result_type;
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(StringAttr::get(GetOpNameWithoutDialect(activation),
rewriter.getContext()));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}
auto loc = rewriter.getFusedLoc(locations);
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, rewriter.getContext());
// Epsilon is used only in fusions with the BatchNorm op.
APFloat epsilon = APFloat(0.0f);
auto fused_op = rewriter.create<_FusedConv2DOp>(
loc, result_type, op.input(), op.filter(), bias_add.bias(),
op.strides(), op.padding(), op.explicit_paddings(), op.data_format(),
op.dilations(), op.use_cudnn_on_gpu(), fused_ops_attr, epsilon);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, {fused_op});
return success();
}
};
void OpFusionPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd>(&getContext());
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass() {
return std::make_unique<OpFusionPass>();
}
static PassRegistration<OpFusionPass> pass(
"tf-op-fusion",
"Replaces commonly occurring subgraphs with optimized fused kernels");
} // namespace TF
} // namespace mlir

View File

@ -54,7 +54,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass();
// Creates pass to rewrite RecvTPUEmbeddingActivationsOp and
// SendTPUEmbeddingGradients ops to internal variants.
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps();
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOpsPass();
// Performs specific fusion for GPU targets.
std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
@ -148,8 +148,10 @@ CreateTensorArrayOpsDecompositionPass();
// Create a pass that legalize HLO to TF dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// Creates a pass that performs fusion of common sequences of ops.
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
// Matches sequence of ops to TensorFlow fused kernels. This pass should not be
// generally used beyond exporting to runtimes that supports these ops. In the
// future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
} // namespace TF
namespace tf_executor {

View File

@ -101,7 +101,7 @@ void RewriteTPUEmbeddingOps::runOnFunction() {
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps() {
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOpsPass() {
return std::make_unique<RewriteTPUEmbeddingOps>();
}

View File

@ -209,8 +209,10 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> head_outside_compiled_ops,
llvm::StringRef host_device) {
Block* launch_block = new Block;
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
}
tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/true, launch_block, host_device);
@ -294,8 +296,10 @@ void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
llvm::StringRef host_device) {
Block* launch_block = new Block;
for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops)
for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
}
tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/false, launch_block, host_device);

View File

@ -92,10 +92,13 @@ void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
for (auto head_outside_compiled_op :
llvm::reverse(head_outside_compiled_ops)) {
if (HasOutsideCompilationAttribute(head_outside_compiled_op)) continue;
auto users = head_outside_compiled_op->getUsers();
if (users.empty() ||
HasOutsideCompilationAttribute(head_outside_compiled_op))
continue;
bool should_expand_op_to_host_computation = true;
for (auto consumer_op : head_outside_compiled_op->getUsers()) {
for (auto consumer_op : users) {
if (should_expand_op_to_host_computation &&
!HasOutsideCompilationAttribute(consumer_op)) {
should_expand_op_to_host_computation = false;

View File

@ -219,8 +219,7 @@ class ImporterBase {
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs,
bool function_graph);
llvm::ArrayRef<mlir::NamedAttribute> attrs);
// Finds out the function definition for the given function name from the
// graph and converts it to a function of the module. This method is called
@ -1302,8 +1301,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
TF_RETURN_IF_ERROR(child_importer.Convert(
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
llvm::makeArrayRef(attributes.begin(), attributes.end()),
/*function_graph=*/true));
llvm::makeArrayRef(attributes.begin(), attributes.end())));
return Status::OK();
}
@ -1405,7 +1403,7 @@ Status ImporterBase::Convert(
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs, bool function_graph) {
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// TODO(b/122040776): Uses debug info for FunctionDef.
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
func_name, func_type, attrs);
@ -2222,8 +2220,15 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
PopulateTfVersions(module.get(), graph.versions());
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
specs.graph_as_function));
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
// Mark main function public, others private.
for (auto function : module.get().getOps<mlir::FuncOp>()) {
auto visibility = function.getName() == func_name
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
function.setVisibility(visibility);
}
return module;
}
@ -2888,6 +2893,16 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
}
}
// Marks the visibility of functions in the saved model module.
void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
for (auto func : module.getOps<mlir::FuncOp>()) {
auto visibility = mlir::tf_saved_model::IsExported(func)
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
func.setVisibility(visibility);
}
}
// Reorder the ops in the module to make testing easier and less dependent
// on implementation details such as the order of functions in the
// FunctionDefLibrary.
@ -3130,6 +3145,7 @@ Status CreateSavedModelIR(
AdjustBoundInputArgTypes(module);
module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(module);
MarkSavedModelFunctionVisibility(module);
return Status::OK();
}
@ -3299,6 +3315,7 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
mlir::OpBuilder builder(module_->getBodyRegion());
module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(*module_);
MarkSavedModelFunctionVisibility(*module_);
return std::move(module_);
}

View File

@ -57,7 +57,7 @@ cc_library(
],
deps = [
":tfjs_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
@ -109,7 +109,7 @@ cc_library(
":tensorflow_js",
":tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -190,7 +190,7 @@ cc_library(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -229,7 +229,7 @@ tf_cc_binary(
"//tensorflow/core/platform:errors",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",

View File

@ -11,7 +11,7 @@ cc_library(
deps = [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",

View File

@ -107,7 +107,7 @@ gentbl(
td_file = "transforms/legalize_tf_patterns.td",
td_srcs = [
":hlo_ops_td_files",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
@ -177,7 +177,8 @@ cc_library(
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@llvm-project//llvm:support",
"//tensorflow/core/lib/bfloat16",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -217,7 +218,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -233,7 +234,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -249,7 +250,7 @@ cc_library(
":hlo",
":lhlo",
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StandardOps",
],
)
@ -272,7 +273,7 @@ cc_library(
":map_xla_to_scalar_op",
"//tensorflow/compiler/xla:status",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -287,7 +288,7 @@ cc_library(
deps = [
":lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
@ -321,7 +322,7 @@ cc_library(
":lhlo",
":map_xla_to_scalar_op",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
@ -354,7 +355,7 @@ cc_library(
":lhlo",
":map_xla_to_scalar_op",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
@ -372,7 +373,7 @@ cc_library(
deps = [
":lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
@ -416,7 +417,7 @@ cc_library(
srcs = ["transforms/cycle_detector.cc"],
hdrs = ["transforms/cycle_detector.h"],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@ -437,8 +438,8 @@ cc_library(
deps = [
":cycle_detector",
":hlo",
"@llvm-project//llvm:ir",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -466,7 +467,7 @@ cc_library(
srcs = ["transforms/legalize_control_flow.cc"],
deps = [
":hlo",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -490,7 +491,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -504,7 +505,7 @@ cc_library(
deps = [
":hlo",
":xla_legalize_to_standard_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -522,7 +523,7 @@ gentbl(
td_file = "transforms/lower_complex_patterns.td",
td_srcs = [
":hlo_ops_td_files",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
@ -537,7 +538,7 @@ cc_library(
deps = [
":hlo",
":xla_dialect_registration",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -564,7 +565,7 @@ cc_library(
srcs = ["transforms/unfuse_batch_norm.cc"],
deps = [
":hlo",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
@ -637,7 +638,7 @@ cc_library(
":infer_fusibility_op_interface",
":xla_canonicalize_inc_gen",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
@ -659,6 +660,7 @@ cc_library(
deps = [
":attribute_importer",
":hlo",
":hlo_module_importer",
":hlo_utils",
":type_to_shape",
"//tensorflow/compiler/xla:comparison_util",
@ -671,7 +673,7 @@ cc_library(
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -692,7 +694,7 @@ cc_library(
deps = [
":hlo_ops_base_inc_gen",
":lhlo_ops_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -748,7 +750,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:types",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
@ -798,7 +800,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -846,7 +848,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
@ -877,7 +879,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Translation",
],
@ -888,8 +890,8 @@ tf_native_cc_binary(
name = "operator_writer_gen",
srcs = ["operator_writer_gen.cc"],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TableGen",
],

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include <unordered_map>
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@ -79,30 +81,35 @@ bool DotIsDefault(const HloInstruction* instruction) {
}
} // namespace
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
mlir::ModuleOp module, mlir::Builder* builder,
std::unordered_map<HloComputation*, FuncOp>* function_map,
HloComputation* computation) {
HloFunctionImporter importer(module, builder, function_map);
return importer.ImportFunction(computation);
Status HloFunctionImporter::ImportAsFunc(
const HloComputation& computation, mlir::ModuleOp module,
std::unordered_map<const HloComputation*, FuncOp>* function_map,
mlir::Builder* builder) {
HloFunctionImporter importer(module, function_map, builder);
return importer.ImportAsFunc(computation).status();
}
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
HloComputation* computation) {
auto& imported = (*function_map_)[computation];
Status HloFunctionImporter::ImportAsRegion(
const xla::HloComputation& computation, mlir::Region* region,
mlir::Builder* builder) {
HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
builder);
return importer.ImportAsRegion(computation, region);
}
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
const HloComputation& computation) {
auto& imported = (*function_map_)[&computation];
if (imported) return imported;
llvm::SmallVector<Type, 4> args, rets;
TF_RETURN_IF_ERROR(
GetMlirTypes(computation->parameter_instructions(), &args));
TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets));
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
auto func_type = mlir::FunctionType::get(args, rets, context_);
string computation_name =
computation->parent()->entry_computation() == computation
computation.parent()->entry_computation() == &computation
? "main"
: SanitizeFunctionName(computation->name());
: SanitizeFunctionName(computation.name());
// Construct the MLIR function and map arguments.
llvm::ArrayRef<mlir::NamedAttribute> attrs;
@ -119,31 +126,30 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
return function;
}
tensorflow::Status HloFunctionImporter::ImportComputation(
HloComputation* computation, mlir::Region* region) {
tensorflow::Status HloFunctionImporter::ImportAsRegion(
const HloComputation& computation, mlir::Region* region) {
// TODO(hinsu): Store computation name as an attribute for round-trip.
auto* block = new mlir::Block;
region->push_back(block);
llvm::SmallVector<Type, 4> args;
TF_RETURN_IF_ERROR(
GetMlirTypes(computation->parameter_instructions(), &args));
TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
block->addArguments(args);
return ImportInstructions(computation, block);
}
tensorflow::Status HloFunctionImporter::ImportInstructions(
HloComputation* computation, mlir::Block* block) {
const HloComputation& computation, mlir::Block* block) {
// Setup the input parameters.
const int num_parameters = computation->num_parameters();
const int num_parameters = computation.num_parameters();
for (int i = 0; i < num_parameters; i++) {
auto hlo_parameter = computation->parameter_instruction(i);
auto hlo_parameter = computation.parameter_instruction(i);
instruction_value_map_[hlo_parameter] = block->getArgument(i);
}
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
for (auto instruction : computation->MakeInstructionPostOrder()) {
for (auto instruction : computation.MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto new_operation,
ImportInstruction(instruction, &builder));
if (new_operation) {
@ -156,7 +162,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
// Setup the return type (HLO only supports a single return value).
TF_ASSIGN_OR_RETURN(auto result,
GetMlirValue(computation->root_instruction()));
GetMlirValue(computation.root_instruction()));
// Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) {
@ -249,7 +255,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kCall: {
TF_ASSIGN_OR_RETURN(FuncOp function,
ImportFunction(instruction->to_apply()));
ImportAsFunc(*instruction->to_apply()));
mlir::Operation* new_operation =
func_builder->create<mlir::CallOp>(loc, function, operands);
return new_operation;
@ -365,7 +371,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
&scatter_op.update_computation()));
return scatter_op.getOperation();
}
@ -387,9 +393,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto select_scatter_op =
func_builder->create<mlir::xla_hlo::SelectAndScatterOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(select_scatter->select(),
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
&select_scatter_op.select()));
TF_RETURN_IF_ERROR(ImportComputation(select_scatter->scatter(),
TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
&select_scatter_op.scatter()));
return select_scatter_op.getOperation();
}
@ -414,8 +420,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
loc, result_type, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable()));
TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(),
&sort_op.comparator()));
TF_RETURN_IF_ERROR(
ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
return sort_op.getOperation();
}
case HloOpcode::kConditional: {
@ -430,9 +436,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
attributes);
TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(),
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
&op.true_branch()));
TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(),
TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
&op.false_branch()));
return op.getOperation();
}
@ -448,8 +454,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
llvm::enumerate(instruction->branch_computations())) {
auto index = index_and_computation.index();
HloComputation* computation = index_and_computation.value();
TF_RETURN_IF_ERROR(
ImportComputation(computation, &op.branches()[index]));
TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
}
return op.getOperation();
}
@ -468,7 +473,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(all_reduce->to_apply(),
TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
&all_reduce_op.computation()));
return all_reduce_op.getOperation();
}
@ -481,7 +486,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &reduce.body()));
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
case HloOpcode::kReverse: {
@ -517,9 +522,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
loc, operands[0].getType(), operands[0]);
TF_RETURN_IF_ERROR(
ImportComputation(instruction->while_condition(), &op.cond()));
ImportAsRegion(*instruction->while_condition(), &op.cond()));
TF_RETURN_IF_ERROR(
ImportComputation(instruction->while_body(), &op.body()));
ImportAsRegion(*instruction->while_body(), &op.body()));
return op.getOperation();
}
case HloOpcode::kGetTupleElement: {
@ -580,7 +585,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &reduce.body()));
ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation();
}
case HloOpcode::kMap: {
@ -588,7 +593,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
loc, result_type, operands,
ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &op.computation()));
ImportAsRegion(*instruction->to_apply(), &op.computation()));
return op.getOperation();
}
case HloOpcode::kConvolution: {

View File

@ -42,29 +42,39 @@ class Shape;
// Helper class for importing HloComputations.
class HloFunctionImporter {
public:
static StatusOr<mlir::FuncOp> ImportFunction(
mlir::ModuleOp module, mlir::Builder* builder,
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map,
xla::HloComputation* computation);
// Imports the given computation as a function in the given module. This also
// imports any computations referred by instructions in this computation.
static Status ImportAsFunc(const xla::HloComputation& computation,
mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
mlir::FuncOp>* function_map,
mlir::Builder* builder);
// Imports the given hlo computation to the specified region.
static Status ImportAsRegion(const xla::HloComputation& computation,
mlir::Region* region, mlir::Builder* builder);
private:
HloFunctionImporter(
mlir::ModuleOp module, mlir::Builder* builder,
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map)
HloFunctionImporter(mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
mlir::FuncOp>* function_map,
mlir::Builder* builder)
: context_(module.getContext()),
module_(module),
builder_(builder),
function_map_(function_map) {}
StatusOr<mlir::FuncOp> ImportFunction(xla::HloComputation* computation);
// Imports the given computation as a new function, if it hasn't been already
// imported.
StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);
// Imports the given computation in the specified region.
tensorflow::Status ImportComputation(HloComputation* computation,
tensorflow::Status ImportAsRegion(const HloComputation& computation,
mlir::Region* region);
// Imports instructions from the given computation in the specified block.
// Assumes that the block already has correct arguments populated.
tensorflow::Status ImportInstructions(HloComputation* computation,
tensorflow::Status ImportInstructions(const HloComputation& computation,
mlir::Block* block);
// Imports an instruction.
@ -125,7 +135,7 @@ class HloFunctionImporter {
mlir::Builder* builder_;
// Mapping from HloComputation to the created MLIR function.
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_;
std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;
// Mapping from HloInstructions to the associative MLIR values.
std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;

View File

@ -33,11 +33,11 @@ namespace xla {
Status HloModuleImporter::Import(const xla::HloModule& module) {
// TODO(hinsu): Only import the entry computation here once all HLO ops with
// reference to other computation are updated to have a region instead of a
// function attribute.
for (const auto& computation : module.computations()) {
auto result = HloFunctionImporter::ImportFunction(
module_, &builder_, &function_map_, computation);
TF_RETURN_IF_ERROR(result.status());
// function attribute. Currently the importer test doesn't refer to all the
// computations from the entry computation so tests may need some update.
for (const auto* computation : module.computations()) {
TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc(
*computation, module_, &function_map_, &builder_));
}
return Status::OK();

View File

@ -54,7 +54,7 @@ class HloModuleImporter {
// Map for tracking which MLIR function map to which HLO Computation. This
// tracks functions as they are imported and provides a quick lookup for
// functions invoked by control flow related operations (e.g. while, call).
std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map_;
std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map_;
};
} // namespace xla

View File

@ -836,21 +836,33 @@ LogicalResult ConcatenateOp::inferReturnTypes(
auto dimension = dimension_attr.getInt();
auto first_type = (*operands.begin()).getType().cast<ShapedType>();
auto out_element = first_type.getElementType();
for (auto operand : operands.getTypes()) {
auto element_type = getElementTypeOrSelf(operand);
if (element_type != out_element) {
return failure();
}
}
// If an input is unranked the output shape is unranked.
if (!first_type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
return success();
}
auto out_shape = llvm::to_vector<6>(first_type.getShape());
out_shape[dimension] = 0;
for (auto operand : operands.getTypes()) {
auto type = operand.cast<ShapedType>();
auto dim = type.getShape()[dimension];
// Validate the element types match.
if (type.getElementType() != out_element) {
return failure();
if (!type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
return success();
}
// If the dimension is dynamic we know the output dimension is dynamic.
auto dim = type.getShape()[dimension];
if (dim == -1) {
out_shape[dimension] = -1;
break;
@ -937,26 +949,39 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
}
static LogicalResult Verify(ConcatenateOp op) {
auto firstType = op.getOperand(0).getType().cast<RankedTensorType>();
Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
RankedTensorType first_ranked_type;
int num_operands = op.getNumOperands();
for (int i = 0; i < num_operands; i++) {
auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
if (second_type.getElementType() != element_type) {
return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match element type", i));
}
auto firstShape = firstType.getShape();
int numOperands = op.getNumOperands();
for (int i = 1; i < numOperands; i++) {
auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
if (!second_type.hasRank()) {
continue;
}
if (firstType.getRank() != secondType.getRank()) {
if (!first_ranked_type) {
first_ranked_type = second_type.cast<RankedTensorType>();
continue;
}
if (first_ranked_type.getRank() != second_type.getRank()) {
return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match rank", i));
}
auto secondShape = secondType.getShape();
for (int d = 0; d < firstType.getRank(); ++d) {
if (firstShape[d] != secondShape[d] && d != op.dimension()) {
auto first_shape = second_type.getShape();
auto second_shape = second_type.getShape();
for (int d = 0; d < first_ranked_type.getRank(); ++d) {
if (first_shape[d] != second_shape[d] && d != op.dimension()) {
return op.emitOpError(llvm::formatv(
"operands (0) and ({0}) non-concat dimensions do not match "
"({1}) != ({2})",
i, llvm::make_range(firstShape.begin(), firstShape.end()),
llvm::make_range(secondShape.begin(), secondShape.end())));
i, llvm::make_range(first_shape.begin(), first_shape.end()),
llvm::make_range(second_shape.begin(), second_shape.end())));
}
}
}

View File

@ -358,7 +358,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
}
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// XLA binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -379,15 +379,6 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
// XLA communication op definitions.
//===----------------------------------------------------------------------===//
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> {
@ -451,7 +442,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
let arguments = (ins
HLO_TensorOrTuple:$operand,
HLO_Token:$token,
ChannelHandle:$channel_id,
ChannelHandle<HLO_Dialect>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
);
@ -476,7 +467,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
let arguments = (ins
HLO_Token:$token,
ChannelHandle:$channel_id,
ChannelHandle<HLO_Dialect>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
);
@ -564,16 +555,8 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
SameOperandsAndResultType]> {
string summary = "While operator";
string description = [{
Returns the result of executing a body function until the cond body returns
true.
See https://www.tensorflow.org/xla/operation_semantics#while.
}];
SameOperandsAndResultType]>,
BASE_HLO_WhileOp {
let arguments = (ins HLO_TensorOrTuple:$val);
let regions = (region AnyRegion:$cond, AnyRegion:$body);
@ -590,7 +573,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
let arguments = (ins
HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups,
OptionalAttr<ChannelHandle>:$channel_id
OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
);
let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor);

View File

@ -584,6 +584,15 @@ class BASE_HLO_CaseOp {
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
class BASE_HLO_ReplicaIdOp {
string summary = "ReplicaId operator";
@ -1258,4 +1267,45 @@ class BASE_HLO_RngNormalOp {
}];
}
class BASE_HLO_ReducePrecisionOp {
string summary = "Reduce precision operator";
string description = [{
Models the effect of converting floating - point values to a lower -
precision format(such as IEEE - FP16) and back to the original
format. The number of exponent and mantissa bits in the lower -
precision format can be specified arbitrarily,
although all bit sizes may not be supported on all hardware
implementations.
See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
}];
}
class BASE_HLO_InfeedOp {
string summary = "Infeed operator";
string description = [{
Reads a single data item from the implicit Infeed streaming interface of
the device, interpreting the data as the given shape and its layout, and
returns an LHLO op of the data. Multiple Infeed operations are allowed in a
computation, but there must be a total order among the Infeed operations.
For example, two Infeeds in the code below have a total order since there
is a dependency between the while loops.
See https://www.tensorflow.org/xla/operation_semantics#infeed
}];
}
class BASE_HLO_WhileOp {
string summary = "While operator";
string description = [{
Returns the result of executing a body function until the cond body returns
true.
See https://www.tensorflow.org/xla/operation_semantics#while.
}];
}
#endif // HLO_OPS_BASE

View File

@ -14,6 +14,20 @@ limitations under the License.
==============================================================================*/
// This is the operation definition file for LXLA.
//
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
// merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within tuples
// also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// one above.
// * Traits aren't identical, but need to be coped. For example,
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API in XLA's client side, not service
// side. LHLO aims for the service side.
#ifndef LHLO_OPS
#define LHLO_OPS
@ -38,11 +52,17 @@ def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;
@ -74,88 +94,126 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class LHLO_UnaryElementwiseOp<string mnemonic> :
LHLO_Op<mnemonic, [SameTypeOperands]> {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
class LHLO_UnaryElementwiseOp<string mnemonic,
Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]>
: LHLO_Op<mnemonic, traits> {
let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
Arg<BufferType, "", [MemWrite]>:$output);
}
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp;
// TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
def LHLO_ConvertOp : LHLO_Op<"convert", [SameOperandsShape]>, BASE_HLO_ConvertOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp;
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp;
// TODO(timshen): add a custom verifier.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;
def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;
def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;
def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp;
def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;
def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;
def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp;
def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class LHLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]> :
LHLO_Op<mnemonic, traits> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$out,
Arg<BufferType, "", [MemRead]>:$lhs,
Arg<BufferType, "", [MemRead]>:$rhs,
Arg<BufferType, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp;
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;
def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;
def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
}
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp;
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp;
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum", []>, BASE_HLO_MinOp;
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply", []>, BASE_HLO_MulOp;
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;
def LHLO_RemOp :
LHLO_BinaryElementwiseOp<"remainder", []>, BASE_HLO_RemOp;
def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract", []>, BASE_HLO_SubOp;
def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp;
def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;
def LHLO_OrOp: LHLO_BinaryElementwiseOp<"or", []>, BASE_HLO_OrOp;
def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;
def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;
def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
@ -210,6 +268,16 @@ def LHLO_CaseOp: LHLO_Op<"case", [
let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}
// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameTypeOperands]>, BASE_HLO_WhileOp {
let arguments = (ins
Arg<LHLO_BufferOrTuple, "", [MemRead]>:$val,
Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$output
);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
@ -269,7 +337,9 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = "static memref cast operation";
let summary = [{
"modifies the offset, sizes and strides of a statically shaped memref.
}];
let description = [{
Allows to modify the offset, sizes and strides of a statically shaped memref.
@ -357,7 +427,23 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
// XLA Other op definitions.
//===----------------------------------------------------------------------===//
def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
BASE_HLO_BatchNormGradOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
BASE_HLO_BatchNormInferenceOp {
let arguments = (ins
@ -372,6 +458,19 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
);
}
def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
BASE_HLO_BatchNormTrainingOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp {
let arguments = (ins
@ -531,6 +630,88 @@ def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
);
}
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
}
def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
BASE_HLO_AllReduceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
);
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
BASE_HLO_CollectivePermuteOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
);
}
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
HLO_FftTypeAttr:$fft_type,
I64ElementsAttr:$fft_length
);
}
def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
DefaultValuedAttr<BoolAttr, "false">:$lower
);
}
def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
DefaultValuedAttr<StrAttr, "">:$config
);
}
def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
DefaultValuedAttr<StrAttr, "">:$config
);
}
def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
HLO_TransposeAttr:$transpose_a
);
}
//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//

View File

@ -19,10 +19,12 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
@ -118,6 +120,76 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::FftInternal(
const Shape& shape, XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::FftOp>(
loc_, ty, GetValue(operand),
builder_.getStringAttr(FftType_Name(fft_type)),
GetI64ElementsAttr(fft_length, &builder_));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
// Reduce takes two set of variadic operands inputs and init_values.
// all_operands contains both of these so split operands into two parts.
int64_t num_args = all_operands.size() / 2;
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
loc_, GetValues(all_operands.first(num_args)),
GetValues(all_operands.subspan(num_args)),
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
llvm::SmallVector<int64, 8> padding;
for (const auto& dim : window.dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
auto padding_ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_.getIntegerType(64));
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
loc_, ty, GetValue(operand), GetValue(init_value),
GetI64ElementsAttr(sizes, &builder_),
GetI64ElementsAttr(strides, &builder_),
GetI64ElementsAttr(base_dilations, &builder_),
GetI64ElementsAttr(win_dilations, &builder_),
mlir::DenseIntElementsAttr::get(padding_ty, padding));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
return MakeXlaOp(op);
}
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto op = builder_.create<mlir::xla_hlo::IotaOp>(
loc_, ty,
builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
return MakeXlaOp(op);
});
}
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@ -127,6 +199,15 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::RevInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,
@ -140,6 +221,24 @@ StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
bool unique_indices) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ScatterOp>(
loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
builder_.getBoolAttr(indices_are_sorted),
builder_.getBoolAttr(unique_indices));
TF_RETURN_IF_ERROR(
ImportComputation(update_computation.proto(), &op.update_computation()));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
RandomDistribution distribution, absl::Span<const XlaOp> parameters,
const Shape& shape) {
@ -348,6 +447,18 @@ StatusOr<XlaOp> MlirHloBuilder::CreateOp(
return MakeXlaOp(op->getResult(0));
}
Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
mlir::Region* region) {
TF_ASSIGN_OR_RETURN(auto module_config,
xla::HloModule::CreateModuleConfigFromProto(
computation, xla::DebugOptions()));
TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
computation, module_config));
return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
region, &builder_);
}
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
TF_RETURN_IF_ERROR(first_error());
TF_RETURN_IF_ERROR(CheckOpBuilder(op));

View File

@ -120,15 +120,40 @@ class MlirHloBuilder : public XlaBuilder {
int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) override;
StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
FftType fft_type,
absl::Span<const int64> fft_length) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) override;
StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
XlaOp init_value,
const XlaComputation& computation,
Window window) override;
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
StatusOr<XlaOp> TransposeInternal(
const Shape& shape, XlaOp operand,
absl::Span<const int64> permutation) override;
StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> dimensions) override;
StatusOr<XlaOp> GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers,
absl::Span<const int64> slice_sizes, bool indices_are_sorted) override;
StatusOr<XlaOp> ScatterInternal(
const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
bool unique_indices) override;
StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
absl::Span<const XlaOp> parameters,
const Shape& shape) override;
@ -196,6 +221,9 @@ class MlirHloBuilder : public XlaBuilder {
llvm::ArrayRef<XlaOp> operands,
llvm::ArrayRef<mlir::NamedAttribute> attributes = {});
Status ImportComputation(const HloModuleProto& computation,
mlir::Region* region);
mlir::OpBuilder builder_;
mlir::Location loc_;

View File

@ -68,10 +68,11 @@ static StringRef GetClientBuilder(const Operator& op) {
return kOpToXLABuilderMap->lookup(op_name);
}
static void BuildOperator(const Operator& op, raw_ostream* output) {
auto& os = *output;
os << " auto& value_map = *lowering_context.values;\n"
<< " auto result = xla_op.getResult();\n";
static void BuildOperator(const Operator& op, raw_ostream& os) {
os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::"
<< op.getCppClassName() << " op, OpLoweringContext ctx) {\n"
<< " auto& value_map = *ctx.values;\n"
<< " auto result = op.getResult();\n";
// Build a conversion for each of the arguments.
int operand_number = 0;
@ -82,15 +83,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
// Handle a non-variadic operand.
if (!operand_cst->isVariableLength()) {
os << " auto xla_arg_" << index
<< " = value_map[*xla_op.getODSOperands(" << operand_number++
<< ").begin()];\n";
os << " auto xla_arg_" << index << " = value_map[*op.getODSOperands("
<< operand_number++ << ").begin()];\n";
continue;
}
// Otherwise, this is a varidiac operand list.
os << " std::vector<xla::XlaOp> xla_arg_" << index << ";\n"
<< " for (auto operand : xla_op.getODSOperands(" << operand_number++
<< " for (auto operand : op.getODSOperands(" << operand_number++
<< "))\n xla_arg_" << index
<< ".push_back(value_map[operand]);\n";
continue;
@ -99,8 +99,8 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
// Otherwise, this is an attribute.
auto named_attr = arg.get<NamedAttribute*>();
os << " auto xla_arg_" << index << " = "
<< GetDefaultAttrExport(*named_attr) << "(xla_op."
<< op.getArgName(index) << "());\n";
<< GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index)
<< "());\n";
}
// Emit call to client API
@ -109,7 +109,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
// If all operands are variadic, then pass the builder explicitly to xla
// client API call
if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
os << "lowering_context.builder";
os << "ctx.builder";
if (op.getNumArgs() != 0) os << ", ";
}
@ -120,6 +120,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
os << " value_map[result] = xla_result;\n";
os << " return mlir::success();\n";
os << "}\n";
}
// The function below has a non-constant reference as that is required by LLVM's
@ -128,6 +129,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
emitSourceFileHeader("MLIR XLA Builders", os);
// Emit all the helper functions.
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
Operator op(def);
// Skip operations that have a custom exporter.
if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os);
}
// Emit a function to generate an XLA operation for the operations with
// auto-generated builders.
os << "mlir::LogicalResult ExportXlaOperator(\n"
@ -153,12 +162,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
// Cast to the current operation and build the exporter.
os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
<< op.getCppClassName() << ">(op)) {\n";
if (def->getValueAsBit("hasCustomHLOConverter")) {
os << " return mlir::xla_hlo::ExportXlaOp(xla_op, "
"lowering_context);\n";
} else {
BuildOperator(op, &os);
}
os << " return ";
// The autogenerated converters aren't in the same namespace.
// TODO(jpienaar): Reconsider this.
if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::";
os << "ExportXlaOp(xla_op, lowering_context);\n";
os << " }\n";
}

View File

@ -30,7 +30,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)

View File

@ -9,7 +9,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]])
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>

View File

@ -17,7 +17,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
@ -34,7 +34,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@ -51,7 +51,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>

View File

@ -184,14 +184,16 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
@ -389,15 +391,18 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -411,15 +416,18 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()

View File

@ -340,36 +340,36 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, 0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// -----
@ -407,7 +407,8 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -554,4 +555,5 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]

View File

@ -14,10 +14,10 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[LHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[RHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>

View File

@ -27,7 +27,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
@ -42,7 +42,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
@ -55,7 +55,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32>
@ -203,7 +203,7 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
@ -216,7 +216,7 @@ func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
@ -284,7 +284,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
@ -297,7 +297,7 @@ func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}

View File

@ -187,6 +187,70 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2
return %0: tensor<3x4xi32>
}
// CHECK-LABEL: @sparse_to_dense
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor<f32>)
func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {
// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): // no predecessors
// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: update_window_dims = dense<[]> : tensor<0xi64>
// CHECK-SAME: unique_indices = false
// CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32>
// return %[[RESULT]] : tensor<3x3xf32>
%cst = xla_hlo.constant dense<3> : tensor<2xi32>
%0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor<f32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}
// CHECK-LABEL: fft
func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
// CHECK: "xla_hlo.fft"(%arg0)
%0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
return %0 : tensor<3x5x8xcomplex<f32>>
}
// CHECK-LABEL: reverse_sequence
func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
// CHECK-NOT: tf.ReverseSequence
%0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32>
return %0 : tensor<4x2x3x1x1xi32>
}
// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
%0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
// CHECK-NOT: tf.MirrorPad
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
return %1 : tensor<4x7xcomplex<f64>>
}
// CHECK-LABEL: bucketize
func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
// CHECK-NOT: tf.Bucketize
%0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
return %0 : tensor<2x5xi32>
}
// CHECK-LABEL: arg_min
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
// CHECK-NOT: ArgMin
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}

View File

@ -420,7 +420,7 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens
// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]])
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -431,7 +431,7 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens
// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]])
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -442,7 +442,7 @@ func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens
// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]])
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -1445,7 +1445,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]]
// CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]])
@ -1460,7 +1460,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]]
// CHECK: return %[[RESULT]]
@ -1517,7 +1517,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]]
// CHECK: return %[[RESULT]]
@ -1544,58 +1544,34 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
// CHECK-LABEL: func @shape_1D
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
// CHECK: return [[CONCAT]]
// CHECK: return [[CAST]]
return %0 : tensor<1xi32>
}
// CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], 1
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_with_const
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
// CHECK: return [[CAST]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
// CHECK: return [[CAST]]
return %0 : tensor<?xi32>
}
@ -1884,7 +1860,7 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<1xindex>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<1xindex>
// CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32>
// CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32>
@ -1906,7 +1882,7 @@ func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<?xindex>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<?xindex>
// CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32>
@ -3826,42 +3802,6 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
return %0: tensor<4x?x16xf32>
}
//===----------------------------------------------------------------------===//
// tf.VariableShape legalization
//===----------------------------------------------------------------------===//
// CHECK-LABLE: @variable_shape32
func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>)
// CHECK: return [[CST_CAST]]
return %0: tensor<3xi32>
}
// CHECK-LABLE: @variable_shape64
func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>)
// CHECK: return [[CST_CAST]]
return %0: tensor<3xi64>
}
// CHECK-LABEL: @variable_shape_unknown_resource
func @variable_shape_unknown_resource(%input: tensor<!tf.resource>) -> tensor<?xi32> {
// CHECK: tf.VariableShape
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource>) -> (tensor<?xi32>)
return %0: tensor<?xi32>
}
// CHECK-LABEL: @variable_shape_unknown_resource_shape
func @variable_shape_unknown_resource_shape(%input: tensor<!tf.resource<tensor<?x?xf32>>>) -> tensor<2xi32> {
// CHECK: tf.VariableShape
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?xf32>>>) -> (tensor<2xi32>)
return %0: tensor<2xi32>
}
//===----------------------------------------------------------------------===//
// tf.AvgPool legalization
//===----------------------------------------------------------------------===//
@ -4025,3 +3965,87 @@ func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x7
%0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}
//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.220700e-04> : tensor<f16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16>
return %0 : tensor<8x16xf16>
}
// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<7.812500e-03> : tensor<bf16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<bf16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16>
return %0 : tensor<8x16xbf16>
}
// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.1920929E-7> : tensor<f32>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<2.2204460492503131E-16> : tensor<f64>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f64>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
return %0 : tensor<8x16xf64>
}

View File

@ -173,7 +173,8 @@ func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return
}
// CHECK: linalg.indexed_generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
@ -190,7 +191,8 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
} : (memref<f32>, memref<4x2x1xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -206,7 +208,8 @@ func @broadcast(%operand: memref<4x?x16xf32>,
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -222,7 +225,8 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -645,39 +649,42 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, 0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// -----

View File

@ -103,9 +103,10 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<?x?xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], 0 : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref<?x?x?xf32>
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {

View File

@ -1,8 +1,66 @@
// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
// -----
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// -----
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// -----
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -25,16 +83,40 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// -----
// CHECK-LABEL: func @convert_memref
func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @exp_memref
func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.exponential"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// -----
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -48,6 +130,22 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// -----
// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
@ -64,6 +162,46 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// -----
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
@ -80,6 +218,30 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// -----
// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
@ -129,13 +291,77 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// -----
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
// -----
func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
// -----
func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
// -----
func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
@ -248,3 +474,392 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
// -----
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
return
}
// -----
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @clz_memrefs
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @floor_memrefs
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @imag_memrefs
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
// -----
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @real_memrefs
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
// -----
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
return
}
// -----
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
// -----
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
return
}
// -----
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @round_memrefs
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 },
constrain_layout = true,
use_global_device_ids = true
}: (memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 }
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @batch_norm_training_memrefs
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
return
}
// -----
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
"xla_lhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> ()
return
}

View File

@ -304,6 +304,46 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<
// -----
// CHECK-LABEL: @concat_1D
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
// expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}}
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
// CHECK-LABEL: @concat_1D_unranked
func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// -----
func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
// expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
// expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// -----
// CHECK-LABEL: func @clamp
func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

View File

@ -104,16 +104,20 @@ func @batchNormInference_dynamic_shape(
%x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>,
%mean: tensor<?xf32>, %variance: tensor<?xf32>)
-> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor<?xf32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>

View File

@ -16,8 +16,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape.
%shape_as_extent_tensor = "shape.to_extent_tensor"(%shape)
: (!shape.shape) -> tensor<?xindex>
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -35,7 +34,7 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = "shape.to_extent_tensor"(%[[SHAPE]]) : (!shape.shape) -> tensor<?xindex>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>
#include "llvm/ADT/ArrayRef.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -54,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
@ -881,6 +884,31 @@ static Type GetAccumulationType(Type ty) {
return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}
//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//
static DenseElementsAttr GetEpsilonValue(Type ty) {
auto element_ty = ty.cast<TensorType>().getElementType();
auto scalar_ty = RankedTensorType::get({}, element_ty);
if (element_ty.isF16()) {
uint16_t raw_epsilon = Eigen::NumTraits<Eigen::half>::epsilon().x;
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isBF16()) {
uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isF32()) {
auto value = APFloat(std::numeric_limits<float>::epsilon());
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isF64()) {
auto value = APFloat(std::numeric_limits<double>::epsilon());
return DenseElementsAttr::get(scalar_ty, value);
}
llvm_unreachable("unsupported element type for tf.SoftPlus");
}
//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//
@ -4387,45 +4415,6 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
}
};
// Converts tf.VariableShape op to a XLA HLO constant representing the variable
// shape.
class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TF::VariableShapeOp op,
PatternRewriter &rewriter) const override {
// The input type should be a tensor<!tf.resource<resource-type>>. We need
// to get the inner resource type.
auto input_type = op.input().getType().cast<TensorType>();
auto subtypes =
input_type.getElementType().cast<TF::ResourceType>().getSubtypes();
// It can be missing; then we cannot convert.
if (subtypes.empty()) return failure();
auto resource_type = subtypes[0].cast<TensorType>();
if (!resource_type.hasStaticShape()) return failure();
auto resource_shape = resource_type.getShape();
Attribute const_attr;
// We need to match the original op result's element type.
auto element_type = op.getType().cast<TensorType>().getElementType();
unsigned bitwidth = element_type.cast<IntegerType>().getWidth();
if (bitwidth == 32) {
SmallVector<int32_t, 4> shape(resource_shape.begin(),
resource_shape.end());
const_attr = GetI32ElementsAttr(shape, &rewriter);
} else {
assert(bitwidth == 64);
const_attr = GetI64ElementsAttr(resource_shape, &rewriter);
}
rewriter.replaceOpWithNewOp<xla_hlo::ConstOp>(op, const_attr);
return success();
}
};
// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes.
class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
public:
@ -4621,45 +4610,19 @@ class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
LogicalResult matchAndRewrite(TF::ShapeOp op,
PatternRewriter &rewriter) const override {
Value input = op.input();
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
// If the shape is static it can be canonicalized.
if (!input_ty || input_ty.hasStaticShape()) {
auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!result_ty) {
return failure();
}
auto result_ty = op.getResult().getType().cast<RankedTensorType>();
auto element_ty = result_ty.getElementType();
auto index_tensor =
RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
op.getLoc(), index_tensor, shape_op);
int64_t rank = input_ty.getRank();
auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
auto index_ty = RankedTensorType::get({1}, element_ty);
llvm::SmallVector<Value, 4> dim_values;
for (int64_t i = 0; i < rank; ++i) {
if (!input_ty.isDynamicDim(i)) {
auto dim_attr = DenseElementsAttr::get(
index_ty,
rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
dim_values.push_back(index);
continue;
}
auto extent_op = rewriter.create<shape::GetExtentOp>(
op.getLoc(), shape_op, rewriter.getI64IntegerAttr(i));
auto index_op = rewriter.create<shape::SizeToIndexOp>(
op.getLoc(), rewriter.getIndexType(), extent_op);
auto int_op =
rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
auto from_tensor = rewriter.create<TensorFromElementsOp>(
op.getLoc(), int_op.getResult());
auto reshape_op =
rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
dim_values.push_back(reshape_op);
}
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
return success();
}
};
@ -5250,7 +5213,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp,
ConvertRandomShuffleOp, ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp>(op->getContext());
// Populate with CHLO->HLO lowerings to account for TF ops legalized to

View File

@ -618,3 +618,39 @@ def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_MulOp
(HLO_MulOp $r, $l),
(HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>;
//===----------------------------------------------------------------------===//
// Softplus op.
//===----------------------------------------------------------------------===//
def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">;
def : Pattern<(TF_SoftplusOp AnyTensor:$features),
[
(HLO_ExpOp:$features_exp $features),
(HLOClient_BroadcastAddOp:$threshold
(HLO_LogOp (HLO_ConstOp (EpsilonValue $features))),
(HLO_ConstOp (GetScalarOfType<2> $features)),
(NullDenseIntElementsAttr)
),
(HLO_SelectOp:$output
(HLOClient_BroadcastCompareOp
$features,
(HLO_NegOp $threshold),
(NullDenseIntElementsAttr),
HLO_COMPARISON_DIRECTION_GT
),
$features,
(HLO_SelectOp
(HLOClient_BroadcastCompareOp
$features,
$threshold,
(NullDenseIntElementsAttr),
HLO_COMPARISON_DIRECTION_LT
),
$features_exp,
(HLO_Log1pOp $features_exp)
)
),
(replaceWithValue $output)
]>;

View File

@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(),
TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
TypeID::get<TF::AsinhOp>(),
TypeID::get<TF::AsinOp>(),
TypeID::get<TF::Atan2Op>(),
@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::BitwiseAndOp>(),
TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::BucketizeOp>(),
TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(),
@ -116,13 +119,24 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfOp>(),
TypeID::get<TF::Expm1Op>(),
TypeID::get<TF::FFT2DOp>(),
TypeID::get<TF::FFT3DOp>(),
TypeID::get<TF::FFTOp>(),
TypeID::get<TF::FloorDivOp>(),
TypeID::get<TF::FloorModOp>(),
TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::IFFT2DOp>(),
TypeID::get<TF::IFFT3DOp>(),
TypeID::get<TF::IFFTOp>(),
TypeID::get<TF::IRFFT2DOp>(),
TypeID::get<TF::IRFFT3DOp>(),
TypeID::get<TF::IRFFTOp>(),
TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(),
TypeID::get<TF::LeakyReluOp>(),
TypeID::get<TF::LeftShiftOp>(),
@ -134,16 +148,20 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::LogicalOrOp>(),
TypeID::get<TF::LogOp>(),
TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MulOp>(),
TypeID::get<TF::NegOp>(),
TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::PadOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PowOp>(),
TypeID::get<TF::RFFT2DOp>(),
TypeID::get<TF::RFFT3DOp>(),
TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalOp>(),
TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(),
TypeID::get<TF::RoundOp>(),
@ -156,6 +174,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::SoftplusGradOp>(),
TypeID::get<TF::SoftsignGradOp>(),
TypeID::get<TF::SoftsignOp>(),
TypeID::get<TF::SparseToDenseOp>(),
TypeID::get<TF::SqrtGradOp>(),
TypeID::get<TF::SquareOp>(),
TypeID::get<TF::SubOp>(),

View File

@ -16,7 +16,6 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include "absl/memory/memory.h"
#include "llvm/ADT/APInt.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -39,13 +38,9 @@ limitations under the License.
namespace mlir {
namespace {
ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) {
auto parallelLoopTypeAttr = b->getStringAttr("parallel");
SmallVector<Attribute, 3> iteratorTypes;
for (int i = 0; i < nParallelLoops; ++i) {
iteratorTypes.push_back(parallelLoopTypeAttr);
}
return b->getArrayAttr(iteratorTypes);
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
static constexpr StringRef kParallelIterType = "parallel";
return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
}
template <bool isLHLO = true>
@ -90,7 +85,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
}
// Construct the indexing maps needed for linalg.generic ops.
SmallVector<Attribute, 2> indexingMaps;
SmallVector<AffineMap, 2> indexing_maps;
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
// This doesnt account for implicit broadcast, but the working assumption
@ -107,9 +102,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
!shapedType.isa<RankedTensorType>()) ||
shapedType.getRank() != nloops)
return nullptr;
indexingMaps.emplace_back(AffineMapAttr::get(
indexing_maps.emplace_back(
nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext())));
: AffineMap::get(nloops, 0, rewriter.getContext()));
return shapedType;
};
for (const auto& arg : llvm::enumerate(args)) {
@ -132,11 +127,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, args,
rewriter.getI64IntegerAttr(bodyArgTypes.size()), // args_in
rewriter.getI64IntegerAttr(bodyResultTypes.size()), // args_out
rewriter.getArrayAttr(indexingMaps),
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
/*inputCount=*/bodyArgTypes.size(),
/*outputCount=*/bodyResultTypes.size(), indexing_maps,
GetNParallelLoopsAttrs(nloops));
// Add a block to the region.
auto* region = &linalgOp.region();
@ -297,8 +290,8 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
/// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
/// `getIndexingMapsAttr` that returns an ArrayAttr containing AffineMapAttr for
/// the index maps of the input and the output.
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
/// and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
public:
@ -310,17 +303,17 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto operandType = op.operand().getType().template cast<ShapedType>();
auto resultType = getXLAOpResultType<isLHLO>(op);
ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter);
if (!indexingMapsAttr) return failure();
SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter);
if (indexing_maps.empty()) return failure();
OpBuilder::InsertionGuard linalgOpGuard(rewriter);
auto nloops = resultType.getRank();
auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, isLHLO ? ArrayRef<Type>{} : resultType, args,
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
/*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops));
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
@ -344,7 +337,8 @@ class BroadcastConverter
using DataMovementOpConverter<BroadcastConverter, OpTy,
isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
Builder* b) {
ShapedType inputType =
broadcastOp.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank();
@ -368,8 +362,7 @@ class BroadcastConverter
inputMap =
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
}
return b->getAffineMapArrayAttr(
{inputMap, b->getMultiDimIdentityMap(nloops)});
return {inputMap, b->getMultiDimIdentityMap(nloops)};
}
};
@ -381,8 +374,8 @@ class HloBroadcastInDimConverter
xla_hlo::BroadcastInDimOp,
false>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(xla_hlo::BroadcastInDimOp broadcastOp,
Builder* b) {
static SmallVector<AffineMap, 2> getIndexingMaps(
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>();
@ -390,9 +383,8 @@ class HloBroadcastInDimConverter
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
auto operandShape = operandType.getShape();
@ -409,9 +401,9 @@ class HloBroadcastInDimConverter
: b->getAffineDimExpr(size));
}
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -447,11 +439,9 @@ class LhloBroadcastInDimConverter
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1),
rewriter.getAffineMapArrayAttr(
{rewriter.getMultiDimIdentityMap(nloops)}),
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
/*inputCount=*/0, /*outputCount=*/1,
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops));
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
@ -460,16 +450,15 @@ class LhloBroadcastInDimConverter
rewriter.setInsertionPointToEnd(block);
rewriter.create<linalg::YieldOp>(loc, val);
} else {
ArrayAttr indexingMapsAttr = getIndexingMapsAttr(
op, broadcast_dims, result_shape, operand_type, &rewriter);
auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
operand_type, &rewriter);
OpBuilder::InsertionGuard linalgOpGuard(rewriter);
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None,
llvm::makeArrayRef({operand, operand_adaptor.output()}),
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
/*inputCount=*/1, /*outputCount=*/1, indexing_maps,
GetNParallelLoopsAttrs(nloops));
auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end());
@ -504,14 +493,14 @@ class LhloBroadcastInDimConverter
}
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
SmallVector<SmallVector<AffineExpr, 2>, 4> collapsed_dims_list;
SmallVector<AffineExpr, 2> collapsed_dims;
SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
linalg::ReassociationIndices collapsed_dims;
for (const auto& item :
enumerate(op.broadcast_dimensions().getIntValues())) {
size_t index = item.index();
int dim = item.value().getSExtValue();
collapsed_dims.push_back(rewriter.getAffineDimExpr(index));
collapsed_dims.push_back(index);
bool expansion_needed =
operand_shape[index] == 1 && result_shape[dim] != 1;
@ -542,31 +531,28 @@ class LhloBroadcastInDimConverter
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
// reduced.
if (new_shape.size() < operand_shape.size()) {
SmallVector<ArrayRef<AffineExpr>, 4> reassociation_maps;
for (const auto& dims : collapsed_dims_list)
reassociation_maps.push_back(dims);
auto new_memref_type = MemRefType::get(
new_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(new_strides, operand_offset,
rewriter.getContext()));
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
operand_adaptor.operand(),
reassociation_maps);
collapsed_dims_list);
}
return std::make_pair(operand, broadcast_dims);
}
ArrayAttr getIndexingMapsAttr(xla_lhlo::BroadcastInDimOp op,
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape,
MemRefType operandType, Builder* b) const {
MemRefType operandType,
Builder* b) const {
unsigned nloops = resultShape.size();
// The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) {
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
auto operandShape = operandType.getShape();
@ -584,99 +570,9 @@ class LhloBroadcastInDimConverter
}
dimExprs.push_back(b->getAffineDimExpr(size));
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
}
};
/// Pattern for the special case where reshape is adding or removing a dimension
/// of size 1. These can be lowered to a linalg.generic op.
///
/// For example a
/// "xla_hlo.reshape"(..) : (tensor<12x1x42xi32) -> tensor<12x42xi32>
/// can have indexing maps
/// [affine_map<(d0, d1) -> (d0, 0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]
///
/// Similarly a
/// "xla_hlo.reshape"(..) : (tensor<12x42xi32>) -> tensor<12x1x42xi32>
/// can have indexing maps
/// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1,
/// d2)>]
// TODO(ravishankarm): This pattern needs to be removed. The general reshape
// lowering hits a corner case where the following sequence of operations
// cannot be fused cause the resulting indexing map is not invertible.
//
// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>,
// affine_map<(d0, d1, d2) -> (d2)>]
// : tensor<5x5xf32> into tensor<5x1x5xf32>
// %f = linalg.generic
// {...
// indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
// affine_map<(d0, d1, d2) -> (d0, d2)>],
// iterator_types = ["parallel", "parallel", "parallel"]} %r {..}
// : tensor<5x1x5xf32> -> tensor<5x5xf32>
//
// The resolution of this requires a canonicalization on linalg ops where the
// dims of size 1 are removed. This pattern can be removed after that.
template <typename OpTy, bool isLHLO = true>
class ReshapeAddRemoveDimConverter
: public DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO> {
public:
ReshapeAddRemoveDimConverter(MLIRContext* context)
: DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO>(context, 100) {}
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto operandType =
op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
return nullptr;
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
unsigned resultIndex = 0, operandIndex = 0;
auto resultShape = resultType.getShape();
auto operandShape = operandType.getShape();
while (resultIndex < resultShape.size() &&
operandIndex < operandShape.size()) {
if (resultShape[resultIndex] == operandShape[operandIndex]) {
// Copy over the affine expr when the size of the result and operand
// match at a dim
inputExprs.push_back(b->getAffineDimExpr(resultIndex));
resultIndex++;
operandIndex++;
} else if (resultShape[resultIndex] == 1) {
// If size at result is 1, then ignore this dimension for the input, it
// is an extra dim added.
resultIndex++;
} else if (operandShape[operandIndex] == 1) {
// If the operandShape is 1, then add a (0) for the operand map since
// this dimension is dropped.
inputExprs.push_back(b->getAffineConstantExpr(0));
operandIndex++;
} else {
return nullptr;
}
}
// Make sure all remaining dimensions of the operand and result are ones.
auto checkRemainingDims = [](int64_t dim) { return dim != 1; };
if ((resultIndex < resultShape.size() &&
llvm::any_of(resultShape.drop_front(resultIndex),
checkRemainingDims)) ||
(operandIndex < operandShape.size() &&
llvm::any_of(operandShape.drop_front(operandIndex),
checkRemainingDims)))
return nullptr;
inputExprs.resize(operandShape.size(), b->getAffineConstantExpr(0));
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
return {
AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -687,7 +583,7 @@ class TransposeConverter
public:
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
@ -697,9 +593,9 @@ class TransposeConverter
inputExprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index());
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -722,13 +618,6 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
// TODO(ravishankarm): To make this pattern not match the pattern that
// ReshapeAddRemoveDimConverter is for, check that condition here. Remove
// this when ReshapeAddRemoveDimConverter pattern is removed.
if (ReshapeAddRemoveDimConverter<OpTy, isLHLO>::getIndexingMapsAttr(
reshapeOp, &rewriter))
return failure();
// Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> srcShape =
(operandType.getRank() > resultType.getRank() ? operandType.getShape()
@ -737,22 +626,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
(operandType.getRank() > resultType.getRank() ? resultType.getShape()
: operandType.getShape());
unsigned currSrcDim = 0, currDstDim = 0;
SmallVector<SmallVector<AffineExpr, 4>, 4> exprs(dstShape.size());
SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
dstShape.size());
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim];
while (srcSize < dstSize && currSrcDim < srcShape.size()) {
exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
srcSize *= srcShape[currSrcDim];
}
if (srcSize == dstSize) {
exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
// If the next dim in dstShape is not 1, treat subsequent dims in
// srcShape which are 1 to be collapsed.
if (currDstDim == dstShape.size() - 1 ||
dstShape[currDstDim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
exprs[currDstDim].push_back(
reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
}
}
@ -763,18 +655,15 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
if (currSrcDim != srcShape.size()) return failure();
SmallVector<ArrayRef<AffineExpr>, 4> reassociationMaps;
for (auto& expr : exprs) reassociationMaps.push_back(expr);
if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
reshapeOp.getLoc(), resultType, args[0], reassociationMaps);
reshapeOp.getLoc(), resultType, args[0], reassociationMap);
rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr);
} else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, args[0], reassociationMaps);
reshapeOp, resultType, args[0], reassociationMap);
}
return success();
}
@ -796,18 +685,14 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
// Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank();
SmallVector<Attribute, 2> indexingMaps;
indexingMaps.emplace_back(
AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));
auto loc = iotaOp.getLoc();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
loc, ArrayRef<Type>{}, args,
rewriter.getI64IntegerAttr(0), // args_in
rewriter.getI64IntegerAttr(1), // args_out
rewriter.getArrayAttr(indexingMaps),
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
0, // args_in
1, // args_out
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops));
// Add a block to the region.
auto* region = &linalgOp.region();
@ -857,7 +742,7 @@ class ReverseConverter
public:
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank();
@ -871,9 +756,9 @@ class ReverseConverter
int n = resultType.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
}
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
return {
AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)};
}
};
@ -946,7 +831,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
SliceConverter
@ -1045,7 +930,6 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context);

View File

@ -185,6 +185,7 @@ tf_xla_py_test(
name = "argminmax_test",
size = "small",
srcs = ["argminmax_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -253,6 +254,7 @@ tf_xla_py_test(
name = "bucketize_op_test",
size = "small",
srcs = ["bucketize_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -692,6 +694,7 @@ tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
shard_count = 6,
tags = [
@ -805,6 +808,7 @@ tf_xla_py_test(
name = "lrn_ops_test",
size = "medium",
srcs = ["lrn_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1129,6 +1133,7 @@ tf_xla_py_test(
name = "reverse_sequence_op_test",
size = "medium",
srcs = ["reverse_sequence_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1218,6 +1223,7 @@ tf_xla_py_test(
name = "sparse_to_dense_op_test",
size = "small",
srcs = ["sparse_to_dense_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3",
tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip

View File

@ -1225,8 +1225,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
[7, 7, 7, 7, 7, 7]],
dtype=dtype))
@test_util.disable_mlir_bridge(
"Requires concatenate op support in MlirHloBuilder")
def testSymmetricMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
for dtype in self.numeric_types:
@ -1258,8 +1256,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([[0, 0], [0, 0]], dtype=np.int32),
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
@test_util.disable_mlir_bridge(
"Requires concatenate op support in MlirHloBuilder")
def testReflectMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
for dtype in self.numeric_types:

View File

@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
expected_out, sess.run(op,
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
@test_util.disable_mlir_bridge("Error handling")
def testInvalidBoundariesOrder(self):
with self.session() as sess:
p = array_ops.placeholder(dtypes.int32)

View File

@ -78,10 +78,17 @@ class ConvolutionNodeNameTest(xla_test.XLATestCase):
xla_names = _GetNodeNames(use_xla=True)
no_xla_names = _GetNodeNames(use_xla=False)
self.assertListEqual(
xla_names,
no_xla_names,
)
# CPU path creates some additional nodes to handle dilations.
# TODO(b/138804006): Remove this when CPU & GPU support dilations.
filtered_no_xla_names = []
for name in no_xla_names:
if ("dilation_rate" in name or "filter_shape" in name or "stack" in name):
continue
else:
filtered_no_xla_names.append(name)
self.assertListEqual(xla_names, filtered_no_xla_names)
def testConv1DNodeNameMatch(self):
input_sizes = [8, 16, 3]

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
@ -101,6 +102,7 @@ class SparseToDenseTest(xla_test.XLATestCase):
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1)
@test_util.disable_mlir_bridge("Error handling")
def testBadValue(self):
with self.session(), self.test_scope():
with self.assertRaisesOpError(
@ -108,12 +110,14 @@ class SparseToDenseTest(xla_test.XLATestCase):
r"should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [[5], [3]], -1)
@test_util.disable_mlir_bridge("Error handling")
def testBadNumValues(self):
with self.session(), self.test_scope():
with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [1, 2, 3], -1)
@test_util.disable_mlir_bridge("Error handling")
def testBadDefault(self):
with self.session(), self.test_scope():
with self.assertRaisesOpError("default_value should be a scalar"):

View File

@ -918,16 +918,12 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([1, 0x100000003f800000], np.int64),
expected=np.array([1, 0x100000003f800000], np.uint64))
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.InvertPermutation compilation")
def testInvertPermutation(self):
self._assertOpOutputMatchesExpected(
array_ops.invert_permutation,
np.array([1, 2, 0], np.int32),
expected=np.array([2, 0, 1], dtype=np.int32))
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.InvertPermutation compilation")
def testInvertPermutationTwiceIsNoop(self):
self._assertOpOutputMatchesExpected(
lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
@ -1144,8 +1140,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
rtol=rtol,
atol=atol)
@test_util.disable_mlir_bridge(
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
def testSoftplus(self):
for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)

View File

@ -132,7 +132,7 @@ cc_library(
"//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@ -150,7 +150,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

Some files were not shown because too many files have changed in this diff Show More