Merge remote-tracking branch 'upstream/master' into offline_memory_planner

Jens Elofsson, 2020-06-15 10:06:36 +02:00
commit 708ecda43e
709 changed files with 15423 additions and 6384 deletions

View File

@@ -202,7 +202,6 @@ cc_library(
":operation_interface",
":tensor_handle_interface",
"//tensorflow/c:tensor_interface",
-"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",

View File

@@ -1473,14 +1473,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
}

void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
-  tensorflow::AttrValueMap m;
-  tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
  tensorflow::EagerOperation* operation =
      OperationFromInterface(tensorflow::unwrap(op));
  tensorflow::AttrBuilder* destination = operation->MutableAttrs();
-  for (const auto& attribute : m) {
-    destination->Set(attribute.first, attribute.second);
-  }
  destination->CopyAttributes(*tensorflow::unwrap(attrs));
}

void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,

View File

@@ -21,8 +21,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
-#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@@ -84,11 +84,10 @@ class AbstractContextInterface {
  // Create an operation to perform op execution
  virtual AbstractOperationInterface* CreateOperation() = 0;

-  // Load a SavedModelAPI object from the given directory and tags
-  virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
-      const std::string& directory,
-      const absl::optional<std::unordered_set<std::string>>& tags,
-      tensorflow::Status* status) = 0;
  // Returns whether the runtime is backed by TFRT or the legacy TF Eager
  // Runtime. This is necessary to decouple runtime-dependent
  // code that is layered on top of the runtime.
  virtual bool UsesTFRT() = 0;

  // List attributes of available devices
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
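Layers above the context can branch on this query instead of linking against a specific runtime; a minimal sketch of a caller (hypothetical snippet, mirroring the SavedModel C API change later in this commit):

if (tensorflow::unwrap(ctx)->UsesTFRT()) {
  // TFRT-backed runtime: take the TFRT-specific path.
} else {
  // Legacy TF Eager runtime path.
}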

View File

@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace parallel_device {
@@ -28,21 +30,198 @@ class OpDeleter {
using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

-// Creates a vector of `count` new executors (threads).
-std::vector<ExecutorPtr> MakeExecutors(size_t count) {
-  std::vector<ExecutorPtr> executors;
-  executors.reserve(count);
-  for (int i = 0; i < count; ++i) {
-    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
-  }
-  return executors;
-}
class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

}  // namespace
// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
public:
// Starts a background thread waiting for `StartExecute`.
explicit DeviceThread(const std::string& device)
: status_(TF_NewStatus()),
device_(device),
op_(nullptr),
thread_(tensorflow::Env::Default()->StartThread(
tensorflow::ThreadOptions(), "parallel_device_execute",
std::bind(&DeviceThread::Run, this))) {}
~DeviceThread();
// Requests that the worker thread execute the specified operation. Blocks
// until the previously pending operation (a StartExecute without a Join) has
// finished, if any.
void StartExecute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs);
// Block until the previous `StartExecute` operation has executed. Forwards
// the status from `TFE_Execute` and returns outputs if the status is OK.
std::vector<TensorHandlePtr> Join(TF_Status* status);
private:
void Run();
void Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes, int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);
enum class ExecutionState {
kReadyToExecute,
kHasResult,
kIdle,
kShuttingDown,
};
tensorflow::mutex execution_mutex_;
ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
ExecutionState::kIdle;
// Tells the worker thread that there is new work.
tensorflow::condition_variable start_execute_;
// The worker thread notifies that work has finished.
tensorflow::condition_variable finished_execute_;
// Notifies a StartExecute that the previous Join has finished.
tensorflow::condition_variable finished_join_;
// Temporary state between `StartExecute` and `Join`.
// Inputs
TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
// Outputs
std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
StatusPtr status_ TF_GUARDED_BY(execution_mutex_);
const std::string device_;
mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
std::unique_ptr<Thread> thread_;
};
DeviceThread::~DeviceThread() {
{
tensorflow::mutex_lock l(execution_mutex_);
execution_state_ = ExecutionState::kShuttingDown;
}
start_execute_.notify_one();
}
void DeviceThread::Run() {
while (true) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ == ExecutionState::kIdle ||
execution_state_ == ExecutionState::kHasResult) {
start_execute_.wait(l);
}
if (execution_state_ == ExecutionState::kShuttingDown) {
return;
} else if (execution_state_ == ExecutionState::kReadyToExecute) {
// op_outputs_ may have been std::moved
op_outputs_ = std::vector<TensorHandlePtr>();
Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
expected_max_outputs_, &op_outputs_, status_.get());
execution_state_ = ExecutionState::kHasResult;
}
}
finished_execute_.notify_one();
}
}
void DeviceThread::StartExecute(TFE_Context* context,
const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs) {
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kIdle) {
// If there's already a pending execution, wait until Join finishes before
// starting on the next operation.
finished_join_.wait(l);
}
context_ = context;
operation_name_ = operation_name;
op_inputs_ = inputs;
attributes_ = attributes;
expected_max_outputs_ = expected_max_outputs;
execution_state_ = ExecutionState::kReadyToExecute;
}
start_execute_.notify_one();
}
std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
std::vector<TensorHandlePtr> result;
{
tensorflow::mutex_lock l(execution_mutex_);
while (execution_state_ != ExecutionState::kHasResult) {
finished_execute_.wait(l);
}
if (TF_GetCode(status_.get()) != TF_OK) {
TF_SetStatus(status, TF_GetCode(status_.get()),
TF_Message(status_.get()));
}
execution_state_ = ExecutionState::kIdle;
result = std::move(op_outputs_);
}
finished_join_.notify_one();
return result;
}
void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
std::vector<TFE_TensorHandle*> inputs,
const TFE_OpAttrs* attributes,
int expected_max_outputs,
std::vector<TensorHandlePtr>* outputs,
TF_Status* status) const {
if (op_ == nullptr) {
op_.reset(TFE_NewOp(context, operation_name, status));
if (TF_GetCode(status) != TF_OK) return;
TFE_OpSetDevice(op_.get(), device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
} else {
TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
if (TF_GetCode(status) != TF_OK) return;
}
TFE_OpAddAttrs(op_.get(), attributes);
for (int input_index = 0; input_index < inputs.size(); ++input_index) {
TFE_OpAddInput(op_.get(), inputs[input_index], status);
if (TF_GetCode(status) != TF_OK) return;
}
std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
int real_num_outputs = expected_max_outputs;
if (TF_GetCode(status) != TF_OK) return;
TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
if (TF_GetCode(status) != TF_OK) return;
unwrapped_results.resize(real_num_outputs);
outputs->reserve(real_num_outputs);
for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
outputs->emplace_back(unwrapped_result);
}
}
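The intended call pattern, matching the deadlock note above, is to issue StartExecute on every DeviceThread before joining any of them, always in the same device order. A minimal sketch (hypothetical caller with two threads, a hypothetical "AddV2" op, and prepared per-device inputs; not part of this file):

// Start both executions before joining either one.
first_thread->StartExecute(context, "AddV2", first_device_inputs, attributes,
                           /*expected_max_outputs=*/1);
second_thread->StartExecute(context, "AddV2", second_device_inputs, attributes,
                            /*expected_max_outputs=*/1);
std::vector<TensorHandlePtr> first_outputs = first_thread->Join(status);
std::vector<TensorHandlePtr> second_outputs = second_thread->Join(status);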
ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
-    : underlying_devices_(devices),
-      executors_(MakeExecutors(underlying_devices_.size())) {}
    : underlying_devices_(devices) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str()));
  }
}

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
@@ -108,72 +287,34 @@ ParallelDevice::Execute(TFE_Context* context,
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
-  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
-  // setting the thread-local executor like this.
-  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
-  auto reset_executor =
-      tensorflow::gtl::MakeCleanup([context, previous_executor]() {
-        TFE_ContextSetExecutorForThread(context, previous_executor);
-        TFE_DeleteExecutor(previous_executor);
-      });
-  int first_op_output_count;
  int first_op_output_count = 0;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
-    TFE_Executor* executor = executors_[device_index].get();
-    // Note that the `reset_executor` cleanup sets the thread's executor back to
-    // the value before this function ran.
-    TFE_ContextSetExecutorForThread(context, executor);
-    OpPtr op(TFE_NewOp(context, operation_name, status));
-    if (TF_GetCode(status) != TF_OK) return result;
-    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
-                    status);
-    TFE_OpAddAttrs(op.get(), attributes);
    DeviceThread* device_thread = device_threads_[device_index].get();
    std::vector<TFE_TensorHandle*> device_inputs;
    device_inputs.reserve(device_inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      // Parallel tensors are divided between operations by device.
-      TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
-                     status);
-      if (TF_GetCode(status) != TF_OK) return result;
      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
-    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
-    int real_num_outputs = expected_max_outputs;
-    // For nested devices, the inner device sees the async executor we've
-    // set. Inner parallel devices will just overwrite this with their own and
-    // then set it back to ours before returning. This means parallel devices
-    // which consist of several aliased parallel devices would hypothetically
-    // deadlock if the outer parallel device ran one collective with a group
-    // size equal to the total number of aliased physical devices. Currently
-    // physical devices cannot participate in a single collective reduction
-    // multiple times, so this would fail earlier.
-    //
-    // TODO(allenl): Keep a map from outer executor to list of inner executors
-    // rather than a single list of executors so aliased nested parallel devices
-    // don't re-use an executor.
-    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
    device_thread->StartExecute(context, operation_name,
                                std::move(device_inputs), attributes,
                                expected_max_outputs);
  }
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
    DeviceThread* device_thread = device_threads_[device_index].get();
    per_device_output_tensors.push_back(device_thread->Join(status));
    if (TF_GetCode(status) != TF_OK) return result;
    if (device_index == 0) {
-      first_op_output_count = real_num_outputs;
      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
-      if (real_num_outputs != first_op_output_count) {
      if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
-    if (TF_GetCode(status) != TF_OK) return result;
-    std::vector<TensorHandlePtr> this_outputs;
-    this_outputs.reserve(real_num_outputs);
-    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
-      this_outputs.emplace_back(op_outputs[output_num]);
-    }
-    per_device_output_tensors.push_back(std::move(this_outputs));
-  }
-  for (int device_index = 0; device_index < underlying_devices_.size();
-       ++device_index) {
-    TFE_Executor* executor = executors_[device_index].get();
-    // TODO(b/157523095): Syncing the executor here shouldn't be
-    // necessary. Currently async+remote is missing cross-executor
-    // coordination.
-    TFE_ExecutorWaitForAllPendingNodes(executor, status);
-    if (TF_GetCode(status) != TF_OK) return result;
  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.

View File

@@ -41,16 +41,8 @@ class TensorHandleDeleter {
using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

-class ExecutorDeleter {
- public:
-  void operator()(TFE_Executor* to_delete) const {
-    TFE_DeleteExecutor(to_delete);
-  }
-};
-
-using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
class ParallelTensor;
class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.
@@ -58,6 +50,8 @@ class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);

  ~ParallelDevice();

  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //
@@ -94,9 +88,19 @@ class ParallelDevice {
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
-  // A sequence of TFE_Executors, one per device, for executing operations in
-  // parallel.
-  const std::vector<ExecutorPtr> executors_;
  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
  //
  // Conceptually this is a thread pool with one thread per device. It requires
  // less synchronization than a thread pool would for this task, since Execute
  // acquires each thread in order (and so only one Execute will schedule
  // blocking collective operations at a time), and avoids some dynamic
  // allocation/scheduling.
  //
  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
  // than a single list of threads so aliased nested parallel devices don't
  // re-use a thread.
  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the

View File

@@ -407,7 +407,7 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
  return TensorHandlePtr(result_handle);
}

-TEST(PARALLEL_DEVICE, TestCollective) {
void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
@@ -423,6 +423,9 @@ TEST(PARALLEL_DEVICE, TestCollective) {
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
      TFE_NewExecutor(async), TFE_DeleteExecutor);
  TFE_ContextSetExecutorForThread(context.get(), executor.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
@@ -452,8 +455,16 @@ TEST(PARALLEL_DEVICE, TestCollective) {
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 3.);
  ExpectScalarEq<float>(result_components[1].get(), 3.);

  // Destroying the context's default executor first isn't safe.
  context.reset();
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {

View File

@@ -26,5 +26,6 @@ cc_library(
deps = [
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
],
)

View File

@@ -15,11 +15,22 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>

#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
// same integer values. See
// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
static inline void TF_SetStatusFromGCSStatus(
const google::cloud::Status& gcs_status, TF_Status* status) {
TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
gcs_status.message().c_str());
}
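Because the two enumerations use the same integer values, the cast above is a plain reinterpretation. A compile-time spot check could look like the following sketch (not part of the commit; assumes both headers are included in the translation unit):

static_assert(static_cast<int>(google::cloud::StatusCode::kNotFound) ==
                  static_cast<int>(TF_NOT_FOUND),
              "google-cloud-cpp status codes must line up with TF_Code");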
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

@@ -52,6 +63,20 @@ namespace tf_read_only_memory_region {
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
static void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
TF_SetStatusFromGCSStatus(client.status(), status);
return;
}
filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
(*gcs_client) = client.value();
TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
}  // namespace tf_gcs_filesystem

@@ -60,6 +85,10 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);
  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
  ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@@ -57,6 +57,7 @@ cc_library(
":concrete_function",
":saved_model_api",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@@ -21,6 +21,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
@@ -51,7 +52,7 @@ std::vector<ConcreteFunction*> TFSavedModelAPIImpl::ListFunctions() {
Status TFSavedModelAPIImpl::Load(
    const std::string& directory,
    const absl::optional<std::unordered_set<std::string>>& tags,
-    TFSavedModelAPIImpl* out) {
    EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
  // TODO(bmzhao): Add support for loading a TFSavedModelImpl.
  return errors::Unimplemented(
      "TFSavedModelAPIImpl loading is unimplemented currently");

View File

@@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

class TFSavedModelAPIImpl : public SavedModelAPI {
 public:
-  TFSavedModelAPIImpl() = default;

  Status GetFunction(const std::string& function_path,
                     ConcreteFunction** function) override;
@@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
  static Status Load(
      const std::string& directory,
      const absl::optional<std::unordered_set<std::string>>& tags,
-      TFSavedModelAPIImpl* out);
      EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);

  std::vector<ConcreteFunction*> ListFunctions() override;

  ~TFSavedModelAPIImpl() override = default;

 private:
  TFSavedModelAPIImpl() = default;
  std::vector<ConcreteFunction> functions_;
};

View File

@@ -144,7 +144,9 @@ cc_library(
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:tfe_context_internal",
"//tensorflow/c/experimental/saved_model/core:saved_model_api",
"//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
"//tensorflow/core:lib",
"//tensorflow/core/common_runtime/eager:context",
"@com_google_absl//absl/types:optional",
],
)

View File

@@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

extern "C" {
@@ -34,10 +38,21 @@ extern "C" {
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
                                 TF_Status* status) {
  std::string saved_model_dir(dirname);
-  std::unique_ptr<tensorflow::SavedModelAPI> result =
-      tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
-                                                 &status->status);
  std::unique_ptr<tensorflow::SavedModelAPI> result;
  if (tensorflow::unwrap(ctx)->UsesTFRT()) {
    status->status = tensorflow::errors::Unimplemented(
        "TFRT SavedModel implementation will be added in the future");
  } else {
    std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
    status->status = tensorflow::TFSavedModelAPIImpl::Load(
        dirname, absl::nullopt,
        tensorflow::down_cast<tensorflow::EagerContext*>(
            tensorflow::unwrap(ctx)),
        &saved_model);
    result = std::move(saved_model);
  }
  if (!status->status.ok()) {
    return nullptr;
  }
@@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
    tagset.insert(std::string(tags[i]));
  }

-  std::unique_ptr<tensorflow::SavedModelAPI> result =
-      tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
-                                                 &status->status);
  std::unique_ptr<tensorflow::SavedModelAPI> result;
  if (tensorflow::unwrap(ctx)->UsesTFRT()) {
    status->status = tensorflow::errors::Unimplemented(
        "TFRT SavedModel implementation will be added in the future");
  } else {
    std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
    status->status = tensorflow::TFSavedModelAPIImpl::Load(
        dirname, tagset,
        tensorflow::down_cast<tensorflow::EagerContext*>(
            tensorflow::unwrap(ctx)),
        &saved_model);
    result = std::move(saved_model);
  }
  if (!status->status.ok()) {
    return nullptr;
  }

View File

@@ -106,6 +106,7 @@ cc_library(
hdrs = ["loader.h"],
deps = [
":constants",
":loader_util",
":reader",
] + if_not_mobile([
"//tensorflow/core:core_cpu",
@@ -132,6 +133,17 @@ cc_library(
],
)

cc_library(
name = "loader_util",
srcs = ["loader_util.cc"],
hdrs = ["loader_util.h"],
deps = [":constants"] + if_not_mobile([
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
]),
)

tf_cc_test(
name = "bundle_v2_test",
srcs = ["bundle_v2_test.cc"],

View File

@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
  return Status::OK();
}
// A SavedModel may store the name of the initialization op to run in the
// in the SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,
@@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
                nullptr /* outputs */, &run_metadata, session);
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
Status ReadSavedModelDebugInfoIfPresent( Status ReadSavedModelDebugInfoIfPresent(
const string& export_dir, const string& export_dir,
std::unique_ptr<GraphDebugInfo>* debug_info_proto) { std::unique_ptr<GraphDebugInfo>* debug_info_proto) {
@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
std::vector<AssetFileDef> asset_file_defs; std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs)); internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
RunRestore(run_options, export_dir, RunRestore(run_options, export_dir,
bundle->meta_graph_def.saver_def().restore_op_name(), bundle->meta_graph_def.saver_def().restore_op_name(),
@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
const uint64 graph_init_start_microseconds = Env::Default()->NowMicros(); const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
string init_op_name; string init_op_name;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name)); internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def, TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
asset_file_defs, bundle->session.get(), asset_file_defs, bundle->session.get(),
init_op_name)); init_op_name));

View File

@@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/cc/saved_model/loader_util.h"
#include <vector>
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name) {
const auto& sig_def_map = meta_graph_def.signature_def();
const auto& init_op_sig_it =
meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
if (init_op_sig_it != sig_def_map.end()) {
*init_op_name = init_op_sig_it->second.outputs()
.find(kSavedModelInitOpSignatureKey)
->second.name();
return Status::OK();
}
const auto& collection_def_map = meta_graph_def.collection_def();
string init_op_collection_key;
if (collection_def_map.find(kSavedModelMainOpKey) !=
collection_def_map.end()) {
init_op_collection_key = kSavedModelMainOpKey;
} else {
init_op_collection_key = kSavedModelLegacyInitOpKey;
}
const auto init_op_it = collection_def_map.find(init_op_collection_key);
if (init_op_it != collection_def_map.end()) {
if (init_op_it->second.node_list().value_size() != 1) {
return errors::FailedPrecondition(
strings::StrCat("Expected exactly one main op in : ", export_dir));
}
*init_op_name = init_op_it->second.node_list().value(0);
}
return Status::OK();
}
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs) {
// With SavedModel v2, we write asset file def into metagraph instead of
// collection, so read from metagraph first.
if (meta_graph_def.asset_file_def_size() > 0) {
for (const auto& asset : meta_graph_def.asset_file_def()) {
asset_file_defs->push_back(asset);
}
return Status::OK();
}
// Fall back to read from collection to be backward compatible with v1.
const auto& collection_def_map = meta_graph_def.collection_def();
const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
if (assets_it == collection_def_map.end()) {
return Status::OK();
}
const auto& any_assets = assets_it->second.any_list().value();
for (const auto& any_asset : any_assets) {
AssetFileDef asset_file_def;
TF_RETURN_IF_ERROR(
ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
asset_file_defs->push_back(asset_file_def);
}
return Status::OK();
}
} // namespace internal
} // namespace tensorflow

View File

@@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
namespace tensorflow {
namespace internal {
// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
string* init_op_name);
Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
std::vector<AssetFileDef>* asset_file_defs);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
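A minimal sketch of how a caller uses these newly exposed helpers (hypothetical local variables inside a Status-returning function; it mirrors the updated call sites in loader.cc above):

string init_op_name;
std::vector<AssetFileDef> asset_file_defs;
TF_RETURN_IF_ERROR(
    internal::GetInitOp(export_dir, meta_graph_def, &init_op_name));
TF_RETURN_IF_ERROR(
    internal::GetAssetFileDefs(meta_graph_def, &asset_file_defs));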

View File

@@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
MlirCommonFlags* mlir_flags;

std::vector<Flag>* flag_list;
absl::once_flag flags_init;
@@ -166,6 +167,9 @@ void AllocateAndParseFlags() {
  jitter_flags = new IntroduceFloatingPointJitterPassFlags;
  jitter_flags->jitter_amount = 1e-5;

  mlir_flags = new MlirCommonFlags;
  mlir_flags->tf_mlir_enable_mlir_bridge = false;

  auto setter_for_jitter_tensor_names = [](string sequence) {
    jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
    return true;
@@ -211,7 +215,11 @@ void AllocateAndParseFlags() {
       Flag("tf_introduce_floating_point_jitter_amount",
            &jitter_flags->jitter_amount,
            "The amount of jitter to introduce. This amount is added to each "
-           "element in the tensors named in `tensor_names.")});
           "element in the tensors named in `tensor_names."),
       Flag("tf_mlir_enable_mlir_bridge",
            &mlir_flags->tf_mlir_enable_mlir_bridge,
            "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});

  AppendMarkForCompilationPassFlagsInternal(flag_list);
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);
@@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
  return *jitter_flags;
}

MlirCommonFlags* GetMlirCommonFlags() {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  return mlir_flags;
}

void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);

View File

@@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
  std::vector<string> tensor_names;
};

// Flags for common MLIR configurations.
struct MlirCommonFlags {
  bool tf_mlir_enable_mlir_bridge;
};

// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
@@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();

MlirCommonFlags* GetMlirCommonFlags();

// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//
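A minimal usage sketch for the new accessor (hypothetical call site, not part of this header):

if (tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
  // Route compilation through the experimental MLIR-based bridge.
} else {
  // Keep the existing non-MLIR path.
}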

View File

@@ -30,7 +30,7 @@ cc_library(
hdrs = ["op_or_arg_name_mapper.h"],
deps = [
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@@ -42,7 +42,7 @@ cc_library(
":init_mlir",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirOptLib",
@@ -86,7 +86,7 @@ cc_library(
hdrs = ["init_mlir.h"],
deps = [
"//tensorflow/core:lib",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@@ -102,7 +102,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"@com_google_absl//absl/container:flat_hash_set",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
@@ -155,7 +155,7 @@ tf_cc_binary(
"//tensorflow/core:tensorflow",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",

View File

@@ -225,7 +225,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
@@ -253,7 +253,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@@ -272,7 +272,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@@ -289,7 +289,7 @@ cc_library(
],
deps = [
":tensorflow_lite",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
@@ -304,7 +304,7 @@ tf_cc_test(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
@@ -357,7 +357,7 @@ cc_library(
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -383,7 +383,7 @@ cc_library(
":validators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -416,7 +416,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@@ -441,7 +441,7 @@ cc_library(
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@@ -494,8 +494,8 @@ tf_native_cc_binary(
"converter_gen.cc",
],
deps = [
-"@llvm-project//llvm:support",
-"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
@@ -541,8 +541,8 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@flatbuffers",
-"@llvm-project//llvm:analysis",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Analysis",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:TransformUtils",
],
@@ -619,7 +619,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@flatbuffers",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@@ -653,7 +653,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps",
@@ -713,7 +713,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain",
"@llvm-project//mlir:QuantOps",
@@ -743,7 +743,7 @@ cc_library(
"tf_tfl_translate_cl.h",
],
deps = [
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@@ -755,7 +755,7 @@ cc_library(
],
deps = [
"//tensorflow/compiler/mlir/lite/quantization:quantization_config",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
@@ -780,7 +780,7 @@ tf_cc_binary(
":tf_tfl_translate_cl_options",
":tf_to_tfl_flatbuffer",
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@@ -805,7 +805,7 @@ tf_cc_binary(
":flatbuffer_translate_lib",
":flatbuffer_translate_registeration",
"@com_google_absl//absl/strings",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
# TODO(b/155809683): Link only necessary dialects.
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
@@ -874,7 +874,7 @@ cc_library(
"//tensorflow/lite/tools/optimize:quantize_weights",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/types:span",
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@@ -894,6 +894,6 @@ cc_library(
"//tensorflow/lite/experimental/mlir:__subpackages__",
],
deps = [
-"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)

View File

@ -868,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
subgraph, &builder, "outputs", func_outputs)); subgraph, &builder, "outputs", func_outputs));
} }
func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes)); func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
} else {
func.setVisibility(FuncOp::Visibility::Private);
} }
absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops; absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
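A minimal sketch of the visibility rule this import change introduces (the helper name is illustrative, not part of the commit): only the entry subgraph keeps a public function, while every other imported subgraph function is marked private so the symbol DCE pass added later in the pipeline can drop it once nothing references it.

#include "mlir/IR/Function.h"  // mlir::FuncOp, header path as of this commit

// Hypothetical helper mirroring the hunk above.
void SetImportedSubgraphVisibility(mlir::FuncOp func, bool is_entry_point) {
  // Non-entry subgraph functions become private symbols after import.
  if (!is_entry_point) func.setVisibility(mlir::FuncOp::Visibility::Private);
}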

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -56,7 +56,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -85,7 +85,7 @@ cc_library(
"//tensorflow/lite/toco:toco_flags_proto_cc", "//tensorflow/lite/toco:toco_flags_proto_cc",
"//tensorflow/lite/toco:types_proto_cc", "//tensorflow/lite/toco:types_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",

View File

@ -80,7 +80,7 @@ cc_library(
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -106,7 +106,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib_proto_parsing",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -125,7 +125,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -135,8 +135,8 @@ tf_native_cc_binary(
"tools/op_quant_spec_getters_gen.cc", "tools/op_quant_spec_getters_gen.cc",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:tablegen", "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen", "@llvm-project//mlir:TableGen",
], ],
) )
@ -157,7 +157,7 @@ cc_library(
deps = [ deps = [
":numerical_utils", ":numerical_utils",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -172,7 +172,7 @@ cc_library(
":device_target", ":device_target",
":quantization_lib", ":quantization_lib",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",

View File

@ -36,7 +36,7 @@ cc_library(
"//tensorflow/lite/core/api", "//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],
@ -54,7 +54,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -73,7 +73,7 @@ tf_cc_binary(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
], ],
) )

View File

@ -27,7 +27,7 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib", "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps", "@llvm-project//mlir:QuantOps",

View File

@ -32,7 +32,7 @@ cc_library(
"//tensorflow/lite/core/api", "//tensorflow/lite/core/api",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],

View File

@ -54,7 +54,7 @@ tf_native_cc_binary(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -70,6 +70,6 @@ tf_native_cc_binary(
"//tensorflow/lite:framework", "//tensorflow/lite:framework",
"//tensorflow/lite/schema:schema_fbs", "//tensorflow/lite/schema:schema_fbs",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )

View File

@ -0,0 +1,12 @@
// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
return %0: tensor<3x3xbf16>
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
}
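The CHECK lines above rely on a simple identity: broadcasting a tensor x to shape s can be expressed as an element-wise multiply against a ones tensor of that shape, which is exactly what the tfl.fill plus tfl.mul pair materializes. Informally, for numeric element types,

  \operatorname{broadcast\_to}(x, s) = x \odot \operatorname{fill}(s, 1)

where the multiply broadcasts x against the filled tensor.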

View File

@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32> // CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
} }
func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
// CHECK-LABEL: concat2Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
}
func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>
// CHECK-LABEL: concat3Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}
func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> { func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32> %0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32> %1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>

View File

@ -13,14 +13,12 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
return %3 : tensor<f32> return %3 : tensor<f32>
} }
// CHECK-NOT: add func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
// CHECK-NOT: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
@ -42,65 +40,31 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
return %3 : tensor<f32> return %3 : tensor<f32>
} }
// CHECK-NOT: addormul func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
// CHECK-NOT: sub
// CHECK-NOT: mul
// CHECK-NOT: add
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = constant dense<false> : tensor<i1> %0 = constant dense<false> : tensor<i1>
%1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32> return %1 : tensor<*xf32>
} }
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Multiply"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> %0 = "tf.Multiply"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
// ----- // -----
// Verify that branch functions with multiple references are not erased. // Verify unused if with functions without side-effects is removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
%2 = constant dense<true> : tensor<i1>
// CHECK: tf.Add
%3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: tf.If
%4 = "tf.If"(%arg2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %3, %4 : tensor<f32>, tensor<f32>
}
// CHECK: add
// CHECK: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
// -----
// Verify unused if with functions without side-effects are removed.
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32> func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} { attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32> %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@ -118,26 +82,22 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32> return %4 : tensor<3x15x14x8xf32>
} }
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> { func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1> %cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1> return %cst : tensor<i1>
} }
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> { func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<true> : tensor<i1> %cst = constant dense<true> : tensor<i1>
return %cst : tensor<i1> return %cst : tensor<i1>
} }
// CHECK: func @main
// CHECK-NOT: tf.If // CHECK-NOT: tf.If
// CHECK: return // CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_00
// CHECK-NOT: func @_functionalize_if_then_branch_00
// ----- // -----
// Verify unused if with function with side-effects is not removed. // Verify unused if with function with side-effects is not removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32> func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} { attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32> %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@ -155,27 +115,25 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32> return %4 : tensor<3x15x14x8xf32>
} }
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> { func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1> %cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1> return %cst : tensor<i1>
} }
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> { func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1> %0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1> return %0 : tensor<i1>
} }
// CHECK: func @main
// CHECK: tf.If // CHECK: tf.If
// CHECK: return // CHECK: return
// CHECK: func @_functionalize_if_else_branch_01
// CHECK: func @_functionalize_if_then_branch_01
// ----- // -----
// Verify unused if with function with side-effects is removed if op says // Verify unused if with function with side-effects is removed if op says
// stateless. // stateless.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32> func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} { attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32> %cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
@ -193,18 +151,15 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
return %4 : tensor<3x15x14x8xf32> return %4 : tensor<3x15x14x8xf32>
} }
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> { func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1> %cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1> return %cst : tensor<i1>
} }
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> { func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1> %0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1> return %0 : tensor<i1>
} }
// CHECK: func @main
// CHECK-NOT: tf.If // CHECK-NOT: tf.If
// CHECK: return // CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_02
// CHECK-NOT: func @_functionalize_if_then_branch_02

View File

@ -94,12 +94,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass()); pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
} }
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
pass_manager->addPass(mlir::createInlinerPass()); pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass()); pass_manager->addPass(mlir::createSymbolDCEPass());
@ -162,6 +156,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// so that it can target constants introduced once TensorFlow Identity ops // so that it can target constants introduced once TensorFlow Identity ops
// are removed during legalization. // are removed during legalization.
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass()); pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
// This pass should be always at the end of the floating point model // This pass should be always at the end of the floating point model
@ -237,6 +232,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true)); mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass()); pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass()); pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pm.addPass(mlir::createSymbolDCEPass());
// Canonicalize, CSE etc. // Canonicalize, CSE etc.
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
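A minimal sketch of the ordering the two hunks above establish (the wrapper name is illustrative; pass constructors and include paths are assumed from this commit): -symbol-dce now runs right after the functional-ops optimization, because the rewritten pattern no longer erases the branch functions it inlines.

#include "mlir/IR/Function.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"

// Illustrative wrapper showing the cleanup ordering added above.
void AddFunctionalOpsCleanup(mlir::OpPassManager& pm) {
  // Fold tf.If ops with constant conditions and drop unused stateless ones.
  pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
  // The fold leaves the inlined branch functions in place; symbol DCE erases
  // whichever of them end up unreferenced.
  pm.addPass(mlir::createSymbolDCEPass());
  // Per-function local cleanups, as in the existing pipeline.
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
}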

View File

@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) {
// operands are properly supported in declarative rewrite rule specification. // operands are properly supported in declarative rewrite rule specification.
DECL_CONVERT_OP(Assert); DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2); DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul); DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2); DECL_CONVERT_OP(MatrixDiagV2);
@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
return failure(); return failure();
} }
LogicalResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);
auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
return failure();
StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
fused_activation_function);
return success();
}
// Converts any IntegerAttr to an IntegerAttr of an i32 type. // Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of // The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure. // the bound of i32, the function returns a failure.
@ -517,6 +497,14 @@ StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
attr = DenseElementsAttr::get(scalar_type, floatValues); attr = DenseElementsAttr::get(scalar_type, floatValues);
break; break;
} }
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::F32: { case mlir::StandardTypes::F32: {
attr = attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value)); DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
@ -756,11 +744,11 @@ void LegalizeTF::runOnFunction() {
// Add the generated patterns to the list. // Add the generated patterns to the list.
populateWithGenerated(context, &patterns); populateWithGenerated(context, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp, patterns
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op, .insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp, ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp, ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFAssertOp, ConvertTFReciprocalOp, ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context); ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);
// Ophint python converter converted tf node pattern. // Ophint python converter converted tf node pattern.

View File

@ -32,8 +32,6 @@ namespace mlir {
namespace TFL { namespace TFL {
namespace { namespace {
using FuncSet = llvm::SmallSet<FuncOp, 4>;
// Module pass to optimize TensorFlow functional ops. // Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass struct OptimizeFunctionalOpsPass
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> { : public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
@ -44,8 +42,8 @@ struct OptimizeFunctionalOpsPass
// op operands' types. // op operands' types.
// //
// Requires the function has exactly one block. // Requires the function has exactly one block.
static void UpdateFuncType(FuncOp func) { void UpdateFuncType(FuncOp func) {
Operation* terminator = &func.getBlocks().front().back(); Operation* terminator = func.front().getTerminator();
auto return_types = llvm::to_vector<4>(terminator->getOperandTypes()); auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());
FunctionType func_type = func.getType(); FunctionType func_type = func.getType();
@ -57,7 +55,7 @@ static void UpdateFuncType(FuncOp func) {
} }
// TODO(jpienaar): Remove when recursive side-effect modeling is added. // TODO(jpienaar): Remove when recursive side-effect modeling is added.
static bool IsSideEffectFree(FuncOp func) { bool IsSideEffectFree(FuncOp func) {
return !func.getBody() return !func.getBody()
.walk([&](Operation* op) { .walk([&](Operation* op) {
if (!MemoryEffectOpInterface::hasNoEffect(op) && if (!MemoryEffectOpInterface::hasNoEffect(op) &&
@ -72,8 +70,8 @@ static bool IsSideEffectFree(FuncOp func) {
// function body based on the conditional value. // function body based on the conditional value.
class FoldIfOp : public OpRewritePattern<TF::IfOp> { class FoldIfOp : public OpRewritePattern<TF::IfOp> {
public: public:
explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs) explicit FoldIfOp(MLIRContext* context)
: OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {} : OpRewritePattern<TF::IfOp>(context) {}
LogicalResult matchAndRewrite(TF::IfOp op, LogicalResult matchAndRewrite(TF::IfOp op,
PatternRewriter& rewriter) const override { PatternRewriter& rewriter) const override {
@ -82,7 +80,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// updated if operands' shapes change after inlining. Without this // updated if operands' shapes change after inlining. Without this
// restriction, it would require tensor cast ops. // restriction, it would require tensor cast ops.
FuncOp parent_op = op.getParentOfType<FuncOp>(); FuncOp parent_op = op.getParentOfType<FuncOp>();
if (parent_op.getBlocks().size() != 1) return failure(); if (!llvm::hasSingleElement(parent_op)) return failure();
// Find the then and else branch functions. // Find the then and else branch functions.
SymbolTable table(op.getParentOfType<ModuleOp>()); SymbolTable table(op.getParentOfType<ModuleOp>());
@ -95,8 +93,6 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
if (op.use_empty() && if (op.use_empty() &&
(op.is_stateless() || (op.is_stateless() ||
(IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) { (IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) {
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
rewriter.eraseOp(op.getOperation()); rewriter.eraseOp(op.getOperation());
return success(); return success();
} }
@ -118,14 +114,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// Make sure that the function has exactly one block to simplify inlining. // Make sure that the function has exactly one block to simplify inlining.
// TFLite doesn't use control flow with blocks so functions with more than // TFLite doesn't use control flow with blocks so functions with more than
// one blocks are not encountered in practice. // one blocks are not encountered in practice.
if (func.getBody().getBlocks().size() != 1) return failure(); if (!llvm::hasSingleElement(func)) return failure();
BlockAndValueMapping mapper; BlockAndValueMapping mapper;
for (int i = 0, e = func.getNumArguments(); i != e; ++i) for (int i = 0, e = func.getNumArguments(); i != e; ++i)
mapper.map(func.getArgument(i), op.getOperand(i + 1)); mapper.map(func.getArgument(i), op.getOperand(i + 1));
llvm::SmallVector<Value, 4> updated_results; llvm::SmallVector<Value, 4> updated_results;
for (auto& op_to_inline : func.getBody().front()) { for (auto& op_to_inline : func.front()) {
// If this is a terminator, identify the values to use to replace the // If this is a terminator, identify the values to use to replace the
// original If op. // original If op.
if (op_to_inline.isKnownTerminator()) { if (op_to_inline.isKnownTerminator()) {
@ -145,64 +141,26 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
// return type should be updated. // return type should be updated.
UpdateFuncType(parent_op); UpdateFuncType(parent_op);
// Track functions that could be erased if this op was the last reference
// of the function.
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
return success(); return success();
} }
private:
FuncSet* inlined_funcs_;
}; };
// Erases functions from the given candidates that are not referenced by any of
// the ops in the module.
static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
if (candidate_funcs.empty()) return;
SymbolTable manager(module);
// Identify the functions that are used as symbols in the module and shouldn't
// be erased.
FuncSet in_use_funcs;
manager.getOp()->walk([&](Operation* op) {
for (auto attr : op->getAttrs()) {
if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
auto func = manager.lookup<FuncOp>(symbol.getValue());
in_use_funcs.insert(func);
}
}
});
for (FuncOp func : candidate_funcs) {
if (!in_use_funcs.count(func)) manager.erase(func);
}
}
void OptimizeFunctionalOpsPass::runOnOperation() { void OptimizeFunctionalOpsPass::runOnOperation() {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
FuncSet inlined_funcs; patterns.insert<FoldIfOp>(&getContext());
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
ModuleOp module = getOperation(); ModuleOp module = getOperation();
applyPatternsAndFoldGreedily(module, patterns); applyPatternsAndFoldGreedily(module, patterns);
// Erase inlined functions that don't have any references.
//
// TODO(hinsu): Update this to not erase entry points once TFLite support to
// have multiple entry points is implemented. Until then, it is safe to
// erase these functions.
EraseDeadFuncs(inlined_funcs, module);
} }
PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace } // namespace
std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() { std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
return std::make_unique<OptimizeFunctionalOpsPass>(); return std::make_unique<OptimizeFunctionalOpsPass>();
} }
static PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace TFL } // namespace TFL
} // namespace mlir } // namespace mlir
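A condensed, illustrative skeleton of the decision logic after this rewrite (the real implementation is the FoldIfOp class above): unused side-effect-free tf.If ops are erased outright, constant-condition ones get the taken branch inlined, and the branch functions themselves are left for -symbol-dce instead of the removed FuncSet bookkeeping.

#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/PatternMatch.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

struct FoldIfOpSketch : public mlir::OpRewritePattern<mlir::TF::IfOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::TF::IfOp op, mlir::PatternRewriter& rewriter) const override {
    // Restrict to single-block parents so inlining never needs tensor casts.
    auto parent = op.getParentOfType<mlir::FuncOp>();
    if (!parent || !llvm::hasSingleElement(parent)) return mlir::failure();

    // Unused and stateless: erase the op. The now-dead branch functions are
    // removed later by -symbol-dce, not tracked here.
    if (op.use_empty() && op.is_stateless()) {
      rewriter.eraseOp(op.getOperation());
      return mlir::success();
    }

    // Otherwise, a constant condition would select a branch whose single
    // block is cloned into the parent (elided in this sketch).
    return mlir::failure();
  }
};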

View File

@ -29,7 +29,7 @@ cc_library(
# place for core related components. # place for core related components.
"//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration", "//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:import_utils", "//tensorflow/compiler/mlir/tensorflow:import_utils",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",

View File

@ -20,7 +20,7 @@ tf_python_pybind_extension(
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status", "//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@pybind11", "@pybind11",
@ -35,7 +35,7 @@ tf_python_pybind_extension(
deps = [ deps = [
"//tensorflow/python:pybind11_lib", "//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status", "//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@pybind11", "@pybind11",
], ],
) )

View File

@ -187,7 +187,7 @@ gentbl(
td_file = "transforms/legalize_hlo_patterns.td", td_file = "transforms/legalize_hlo_patterns.td",
td_srcs = [ td_srcs = [
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files", "//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles", "@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
], ],
@ -204,7 +204,7 @@ cc_library(
":tensorflow", ":tensorflow",
"//tensorflow/compiler/mlir/xla:hlo", "//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -225,7 +225,7 @@ cc_library(
"ir/tf_attributes.h", "ir/tf_attributes.h",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -240,7 +240,7 @@ cc_library(
"ir/tf_types.h", "ir/tf_types.h",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
@ -293,7 +293,7 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfacesIncGen", "@llvm-project//mlir:CallOpInterfacesIncGen",
"@llvm-project//mlir:DerivedAttributeOpInterface", "@llvm-project//mlir:DerivedAttributeOpInterface",
@ -388,7 +388,7 @@ cc_library(
":tensorflow", ":tensorflow",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -416,6 +416,7 @@ cc_library(
"transforms/fold_switch.cc", "transforms/fold_switch.cc",
"transforms/freeze_global_tensors.cc", "transforms/freeze_global_tensors.cc",
"transforms/functional_control_flow_to_cfg.cc", "transforms/functional_control_flow_to_cfg.cc",
"transforms/fused_kernel_matcher.cc",
"transforms/generated_canonicalize.inc", "transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc", "transforms/generated_optimize.inc",
"transforms/gpu_fusion.cc", "transforms/gpu_fusion.cc",
@ -424,7 +425,6 @@ cc_library(
"transforms/layout_optimization.cc", "transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc", "transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc", "transforms/materialize_mlir_passthrough_op.cc",
"transforms/op_fusion.cc",
"transforms/optimize.cc", "transforms/optimize.cc",
"transforms/optimize_global_tensors.cc", "transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc", "transforms/parallel_execute_to_islands.cc",
@ -500,7 +500,7 @@ cc_library(
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc", "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
@ -609,7 +609,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -639,7 +639,7 @@ cc_library(
":parse_text_proto", ":parse_text_proto",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -669,7 +669,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -712,7 +712,7 @@ cc_library(
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -722,7 +722,7 @@ cc_library(
srcs = ["translate/translate_tf_dialect_op.cc"], srcs = ["translate/translate_tf_dialect_op.cc"],
deps = [ deps = [
":export_tf_dialect_op", ":export_tf_dialect_op",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation", "@llvm-project//mlir:Translation",
@ -781,7 +781,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -799,7 +799,7 @@ cc_library(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],
@ -816,7 +816,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -838,7 +838,7 @@ cc_library(
"@com_google_absl//absl/base", "@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -882,7 +882,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status", "//tensorflow/core/platform:status",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -925,7 +925,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/stream_executor", "//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects", "@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -957,7 +957,7 @@ cc_library(
"//tensorflow/core:ops", "//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],
@ -986,7 +986,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],
@ -1011,7 +1011,7 @@ cc_library(
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -1027,7 +1027,7 @@ cc_library(
"translate/tf_mlir_translate_cl.h", "translate/tf_mlir_translate_cl.h",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -1044,7 +1044,7 @@ cc_library(
":translate_lib", ":translate_lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Translation", "@llvm-project//mlir:Translation",
], ],
@ -1060,7 +1060,7 @@ tf_cc_test(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -1071,8 +1071,8 @@ tf_native_cc_binary(
"translate/derived_attr_populator_gen.cc", "translate/derived_attr_populator_gen.cc",
], ],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:tablegen", "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen", "@llvm-project//mlir:TableGen",
], ],
) )
@ -1134,7 +1134,7 @@ COMPILE_MLIR_UTIL_DEPS = [
":tensorflow_passes", ":tensorflow_passes",
":translate_utils", ":translate_utils",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -1266,7 +1266,7 @@ cc_library(
":tensorflow", ":tensorflow",
":tensorflow_types", ":tensorflow_types",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
alwayslink = 1, alwayslink = 1,
@ -1285,7 +1285,7 @@ cc_library(
"//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -1300,7 +1300,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc", "//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -1313,7 +1313,7 @@ cc_library(
":tensorflow", ":tensorflow",
"//tensorflow/core:core_cpu_lib", "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],
@ -1331,7 +1331,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],
@ -1344,7 +1344,7 @@ cc_library(
deps = [ deps = [
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -1359,7 +1359,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:test", "//tensorflow/core/platform:test",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -1378,7 +1378,7 @@ cc_library(
"//tensorflow/core:graph", "//tensorflow/core:graph",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -1398,7 +1398,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core/platform:test", "//tensorflow/core/platform:test",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -1409,7 +1409,7 @@ cc_library(
hdrs = ["utils/bridge_logger.h"], hdrs = ["utils/bridge_logger.h"],
deps = [ deps = [
":dump_mlir_util", ":dump_mlir_util",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
], ],
@ -1425,7 +1425,7 @@ cc_library(
"//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],
@ -1443,7 +1443,7 @@ cc_library(
":tensorflow", ":tensorflow",
"//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc", "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],

View File

@ -36,7 +36,7 @@ tf_cuda_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",

View File

@ -794,6 +794,34 @@ This op is deprecated. Prefer `tf.nn.batch_normalization`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_BatchToSpaceOp : TF_Op<"BatchToSpace", [NoSideEffect]> {
let summary = "BatchToSpace for 4-D tensors of type T.";
let description = [{
This is a legacy version of the more general BatchToSpaceND.
Rearranges (permutes) data from batch into blocks of spatial data, followed by
cropping. This is the reverse transformation of SpaceToBatch. More specifically,
this op outputs a copy of the input tensor where values from the `batch`
dimension are moved in spatial blocks to the `height` and `width` dimensions,
followed by cropping along the `height` and `width` dimensions.
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$crops,
Confined<I64Attr, [IntMinValue<2>]>:$block_size
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
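A small sketch of the shape arithmetic the description above implies for the 4-D NHWC case, assuming a square block and a 2x2 `crops` tensor laid out as [[crop_top, crop_bottom], [crop_left, crop_right]] (the helper is illustrative, not part of the op definition):

#include <array>
#include <cassert>
#include <cstdint>

std::array<int64_t, 4> BatchToSpaceOutputShape(
    const std::array<int64_t, 4>& in,  // {batch, height, width, depth}
    int64_t block_size, int64_t crop_top, int64_t crop_bottom,
    int64_t crop_left, int64_t crop_right) {
  // Batch shrinks by block_size^2; the spatial dims grow by block_size and
  // are then cropped, matching "moved in spatial blocks ... followed by
  // cropping" in the description.
  assert(block_size >= 2 && in[0] % (block_size * block_size) == 0);
  return {in[0] / (block_size * block_size),
          in[1] * block_size - crop_top - crop_bottom,
          in[2] * block_size - crop_left - crop_right,
          in[3]};
}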
def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> { def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
let summary = "BatchToSpace for N-D tensors of type T."; let summary = "BatchToSpace for N-D tensors of type T.";
@ -1219,6 +1247,35 @@ subsequent operation and then be optimized away, however.)
}]; }];
} }
def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Bucketizes 'input' based on 'boundaries'.";
let description = [{
For example, if the inputs are
    boundaries = [0, 10, 100]
    input = [[-5, 10000]
             [150, 10]
             [5, 100]]

then the output will be

    output = [[0, 3]
              [3, 2]
              [1, 3]]
}];
let arguments = (ins
TensorOf<[F32, F64, I32, I64]>:$input,
F32ArrayAttr:$boundaries
);
let results = (outs
I32Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
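The worked example in the description follows a simple rule: with `boundaries` sorted ascending, a value's bucket index is the number of boundaries less than or equal to it. A minimal sketch (helper name illustrative):

#include <algorithm>
#include <cstdint>
#include <vector>

int32_t BucketIndex(float value, const std::vector<float>& boundaries) {
  // upper_bound returns the first boundary strictly greater than value, so
  // its offset is the bucket index; e.g. boundaries {0, 10, 100} map -5 -> 0,
  // 5 -> 1, 10 -> 2, 10000 -> 3, reproducing the output above.
  return static_cast<int32_t>(
      std::upper_bound(boundaries.begin(), boundaries.end(), value) -
      boundaries.begin());
}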
def TF_CaseOp : TF_Op<"Case", []> { def TF_CaseOp : TF_Op<"Case", []> {
let summary = [{ let summary = [{
An n-way switch statement which calls a single branch function. An n-way switch statement which calls a single branch function.
@ -1257,6 +1314,8 @@ An n-way switch statement, implementing the following:
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>; TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let hasCanonicalizer = 1;
} }
def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> { def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
@ -1519,6 +1578,8 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
let verifier = [{ let verifier = [{
return Verify(*this); return Verify(*this);
}]; }];
let hasCanonicalizer = 1;
} }
def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> { def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {
@ -2320,6 +2381,21 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
let summary = "Return the index of device the op runs.";
let description = [{
}];
let arguments = (ins
StrArrayAttr:$device_names
);
let results = (outs
I32Tensor:$index
);
}
def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> { def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> {
let summary = "Returns the diagonal part of the tensor."; let summary = "Returns the diagonal part of the tensor.";
@ -2968,6 +3044,63 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> {
let summary = "Fast Fourier transform.";
let description = [{
Computes the 1-dimensional discrete Fourier transform over the inner-most
dimension of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
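For reference, the transform the summary names is the standard 1-D discrete Fourier transform over the inner-most dimension (length N); the 2-D and 3-D variants below apply the same definition along the inner-most 2 and 3 dimensions:

  X_k = \sum_{n=0}^{N-1} x_n \, e^{-2\pi i k n / N}, \qquad k = 0, 1, \dots, N-1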
def TF_FFT2DOp : TF_Op<"FFT2D", [NoSideEffect]> {
let summary = "2D fast Fourier transform.";
let description = [{
Computes the 2-dimensional discrete Fourier transform over the inner-most
2 dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_FFT3DOp : TF_Op<"FFT3D", [NoSideEffect]> {
let summary = "3D fast Fourier transform.";
let description = [{
Computes the 3-dimensional discrete Fourier transform over the inner-most 3
dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> { def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{ let summary = [{
Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type. Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
@ -3727,6 +3860,161 @@ table will be immutable.
); );
} }
def TF_IFFTOp : TF_Op<"IFFT", [NoSideEffect]> {
let summary = "Inverse fast Fourier transform.";
let description = [{
Computes the inverse 1-dimensional discrete Fourier transform over the
inner-most dimension of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IFFT2DOp : TF_Op<"IFFT2D", [NoSideEffect]> {
let summary = "Inverse 2D fast Fourier transform.";
let description = [{
Computes the inverse 2-dimensional discrete Fourier transform over the
inner-most 2 dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IFFT3DOp : TF_Op<"IFFT3D", [NoSideEffect]> {
let summary = "Inverse 3D fast Fourier transform.";
let description = [{
Computes the inverse 3-dimensional discrete Fourier transform over the
inner-most 3 dimensions of `input`.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IRFFTOp : TF_Op<"IRFFT", [NoSideEffect]> {
let summary = "Inverse real-valued fast Fourier transform.";
let description = [{
Computes the inverse 1-dimensional discrete Fourier transform of a real-valued
signal over the inner-most dimension of `input`.
The inner-most dimension of `input` is assumed to be the result of `RFFT`: the
`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If
`fft_length` is not provided, it is computed from the size of the inner-most
dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to
compute `input` is odd, it should be provided since it cannot be inferred
properly.
Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
than the corresponding dimension of `input`, the dimension is cropped. If it is
larger, the dimension is padded with zeros.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
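To make the fft_length convention concrete, a hedged example with assumed shapes: an inner dimension of 129 complex bins corresponds to a default fft_length of 2 * (129 - 1) = 256; below it is passed explicitly.

%fft_length = "tf.Const"() {value = dense<256> : tensor<1xi32>} : () -> tensor<1xi32>
%real = "tf.IRFFT"(%spectrum, %fft_length)
    : (tensor<4x129xcomplex<f32>>, tensor<1xi32>) -> tensor<4x256xf32>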
def TF_IRFFT2DOp : TF_Op<"IRFFT2D", [NoSideEffect]> {
let summary = "Inverse 2D real-valued fast Fourier transform.";
let description = [{
Computes the inverse 2-dimensional discrete Fourier transform of a real-valued
signal over the inner-most 2 dimensions of `input`.
The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:
The inner-most dimension contains the `fft_length / 2 + 1` unique components of
the DFT of a real-valued signal. If `fft_length` is not provided, it is computed
from the size of the inner-most 2 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
Along each axis `IRFFT2D` is computed on, if `fft_length` (or
`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IRFFT3DOp : TF_Op<"IRFFT3D", [NoSideEffect]> {
let summary = "Inverse 3D real-valued fast Fourier transform.";
let description = [{
Computes the inverse 3-dimensional discrete Fourier transform of a real-valued
signal over the inner-most 3 dimensions of `input`.
The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:
The inner-most dimension contains the `fft_length / 2 + 1` unique components of
the DFT of a real-valued signal. If `fft_length` is not provided, it is computed
from the size of the inner-most 3 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.
Along each axis `IRFFT3D` is computed on, if `fft_length` (or
`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);
let results = (outs
TF_F32OrF64Tensor:$output
);
TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}
def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> { def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> {
let summary = [{ let summary = [{
Returns a list of tensors with the same shapes and contents as the input Returns a list of tensors with the same shapes and contents as the input
@ -4142,6 +4430,30 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> {
let summary = "Gradients for Local Response Normalization.";
let description = [{
}];
let arguments = (ins
TensorOf<[BF16, F16, F32]>:$input_grads,
TensorOf<[BF16, F16, F32]>:$input_image,
TensorOf<[BF16, F16, F32]>:$output_image,
DefaultValuedAttr<I64Attr, "5">:$depth_radius,
DefaultValuedAttr<F32Attr, "1.0f">:$bias,
DefaultValuedAttr<F32Attr, "1.0f">:$alpha,
DefaultValuedAttr<F32Attr, "0.5f">:$beta
);
let results = (outs
TensorOf<[BF16, F16, F32]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> { def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes rectified linear: `max(features, features * alpha)`."; let summary = "Computes rectified linear: `max(features, features * alpha)`.";
@ -6333,6 +6645,66 @@ the dimension is padded with zeros.
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>; TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
} }
def TF_RFFT2DOp : TF_Op<"RFFT2D", [NoSideEffect]> {
let summary = "2D real-valued fast Fourier transform.";
let description = [{
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 2 dimensions of `input`.
Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TF_F32OrF64Tensor:$input,
I32Tensor:$fft_length
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Treal = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RFFT3DOp : TF_Op<"RFFT3D", [NoSideEffect]> {
let summary = "3D real-valued fast Fourier transform.";
let description = [{
Computes the 3-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 3 dimensions of `input`.
Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.
Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];
let arguments = (ins
TF_F32OrF64Tensor:$input,
I32Tensor:$fft_length
);
let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);
TF_DerivedOperandTypeAttr Treal = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}
def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>, def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder { WithBroadcastableBinOpBuilder {
let summary = [{ let summary = [{
@ -8134,6 +8506,34 @@ def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResu
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>; TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
} }
def TF_SpaceToBatchOp : TF_Op<"SpaceToBatch", [NoSideEffect]> {
let summary = "SpaceToBatch for 4-D tensors of type T.";
let description = [{
This is a legacy version of the more general SpaceToBatchND.
Zero-pads and then rearranges (permutes) blocks of spatial data into batch.
More specifically, this op outputs a copy of the input tensor where values from
the `height` and `width` dimensions are moved to the `batch` dimension. After
the zero-padding, both `height` and `width` of the input must be divisible by the
block size.
}];
let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$paddings,
Confined<I64Attr, [IntMinValue<2>]>:$block_size
);
let results = (outs
TF_Tensor:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
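A hedged shape example with assumed values: block_size = 2 and zero paddings shrink each spatial dimension by the block size and grow the batch by its square.

%paddings = "tf.Const"() {value = dense<0> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%batched = "tf.SpaceToBatch"(%input, %paddings) {block_size = 2 : i64}
    : (tensor<1x4x4x1xf32>, tensor<2x2xi32>) -> tensor<4x2x2x1xf32>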
def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> { def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> {
let summary = "SpaceToBatch for N-D tensors of type T."; let summary = "SpaceToBatch for N-D tensors of type T.";
@ -10892,6 +11292,50 @@ create these operators.
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>; TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
} }
def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect]> {
let summary = [{
Performs a MatMul followed by a specified series of operations.
}];
let description = [{
The inputs to the MatMul are specified by `a` and `b`. The series of operations
that follows is specified by the `fused_ops` attribute, which is a list of TF op
names specified as strings (e.g. "Relu"). They are performed in order, where the
(first) input to each op is the output of the preceding op. The first input and
the output of each fused_op must be of type T.
Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd", A],
where A is one of {"Elu", "Relu", "Relu6"}.
* The first input to BiasAdd is the MatMul result, and the additional BiasAdd
input is specified by `args`.
* If an op A is specified, the output of the BiasAdd is the input to op A, and
op A produces the _FusedMatMul output. Otherwise, the BiasAdd produces the
_FusedMatMul output.
*NOTE*: Do not invoke this operator directly in Python. Grappler is
expected to create these operators.
}];
let arguments = (ins
F32Tensor:$a,
F32Tensor:$b,
Variadic<F32Tensor>:$args,
DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
);
let results = (outs
F32Tensor:$product
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
}
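As a hedged illustration of the generic form (operand shapes assumed; the fused_kernel_matcher tests later in this commit produce the same op):

// MatMul(%a, %b), BiasAdd with %bias, and Relu collapsed into one kernel.
%product = "tf._FusedMatMul"(%a, %b, %bias)
    {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu"],
     transpose_a = false, transpose_b = false}
    : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<8x64xf32>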
def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> { def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device."; let summary = "A host-side computation called from a TPU device.";


@ -44,6 +44,7 @@ limitations under the License.
#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project #include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project
@ -68,6 +69,17 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace TF { namespace TF {
// Propagates attributes whose names begin with an underscore, as well as the
// "device" attribute, from `src` to `dst`.
// TODO(b/158769932): Turn this into a general feature once the attribute
// propagation policy discussion is settled.
static void PropagateAttributes(Operation *src, Operation *dst) {
auto device = mlir::Identifier::get("device", src->getContext());
for (auto named_attr : src->getAttrs()) {
if (*named_attr.first.begin() == '_' || named_attr.first == device)
dst->setAttr(named_attr.first, named_attr.second);
}
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// TF op helper functions // TF op helper functions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -786,11 +798,49 @@ static LogicalResult Verify(BroadcastToOp op) {
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// CastOp // CaseOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
public:
explicit FoldConstantCaseOp(MLIRContext *context)
: OpRewritePattern<TF::CaseOp>(context) {}
LogicalResult matchAndRewrite(TF::CaseOp op,
PatternRewriter &rewriter) const override;
};
LogicalResult FoldConstantCaseOp::matchAndRewrite(
TF::CaseOp op, PatternRewriter &rewriter) const {
// Extract the constant branch_index value.
DenseIntElementsAttr branch;
if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();
// Only attempt to fold scalar valued case statements.
// TODO(jpienaar): This can be removed if CaseOp's verifier covers it.
if (!branch.getType().cast<RankedTensorType>().getShape().empty())
return failure();
int index = *branch.getValues<int>().begin();
// TODO(jpienaar): This can be removed if CaseOp's verifier covers it.
if (index >= op.branches().size()) return failure();
auto func = op.branches()[index].cast<SymbolRefAttr>();
auto empty = rewriter.getStringAttr("");
auto call_op = rewriter.create<PartitionedCallOp>(
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
PropagateAttributes(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
return success();
}
void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<FoldConstantCaseOp>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// LeakyReluOp // CastOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) { OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
@ -823,6 +873,11 @@ static LogicalResult Verify(OpT op) {
/*mask_one_dim=*/true, op.getOperation()); /*mask_one_dim=*/true, op.getOperation());
} }
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ConvertToConcatV2>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ConcatOffsetOp // ConcatOffsetOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//


@ -334,7 +334,7 @@ SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
bool IsExported(Operation *op) { bool IsExported(Operation *op) {
auto exported_names = auto exported_names =
op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names"); op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
return exported_names && exported_names.size() != 0; return exported_names && !exported_names.empty();
} }
bool HasTfSavedModelSemantics(ModuleOp module) { bool HasTfSavedModelSemantics(ModuleOp module) {


@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3
// CHECK: return %arg0 // CHECK: return %arg0
} }
// CHECK-LABEL: testConcatCanonicalization
func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
// CHECK: %[[AXIS:.*]] = "tf.Const"
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
// CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]])
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}
// CHECK-LABEL: testLogOfSoftmax // CHECK-LABEL: testLogOfSoftmax
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32> %0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
@ -550,3 +560,29 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {
return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>> return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
} }
// CHECK-LABEL: foldCase
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
%2 = constant dense<1> : tensor<i32>
%3 = constant dense<0> : tensor<i32>
// CHECK: PartitionedCall
// CHECK-SAME: device = "noodle"
// CHECK-SAME: f = @add
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: PartitionedCall
// CHECK-SAME: _cluster_launch = "not_ready"
// CHECK-SAME: f = @sub
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %5 : tensor<f32>
}
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}


@ -370,7 +370,6 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3
// Tests that composite tf.ResourceScatterUpdate operation is decomposed. // Tests that composite tf.ResourceScatterUpdate operation is decomposed.
// CHECK-LABEL: @decompose_resource_scatter_update_op // CHECK-LABEL: @decompose_resource_scatter_update_op
// CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>) // CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) { func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
@ -384,3 +383,34 @@ func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates:
return return
} }
// -----
// Tests that tf.VariableShape operation is decomposed.
// CHECK-LABEL: @decompose_variable_shape_i32
func @decompose_variable_shape_i32(%input: tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi32> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi32>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%arg0)
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[READ]])
// CHECK: return %[[SHAPE]]
return %0 : tensor<3xi32>
}
// CHECK-LABEL: @decompose_variable_shape_i64
func @decompose_variable_shape_i64(%input: tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi64> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi64>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%arg0)
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[READ]])
// CHECK: return %[[SHAPE]]
return %0 : tensor<3xi64>
}
// CHECK-LABEL: @decompose_variable_shape_no_subtype
func @decompose_variable_shape_no_subtype(%input: tensor<!tf.resource>) -> tensor<3xi32> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource>) -> tensor<3xi32>
// CHECK: "tf.VariableShape"
// CHECK-NOT: "tf.ReadVariableOp"
// CHECK-NOT: "tf.Shape"
return %0 : tensor<3xi32>
}


@ -1,4 +1,4 @@
// RUN: tf-opt %s -tf-op-fusion | FileCheck %s // RUN: tf-opt %s -tf-fused-kernel-matcher | FileCheck %s
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Conv2D + BiasAdd + <Activation> fusions. // Conv2D + BiasAdd + <Activation> fusions.
@ -107,3 +107,54 @@ func @conv2D_dataFormatMismatch(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128x
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32> %3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32> return %3 : tensor<*xf32>
} }
//===----------------------------------------------------------------------===//
// MatMul + BiasAdd + <Activation> fusions.
//===----------------------------------------------------------------------===//
// CHECK-LABEL: matmulBiasAdd
func @matmulBiasAdd(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Identity"(%4) : (tensor<*xf32>) -> tensor<*xf32>
return %5 : tensor<*xf32>
}
// CHECK-LABEL: matmulBiasAdd_relu
func @matmulBiasAdd_relu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Relu"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}
// CHECK-LABEL: matmulBiasAdd_relu6
func @matmulBiasAdd_relu6(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu6"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Relu6"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}
// CHECK-LABEL: matmulBiasAdd_elu
func @matmulBiasAdd_elu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Elu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Elu"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}


@ -49,5 +49,5 @@ library {
} }
} }
# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>} # CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>} # CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}


@ -124,5 +124,5 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110} # CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111} # CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}
# CHECK-LABEL: func @foo110() { # CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() { # CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}


@ -57,7 +57,7 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0} # CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} # CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @foo0() { # CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0} # CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}
# CHECK-LABEL: func @bar0() { # CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}


@ -7,6 +7,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_single_outside_compiled_op(%arg0: tensor<i32>) { func @head_single_outside_compiled_op(%arg0: tensor<i32>) {
// CHECK: "tf_device.launch" // CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.A" // CHECK-NEXT: "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return // CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -27,6 +28,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_single_outside_compiled_op_no_operands() { func @head_single_outside_compiled_op_no_operands() {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -49,6 +51,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%a = "tf.A"() : () -> tensor<i32> %a = "tf.A"() : () -> tensor<i32>
// CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B" // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]] // CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -69,6 +72,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) { func @head_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -96,8 +100,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_all_cluster_op(%arg0: tensor<i32>) -> tensor<i32> { func @head_all_cluster_op(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0) // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[C_OUT]] // CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -117,8 +124,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_multiple_outside_compiled_ops(%arg0: tensor<i32>) { func @head_multiple_outside_compiled_ops(%arg0: tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A" // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]]) // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.C" // CHECK-NEXT: "tf.C"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]] // CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -141,6 +151,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"() // CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST" // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
// //
@ -173,6 +184,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK: "tf_device.launch" // CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]]) // CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return // CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( { "tf_device.cluster"() ( {
@ -199,6 +211,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]]) // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]] // CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( { %cluster = "tf_device.cluster"() ( {
@ -226,7 +239,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK: "tf_device.launch" // CHECK: "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1) // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0) // CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return // CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( { "tf_device.cluster"() ( {
@ -258,6 +273,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]]) // CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return // CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster:5 = "tf_device.cluster"() ( { %cluster:5 = "tf_device.cluster"() ( {
@ -286,6 +302,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK-NEXT: "tf_device.launch"() // CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]]) // CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return // CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST" // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} { tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
@ -320,6 +337,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> { func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0) // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// //
@ -335,6 +353,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch" // CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]]) // CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[C_OUT]] // CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0" // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( { %cluster = "tf_device.cluster"() ( {
@ -353,6 +372,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"() // CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]]) // CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]] // CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST" // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
// //
@ -370,6 +390,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// //
// CHECK-NEXT: "tf_device.launch"() // CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]]) // CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return // CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST" // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} { tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {


@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-tpu-host-computation-expansion | FileCheck %s --dump-input-on-failure // RUN: tf-opt %s -split-input-file -tf-tpu-host-computation-expansion | FileCheck %s
// Tests expansion of a outside compiled ops at head/tail of TPU computation. // Tests expansion of a outside compiled ops at head/tail of TPU computation.
@ -26,7 +26,7 @@ func @cast_at_head_expanded(%arg0: tensor<?xi32>) {
"tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () "tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> () "tf.C"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) {} : () -> ()
return return
} }
@ -44,7 +44,7 @@ func @check_consecutive_unary_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%2) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> () "tf.B"(%2) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> () "tf.C"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) {} : () -> ()
return return
} }
@ -59,7 +59,7 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%1) : (tensor<?xi32>) -> () "tf.B"(%1) : (tensor<?xi32>) -> ()
"tf.C"() : () -> () "tf.C"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) {} : () -> ()
return return
} }
@ -67,9 +67,9 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<?xi32>) { func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster" // CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Cast" // CHECK-NEXT: "tf.Cast"
// CHECK-NOT: _xla_outside_compilation = "" // CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.Identity" // CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation = "" // CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.B" // CHECK-NEXT: "tf.B"
"tf_device.cluster"() ( { "tf_device.cluster"() ( {
%1 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>) %1 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
@ -77,6 +77,19 @@ func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<
"tf.B"(%2) : (tensor<?xi32>) -> () "tf.B"(%2) : (tensor<?xi32>) -> ()
"tf.C"() : () -> () "tf.C"() : () -> ()
tf_device.return tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> () }) : () -> ()
return
}
// CHECK-LABEL: func @check_op_without_usage_not_outside_compiled
func @check_op_without_usage_not_outside_compiled(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
"tf_device.cluster"() ( {
"tf.Identity"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
"tf.C"() : () -> ()
tf_device.return
}) : () -> ()
return return
} }


@ -87,7 +87,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
// because DecomposeResourceOpsPass uses pattern rewriter which hoists // because DecomposeResourceOpsPass uses pattern rewriter which hoists
// changed constants out of tf_device.Launch. // changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass()); func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());
pm.addNestedPass<FuncOp>(CreateTPUHostComputationExpansionPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
// Run another shape inference pass because resource decomposition might have // Run another shape inference pass because resource decomposition might have
// created new partial types. // created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass()); pm.addPass(TF::CreateTFShapeInferencePass());


@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg),
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)), def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
(TF_BitcastOp $arg)>; (TF_BitcastOp $arg)>;
//===----------------------------------------------------------------------===//
// Concat op patterns.
//===----------------------------------------------------------------------===//
def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs),
(TF_ConcatV2Op $inputs, $axis)>;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Conj op patterns. // Conj op patterns.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//


@ -50,6 +50,24 @@ static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
return UnrankedTensorType::get(element_type); return UnrankedTensorType::get(element_type);
} }
static bool HasResourceSubtype(Value resource) {
return resource.getType()
.cast<TensorType>()
.getElementType()
.cast<ResourceType>()
.getSubtypes()
.size() == 1;
}
static Type GetResourceSubtype(Value resource) {
return resource.getType()
.cast<TensorType>()
.getElementType()
.cast<ResourceType>()
.getSubtypes()
.front();
}
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc" #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
} // namespace } // namespace


@ -31,6 +31,12 @@ def CreateTFReadVariableOp: NativeCodeCall<
" $2)" " $2)"
>; >;
def CheckHasResourceSubtype : Constraint<CPred<"HasResourceSubtype($0)">>;
def CreateTFReadVariableOpFromResourceHandle : NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
"$0.getLoc(), GetResourceSubtype($1), $1)">;
def DecomposeAssignAddVariableOp : def DecomposeAssignAddVariableOp :
Pat< Pat<
(TF_AssignAddVariableOp:$src_op $resource, $value), (TF_AssignAddVariableOp:$src_op $resource, $value),
@ -315,3 +321,9 @@ def DecomposeResourceScatterUpdate : Pat<
$updates $updates
) )
)>; )>;
// Pattern to decompose tf.VariableShape into tf.ReadVariableOp and tf.Shape.
def DecomposeVariableShape : Pat<
(TF_VariableShapeOp:$src_op $resource),
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
[(CheckHasResourceSubtype $resource)]>;


@ -0,0 +1,225 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <iostream>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Note: This implements fusions performed in the old Remapper Grappler pass.
// That pass has dedicated cases for GPU and for different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers the general CPU case and at the moment does not account for any
// target-specific configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.
// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
: public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
void runOnFunction() override;
};
// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}
// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}
// Performs a fusion of the following pattern(s), if possible:
// <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with activation is preferred, but a contraction and BiasAdd
// can also be replaced by the fused contraction if there is no activation
// function to fuse.
// i.e., this class also supports the following fusion:
// <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
public:
using OpRewritePattern<SrcOpT>::OpRewritePattern;
// Class users should override this method if there are any op-specific
// compatibility requirements between the contraction op and the BiasAdd op.
virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
PatternRewriter &rewriter) const {
return true;
}
LogicalResult matchAndRewrite(SrcOpT contraction,
PatternRewriter &rewriter) const override {
auto context = rewriter.getContext();
// If the contraction is used in multiple places, fusing it will only create
// more contraction nodes, which is slower.
if (!contraction.getResult().hasOneUse())
return rewriter.notifyMatchFailure(contraction,
"result is used by multiple ops");
BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
if (!bias_add) {
return rewriter.notifyMatchFailure(
contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}
if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
return rewriter.notifyMatchFailure(
contraction, "cannot fuse with the subsequent BiasAdd op");
}
SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{
StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
Type result_type;
// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(
StringAttr::get(GetOpNameWithoutDialect(activation), context));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}
auto fused_loc = rewriter.getFusedLoc(locations);
// The fused contraction has the same operands as the original contraction
// with `bias` from the BiasAddOp appended.
SmallVector<Value, 4> operands(contraction.operand_begin(),
contraction.operand_end());
operands.push_back(bias_add.bias());
// The fused contraction has the same attributes as the original
// contraction, with two additions: the list of ops that have been fused
// together, and epsilon (used only in FusedBatchNorm fusions).
std::vector<NamedAttribute> attrs = contraction.getAttrs();
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, context);
attrs.push_back(
NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
// Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
// here.
Attribute epsilon = rewriter.getF32FloatAttr(0);
attrs.push_back(
NamedAttribute(Identifier::get("epsilon", context), epsilon));
Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
ValueRange(operands), attrs);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
return success();
}
};
// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
: public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
public:
using FuseContractionWithBiasAdd<Conv2DOp,
_FusedConv2DOp>::FuseContractionWithBiasAdd;
// Verify that the Conv2D and BiasAdd data formats match. This is necessary
// for the ops to fuse correctly, since the fused Conv2D op has a single data
// format attribute that is shared by both.
bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
PatternRewriter &rewriter) const override {
// Verify that the data formats match and are valid for fusion.
if (conv.data_format() != bias_add.data_format()) {
rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << conv.data_format() << ")";
});
return false;
}
return true;
}
};
// Performs a fusion of the following pattern(s), if possible:
// MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
using FuseMatMulBiasAdd = FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp>;
void FusedKernelMatcherPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
return std::make_unique<FusedKernelMatcherPass>();
}
static PassRegistration<FusedKernelMatcherPass> pass(
"tf-fused-kernel-matcher",
"Matches computations corresponding to optimized fused kernels");
} // namespace TF
} // namespace mlir


@ -1,174 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdio>
#include <iostream>
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
namespace mlir {
namespace TF {
namespace {
// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU and based on different
// target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers the general CPU case and at the moment does not account for any
// specific target configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.
// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct OpFusionPass : public PassWrapper<OpFusionPass, FunctionPass> {
void runOnFunction() override;
};
// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}
bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}
// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}
// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}
// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
//
// Note that fusion with activation is preferred, but a Conv2D and BiasAdd can
// also be replaced by a _FusedConv2D if there is no other activation function.
// i.e., this class also supports the following fusion:
// Conv2D + BiasAdd -> _FusedConv2D
//
// TODO(b/158266331): Support fusing Conv2D + BiasAdd + a chain of activations.
class FuseConv2DBiasAdd : public OpRewritePattern<Conv2DOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Conv2DOp op,
PatternRewriter &rewriter) const override {
// If the convolution is used in multiple places, fusing it will only create
// more convolutions, which is slower.
if (!op.getResult().hasOneUse())
return rewriter.notifyMatchFailure(op, "result is used by multiple ops");
BiasAddOp bias_add = GetBiasAdd(op);
if (!bias_add) {
return rewriter.notifyMatchFailure(
op, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}
// Check that Conv and BiasAdd formats match.
if (op.data_format() != bias_add.data_format()) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << op.data_format() << ")";
});
}
SmallVector<Location, 3> locations{op.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
GetOpNameWithoutDialect(bias_add), rewriter.getContext())};
Type result_type;
// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);
// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(StringAttr::get(GetOpNameWithoutDialect(activation),
rewriter.getContext()));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}
auto loc = rewriter.getFusedLoc(locations);
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, rewriter.getContext());
// Epsilon is used only in fusions with the BatchNorm op.
APFloat epsilon = APFloat(0.0f);
auto fused_op = rewriter.create<_FusedConv2DOp>(
loc, result_type, op.input(), op.filter(), bias_add.bias(),
op.strides(), op.padding(), op.explicit_paddings(), op.data_format(),
op.dilations(), op.use_cudnn_on_gpu(), fused_ops_attr, epsilon);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, {fused_op});
return success();
}
};
void OpFusionPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd>(&getContext());
applyPatternsAndFoldGreedily(func, patterns);
}
} // namespace
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass() {
return std::make_unique<OpFusionPass>();
}
static PassRegistration<OpFusionPass> pass(
"tf-op-fusion",
"Replaces commonly occurring subgraphs with optimized fused kernels");
} // namespace TF
} // namespace mlir
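As a rough usage sketch (not part of this commit), the pass registered above can be driven from tf-opt with the -tf-op-fusion flag that appears in the PassRegistration, or added to a C++ pipeline along the following lines. The snippet is illustrative only: it assumes the CreateOpFusionPass() entry point declared in this file (a later hunk in this diff renames it to CreateFusedKernelMatcherPass()), an existing MLIRContext `context` and ModuleOp `module`, and the tensorflow/compiler/mlir/tensorflow/transforms/passes.h header for the declaration; the exact PassManager API depends on the MLIR revision in use.

    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

    // Run the fusion pass on every function in the module (illustrative sketch).
    mlir::PassManager pm(&context);
    pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateOpFusionPass());
    if (mlir::failed(pm.run(module))) {
      // The pipeline signaled failure; handle the error here.
    }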

View File

@ -54,7 +54,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass();
// Creates pass to rewrite RecvTPUEmbeddingActivationsOp and // Creates pass to rewrite RecvTPUEmbeddingActivationsOp and
// SendTPUEmbeddingGradients ops to internal variants. // SendTPUEmbeddingGradients ops to internal variants.
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps(); std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOpsPass();
// Performs specific fusion for GPU targets. // Performs specific fusion for GPU targets.
std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass(); std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
@ -148,8 +148,10 @@ CreateTensorArrayOpsDecompositionPass();
// Create a pass that legalizes HLO to TF dialect. // Create a pass that legalizes HLO to TF dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass(); std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();
// Creates a pass that performs fusion of common sequences of ops. // Matches sequences of ops to TensorFlow fused kernels. This pass should not be
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass(); // generally used beyond exporting to runtimes that support these ops. In the
// future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
} // namespace TF } // namespace TF
namespace tf_executor { namespace tf_executor {

View File

@ -101,7 +101,7 @@ void RewriteTPUEmbeddingOps::runOnFunction() {
} // anonymous namespace } // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps() { std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOpsPass() {
return std::make_unique<RewriteTPUEmbeddingOps>(); return std::make_unique<RewriteTPUEmbeddingOps>();
} }

View File

@ -209,8 +209,10 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> head_outside_compiled_ops, llvm::ArrayRef<Operation*> head_outside_compiled_ops,
llvm::StringRef host_device) { llvm::StringRef host_device) {
Block* launch_block = new Block; Block* launch_block = new Block;
for (Operation* head_outside_compiled_op : head_outside_compiled_ops) for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
head_outside_compiled_op->moveBefore(launch_block, launch_block->end()); head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
}
tf_device::LaunchOp launch = CreateLaunchForBlock( tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/true, launch_block, host_device); builder, cluster, /*before=*/true, launch_block, host_device);
@ -294,8 +296,10 @@ void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> tail_outside_compiled_ops, llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
llvm::StringRef host_device) { llvm::StringRef host_device) {
Block* launch_block = new Block; Block* launch_block = new Block;
for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin()); tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
}
tf_device::LaunchOp launch = CreateLaunchForBlock( tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/false, launch_block, host_device); builder, cluster, /*before=*/false, launch_block, host_device);

View File

@ -92,10 +92,13 @@ void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,
for (auto head_outside_compiled_op : for (auto head_outside_compiled_op :
llvm::reverse(head_outside_compiled_ops)) { llvm::reverse(head_outside_compiled_ops)) {
if (HasOutsideCompilationAttribute(head_outside_compiled_op)) continue; auto users = head_outside_compiled_op->getUsers();
if (users.empty() ||
HasOutsideCompilationAttribute(head_outside_compiled_op))
continue;
bool should_expand_op_to_host_computation = true; bool should_expand_op_to_host_computation = true;
for (auto consumer_op : head_outside_compiled_op->getUsers()) { for (auto consumer_op : users) {
if (should_expand_op_to_host_computation && if (should_expand_op_to_host_computation &&
!HasOutsideCompilationAttribute(consumer_op)) { !HasOutsideCompilationAttribute(consumer_op)) {
should_expand_op_to_host_computation = false; should_expand_op_to_host_computation = false;

View File

@ -219,8 +219,7 @@ class ImporterBase {
const absl::InlinedVector<OutputTensor, 4>& arg_nodes, const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes, const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes, const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs, llvm::ArrayRef<mlir::NamedAttribute> attrs);
bool function_graph);
// Finds out the function definition for the given function name from the // Finds out the function definition for the given function name from the
// graph and converts it to a function of the module. This method is called // graph and converts it to a function of the module. This method is called
@ -1302,8 +1301,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {
TF_RETURN_IF_ERROR(child_importer.Convert( TF_RETURN_IF_ERROR(child_importer.Convert(
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
llvm::makeArrayRef(attributes.begin(), attributes.end()), llvm::makeArrayRef(attributes.begin(), attributes.end())));
/*function_graph=*/true));
return Status::OK(); return Status::OK();
} }
@ -1405,7 +1403,7 @@ Status ImporterBase::Convert(
const absl::InlinedVector<OutputTensor, 4>& arg_nodes, const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes, const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes, const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs, bool function_graph) { llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// TODO(b/122040776): Uses debug info for FunctionDef. // TODO(b/122040776): Uses debug info for FunctionDef.
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_), auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
func_name, func_type, attrs); func_name, func_type, attrs);
@ -2222,8 +2220,15 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
PopulateTfVersions(module.get(), graph.versions()); PopulateTfVersions(module.get(), graph.versions());
TF_RETURN_IF_ERROR(importer.ImporterBase::Convert( TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs, func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));
specs.graph_as_function));
// Mark main function public, others private.
for (auto function : module.get().getOps<mlir::FuncOp>()) {
auto visibility = function.getName() == func_name
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
function.setVisibility(visibility);
}
return module; return module;
} }
@ -2888,6 +2893,16 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
} }
} }
// Marks the visibility of functions in the saved model module.
void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
for (auto func : module.getOps<mlir::FuncOp>()) {
auto visibility = mlir::tf_saved_model::IsExported(func)
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
func.setVisibility(visibility);
}
}
// Reorder the ops in the module to make testing easier and less dependent // Reorder the ops in the module to make testing easier and less dependent
// on implementation details such as the order of functions in the // on implementation details such as the order of functions in the
// FunctionDefLibrary. // FunctionDefLibrary.
@ -3130,6 +3145,7 @@ Status CreateSavedModelIR(
AdjustBoundInputArgTypes(module); AdjustBoundInputArgTypes(module);
module.setAttr("tf_saved_model.semantics", builder.getUnitAttr()); module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(module); SortSavedModelModule(module);
MarkSavedModelFunctionVisibility(module);
return Status::OK(); return Status::OK();
} }
@ -3299,6 +3315,7 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
mlir::OpBuilder builder(module_->getBodyRegion()); mlir::OpBuilder builder(module_->getBodyRegion());
module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr()); module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(*module_); SortSavedModelModule(*module_);
MarkSavedModelFunctionVisibility(*module_);
return std::move(module_); return std::move(module_);
} }

View File

@ -57,7 +57,7 @@ cc_library(
], ],
deps = [ deps = [
":tfjs_inc_gen", ":tfjs_inc_gen",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects", "@llvm-project//mlir:SideEffects",
@ -109,7 +109,7 @@ cc_library(
":tensorflow_js", ":tensorflow_js",
":tensorflow_js_dialect_registration", ":tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -190,7 +190,7 @@ cc_library(
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser", "@llvm-project//mlir:Parser",
@ -229,7 +229,7 @@ tf_cc_binary(
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",

View File

@ -11,7 +11,7 @@ cc_library(
deps = [ deps = [
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects", "@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",

View File

@ -107,7 +107,7 @@ gentbl(
td_file = "transforms/legalize_tf_patterns.td", td_file = "transforms/legalize_tf_patterns.td",
td_srcs = [ td_srcs = [
":hlo_ops_td_files", ":hlo_ops_td_files",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles", "@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
], ],
@ -177,7 +177,8 @@ cc_library(
"//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/core/kernels:conv_grad_shape_utils",
"@llvm-project//llvm:support", "//tensorflow/core/lib/bfloat16",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect", "@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
@ -217,7 +218,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -233,7 +234,7 @@ cc_library(
":hlo", ":hlo",
"//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
@ -249,7 +250,7 @@ cc_library(
":hlo", ":hlo",
":lhlo", ":lhlo",
":map_hlo_to_lhlo_op", ":map_hlo_to_lhlo_op",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
) )
@ -272,7 +273,7 @@ cc_library(
":map_xla_to_scalar_op", ":map_xla_to_scalar_op",
"//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine", "@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -287,7 +288,7 @@ cc_library(
deps = [ deps = [
":lhlo", ":lhlo",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -321,7 +322,7 @@ cc_library(
":lhlo", ":lhlo",
":map_xla_to_scalar_op", ":map_xla_to_scalar_op",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -354,7 +355,7 @@ cc_library(
":lhlo", ":lhlo",
":map_xla_to_scalar_op", ":map_xla_to_scalar_op",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
@ -372,7 +373,7 @@ cc_library(
deps = [ deps = [
":lhlo", ":lhlo",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:LinalgOps", "@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -416,7 +417,7 @@ cc_library(
srcs = ["transforms/cycle_detector.cc"], srcs = ["transforms/cycle_detector.cc"],
hdrs = ["transforms/cycle_detector.h"], hdrs = ["transforms/cycle_detector.h"],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
alwayslink = 1, alwayslink = 1,
) )
@ -437,8 +438,8 @@ cc_library(
deps = [ deps = [
":cycle_detector", ":cycle_detector",
":hlo", ":hlo",
"@llvm-project//llvm:ir", "@llvm-project//llvm:Core",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -466,7 +467,7 @@ cc_library(
srcs = ["transforms/legalize_control_flow.cc"], srcs = ["transforms/legalize_control_flow.cc"],
deps = [ deps = [
":hlo", ":hlo",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -490,7 +491,7 @@ cc_library(
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
@ -504,7 +505,7 @@ cc_library(
deps = [ deps = [
":hlo", ":hlo",
":xla_legalize_to_standard_inc_gen", ":xla_legalize_to_standard_inc_gen",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -522,7 +523,7 @@ gentbl(
td_file = "transforms/lower_complex_patterns.td", td_file = "transforms/lower_complex_patterns.td",
td_srcs = [ td_srcs = [
":hlo_ops_td_files", ":hlo_ops_td_files",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles", "@llvm-project//mlir:StdOpsTdFiles",
], ],
) )
@ -537,7 +538,7 @@ cc_library(
deps = [ deps = [
":hlo", ":hlo",
":xla_dialect_registration", ":xla_dialect_registration",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -564,7 +565,7 @@ cc_library(
srcs = ["transforms/unfuse_batch_norm.cc"], srcs = ["transforms/unfuse_batch_norm.cc"],
deps = [ deps = [
":hlo", ":hlo",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms", "@llvm-project//mlir:Transforms",
@ -637,7 +638,7 @@ cc_library(
":infer_fusibility_op_interface", ":infer_fusibility_op_interface",
":xla_canonicalize_inc_gen", ":xla_canonicalize_inc_gen",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:InferTypeOpInterface",
@ -659,6 +660,7 @@ cc_library(
deps = [ deps = [
":attribute_importer", ":attribute_importer",
":hlo", ":hlo",
":hlo_module_importer",
":hlo_utils", ":hlo_utils",
":type_to_shape", ":type_to_shape",
"//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:comparison_util",
@ -671,7 +673,7 @@ cc_library(
"//tensorflow/core/platform:types", "//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_map",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
@ -692,7 +694,7 @@ cc_library(
deps = [ deps = [
":hlo_ops_base_inc_gen", ":hlo_ops_base_inc_gen",
":lhlo_ops_inc_gen", ":lhlo_ops_inc_gen",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -748,7 +750,7 @@ cc_library(
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/platform:logging", "//tensorflow/core/platform:logging",
"//tensorflow/core/platform:types", "//tensorflow/core/platform:types",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
], ],
@ -798,7 +800,7 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis", "@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass", "@llvm-project//mlir:Pass",
@ -846,7 +848,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_casting_utils", "//tensorflow/compiler/xla/service:hlo_casting_utils",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@com_google_absl//absl/types:optional", "@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps", "@llvm-project//mlir:StandardOps",
], ],
@ -877,7 +879,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_parser", "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto_cc", "//tensorflow/compiler/xla/service:hlo_proto_cc",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
"@llvm-project//mlir:Translation", "@llvm-project//mlir:Translation",
], ],
@ -888,8 +890,8 @@ tf_native_cc_binary(
name = "operator_writer_gen", name = "operator_writer_gen",
srcs = ["operator_writer_gen.cc"], srcs = ["operator_writer_gen.cc"],
deps = [ deps = [
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//llvm:tablegen", "@llvm-project//llvm:TableGen",
"@llvm-project//mlir:Support", "@llvm-project//mlir:Support",
"@llvm-project//mlir:TableGen", "@llvm-project//mlir:TableGen",
], ],

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include <unordered_map>
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
@ -79,30 +81,35 @@ bool DotIsDefault(const HloInstruction* instruction) {
} }
} // namespace } // namespace
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction( Status HloFunctionImporter::ImportAsFunc(
mlir::ModuleOp module, mlir::Builder* builder, const HloComputation& computation, mlir::ModuleOp module,
std::unordered_map<HloComputation*, FuncOp>* function_map, std::unordered_map<const HloComputation*, FuncOp>* function_map,
HloComputation* computation) { mlir::Builder* builder) {
HloFunctionImporter importer(module, builder, function_map); HloFunctionImporter importer(module, function_map, builder);
return importer.ImportFunction(computation); return importer.ImportAsFunc(computation).status();
} }
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction( Status HloFunctionImporter::ImportAsRegion(
HloComputation* computation) { const xla::HloComputation& computation, mlir::Region* region,
auto& imported = (*function_map_)[computation]; mlir::Builder* builder) {
HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
builder);
return importer.ImportAsRegion(computation, region);
}
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
const HloComputation& computation) {
auto& imported = (*function_map_)[&computation];
if (imported) return imported; if (imported) return imported;
llvm::SmallVector<Type, 4> args, rets; llvm::SmallVector<Type, 4> args, rets;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
GetMlirTypes(computation->parameter_instructions(), &args)); TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets));
auto func_type = mlir::FunctionType::get(args, rets, context_); auto func_type = mlir::FunctionType::get(args, rets, context_);
string computation_name = string computation_name =
computation->parent()->entry_computation() == computation computation.parent()->entry_computation() == &computation
? "main" ? "main"
: SanitizeFunctionName(computation->name()); : SanitizeFunctionName(computation.name());
// Construct the MLIR function and map arguments. // Construct the MLIR function and map arguments.
llvm::ArrayRef<mlir::NamedAttribute> attrs; llvm::ArrayRef<mlir::NamedAttribute> attrs;
@ -119,31 +126,30 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
return function; return function;
} }
tensorflow::Status HloFunctionImporter::ImportComputation( tensorflow::Status HloFunctionImporter::ImportAsRegion(
HloComputation* computation, mlir::Region* region) { const HloComputation& computation, mlir::Region* region) {
// TODO(hinsu): Store computation name as an attribute for round-trip. // TODO(hinsu): Store computation name as an attribute for round-trip.
auto* block = new mlir::Block; auto* block = new mlir::Block;
region->push_back(block); region->push_back(block);
llvm::SmallVector<Type, 4> args; llvm::SmallVector<Type, 4> args;
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
GetMlirTypes(computation->parameter_instructions(), &args));
block->addArguments(args); block->addArguments(args);
return ImportInstructions(computation, block); return ImportInstructions(computation, block);
} }
tensorflow::Status HloFunctionImporter::ImportInstructions( tensorflow::Status HloFunctionImporter::ImportInstructions(
HloComputation* computation, mlir::Block* block) { const HloComputation& computation, mlir::Block* block) {
// Setup the input parameters. // Setup the input parameters.
const int num_parameters = computation->num_parameters(); const int num_parameters = computation.num_parameters();
for (int i = 0; i < num_parameters; i++) { for (int i = 0; i < num_parameters; i++) {
auto hlo_parameter = computation->parameter_instruction(i); auto hlo_parameter = computation.parameter_instruction(i);
instruction_value_map_[hlo_parameter] = block->getArgument(i); instruction_value_map_[hlo_parameter] = block->getArgument(i);
} }
mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block); mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
for (auto instruction : computation->MakeInstructionPostOrder()) { for (auto instruction : computation.MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto new_operation, TF_ASSIGN_OR_RETURN(auto new_operation,
ImportInstruction(instruction, &builder)); ImportInstruction(instruction, &builder));
if (new_operation) { if (new_operation) {
@ -156,7 +162,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(
// Setup the return type (HLO only supports a single return value). // Setup the return type (HLO only supports a single return value).
TF_ASSIGN_OR_RETURN(auto result, TF_ASSIGN_OR_RETURN(auto result,
GetMlirValue(computation->root_instruction())); GetMlirValue(computation.root_instruction()));
// Create terminator op depending on the parent op of this region. // Create terminator op depending on the parent op of this region.
if (llvm::isa<FuncOp>(block->getParentOp())) { if (llvm::isa<FuncOp>(block->getParentOp())) {
@ -249,7 +255,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
} }
case HloOpcode::kCall: { case HloOpcode::kCall: {
TF_ASSIGN_OR_RETURN(FuncOp function, TF_ASSIGN_OR_RETURN(FuncOp function,
ImportFunction(instruction->to_apply())); ImportAsFunc(*instruction->to_apply()));
mlir::Operation* new_operation = mlir::Operation* new_operation =
func_builder->create<mlir::CallOp>(loc, function, operands); func_builder->create<mlir::CallOp>(loc, function, operands);
return new_operation; return new_operation;
@ -365,7 +371,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>( auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
loc, result_type, operands, attributes); loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(), TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
&scatter_op.update_computation())); &scatter_op.update_computation()));
return scatter_op.getOperation(); return scatter_op.getOperation();
} }
@ -387,9 +393,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto select_scatter_op = auto select_scatter_op =
func_builder->create<mlir::xla_hlo::SelectAndScatterOp>( func_builder->create<mlir::xla_hlo::SelectAndScatterOp>(
loc, result_type, operands, attributes); loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(select_scatter->select(), TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
&select_scatter_op.select())); &select_scatter_op.select()));
TF_RETURN_IF_ERROR(ImportComputation(select_scatter->scatter(), TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
&select_scatter_op.scatter())); &select_scatter_op.scatter()));
return select_scatter_op.getOperation(); return select_scatter_op.getOperation();
} }
@ -414,8 +420,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
loc, result_type, operands, loc, result_type, operands,
builder_->getI64IntegerAttr(sort_instruction->sort_dimension()), builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
builder_->getBoolAttr(sort_instruction->is_stable())); builder_->getBoolAttr(sort_instruction->is_stable()));
TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(), TF_RETURN_IF_ERROR(
&sort_op.comparator())); ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
return sort_op.getOperation(); return sort_op.getOperation();
} }
case HloOpcode::kConditional: { case HloOpcode::kConditional: {
@ -430,9 +436,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands, auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
attributes); attributes);
TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(), TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
&op.true_branch())); &op.true_branch()));
TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(), TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
&op.false_branch())); &op.false_branch()));
return op.getOperation(); return op.getOperation();
} }
@ -448,8 +454,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
llvm::enumerate(instruction->branch_computations())) { llvm::enumerate(instruction->branch_computations())) {
auto index = index_and_computation.index(); auto index = index_and_computation.index();
HloComputation* computation = index_and_computation.value(); HloComputation* computation = index_and_computation.value();
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
ImportComputation(computation, &op.branches()[index]));
} }
return op.getOperation(); return op.getOperation();
} }
@ -468,7 +473,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
attributes.push_back(ConvertChannelHandle(all_reduce->channel_id())); attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>( auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>(
loc, result_type, operands, attributes); loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR(ImportComputation(all_reduce->to_apply(), TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
&all_reduce_op.computation())); &all_reduce_op.computation()));
return all_reduce_op.getOperation(); return all_reduce_op.getOperation();
} }
@ -481,7 +486,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
llvm::makeArrayRef(operands).drop_front(num_inputs), llvm::makeArrayRef(operands).drop_front(num_inputs),
ConvertDimensions(instruction->dimensions())); ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &reduce.body())); ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation(); return reduce.getOperation();
} }
case HloOpcode::kReverse: { case HloOpcode::kReverse: {
@ -517,9 +522,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto op = func_builder->create<mlir::xla_hlo::WhileOp>( auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
loc, operands[0].getType(), operands[0]); loc, operands[0].getType(), operands[0]);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ImportComputation(instruction->while_condition(), &op.cond())); ImportAsRegion(*instruction->while_condition(), &op.cond()));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ImportComputation(instruction->while_body(), &op.body())); ImportAsRegion(*instruction->while_body(), &op.body()));
return op.getOperation(); return op.getOperation();
} }
case HloOpcode::kGetTupleElement: { case HloOpcode::kGetTupleElement: {
@ -580,7 +585,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>( auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
loc, result_type, operands, attributes); loc, result_type, operands, attributes);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &reduce.body())); ImportAsRegion(*instruction->to_apply(), &reduce.body()));
return reduce.getOperation(); return reduce.getOperation();
} }
case HloOpcode::kMap: { case HloOpcode::kMap: {
@ -588,7 +593,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
loc, result_type, operands, loc, result_type, operands,
ConvertDimensions(instruction->dimensions())); ConvertDimensions(instruction->dimensions()));
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
ImportComputation(instruction->to_apply(), &op.computation())); ImportAsRegion(*instruction->to_apply(), &op.computation()));
return op.getOperation(); return op.getOperation();
} }
case HloOpcode::kConvolution: { case HloOpcode::kConvolution: {

View File

@ -42,29 +42,39 @@ class Shape;
// Helper class for importing HloComputations. // Helper class for importing HloComputations.
class HloFunctionImporter { class HloFunctionImporter {
public: public:
static StatusOr<mlir::FuncOp> ImportFunction( // Imports the given computation as a function in the given module. This also
mlir::ModuleOp module, mlir::Builder* builder, // imports any computations referred by instructions in this computation.
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map, static Status ImportAsFunc(const xla::HloComputation& computation,
xla::HloComputation* computation); mlir::ModuleOp module,
std::unordered_map<const xla::HloComputation*,
mlir::FuncOp>* function_map,
mlir::Builder* builder);
// Imports the given hlo computation to the specified region.
static Status ImportAsRegion(const xla::HloComputation& computation,
mlir::Region* region, mlir::Builder* builder);
private: private:
HloFunctionImporter( HloFunctionImporter(mlir::ModuleOp module,
mlir::ModuleOp module, mlir::Builder* builder, std::unordered_map<const xla::HloComputation*,
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map) mlir::FuncOp>* function_map,
mlir::Builder* builder)
: context_(module.getContext()), : context_(module.getContext()),
module_(module), module_(module),
builder_(builder), builder_(builder),
function_map_(function_map) {} function_map_(function_map) {}
StatusOr<mlir::FuncOp> ImportFunction(xla::HloComputation* computation); // Imports the given computation as a new function, if it hasn't been already
// imported.
StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);
// Imports the given computation in the specified region. // Imports the given computation in the specified region.
tensorflow::Status ImportComputation(HloComputation* computation, tensorflow::Status ImportAsRegion(const HloComputation& computation,
mlir::Region* region); mlir::Region* region);
// Imports instructions from the given computation in the specified block. // Imports instructions from the given computation in the specified block.
// Assumes that the block already has correct arguments populated. // Assumes that the block already has correct arguments populated.
tensorflow::Status ImportInstructions(HloComputation* computation, tensorflow::Status ImportInstructions(const HloComputation& computation,
mlir::Block* block); mlir::Block* block);
// Imports an instruction. // Imports an instruction.
@ -125,7 +135,7 @@ class HloFunctionImporter {
mlir::Builder* builder_; mlir::Builder* builder_;
// Mapping from HloComputation to the created MLIR function. // Mapping from HloComputation to the created MLIR function.
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_; std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;
// Mapping from HloInstructions to the associative MLIR values. // Mapping from HloInstructions to the associative MLIR values.
std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_; std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;

View File

@ -33,11 +33,11 @@ namespace xla {
Status HloModuleImporter::Import(const xla::HloModule& module) { Status HloModuleImporter::Import(const xla::HloModule& module) {
// TODO(hinsu): Only import the entry computation here once all HLO ops with // TODO(hinsu): Only import the entry computation here once all HLO ops with
// reference to other computation are updated to have a region instead of a // reference to other computation are updated to have a region instead of a
// function attribute. // function attribute. Currently the importer test doesn't refer to all the
for (const auto& computation : module.computations()) { // computations from the entry computation so tests may need some update.
auto result = HloFunctionImporter::ImportFunction( for (const auto* computation : module.computations()) {
module_, &builder_, &function_map_, computation); TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc(
TF_RETURN_IF_ERROR(result.status()); *computation, module_, &function_map_, &builder_));
} }
return Status::OK(); return Status::OK();

View File

@ -54,7 +54,7 @@ class HloModuleImporter {
// Map for tracking which MLIR function map to which HLO Computation. This // Map for tracking which MLIR function map to which HLO Computation. This
// tracks functions as they are imported and provides a quick lookup for // tracks functions as they are imported and provides a quick lookup for
// functions invoked by control flow related operations (e.g. while, call). // functions invoked by control flow related operations (e.g. while, call).
std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map_; std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map_;
}; };
} // namespace xla } // namespace xla

View File

@ -836,21 +836,33 @@ LogicalResult ConcatenateOp::inferReturnTypes(
auto dimension = dimension_attr.getInt(); auto dimension = dimension_attr.getInt();
auto first_type = (*operands.begin()).getType().cast<ShapedType>(); auto first_type = (*operands.begin()).getType().cast<ShapedType>();
auto out_element = first_type.getElementType(); auto out_element = first_type.getElementType();
for (auto operand : operands.getTypes()) {
auto element_type = getElementTypeOrSelf(operand);
if (element_type != out_element) {
return failure();
}
}
// If an input is unranked the output shape is unranked.
if (!first_type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
return success();
}
auto out_shape = llvm::to_vector<6>(first_type.getShape()); auto out_shape = llvm::to_vector<6>(first_type.getShape());
out_shape[dimension] = 0; out_shape[dimension] = 0;
for (auto operand : operands.getTypes()) { for (auto operand : operands.getTypes()) {
auto type = operand.cast<ShapedType>(); auto type = operand.cast<ShapedType>();
auto dim = type.getShape()[dimension]; if (!type.hasRank()) {
inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
// Validate the element types match. return success();
if (type.getElementType() != out_element) {
return failure();
} }
// If the dimension is dynamic we know the output dimension is dynamic. // If the dimension is dynamic we know the output dimension is dynamic.
auto dim = type.getShape()[dimension];
if (dim == -1) { if (dim == -1) {
out_shape[dimension] = -1; out_shape[dimension] = -1;
break; break;
@ -937,26 +949,39 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
} }
static LogicalResult Verify(ConcatenateOp op) { static LogicalResult Verify(ConcatenateOp op) {
auto firstType = op.getOperand(0).getType().cast<RankedTensorType>(); Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
RankedTensorType first_ranked_type;
int num_operands = op.getNumOperands();
for (int i = 0; i < num_operands; i++) {
auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
if (second_type.getElementType() != element_type) {
return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match element type", i));
}
auto firstShape = firstType.getShape(); if (!second_type.hasRank()) {
int numOperands = op.getNumOperands(); continue;
for (int i = 1; i < numOperands; i++) { }
auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
if (firstType.getRank() != secondType.getRank()) { if (!first_ranked_type) {
first_ranked_type = second_type.cast<RankedTensorType>();
continue;
}
if (first_ranked_type.getRank() != second_type.getRank()) {
return op.emitOpError( return op.emitOpError(
llvm::formatv("operands (0) and ({0}) do not match rank", i)); llvm::formatv("operands (0) and ({0}) do not match rank", i));
} }
auto secondShape = secondType.getShape(); auto first_shape = first_ranked_type.getShape();
for (int d = 0; d < firstType.getRank(); ++d) { auto second_shape = second_type.getShape();
if (firstShape[d] != secondShape[d] && d != op.dimension()) { for (int d = 0; d < first_ranked_type.getRank(); ++d) {
if (first_shape[d] != second_shape[d] && d != op.dimension()) {
return op.emitOpError(llvm::formatv( return op.emitOpError(llvm::formatv(
"operands (0) and ({0}) non-concat dimensions do not match " "operands (0) and ({0}) non-concat dimensions do not match "
"({1}) != ({2})", "({1}) != ({2})",
i, llvm::make_range(firstShape.begin(), firstShape.end()), i, llvm::make_range(first_shape.begin(), first_shape.end()),
llvm::make_range(secondShape.begin(), secondShape.end()))); llvm::make_range(second_shape.begin(), second_shape.end())));
} }
} }
} }
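For intuition, a few worked cases (illustrative, not taken from the commit) for concatenation along dimension 0 under the updated inference and verifier:

    concat(tensor<2x3xf32>, tensor<4x3xf32>) -> tensor<6x3xf32>
    concat(tensor<?x3xf32>, tensor<4x3xf32>) -> tensor<?x3xf32>   (a dynamic concat dimension stays dynamic)
    concat(tensor<2x3xf32>, tensor<*xf32>)   -> tensor<*xf32>     (any unranked operand makes the inferred result unranked)
    concat of f32 and i32 operands           -> rejected, since element types must match

The verifier only compares ranks and non-concatenation dimensions between ranked operands; unranked operands are skipped once their element type has been checked.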

View File

@ -358,7 +358,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions. // XLA binary logical elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -379,15 +379,6 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
// XLA communication op definitions. // XLA communication op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'. // InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens. // InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> { def HLO_InfeedOp : HLO_Op<"infeed", []> {
@ -451,7 +442,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
let arguments = (ins let arguments = (ins
HLO_TensorOrTuple:$operand, HLO_TensorOrTuple:$operand,
HLO_Token:$token, HLO_Token:$token,
ChannelHandle:$channel_id, ChannelHandle<HLO_Dialect>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
); );
@ -476,7 +467,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {
let arguments = (ins let arguments = (ins
HLO_Token:$token, HLO_Token:$token,
ChannelHandle:$channel_id, ChannelHandle<HLO_Dialect>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
); );
@ -564,16 +555,8 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,
def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects, def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
SameOperandsAndResultType]> { SameOperandsAndResultType]>,
string summary = "While operator"; BASE_HLO_WhileOp {
string description = [{
Returns the result of executing a body function until the cond body returns
true.
See https://www.tensorflow.org/xla/operation_semantics#while.
}];
let arguments = (ins HLO_TensorOrTuple:$val); let arguments = (ins HLO_TensorOrTuple:$val);
let regions = (region AnyRegion:$cond, AnyRegion:$body); let regions = (region AnyRegion:$cond, AnyRegion:$body);
@ -590,7 +573,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
let arguments = (ins let arguments = (ins
HLO_Tensor:$operand, HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups, I64ElementsAttr:$replica_groups,
OptionalAttr<ChannelHandle>:$channel_id OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
); );
let regions = (region SizedRegion<1>:$computation); let regions = (region SizedRegion<1>:$computation);
let results = (outs HLO_Tensor); let results = (outs HLO_Tensor);

View File

@ -584,6 +584,15 @@ class BASE_HLO_CaseOp {
// XLA parallelism related op definitions. // XLA parallelism related op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
StructFieldAttr<"handle", I64Attr>,
StructFieldAttr<"type", I64Attr>]> {
let description = "two 64-bit integers 'handle' and 'type'";
}
class BASE_HLO_ReplicaIdOp { class BASE_HLO_ReplicaIdOp {
string summary = "ReplicaId operator"; string summary = "ReplicaId operator";
@ -1258,4 +1267,45 @@ class BASE_HLO_RngNormalOp {
}]; }];
} }
class BASE_HLO_ReducePrecisionOp {
string summary = "Reduce precision operator";
string description = [{
Models the effect of converting floating-point values to a lower-precision
format (such as IEEE-FP16) and back to the original format. The number of
exponent and mantissa bits in the lower-precision format can be specified
arbitrarily, although all bit sizes may not be supported on all hardware
implementations.
See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
}];
}
class BASE_HLO_InfeedOp {
string summary = "Infeed operator";
string description = [{
Reads a single data item from the implicit Infeed streaming interface of
the device, interpreting the data as the given shape and its layout, and
returns an LHLO op of the data. Multiple Infeed operations are allowed in a
computation, but there must be a total order among the Infeed operations.
For example, two Infeeds in the code below have a total order since there
is a dependency between the while loops.
See https://www.tensorflow.org/xla/operation_semantics#infeed
}];
}
class BASE_HLO_WhileOp {
string summary = "While operator";
string description = [{
Returns the result of executing a body function until the cond body returns
true.
See https://www.tensorflow.org/xla/operation_semantics#while.
}];
}
#endif // HLO_OPS_BASE #endif // HLO_OPS_BASE

View File

@ -14,6 +14,20 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
// This is the operation definition file for LXLA. // This is the operation definition file for LXLA.
//
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
// merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within tuples
// also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// one above.
// * Traits aren't identical, but need to be copied. For example,
// SameOperandsAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API in XLA's client side, not service
// side. LHLO aims for the service side.
#ifndef LHLO_OPS #ifndef LHLO_OPS
#define LHLO_OPS #define LHLO_OPS
@ -38,11 +52,17 @@ def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types // Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>; def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;
def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;
def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;
def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>; def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;
// Any integer or floating-point tensor types // Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>; def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;
def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>; def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>; def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;
@ -74,88 +94,126 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions // See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions
class LHLO_UnaryElementwiseOp<string mnemonic> : class LHLO_UnaryElementwiseOp<string mnemonic,
LHLO_Op<mnemonic, [SameTypeOperands]> { Type BufferType = LHLO_Buffer,
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input, list<OpTrait> traits = [SameTypeOperands]>
Arg<LHLO_Buffer, "", [MemWrite]>:$output); : LHLO_Op<mnemonic, traits> {
let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
Arg<BufferType, "", [MemWrite]>:$output);
} }
def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp; def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp; // TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:
LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;
def LHLO_ConvertOp : LHLO_Op<"convert", [SameOperandsShape]>, BASE_HLO_ConvertOp { def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp; def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp; // TODO(timshen): add a custom verifier.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;
def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;
def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;
def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;
def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;
def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp { def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input, let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output); Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
} }
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp; def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}
def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;
def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;
def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp; def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;
def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;
def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;
def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp { def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input, let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
Arg<LHLO_Buffer, "", [MemWrite]>:$output); Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
} }
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp; def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp; def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;
def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;
def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp; def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp; def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp; def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions. // XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations // See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
class LHLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> : class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
list<OpTrait> traits = [SameTypeOperands]> :
LHLO_Op<mnemonic, traits> { LHLO_Op<mnemonic, traits> {
let arguments = (ins let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$lhs, Arg<BufferType, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<BufferType, "", [MemRead]>:$rhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$out, Arg<BufferType, "", [MemWrite]>:$out,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
); );
} }
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp; def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;
def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;
def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp { def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$lhs, let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$rhs, Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
Arg<LHLO_Buffer, "", [MemWrite]>:$output); Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
);
} }
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp; def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp; def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum", []>, BASE_HLO_MinOp; def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply", []>, BASE_HLO_MulOp; def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;
def LHLO_RemOp : def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;
LHLO_BinaryElementwiseOp<"remainder", []>, BASE_HLO_RemOp;
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract", []>, BASE_HLO_SubOp; def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;
def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp; def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;
def LHLO_OrOp: LHLO_BinaryElementwiseOp<"or", []>, BASE_HLO_OrOp; def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;
def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;
def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;
def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;
def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA control flow op definitions. // XLA control flow op definitions.
@ -210,6 +268,16 @@ def LHLO_CaseOp: LHLO_Op<"case", [
let regions = (region VariadicRegion<SizedRegion<1>>:$branches); let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
} }
// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameTypeOperands]>, BASE_HLO_WhileOp {
let arguments = (ins
Arg<LHLO_BufferOrTuple, "", [MemRead]>:$val,
Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$output
);
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XLA tuple op definitions. // XLA tuple op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -269,7 +337,9 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {
def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast", def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
[NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> { [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
let summary = "static memref cast operation"; let summary = [{
"modifies the offset, sizes and strides of a statically shaped memref.
}];
let description = [{ let description = [{
Allows to modify the offset, sizes and strides of a statically shaped memref. Allows to modify the offset, sizes and strides of a statically shaped memref.
@ -357,7 +427,23 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
// XLA Other op definitions. // XLA Other op definitions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>, def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
BASE_HLO_BatchNormGradOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$mean,
Arg<LHLO_Buffer, "", [MemRead]>:$variance,
Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
BASE_HLO_BatchNormInferenceOp { BASE_HLO_BatchNormInferenceOp {
let arguments = (ins let arguments = (ins
@ -372,6 +458,19 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
); );
} }
def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
BASE_HLO_BatchNormTrainingOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemRead]>:$scale,
Arg<LHLO_Buffer, "", [MemRead]>:$offset,
Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
F32Attr:$epsilon,
I64Attr:$feature_index
);
}
def LHLO_BroadcastOp : LHLO_Op<"broadcast", def LHLO_BroadcastOp : LHLO_Op<"broadcast",
[]>, BASE_HLO_BroadcastOp { []>, BASE_HLO_BroadcastOp {
let arguments = (ins let arguments = (ins
@ -531,6 +630,88 @@ def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
); );
} }
def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
BASE_HLO_ReducePrecisionOp {
let arguments = (ins
Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
I32Attr:$exponent_bits,
I32Attr:$mantissa_bits
);
}
def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
BASE_HLO_AllReduceOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$replica_groups,
DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
);
let regions = (region SizedRegion<1>:$computation);
}
def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
BASE_HLO_CollectivePermuteOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
I64ElementsAttr:$source_target_pairs,
OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
);
}
def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
HLO_FftTypeAttr:$fft_type,
I64ElementsAttr:$fft_length
);
}
def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
DefaultValuedAttr<BoolAttr, "false">:$lower
);
}
def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemWrite]>:$output,
DefaultValuedAttr<StrAttr, "">:$config
);
}
def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
let arguments = (ins
Arg<LHLO_Buffer, "", [MemRead]>:$operand,
DefaultValuedAttr<StrAttr, "">:$config
);
}
def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}
def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
BASE_HLO_TriangularSolveOp {
let arguments = (ins
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
BoolAttr:$left_side,
BoolAttr:$lower,
BoolAttr:$unit_diagonal,
HLO_TransposeAttr:$transpose_a
);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Late operations // Late operations
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
View File
@ -19,10 +19,12 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_importer.h" #include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/comparison_util.h" #include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/util.h"
@ -118,6 +120,76 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
return MakeXlaOp(op); return MakeXlaOp(op);
} }
StatusOr<XlaOp> MlirHloBuilder::FftInternal(
const Shape& shape, XlaOp operand, FftType fft_type,
absl::Span<const int64> fft_length) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::FftOp>(
loc_, ty, GetValue(operand),
builder_.getStringAttr(FftType_Name(fft_type)),
GetI64ElementsAttr(fft_length, &builder_));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) {
// Reduce takes two sets of variadic operands: inputs and init_values.
// all_operands contains both of these, so split it into two parts.
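// For example, reducing two arrays at once passes
// {input0, input1, init0, init1}, so num_args below is 2.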
int64_t num_args = all_operands.size() / 2;
auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
loc_, GetValues(all_operands.first(num_args)),
GetValues(all_operands.subspan(num_args)),
GetI64ElementsAttr(dimensions_to_reduce, &builder_));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
return MakeXlaOp(tuple);
}
StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
llvm::SmallVector<int64, 8> padding;
for (const auto& dim : window.dimensions()) {
sizes.push_back(dim.size());
strides.push_back(dim.stride());
base_dilations.push_back(dim.base_dilation());
win_dilations.push_back(dim.window_dilation());
padding.push_back(dim.padding_low());
padding.push_back(dim.padding_high());
}
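// `padding` now holds interleaved (low, high) entries; the {N, 2} tensor type
// built below turns them into one (low, high) row per window dimension.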
auto padding_ty =
mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
builder_.getIntegerType(64));
auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
loc_, ty, GetValue(operand), GetValue(init_value),
GetI64ElementsAttr(sizes, &builder_),
GetI64ElementsAttr(strides, &builder_),
GetI64ElementsAttr(base_dilations, &builder_),
GetI64ElementsAttr(win_dilations, &builder_),
mlir::DenseIntElementsAttr::get(padding_ty, padding));
TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
return MakeXlaOp(op);
}
XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(
mlir::Type ty,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto op = builder_.create<mlir::xla_hlo::IotaOp>(
loc_, ty,
builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
return MakeXlaOp(op);
});
}
StatusOr<XlaOp> MlirHloBuilder::TransposeInternal( StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) { const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>( TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@ -127,6 +199,15 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
return MakeXlaOp(op); return MakeXlaOp(op);
} }
StatusOr<XlaOp> MlirHloBuilder::RevInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::GatherInternal( StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices, const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers, const GatherDimensionNumbers& dimension_numbers,
@ -140,6 +221,24 @@ StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
return MakeXlaOp(op); return MakeXlaOp(op);
} }
StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
bool unique_indices) {
TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
shape, builder_));
auto op = builder_.create<mlir::xla_hlo::ScatterOp>(
loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
builder_.getBoolAttr(indices_are_sorted),
builder_.getBoolAttr(unique_indices));
TF_RETURN_IF_ERROR(
ImportComputation(update_computation.proto(), &op.update_computation()));
return MakeXlaOp(op);
}
StatusOr<XlaOp> MlirHloBuilder::RngOpInternal( StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
RandomDistribution distribution, absl::Span<const XlaOp> parameters, RandomDistribution distribution, absl::Span<const XlaOp> parameters,
const Shape& shape) { const Shape& shape) {
@ -348,6 +447,18 @@ StatusOr<XlaOp> MlirHloBuilder::CreateOp(
return MakeXlaOp(op->getResult(0)); return MakeXlaOp(op->getResult(0));
} }
Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
mlir::Region* region) {
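// Rebuild an HloModule from the serialized proto, then import its entry
// computation as the body of `region`.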
TF_ASSIGN_OR_RETURN(auto module_config,
xla::HloModule::CreateModuleConfigFromProto(
computation, xla::DebugOptions()));
TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
computation, module_config));
return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
region, &builder_);
}
StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const { StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
TF_RETURN_IF_ERROR(first_error()); TF_RETURN_IF_ERROR(first_error());
TF_RETURN_IF_ERROR(CheckOpBuilder(op)); TF_RETURN_IF_ERROR(CheckOpBuilder(op));
View File
@ -120,15 +120,40 @@ class MlirHloBuilder : public XlaBuilder {
int64 feature_group_count, int64 batch_group_count, int64 feature_group_count, int64 batch_group_count,
const PrecisionConfig* precision_config) override; const PrecisionConfig* precision_config) override;
StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
FftType fft_type,
absl::Span<const int64> fft_length) override;
StatusOr<XlaOp> ReduceInternal(
const Shape& shape, absl::Span<const XlaOp> all_operands,
const XlaComputation& computation,
absl::Span<const int64> dimensions_to_reduce) override;
StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
XlaOp init_value,
const XlaComputation& computation,
Window window) override;
XlaOp Iota(const Shape& shape, int64 iota_dimension) override;
StatusOr<XlaOp> TransposeInternal( StatusOr<XlaOp> TransposeInternal(
const Shape& shape, XlaOp operand, const Shape& shape, XlaOp operand,
absl::Span<const int64> permutation) override; absl::Span<const int64> permutation) override;
StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> dimensions) override;
StatusOr<XlaOp> GatherInternal( StatusOr<XlaOp> GatherInternal(
const Shape& shape, XlaOp input, XlaOp start_indices, const Shape& shape, XlaOp input, XlaOp start_indices,
const GatherDimensionNumbers& dimension_numbers, const GatherDimensionNumbers& dimension_numbers,
absl::Span<const int64> slice_sizes, bool indices_are_sorted) override; absl::Span<const int64> slice_sizes, bool indices_are_sorted) override;
StatusOr<XlaOp> ScatterInternal(
const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
const XlaComputation& update_computation,
const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
bool unique_indices) override;
StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution, StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
absl::Span<const XlaOp> parameters, absl::Span<const XlaOp> parameters,
const Shape& shape) override; const Shape& shape) override;
@ -196,6 +221,9 @@ class MlirHloBuilder : public XlaBuilder {
llvm::ArrayRef<XlaOp> operands, llvm::ArrayRef<XlaOp> operands,
llvm::ArrayRef<mlir::NamedAttribute> attributes = {}); llvm::ArrayRef<mlir::NamedAttribute> attributes = {});
Status ImportComputation(const HloModuleProto& computation,
mlir::Region* region);
mlir::OpBuilder builder_; mlir::OpBuilder builder_;
mlir::Location loc_; mlir::Location loc_;
View File
@ -68,10 +68,11 @@ static StringRef GetClientBuilder(const Operator& op) {
return kOpToXLABuilderMap->lookup(op_name); return kOpToXLABuilderMap->lookup(op_name);
} }
static void BuildOperator(const Operator& op, raw_ostream* output) { static void BuildOperator(const Operator& op, raw_ostream& os) {
auto& os = *output; os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::"
os << " auto& value_map = *lowering_context.values;\n" << op.getCppClassName() << " op, OpLoweringContext ctx) {\n"
<< " auto result = xla_op.getResult();\n"; << " auto& value_map = *ctx.values;\n"
<< " auto result = op.getResult();\n";
// Build a conversion for each of the arguments. // Build a conversion for each of the arguments.
int operand_number = 0; int operand_number = 0;
@ -82,15 +83,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) { if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
// Handle a non-variadic operand. // Handle a non-variadic operand.
if (!operand_cst->isVariableLength()) { if (!operand_cst->isVariableLength()) {
os << " auto xla_arg_" << index os << " auto xla_arg_" << index << " = value_map[*op.getODSOperands("
<< " = value_map[*xla_op.getODSOperands(" << operand_number++ << operand_number++ << ").begin()];\n";
<< ").begin()];\n";
continue; continue;
} }
// Otherwise, this is a variadic operand list. // Otherwise, this is a variadic operand list.
os << " std::vector<xla::XlaOp> xla_arg_" << index << ";\n" os << " std::vector<xla::XlaOp> xla_arg_" << index << ";\n"
<< " for (auto operand : xla_op.getODSOperands(" << operand_number++ << " for (auto operand : op.getODSOperands(" << operand_number++
<< "))\n xla_arg_" << index << "))\n xla_arg_" << index
<< ".push_back(value_map[operand]);\n"; << ".push_back(value_map[operand]);\n";
continue; continue;
@ -99,8 +99,8 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
// Otherwise, this is an attribute. // Otherwise, this is an attribute.
auto named_attr = arg.get<NamedAttribute*>(); auto named_attr = arg.get<NamedAttribute*>();
os << " auto xla_arg_" << index << " = " os << " auto xla_arg_" << index << " = "
<< GetDefaultAttrExport(*named_attr) << "(xla_op." << GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index)
<< op.getArgName(index) << "());\n"; << "());\n";
} }
// Emit call to client API // Emit call to client API
@ -109,7 +109,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
// If all operands are variadic, then pass the builder explicitly to xla // If all operands are variadic, then pass the builder explicitly to xla
// client API call // client API call
if (op.getNumOperands() == op.getNumVariableLengthOperands()) { if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
os << "lowering_context.builder"; os << "ctx.builder";
if (op.getNumArgs() != 0) os << ", "; if (op.getNumArgs() != 0) os << ", ";
} }
@ -120,6 +120,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
os << " value_map[result] = xla_result;\n"; os << " value_map[result] = xla_result;\n";
os << " return mlir::success();\n"; os << " return mlir::success();\n";
os << "}\n";
} }
// The function below has a non-constant reference as that is required by LLVM's // The function below has a non-constant reference as that is required by LLVM's
@ -128,6 +129,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
emitSourceFileHeader("MLIR XLA Builders", os); emitSourceFileHeader("MLIR XLA Builders", os);
// Emit all the helper functions.
for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
Operator op(def);
// Skip operations that have a custom exporter.
if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os);
}
// Emit a function to generate an XLA operation for the operations with // Emit a function to generate an XLA operation for the operations with
// auto-generated builders. // auto-generated builders.
os << "mlir::LogicalResult ExportXlaOperator(\n" os << "mlir::LogicalResult ExportXlaOperator(\n"
@ -153,12 +162,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
// Cast to the current operation and build the exporter. // Cast to the current operation and build the exporter.
os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::" os << " if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
<< op.getCppClassName() << ">(op)) {\n"; << op.getCppClassName() << ">(op)) {\n";
if (def->getValueAsBit("hasCustomHLOConverter")) { os << " return ";
os << " return mlir::xla_hlo::ExportXlaOp(xla_op, " // The autogenerated converters aren't in the same namespace.
"lowering_context);\n"; // TODO(jpienaar): Reconsider this.
} else { if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::";
BuildOperator(op, &os); os << "ExportXlaOp(xla_op, lowering_context);\n";
}
os << " }\n"; os << " }\n";
} }
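For reference, the helper functions this generator now emits look roughly like the sketch below. AbsOp is only an illustrative choice, and the exact form of the emitted xla:: client call is an assumption, since that part of BuildOperator is outside this hunk.

mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::AbsOp op, OpLoweringContext ctx) {
  auto& value_map = *ctx.values;
  auto result = op.getResult();
  auto xla_arg_0 = value_map[*op.getODSOperands(0).begin()];
  // Assumed shape of the client API call for a single non-variadic operand.
  auto xla_result = xla::Abs(xla_arg_0);
  value_map[result] = xla_result;
  return mlir::success();
}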
View File
@ -30,7 +30,7 @@ tf_cc_test(
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
"@llvm-project//mlir:IR", "@llvm-project//mlir:IR",
], ],
) )
View File
@ -9,7 +9,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]]) // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]] // CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex> %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
View File
@ -17,7 +17,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] // CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
@ -34,7 +34,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> // CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@ -51,7 +51,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]] // CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]] // CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]]) // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1> // CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
View File
@ -184,14 +184,16 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK: %[[C1__:.*]] = constant 1 : index // CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64> // CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32> // CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index // CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]] // CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index // CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index // CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64> // CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32> // CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index // CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] // CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index // CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
@ -389,15 +391,18 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) { func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs) %result = "xla_hlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32> // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32> // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> // CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64> // CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () // CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -411,15 +416,18 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
func @tanh_dyn(%arg0: tensor<?x?xf32>) { func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0) %result = "xla_hlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32> // CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64 // CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32> // CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64 // CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64> // CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0:.*]] = constant 0 : index // CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64> // CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index // CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64> // CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) // CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> () // CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
View File
@ -340,36 +340,36 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// ----- // -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, 0, d1)> // CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D // CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32> return %0 : tensor<12x42xi32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// ----- // -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1, 0, 0)> // CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D // CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32> return %0 : tensor<12x42xi32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// ----- // -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> // CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D // CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32> return %0 : tensor<12x1x42x1xi32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
// ----- // -----
@ -407,7 +407,8 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32> %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]] // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32): // CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]] // CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32 // CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -554,4 +555,5 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
} : (tensor<2x3xf32>) -> tensor<2x3xf32> } : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32> return %result : tensor<2x3xf32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
View File
@ -14,10 +14,10 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape) // CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape // CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape // CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[LHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex> // CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32> // CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape // CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[RHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex> // CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32> // CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32> // CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32> // CHECK: return [[RESULT]] : tensor<3x4x4xf32>
View File
@ -27,7 +27,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1] // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2] // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
@ -42,7 +42,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1] // CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] // CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4] // CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]] // CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
@ -55,7 +55,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32> // CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32>
@ -203,7 +203,7 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1] // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
@ -216,7 +216,7 @@ func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"} // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
@ -284,7 +284,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1] // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2] // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2] // CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
@ -297,7 +297,7 @@ func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0 // CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1 // CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]]) // CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) // CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"} // CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
View File
@ -187,6 +187,70 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2
return %0: tensor<3x4xi32> return %0: tensor<3x4xi32>
} }
// CHECK-LABEL: @sparse_to_dense
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor<f32>)
func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {
// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>
// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): // no predecessors
// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: update_window_dims = dense<[]> : tensor<0xi64>
// CHECK-SAME: unique_indices = false
// CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32>
// CHECK: return %[[RESULT]] : tensor<3x3xf32>
%cst = xla_hlo.constant dense<3> : tensor<2xi32>
%0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor<f32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}
// CHECK-LABEL: fft
func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
// CHECK: "xla_hlo.fft"(%arg0)
%0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
return %0 : tensor<3x5x8xcomplex<f32>>
}
// CHECK-LABEL: reverse_sequence
func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
// CHECK-NOT: tf.ReverseSequence
%0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32>
return %0 : tensor<4x2x3x1x1xi32>
}
// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
%0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
// CHECK-NOT: tf.MirrorPad
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
return %1 : tensor<4x7xcomplex<f64>>
}
// CHECK-LABEL: bucketize
func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
// CHECK-NOT: tf.Bucketize
%0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
return %0 : tensor<2x5xi32>
}
// CHECK-LABEL: arg_min
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
// CHECK-NOT: ArgMin
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}
// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is // TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance. // available but doesn't support this instance.
} }
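For reference, the scatter-based SparseToDense lowering exercised by @sparse_to_dense above can be read as "broadcast the default value, then scatter the sparse values into it". A minimal NumPy sketch of that semantics follows (illustrative only, not the actual kernel or its validation/edge-case handling):

import numpy as np

def sparse_to_dense(indices, output_shape, values, default_value):
    # Broadcast the scalar default into the dense output, like the
    # xla_hlo.broadcast_in_dim in the CHECK lines above.
    dense = np.full(output_shape, default_value, dtype=values.dtype)
    # Scatter each sparse value to its index; the scatter region in the
    # test simply returns the update, so later writes win.
    for idx, val in zip(indices, values):
        dense[tuple(idx)] = val
    return dense

# Shapes matching @sparse_to_dense: 3x2 indices, 3 values, 3x3 output.
out = sparse_to_dense(np.array([[0, 0], [1, 1], [2, 2]]), (3, 3),
                      np.array([1.0, 2.0, 3.0], dtype=np.float32),
                      np.float32(0.0))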


@ -420,7 +420,7 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens
// CHECK-LABEL: func @biasAdd_NHWC // CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>} // CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -431,7 +431,7 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens
// CHECK-LABEL: func @biasAdd_NCHW // CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -442,7 +442,7 @@ func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens
// CHECK-LABEL: func @biasAdd_dynamic // CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> { func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0 // CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]]) // CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]]) // CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]] // CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -1445,7 +1445,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]] // CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]]
// CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]]) // CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]])
@ -1460,7 +1460,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]] // CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]]
// CHECK: return %[[RESULT]] // CHECK: return %[[RESULT]]
@ -1517,7 +1517,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]] // CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex> // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]] // CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]]
// CHECK: return %[[RESULT]] // CHECK: return %[[RESULT]]
@ -1544,58 +1544,34 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {
// CHECK-LABEL: func @shape_1D // CHECK-LABEL: func @shape_1D
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> { func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0 // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]] // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32> %0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>
// CHECK: return [[CONCAT]] // CHECK: return [[CAST]]
return %0 : tensor<1xi32> return %0 : tensor<1xi32>
} }
// CHECK-LABEL: func @shape_2D // CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> { func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 // CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], 0 // CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], 1 // CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32> %0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]] // CHECK: return [[CAST]]
return %0 : tensor<2xi32>
}
// CHECK-LABEL: func @shape_with_const
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>
// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
// CHECK-LABEL: func @shape_rankless // CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> { func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32> %0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>
// CHECK: return [[CAST]]
return %0 : tensor<?xi32> return %0 : tensor<?xi32>
} }
@ -1884,7 +1860,7 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32> // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32> // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<1xindex> // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<1xindex>
// CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32> // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32>
// CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32> // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32> // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32>
@ -1906,7 +1882,7 @@ func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32> // CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32> // CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<?xindex> // CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<?xindex>
// CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32> // CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32> // CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32> // CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32>
@ -3826,42 +3802,6 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
return %0: tensor<4x?x16xf32> return %0: tensor<4x?x16xf32>
} }
//===----------------------------------------------------------------------===//
// tf.VariableShape legalization
//===----------------------------------------------------------------------===//
// CHECK-LABLE: @variable_shape32
func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>)
// CHECK: return [[CST_CAST]]
return %0: tensor<3xi32>
}
// CHECK-LABLE: @variable_shape64
func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>)
// CHECK: return [[CST_CAST]]
return %0: tensor<3xi64>
}
// CHECK-LABEL: @variable_shape_unknown_resource
func @variable_shape_unknown_resource(%input: tensor<!tf.resource>) -> tensor<?xi32> {
// CHECK: tf.VariableShape
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource>) -> (tensor<?xi32>)
return %0: tensor<?xi32>
}
// CHECK-LABEL: @variable_shape_unknown_resource_shape
func @variable_shape_unknown_resource_shape(%input: tensor<!tf.resource<tensor<?x?xf32>>>) -> tensor<2xi32> {
// CHECK: tf.VariableShape
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?xf32>>>) -> (tensor<2xi32>)
return %0: tensor<2xi32>
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// tf.AvgPool legalization // tf.AvgPool legalization
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -4025,3 +3965,87 @@ func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x7
%0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>) %0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32> return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
} }
//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.220700e-04> : tensor<f16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16>
return %0 : tensor<8x16xf16>
}
// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<7.812500e-03> : tensor<bf16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<bf16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16>
return %0 : tensor<8x16xbf16>
}
// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.1920929E-7> : tensor<f32>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}
// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<2.2204460492503131E-16> : tensor<f64>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f64>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>
// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
return %0 : tensor<8x16xf64>
}
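The four Softplus tests above all encode the same threshold trick: with threshold = log(eps) + 2 for the element type, softplus(x) is approximated by x for large x, by exp(x) for very negative x, and by log1p(exp(x)) in between. A minimal NumPy sketch of that selection logic (an illustration of the checked pattern, not the tf2xla kernel itself):

import numpy as np

def softplus(features, dtype=np.float32):
    x = np.asarray(features, dtype=dtype)
    # Same epsilon-derived threshold as the xla_hlo.constant in the tests,
    # e.g. eps ~= 1.1920929e-7 for f32.
    threshold = dtype(np.log(np.finfo(dtype).eps) + 2.0)
    exp_x = np.exp(x)
    return np.where(x > -threshold, x,              # large x: ~= x
                    np.where(x < threshold, exp_x,  # very negative x: ~= exp(x)
                             np.log1p(exp_x)))      # otherwise: log(1 + exp(x))

print(softplus([-20.0, 0.0, 20.0]))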


@ -173,7 +173,8 @@ func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return return
} }
// CHECK: linalg.indexed_generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]] // CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32): // CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32 // CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32 // CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
@ -190,7 +191,8 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
} : (memref<f32>, memref<4x2x1xf32>) -> () } : (memref<f32>, memref<4x2x1xf32>) -> ()
return return
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): // CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -206,7 +208,8 @@ func @broadcast(%operand: memref<4x?x16xf32>,
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return return
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32): // CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -222,7 +225,8 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> () } : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return return
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32): // CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32 // CHECK-NEXT: linalg.yield %[[OPERAND]] : f32
@ -645,39 +649,42 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
// ----- // -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, 0, d1)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D // CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1) "xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> () : (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return return
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// ----- // -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1, 0, 0)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D // CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1) "xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return return
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// ----- // -----
// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D // CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1) "xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return return
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy
// ----- // -----
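The reshape tests above now check for linalg.reshape (followed by a copy) with reassociation maps instead of a linalg.generic; the #[[MAP1]]/#[[MAP2]] pairs simply group each unit dimension with a neighbouring dimension. A small NumPy analogy of those groupings (plain reshapes, shown only for intuition):

import numpy as np

a = np.arange(12 * 1 * 42).reshape(12, 1, 42)
b = a.reshape(12 * 1, 42)        # 3D -> 2D: group dims (0, 1), keep dim 2

c = np.arange(12 * 42).reshape(12, 42)
d = c.reshape(12, 1, 42, 1)      # 2D -> 4D: dim 0 -> (12, 1), dim 1 -> (42, 1)

assert b.shape == (12, 42) and d.shape == (12, 1, 42, 1)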


@ -103,9 +103,10 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<?x?xf32>) { // CHECK-SAME: [[RESULT_BUF:%.*]]: memref<?x?xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index // CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index // CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], 0 : memref<?x?x?xf32> // CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref<?x?x?xf32> // CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref<?x?x?xf32> // CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]] // CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]]) // CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) { // CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {


@ -1,8 +1,66 @@
// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s // RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { // -----
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () // CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// -----
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// -----
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -25,16 +83,40 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// ----- // -----
// CHECK-LABEL: func @convert_memref // CHECK-LABEL: func @convert_memref
func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
return return
} }
// ----- // -----
// CHECK-LABEL: func @exp_memref func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.exponential"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
// -----
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return return
} }
@ -48,6 +130,22 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// ----- // -----
// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @neg_memref // CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
@ -64,6 +162,46 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// ----- // -----
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @sign_memref // CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
@ -80,6 +218,30 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// ----- // -----
// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
// -----
func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @add_memref // CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
@ -129,13 +291,77 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// ----- // -----
// CHECK-LABEL: func @and_memref // CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
// -----
func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return return
} }
// ----- // -----
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
// -----
func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
// -----
func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @broadcast_in_dim_memref // CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () { func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
@ -248,3 +474,392 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]> : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return return
} }
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
// -----
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
return
}
// -----
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @clz_memrefs
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @floor_memrefs
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @imag_memrefs
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
// -----
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @real_memrefs
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
// -----
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
return
}
// -----
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
// -----
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
return
}
// -----
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @round_memrefs
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
// CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
// -----
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 },
constrain_layout = true,
use_global_device_ids = true
}: (memref<10xf32>, memref<10xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 }
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @batch_norm_training_memrefs
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
return
}
// -----
// CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
return
}
// -----
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return
}
// -----
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
"xla_lhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> ()
return
}


@ -304,6 +304,46 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<
// ----- // -----
// CHECK-LABEL: @concat_1D
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
// expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}}
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
// CHECK-LABEL: @concat_1D_unranked
func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// -----
func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
// expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
return %0 : tensor<3xi32>
}
// -----
func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
// expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
%0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// -----
// CHECK-LABEL: func @clamp // CHECK-LABEL: func @clamp
func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> { func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>


@ -104,16 +104,20 @@ func @batchNormInference_dynamic_shape(
%x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>, %x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>,
%mean: tensor<?xf32>, %variance: tensor<?xf32>) %mean: tensor<?xf32>, %variance: tensor<?xf32>)
-> tensor<?x?x?x?xf32> { -> tensor<?x?x?x?xf32> {
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor<?xf32> // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32> // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
View File
@ -16,8 +16,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32> %flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape. // Restore original shape.
%shape_as_extent_tensor = "shape.to_extent_tensor"(%shape) %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
: (!shape.shape) -> tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) %b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
@ -35,7 +34,7 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32> // CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = "shape.to_extent_tensor"(%[[SHAPE]]) : (!shape.shape) -> tensor<?xindex> // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> // CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> %b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
View File
@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <iterator> #include <iterator>
#include <limits>
#include <numeric> #include <numeric>
#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/ArrayRef.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h" #include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project #include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -54,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h" #include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/tensor_format.h"
@ -881,6 +884,31 @@ static Type GetAccumulationType(Type ty) {
return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty; return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
} }
//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//
static DenseElementsAttr GetEpsilonValue(Type ty) {
auto element_ty = ty.cast<TensorType>().getElementType();
auto scalar_ty = RankedTensorType::get({}, element_ty);
if (element_ty.isF16()) {
uint16_t raw_epsilon = Eigen::NumTraits<Eigen::half>::epsilon().x;
auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isBF16()) {
uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isF32()) {
auto value = APFloat(std::numeric_limits<float>::epsilon());
return DenseElementsAttr::get(scalar_ty, value);
} else if (element_ty.isF64()) {
auto value = APFloat(std::numeric_limits<double>::epsilon());
return DenseElementsAttr::get(scalar_ty, value);
}
llvm_unreachable("unsupported element type for tf.SoftPlus");
}
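As an aside (not part of this diff), a minimal standalone C++ sketch of where the two raw 16-bit epsilon constants above come from. The concrete bit patterns are assumptions based on the IEEE half format (10 fraction bits, so epsilon = 2^-10) and the bfloat16 format (7 fraction bits, so epsilon = 2^-7):

#include <cassert>
#include <cstdint>

int main() {
  // IEEE half: epsilon = 2^-10, i.e. biased exponent 15 - 10 = 5, zero fraction.
  const uint16_t half_eps_bits = 5u << 10;     // 0x1400
  // bfloat16: epsilon = 2^-7, i.e. biased exponent 127 - 7 = 120, zero fraction.
  const uint16_t bfloat_eps_bits = 120u << 7;  // 0x3C00
  assert(half_eps_bits == 0x1400 && bfloat_eps_bits == 0x3C00);
  return 0;
}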
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities. // ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -4387,45 +4415,6 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
} }
}; };
// Converts tf.VariableShape op to a XLA HLO constant representing the variable
// shape.
class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TF::VariableShapeOp op,
PatternRewriter &rewriter) const override {
// The input type should be a tensor<!tf.resource<resource-type>>. We need
// to get the inner resource type.
auto input_type = op.input().getType().cast<TensorType>();
auto subtypes =
input_type.getElementType().cast<TF::ResourceType>().getSubtypes();
// It can be missing; then we cannot convert.
if (subtypes.empty()) return failure();
auto resource_type = subtypes[0].cast<TensorType>();
if (!resource_type.hasStaticShape()) return failure();
auto resource_shape = resource_type.getShape();
Attribute const_attr;
// We need to match the original op result's element type.
auto element_type = op.getType().cast<TensorType>().getElementType();
unsigned bitwidth = element_type.cast<IntegerType>().getWidth();
if (bitwidth == 32) {
SmallVector<int32_t, 4> shape(resource_shape.begin(),
resource_shape.end());
const_attr = GetI32ElementsAttr(shape, &rewriter);
} else {
assert(bitwidth == 64);
const_attr = GetI64ElementsAttr(resource_shape, &rewriter);
}
rewriter.replaceOpWithNewOp<xla_hlo::ConstOp>(op, const_attr);
return success();
}
};
// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes. // Converts an XlaSharding op to a XLA HLO shard op with sharding attributes.
class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> { class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
public: public:
@ -4621,45 +4610,19 @@ class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
LogicalResult matchAndRewrite(TF::ShapeOp op, LogicalResult matchAndRewrite(TF::ShapeOp op,
PatternRewriter &rewriter) const override { PatternRewriter &rewriter) const override {
Value input = op.input(); Value input = op.input();
auto input_ty = input.getType().dyn_cast<RankedTensorType>();
// If the shape is static it can be canonicalized. auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
if (!input_ty || input_ty.hasStaticShape()) { auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
if (!result_ty) {
return failure(); return failure();
} }
auto result_ty = op.getResult().getType().cast<RankedTensorType>(); auto index_tensor =
auto element_ty = result_ty.getElementType(); RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
op.getLoc(), index_tensor, shape_op);
int64_t rank = input_ty.getRank(); rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
auto index_ty = RankedTensorType::get({1}, element_ty);
llvm::SmallVector<Value, 4> dim_values;
for (int64_t i = 0; i < rank; ++i) {
if (!input_ty.isDynamicDim(i)) {
auto dim_attr = DenseElementsAttr::get(
index_ty,
rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
dim_values.push_back(index);
continue;
}
auto extent_op = rewriter.create<shape::GetExtentOp>(
op.getLoc(), shape_op, rewriter.getI64IntegerAttr(i));
auto index_op = rewriter.create<shape::SizeToIndexOp>(
op.getLoc(), rewriter.getIndexType(), extent_op);
auto int_op =
rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
auto from_tensor = rewriter.create<TensorFromElementsOp>(
op.getLoc(), int_op.getResult());
auto reshape_op =
rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
dim_values.push_back(reshape_op);
}
rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
rewriter.getI64IntegerAttr(0));
return success(); return success();
} }
}; };
@ -5250,7 +5213,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op, ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp, ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp, ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp, ConvertRandomShuffleOp, ConvertXlaShardingOp,
ConvertXlaDynamicUpdateSliceOp>(op->getContext()); ConvertXlaDynamicUpdateSliceOp>(op->getContext());
// Populate with CHLO->HLO lowerings to account for TF ops legalized to // Populate with CHLO->HLO lowerings to account for TF ops legalized to
View File
@ -618,3 +618,39 @@ def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
(HLO_MulOp (HLO_MulOp
(HLO_MulOp $r, $l), (HLO_MulOp $r, $l),
(HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>; (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>;
//===----------------------------------------------------------------------===//
// Softplus op.
//===----------------------------------------------------------------------===//
def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">;
def : Pattern<(TF_SoftplusOp AnyTensor:$features),
[
(HLO_ExpOp:$features_exp $features),
(HLOClient_BroadcastAddOp:$threshold
(HLO_LogOp (HLO_ConstOp (EpsilonValue $features))),
(HLO_ConstOp (GetScalarOfType<2> $features)),
(NullDenseIntElementsAttr)
),
(HLO_SelectOp:$output
(HLOClient_BroadcastCompareOp
$features,
(HLO_NegOp $threshold),
(NullDenseIntElementsAttr),
HLO_COMPARISON_DIRECTION_GT
),
$features,
(HLO_SelectOp
(HLOClient_BroadcastCompareOp
$features,
$threshold,
(NullDenseIntElementsAttr),
HLO_COMPARISON_DIRECTION_LT
),
$features_exp,
(HLO_Log1pOp $features_exp)
)
),
(replaceWithValue $output)
]>;
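For readability (this is not part of the diff), here is the selection logic the pattern above encodes, written as a plain C++ sketch for a single float value. The structure mirrors the pattern: threshold = log(epsilon) + 2, large inputs return x unchanged, very negative inputs return exp(x), and the generic case uses log1p(exp(x)):

#include <cmath>
#include <limits>

// Numerically stable softplus, following the same threshold-based selection
// as the TableGen pattern above (sketch only).
float StableSoftplus(float x) {
  const float threshold =
      std::log(std::numeric_limits<float>::epsilon()) + 2.0f;  // roughly -13.9
  if (x > -threshold) return x;             // softplus(x) ~ x for large x
  if (x < threshold) return std::exp(x);    // softplus(x) ~ exp(x) for very negative x
  return std::log1p(std::exp(x));           // generic, numerically safe case
}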
View File
@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::AddV2Op>(), TypeID::get<TF::AddV2Op>(),
TypeID::get<TF::AngleOp>(), TypeID::get<TF::AngleOp>(),
TypeID::get<TF::ApproximateEqualOp>(), TypeID::get<TF::ApproximateEqualOp>(),
TypeID::get<TF::ArgMaxOp>(),
TypeID::get<TF::ArgMinOp>(),
TypeID::get<TF::AsinhOp>(), TypeID::get<TF::AsinhOp>(),
TypeID::get<TF::AsinOp>(), TypeID::get<TF::AsinOp>(),
TypeID::get<TF::Atan2Op>(), TypeID::get<TF::Atan2Op>(),
@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::BitwiseAndOp>(), TypeID::get<TF::BitwiseAndOp>(),
TypeID::get<TF::BitwiseOrOp>(), TypeID::get<TF::BitwiseOrOp>(),
TypeID::get<TF::BitwiseXorOp>(), TypeID::get<TF::BitwiseXorOp>(),
TypeID::get<TF::BucketizeOp>(),
TypeID::get<TF::CastOp>(), TypeID::get<TF::CastOp>(),
TypeID::get<TF::ClipByValueOp>(), TypeID::get<TF::ClipByValueOp>(),
TypeID::get<TF::ComplexAbsOp>(), TypeID::get<TF::ComplexAbsOp>(),
@ -116,13 +119,24 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::ErfcOp>(), TypeID::get<TF::ErfcOp>(),
TypeID::get<TF::ErfOp>(), TypeID::get<TF::ErfOp>(),
TypeID::get<TF::Expm1Op>(), TypeID::get<TF::Expm1Op>(),
TypeID::get<TF::FFT2DOp>(),
TypeID::get<TF::FFT3DOp>(),
TypeID::get<TF::FFTOp>(),
TypeID::get<TF::FloorDivOp>(), TypeID::get<TF::FloorDivOp>(),
TypeID::get<TF::FloorModOp>(), TypeID::get<TF::FloorModOp>(),
TypeID::get<TF::GatherNdOp>(), TypeID::get<TF::GatherNdOp>(),
TypeID::get<TF::GreaterEqualOp>(), TypeID::get<TF::GreaterEqualOp>(),
TypeID::get<TF::GreaterOp>(), TypeID::get<TF::GreaterOp>(),
TypeID::get<TF::IFFT2DOp>(),
TypeID::get<TF::IFFT3DOp>(),
TypeID::get<TF::IFFTOp>(),
TypeID::get<TF::IRFFT2DOp>(),
TypeID::get<TF::IRFFT3DOp>(),
TypeID::get<TF::IRFFTOp>(),
TypeID::get<TF::InvertOp>(), TypeID::get<TF::InvertOp>(),
TypeID::get<TF::InvOp>(), TypeID::get<TF::InvOp>(),
TypeID::get<TF::LRNOp>(),
TypeID::get<TF::LRNGradOp>(),
TypeID::get<TF::LeakyReluGradOp>(), TypeID::get<TF::LeakyReluGradOp>(),
TypeID::get<TF::LeakyReluOp>(), TypeID::get<TF::LeakyReluOp>(),
TypeID::get<TF::LeftShiftOp>(), TypeID::get<TF::LeftShiftOp>(),
@ -134,16 +148,20 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::LogicalOrOp>(), TypeID::get<TF::LogicalOrOp>(),
TypeID::get<TF::LogOp>(), TypeID::get<TF::LogOp>(),
TypeID::get<TF::MatMulOp>(), TypeID::get<TF::MatMulOp>(),
TypeID::get<TF::MirrorPadOp>(),
TypeID::get<TF::MulOp>(), TypeID::get<TF::MulOp>(),
TypeID::get<TF::NegOp>(), TypeID::get<TF::NegOp>(),
TypeID::get<TF::NotEqualOp>(), TypeID::get<TF::NotEqualOp>(),
TypeID::get<TF::PadOp>(), TypeID::get<TF::PadOp>(),
TypeID::get<TF::PlaceholderWithDefaultOp>(), TypeID::get<TF::PlaceholderWithDefaultOp>(),
TypeID::get<TF::PowOp>(), TypeID::get<TF::PowOp>(),
TypeID::get<TF::RFFT2DOp>(),
TypeID::get<TF::RFFT3DOp>(),
TypeID::get<TF::RealDivOp>(), TypeID::get<TF::RealDivOp>(),
TypeID::get<TF::ReciprocalOp>(), TypeID::get<TF::ReciprocalOp>(),
TypeID::get<TF::ReciprocalGradOp>(), TypeID::get<TF::ReciprocalGradOp>(),
TypeID::get<TF::Relu6GradOp>(), TypeID::get<TF::Relu6GradOp>(),
TypeID::get<TF::ReverseSequenceOp>(),
TypeID::get<TF::RightShiftOp>(), TypeID::get<TF::RightShiftOp>(),
TypeID::get<TF::RintOp>(), TypeID::get<TF::RintOp>(),
TypeID::get<TF::RoundOp>(), TypeID::get<TF::RoundOp>(),
@ -156,6 +174,7 @@ static bool IsOpWhitelisted(Operation* op) {
TypeID::get<TF::SoftplusGradOp>(), TypeID::get<TF::SoftplusGradOp>(),
TypeID::get<TF::SoftsignGradOp>(), TypeID::get<TF::SoftsignGradOp>(),
TypeID::get<TF::SoftsignOp>(), TypeID::get<TF::SoftsignOp>(),
TypeID::get<TF::SparseToDenseOp>(),
TypeID::get<TF::SqrtGradOp>(), TypeID::get<TF::SqrtGradOp>(),
TypeID::get<TF::SquareOp>(), TypeID::get<TF::SquareOp>(),
TypeID::get<TF::SubOp>(), TypeID::get<TF::SubOp>(),
View File
@ -16,7 +16,6 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect. // This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "llvm/ADT/APInt.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -39,13 +38,9 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace { namespace {
ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) { SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
auto parallelLoopTypeAttr = b->getStringAttr("parallel"); static constexpr StringRef kParallelIterType = "parallel";
SmallVector<Attribute, 3> iteratorTypes; return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
for (int i = 0; i < nParallelLoops; ++i) {
iteratorTypes.push_back(parallelLoopTypeAttr);
}
return b->getArrayAttr(iteratorTypes);
} }
template <bool isLHLO = true> template <bool isLHLO = true>
@ -90,7 +85,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
} }
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
SmallVector<Attribute, 2> indexingMaps; SmallVector<AffineMap, 2> indexing_maps;
SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes; SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;
// This doesn't account for implicit broadcast, but the working assumption // This doesn't account for implicit broadcast, but the working assumption
@ -107,9 +102,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
!shapedType.isa<RankedTensorType>()) || !shapedType.isa<RankedTensorType>()) ||
shapedType.getRank() != nloops) shapedType.getRank() != nloops)
return nullptr; return nullptr;
indexingMaps.emplace_back(AffineMapAttr::get( indexing_maps.emplace_back(
nloops ? rewriter.getMultiDimIdentityMap(nloops) nloops ? rewriter.getMultiDimIdentityMap(nloops)
: AffineMap::get(nloops, 0, rewriter.getContext()))); : AffineMap::get(nloops, 0, rewriter.getContext()));
return shapedType; return shapedType;
}; };
for (const auto& arg : llvm::enumerate(args)) { for (const auto& arg : llvm::enumerate(args)) {
@ -132,11 +127,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, opResultTypes, args, loc, opResultTypes, args,
rewriter.getI64IntegerAttr(bodyArgTypes.size()), // args_in /*inputCount=*/bodyArgTypes.size(),
rewriter.getI64IntegerAttr(bodyResultTypes.size()), // args_out /*outputCount=*/bodyResultTypes.size(), indexing_maps,
rewriter.getArrayAttr(indexingMaps), GetNParallelLoopsAttrs(nloops));
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
// Add a block to the region. // Add a block to the region.
auto* region = &linalgOp.region(); auto* region = &linalgOp.region();
@ -297,8 +290,8 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
/// Base class for lowering xla operations that have one operand and one result, /// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like /// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method /// transpose, some reshape, etc.). The derived classes need to provide a method
/// `getIndexingMapsAttr` that returns an ArrayAttr containing AffineMapAttr for /// `getIndexingMaps` that returns AffineMaps for the index maps of the input
/// the index maps of the input and the output. /// and the output.
template <typename Derived, typename OpTy, bool isLHLO = true> template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> { class DataMovementOpConverter : public OpConversionPattern<OpTy> {
public: public:
@ -310,17 +303,17 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure(); if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
auto operandType = op.operand().getType().template cast<ShapedType>(); auto operandType = op.operand().getType().template cast<ShapedType>();
auto resultType = getXLAOpResultType<isLHLO>(op); auto resultType = getXLAOpResultType<isLHLO>(op);
ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter);
if (!indexingMapsAttr) return failure(); SmallVector<AffineMap, 2> indexing_maps =
Derived::getIndexingMaps(op, &rewriter);
if (indexing_maps.empty()) return failure();
OpBuilder::InsertionGuard linalgOpGuard(rewriter); OpBuilder::InsertionGuard linalgOpGuard(rewriter);
auto nloops = resultType.getRank(); auto nloops = resultType.getRank();
auto loc = op.getLoc(); auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops));
indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
auto* region = &linalgOp.region(); auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end()); auto* block = rewriter.createBlock(region, region->end());
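To make the refactored getIndexingMaps contract concrete, here is a hypothetical derived converter (illustration only, not part of this diff; it assumes the DataMovementOpConverter base and the getXLAOpResultType helper shown above) that returns identity maps for both the operand and the result, i.e. a plain element-wise copy:

// Hypothetical example; any op with one operand and one result of the same
// rank could be lowered this way.
template <typename OpTy, bool isLHLO = true>
class IdentityCopyConverter
    : public DataMovementOpConverter<IdentityCopyConverter<OpTy, isLHLO>, OpTy,
                                     isLHLO> {
 public:
  using DataMovementOpConverter<IdentityCopyConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;

  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    unsigned nloops = resultType.getRank();
    // Both the input and the output are addressed with the identity map.
    return {b->getMultiDimIdentityMap(nloops),
            b->getMultiDimIdentityMap(nloops)};
  }
};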
@ -344,7 +337,8 @@ class BroadcastConverter
using DataMovementOpConverter<BroadcastConverter, OpTy, using DataMovementOpConverter<BroadcastConverter, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
Builder* b) {
ShapedType inputType = ShapedType inputType =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcastOp.operand().getType().template cast<ShapedType>();
unsigned inputRank = inputType.getRank(); unsigned inputRank = inputType.getRank();
@ -368,8 +362,7 @@ class BroadcastConverter
inputMap = inputMap =
AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context); AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
} }
return b->getAffineMapArrayAttr( return {inputMap, b->getMultiDimIdentityMap(nloops)};
{inputMap, b->getMultiDimIdentityMap(nloops)});
} }
}; };
@ -381,8 +374,8 @@ class HloBroadcastInDimConverter
xla_hlo::BroadcastInDimOp, xla_hlo::BroadcastInDimOp,
false>::DataMovementOpConverter; false>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(xla_hlo::BroadcastInDimOp broadcastOp, static SmallVector<AffineMap, 2> getIndexingMaps(
Builder* b) { xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp); auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType = auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcastOp.operand().getType().template cast<ShapedType>();
@ -390,9 +383,8 @@ class HloBroadcastInDimConverter
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) { if (operandType.getRank() == 0) {
return b->getAffineMapArrayAttr( return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
{AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), b->getMultiDimIdentityMap(nloops)};
b->getMultiDimIdentityMap(nloops)});
} }
auto operandShape = operandType.getShape(); auto operandShape = operandType.getShape();
@ -409,9 +401,9 @@ class HloBroadcastInDimConverter
: b->getAffineDimExpr(size)); : b->getAffineDimExpr(size));
} }
} }
return b->getAffineMapArrayAttr( return {
{AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}); b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -447,11 +439,9 @@ class LhloBroadcastInDimConverter
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero})); rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()), loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1), /*inputCount=*/0, /*outputCount=*/1,
rewriter.getAffineMapArrayAttr( llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
{rewriter.getMultiDimIdentityMap(nloops)}), GetNParallelLoopsAttrs(nloops));
GetNParallelLoopsAttrs(nloops, &rewriter),
/*doc=*/nullptr, /*library_call=*/nullptr);
auto* region = &linalgOp.region(); auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end()); auto* block = rewriter.createBlock(region, region->end());
@ -460,16 +450,15 @@ class LhloBroadcastInDimConverter
rewriter.setInsertionPointToEnd(block); rewriter.setInsertionPointToEnd(block);
rewriter.create<linalg::YieldOp>(loc, val); rewriter.create<linalg::YieldOp>(loc, val);
} else { } else {
ArrayAttr indexingMapsAttr = getIndexingMapsAttr( auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
op, broadcast_dims, result_shape, operand_type, &rewriter); operand_type, &rewriter);
OpBuilder::InsertionGuard linalgOpGuard(rewriter); OpBuilder::InsertionGuard linalgOpGuard(rewriter);
auto linalgOp = rewriter.create<linalg::GenericOp>( auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, loc, llvm::None,
llvm::makeArrayRef({operand, operand_adaptor.output()}), llvm::makeArrayRef({operand, operand_adaptor.output()}),
rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1), /*inputCount=*/1, /*outputCount=*/1, indexing_maps,
indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter), GetNParallelLoopsAttrs(nloops));
/*doc=*/nullptr, /*library_call=*/nullptr);
auto* region = &linalgOp.region(); auto* region = &linalgOp.region();
auto* block = rewriter.createBlock(region, region->end()); auto* block = rewriter.createBlock(region, region->end());
@ -504,14 +493,14 @@ class LhloBroadcastInDimConverter
} }
SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims; SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
SmallVector<SmallVector<AffineExpr, 2>, 4> collapsed_dims_list; SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
SmallVector<AffineExpr, 2> collapsed_dims; linalg::ReassociationIndices collapsed_dims;
for (const auto& item : for (const auto& item :
enumerate(op.broadcast_dimensions().getIntValues())) { enumerate(op.broadcast_dimensions().getIntValues())) {
size_t index = item.index(); size_t index = item.index();
int dim = item.value().getSExtValue(); int dim = item.value().getSExtValue();
collapsed_dims.push_back(rewriter.getAffineDimExpr(index)); collapsed_dims.push_back(index);
bool expansion_needed = bool expansion_needed =
operand_shape[index] == 1 && result_shape[dim] != 1; operand_shape[index] == 1 && result_shape[dim] != 1;
@ -542,31 +531,28 @@ class LhloBroadcastInDimConverter
// `linalg.reshape` is inserted only if necessary, i.e. when the rank can be // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
// reduced. // reduced.
if (new_shape.size() < operand_shape.size()) { if (new_shape.size() < operand_shape.size()) {
SmallVector<ArrayRef<AffineExpr>, 4> reassociation_maps;
for (const auto& dims : collapsed_dims_list)
reassociation_maps.push_back(dims);
auto new_memref_type = MemRefType::get( auto new_memref_type = MemRefType::get(
new_shape, operand_type.getElementType(), new_shape, operand_type.getElementType(),
makeStridedLinearLayoutMap(new_strides, operand_offset, makeStridedLinearLayoutMap(new_strides, operand_offset,
rewriter.getContext())); rewriter.getContext()));
operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type, operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
operand_adaptor.operand(), operand_adaptor.operand(),
reassociation_maps); collapsed_dims_list);
} }
return std::make_pair(operand, broadcast_dims); return std::make_pair(operand, broadcast_dims);
} }
ArrayAttr getIndexingMapsAttr(xla_lhlo::BroadcastInDimOp op, SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims, ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape, ArrayRef<int64_t> resultShape,
MemRefType operandType, Builder* b) const { MemRefType operandType,
Builder* b) const {
unsigned nloops = resultShape.size(); unsigned nloops = resultShape.size();
// The input is a scalar, i.e. this is a scalar broadcast op. // The input is a scalar, i.e. this is a scalar broadcast op.
if (operandType.getRank() == 0) { if (operandType.getRank() == 0) {
return b->getAffineMapArrayAttr( return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
{AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()), b->getMultiDimIdentityMap(nloops)};
b->getMultiDimIdentityMap(nloops)});
} }
auto operandShape = operandType.getShape(); auto operandShape = operandType.getShape();
@ -584,99 +570,9 @@ class LhloBroadcastInDimConverter
} }
dimExprs.push_back(b->getAffineDimExpr(size)); dimExprs.push_back(b->getAffineDimExpr(size));
} }
return b->getAffineMapArrayAttr( return {
{AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}); b->getMultiDimIdentityMap(nloops)};
}
};
/// Pattern for the special case where reshape is adding or removing a dimension
/// of size 1. These can be lowered to a linalg.generic op.
///
/// For example a
/// "xla_hlo.reshape"(..) : (tensor<12x1x42xi32) -> tensor<12x42xi32>
/// can have indexing maps
/// [affine_map<(d0, d1) -> (d0, 0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]
///
/// Similarly a
/// "xla_hlo.reshape"(..) : (tensor<12x42xi32>) -> tensor<12x1x42xi32>
/// can have indexing maps
/// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1,
/// d2)>]
// TODO(ravishankarm): This pattern needs to be removed. The general reshape
// lowering hits a corner case where the following sequence of operations
// cannot be fused cause the resulting indexing map is not invertible.
//
// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>,
// affine_map<(d0, d1, d2) -> (d2)>]
// : tensor<5x5xf32> into tensor<5x1x5xf32>
// %f = linalg.generic
// {...
// indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
// affine_map<(d0, d1, d2) -> (d0, d2)>],
// iterator_types = ["parallel", "parallel", "parallel"]} %r {..}
// : tensor<5x1x5xf32> -> tensor<5x5xf32>
//
// The resolution of this requires a canonicalization on linalg ops where the
// dims of size 1 are removed. This pattern can be removed after that.
template <typename OpTy, bool isLHLO = true>
class ReshapeAddRemoveDimConverter
: public DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO> {
public:
ReshapeAddRemoveDimConverter(MLIRContext* context)
: DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
OpTy, isLHLO>(context, 100) {}
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto operandType =
op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
return nullptr;
auto nloops = resultType.getRank();
SmallVector<AffineExpr, 2> inputExprs;
unsigned resultIndex = 0, operandIndex = 0;
auto resultShape = resultType.getShape();
auto operandShape = operandType.getShape();
while (resultIndex < resultShape.size() &&
operandIndex < operandShape.size()) {
if (resultShape[resultIndex] == operandShape[operandIndex]) {
// Copy over the affine expr when the size of the result and operand
// match at a dim
inputExprs.push_back(b->getAffineDimExpr(resultIndex));
resultIndex++;
operandIndex++;
} else if (resultShape[resultIndex] == 1) {
// If size at result is 1, then ignore this dimension for the input, it
// is an extra dim added.
resultIndex++;
} else if (operandShape[operandIndex] == 1) {
// If the operandShape is 1, then add a (0) for the operand map since
// this dimension is dropped.
inputExprs.push_back(b->getAffineConstantExpr(0));
operandIndex++;
} else {
return nullptr;
}
}
// Make sure all remaining dimensions of the operand and result are ones.
auto checkRemainingDims = [](int64_t dim) { return dim != 1; };
if ((resultIndex < resultShape.size() &&
llvm::any_of(resultShape.drop_front(resultIndex),
checkRemainingDims)) ||
(operandIndex < operandShape.size() &&
llvm::any_of(operandShape.drop_front(operandIndex),
checkRemainingDims)))
return nullptr;
inputExprs.resize(operandShape.size(), b->getAffineConstantExpr(0));
return b->getAffineMapArrayAttr(
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)});
} }
}; };
@ -687,7 +583,7 @@ class TransposeConverter
public: public:
using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy, using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>(); getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = resultType.getRank();
@ -697,9 +593,9 @@ class TransposeConverter
inputExprs[permutation.value().getZExtValue()] = inputExprs[permutation.value().getZExtValue()] =
b->getAffineDimExpr(permutation.index()); b->getAffineDimExpr(permutation.index());
} }
return b->getAffineMapArrayAttr( return {
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}); b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -722,13 +618,6 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
if (!operandType.hasStaticShape() || !resultType.hasStaticShape()) if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
return failure(); return failure();
// TODO(ravishankarm): To make this pattern not match the pattern that
// ReshapeAddRemoveDimConverter is for, check that condition here. Remove
// this when ReshapeAddRemoveDimConverter pattern is removed.
if (ReshapeAddRemoveDimConverter<OpTy, isLHLO>::getIndexingMapsAttr(
reshapeOp, &rewriter))
return failure();
// Compute the reassociation maps for the linalg operation. // Compute the reassociation maps for the linalg operation.
ArrayRef<int64_t> srcShape = ArrayRef<int64_t> srcShape =
(operandType.getRank() > resultType.getRank() ? operandType.getShape() (operandType.getRank() > resultType.getRank() ? operandType.getShape()
@ -737,22 +626,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
(operandType.getRank() > resultType.getRank() ? resultType.getShape() (operandType.getRank() > resultType.getRank() ? resultType.getShape()
: operandType.getShape()); : operandType.getShape());
unsigned currSrcDim = 0, currDstDim = 0; unsigned currSrcDim = 0, currDstDim = 0;
SmallVector<SmallVector<AffineExpr, 4>, 4> exprs(dstShape.size()); SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
dstShape.size());
while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) { while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
int64_t dstSize = dstShape[currDstDim]; int64_t dstSize = dstShape[currDstDim];
int64_t srcSize = srcShape[currSrcDim]; int64_t srcSize = srcShape[currSrcDim];
while (srcSize < dstSize && currSrcDim < srcShape.size()) { while (srcSize < dstSize && currSrcDim < srcShape.size()) {
exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++)); reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
srcSize *= srcShape[currSrcDim]; srcSize *= srcShape[currSrcDim];
} }
if (srcSize == dstSize) { if (srcSize == dstSize) {
exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++)); reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++));
// If the next dim in dstShape is not 1, treat subsequent dims in // If the next dim in dstShape is not 1, treat subsequent dims in
// srcShape which are 1 to be collapsed. // srcShape which are 1 to be collapsed.
if (currDstDim == dstShape.size() - 1 || if (currDstDim == dstShape.size() - 1 ||
dstShape[currDstDim + 1] != 1) { dstShape[currDstDim + 1] != 1) {
while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) { while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
exprs[currDstDim].push_back( reassociationMap[currDstDim].push_back(
rewriter.getAffineDimExpr(currSrcDim++)); rewriter.getAffineDimExpr(currSrcDim++));
} }
} }
@ -763,18 +655,15 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
} }
if (currSrcDim != srcShape.size()) return failure(); if (currSrcDim != srcShape.size()) return failure();
SmallVector<ArrayRef<AffineExpr>, 4> reassociationMaps;
for (auto& expr : exprs) reassociationMaps.push_back(expr);
if (isLHLO) { if (isLHLO) {
Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>( Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
reshapeOp.getLoc(), resultType, args[0], reassociationMaps); reshapeOp.getLoc(), resultType, args[0], reassociationMap);
rewriter.replaceOpWithNewOp<linalg::CopyOp>( rewriter.replaceOpWithNewOp<linalg::CopyOp>(
reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr, reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
/*outputPermutation =*/nullptr); /*outputPermutation =*/nullptr);
} else { } else {
rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>( rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
reshapeOp, resultType, args[0], reassociationMaps); reshapeOp, resultType, args[0], reassociationMap);
} }
return success(); return success();
} }
@ -796,18 +685,14 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
// Construct the indexing maps needed for linalg.generic ops. // Construct the indexing maps needed for linalg.generic ops.
unsigned nloops = resultMemrefType.getRank(); unsigned nloops = resultMemrefType.getRank();
SmallVector<Attribute, 2> indexingMaps;
indexingMaps.emplace_back(
AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));
auto loc = iotaOp.getLoc(); auto loc = iotaOp.getLoc();
auto linalgOp = rewriter.create<linalg::IndexedGenericOp>( auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
loc, ArrayRef<Type>{}, args, loc, ArrayRef<Type>{}, args,
rewriter.getI64IntegerAttr(0), // args_in 0, // args_in
rewriter.getI64IntegerAttr(1), // args_out 1, // args_out
rewriter.getArrayAttr(indexingMaps), llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops, &rewriter), GetNParallelLoopsAttrs(nloops));
/*doc=*/nullptr, /*library_call=*/nullptr);
// Add a block to the region. // Add a block to the region.
auto* region = &linalgOp.region(); auto* region = &linalgOp.region();
@ -857,7 +742,7 @@ class ReverseConverter
public: public:
using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy, using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
isLHLO>::DataMovementOpConverter; isLHLO>::DataMovementOpConverter;
static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) { static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
auto resultType = auto resultType =
getXLAOpResultType<isLHLO>(op).template cast<ShapedType>(); getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
auto nloops = resultType.getRank(); auto nloops = resultType.getRank();
@ -871,9 +756,9 @@ class ReverseConverter
int n = resultType.getShape()[i]; int n = resultType.getShape()[i];
inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i]; inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
} }
return b->getAffineMapArrayAttr( return {
{AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()), AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
b->getMultiDimIdentityMap(nloops)}); b->getMultiDimIdentityMap(nloops)};
} }
}; };
@ -946,7 +831,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>, PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>, PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>, PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>, ReshapeOpConverter<xla_lhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>, ReverseConverter<xla_lhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>, ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
SliceConverter SliceConverter
@ -1045,7 +930,6 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>, PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>, PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>, PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>, ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>, ReverseConverter<xla_hlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context); TransposeConverter<xla_hlo::TransposeOp, false>>(context);
View File
@ -185,6 +185,7 @@ tf_xla_py_test(
name = "argminmax_test", name = "argminmax_test",
size = "small", size = "small",
srcs = ["argminmax_test.py"], srcs = ["argminmax_test.py"],
enable_mlir_bridge = True,
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -253,6 +254,7 @@ tf_xla_py_test(
name = "bucketize_op_test", name = "bucketize_op_test",
size = "small", size = "small",
srcs = ["bucketize_op_test.py"], srcs = ["bucketize_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -692,6 +694,7 @@ tf_xla_py_test(
name = "fft_test", name = "fft_test",
size = "medium", size = "medium",
srcs = ["fft_test.py"], srcs = ["fft_test.py"],
enable_mlir_bridge = True,
python_version = "PY3", python_version = "PY3",
shard_count = 6, shard_count = 6,
tags = [ tags = [
@ -805,6 +808,7 @@ tf_xla_py_test(
name = "lrn_ops_test", name = "lrn_ops_test",
size = "medium", size = "medium",
srcs = ["lrn_ops_test.py"], srcs = ["lrn_ops_test.py"],
enable_mlir_bridge = True,
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1129,6 +1133,7 @@ tf_xla_py_test(
name = "reverse_sequence_op_test", name = "reverse_sequence_op_test",
size = "medium", size = "medium",
srcs = ["reverse_sequence_op_test.py"], srcs = ["reverse_sequence_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
@ -1218,6 +1223,7 @@ tf_xla_py_test(
name = "sparse_to_dense_op_test", name = "sparse_to_dense_op_test",
size = "small", size = "small",
srcs = ["sparse_to_dense_op_test.py"], srcs = ["sparse_to_dense_op_test.py"],
enable_mlir_bridge = True,
python_version = "PY3", python_version = "PY3",
tags = [ tags = [
"no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip
View File
@ -1225,8 +1225,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
[7, 7, 7, 7, 7, 7]], [7, 7, 7, 7, 7, 7]],
dtype=dtype)) dtype=dtype))
@test_util.disable_mlir_bridge(
"Requires concatenate op support in MlirHloBuilder")
def testSymmetricMirrorPad(self): def testSymmetricMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC") mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
for dtype in self.numeric_types: for dtype in self.numeric_types:
@ -1258,8 +1256,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
np.array([[0, 0], [0, 0]], dtype=np.int32), np.array([[0, 0], [0, 0]], dtype=np.int32),
expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
@test_util.disable_mlir_bridge(
"Requires concatenate op support in MlirHloBuilder")
def testReflectMirrorPad(self): def testReflectMirrorPad(self):
mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
for dtype in self.numeric_types: for dtype in self.numeric_types:
View File
@ -21,6 +21,7 @@ from __future__ import print_function
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
expected_out, sess.run(op, expected_out, sess.run(op,
{p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]})) {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))
@test_util.disable_mlir_bridge("Error handling")
def testInvalidBoundariesOrder(self): def testInvalidBoundariesOrder(self):
with self.session() as sess: with self.session() as sess:
p = array_ops.placeholder(dtypes.int32) p = array_ops.placeholder(dtypes.int32)
View File
@ -78,10 +78,17 @@ class ConvolutionNodeNameTest(xla_test.XLATestCase):
xla_names = _GetNodeNames(use_xla=True) xla_names = _GetNodeNames(use_xla=True)
no_xla_names = _GetNodeNames(use_xla=False) no_xla_names = _GetNodeNames(use_xla=False)
self.assertListEqual(
xla_names, # CPU path creates some additional nodes to handle dilations.
no_xla_names, # TODO(b/138804006): Remove this when CPU & GPU support dilations.
) filtered_no_xla_names = []
for name in no_xla_names:
if ("dilation_rate" in name or "filter_shape" in name or "stack" in name):
continue
else:
filtered_no_xla_names.append(name)
self.assertListEqual(xla_names, filtered_no_xla_names)
def testConv1DNodeNameMatch(self): def testConv1DNodeNameMatch(self):
input_sizes = [8, 16, 3] input_sizes = [8, 16, 3]
View File
@ -22,6 +22,7 @@ import numpy as np
from tensorflow.compiler.tests import xla_test from tensorflow.compiler.tests import xla_test
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -101,6 +102,7 @@ class SparseToDenseTest(xla_test.XLATestCase):
with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"): with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
_SparseToDense([1, 3], [[5], [3]], 1, -1) _SparseToDense([1, 3], [[5], [3]], 1, -1)
@test_util.disable_mlir_bridge("Error handling")
def testBadValue(self): def testBadValue(self):
with self.session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesOpError( with self.assertRaisesOpError(
@ -108,12 +110,14 @@ class SparseToDenseTest(xla_test.XLATestCase):
r"should be \[\] or \[2\]"): r"should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [[5], [3]], -1) _SparseToDense([1, 3], [5], [[5], [3]], -1)
@test_util.disable_mlir_bridge("Error handling")
def testBadNumValues(self): def testBadNumValues(self):
with self.session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesOpError( with self.assertRaisesOpError(
r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"): r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
_SparseToDense([1, 3], [5], [1, 2, 3], -1) _SparseToDense([1, 3], [5], [1, 2, 3], -1)
@test_util.disable_mlir_bridge("Error handling")
def testBadDefault(self): def testBadDefault(self):
with self.session(), self.test_scope(): with self.session(), self.test_scope():
with self.assertRaisesOpError("default_value should be a scalar"): with self.assertRaisesOpError("default_value should be a scalar"):
View File
@ -918,16 +918,12 @@ class UnaryOpsTest(xla_test.XLATestCase):
np.array([1, 0x100000003f800000], np.int64), np.array([1, 0x100000003f800000], np.int64),
expected=np.array([1, 0x100000003f800000], np.uint64)) expected=np.array([1, 0x100000003f800000], np.uint64))
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.InvertPermutation compilation")
def testInvertPermutation(self): def testInvertPermutation(self):
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
array_ops.invert_permutation, array_ops.invert_permutation,
np.array([1, 2, 0], np.int32), np.array([1, 2, 0], np.int32),
expected=np.array([2, 0, 1], dtype=np.int32)) expected=np.array([2, 0, 1], dtype=np.int32))
@test_util.disable_mlir_bridge(
"TODO(b/153812660): Handle tf.InvertPermutation compilation")
def testInvertPermutationTwiceIsNoop(self): def testInvertPermutationTwiceIsNoop(self):
self._assertOpOutputMatchesExpected( self._assertOpOutputMatchesExpected(
lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)), lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),
@ -1144,8 +1140,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
rtol=rtol, rtol=rtol,
atol=atol) atol=atol)
@test_util.disable_mlir_bridge(
"bf16 type not supported in CreateDenseElementsAttrFromLiteral")
def testSoftplus(self): def testSoftplus(self):
for dtype in self.float_types & {dtypes.float32, dtypes.float64}: for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)
View File
@ -132,7 +132,7 @@ cc_library(
"//tensorflow/stream_executor:device_memory_allocator", "//tensorflow/stream_executor:device_memory_allocator",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
@ -150,7 +150,7 @@ cc_library(
"//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
"@llvm-project//llvm:support", "@llvm-project//llvm:Support",
], ],
) )
Some files were not shown because too many files have changed in this diff.