commit 708ecda43e
Merge remote-tracking branch 'upstream/master' into offline_memory_planner
@@ -202,7 +202,6 @@ cc_library(
        ":operation_interface",
        ":tensor_handle_interface",
        "//tensorflow/c:tensor_interface",
-       "//tensorflow/c/experimental/saved_model/core:saved_model_api",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
@@ -1473,14 +1473,10 @@ const TFE_OpAttrs* TFE_OpGetAttrs(TFE_Op* op) {
}

void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs) {
-  tensorflow::AttrValueMap m;
-  tensorflow::unwrap(attrs)->FillAttrValueMap(&m);
  tensorflow::EagerOperation* operation =
      OperationFromInterface(tensorflow::unwrap(op));
  tensorflow::AttrBuilder* destination = operation->MutableAttrs();
-  for (const auto& attribute : m) {
-    destination->Set(attribute.first, attribute.second);
-  }
+  destination->CopyAttributes(*tensorflow::unwrap(attrs));
}

void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs, TF_Buffer* buf,
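The rewritten TFE_OpAddAttrs copies the whole attribute set in one CopyAttributes call instead of round-tripping through an AttrValueMap and re-setting each entry. For orientation, a minimal sketch of how a caller forwards one op's attributes to another through this C API; the "Identity" op name and the elided error handling are illustrative assumptions, not part of the commit:

    // Sketch: forward `original_op`'s attributes onto a fresh op.
    const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);  // borrowed view
    TFE_Op* replica = TFE_NewOp(context, "Identity", status);     // illustrative op
    if (TF_GetCode(status) == TF_OK) {
      TFE_OpAddAttrs(replica, attributes);  // wholesale copy, no AttrValueMap
    }
    TFE_DeleteOp(replica);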
@@ -21,8 +21,8 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/operation_interface.h"
#include "tensorflow/c/eager/tensor_handle_interface.h"
-#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
@@ -84,11 +84,10 @@ class AbstractContextInterface {
  // Create an operation to perform op execution
  virtual AbstractOperationInterface* CreateOperation() = 0;

-  // Load a SavedModelAPI object from the given directory and tags
-  virtual std::unique_ptr<SavedModelAPI> LoadSavedModelAPI(
-      const std::string& directory,
-      const absl::optional<std::unordered_set<std::string>>& tags,
-      tensorflow::Status* status) = 0;
+  // Returns whether the runtime is backed by TFRT or the legacy TF Eager
+  // Runtime. This is necessary to decouple runtime-dependent
+  // code that is layered on top of the runtime.
+  virtual bool UsesTFRT() = 0;

  // List attributes of available devices
  virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
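With LoadSavedModelAPI gone from the context interface, the runtime choice surfaces through UsesTFRT() and callers branch themselves. A sketch of the dispatch pattern, mirroring the TF_LoadSavedModel change later in this commit (abbreviated and non-authoritative):

    // Sketch: callers pick a SavedModel implementation per runtime.
    if (context->UsesTFRT()) {
      // TFRT path: reported as Unimplemented at this point in the commit.
    } else {
      // Legacy eager runtime: down_cast to EagerContext and load directly.
    }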
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"

#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace parallel_device {
@@ -28,21 +30,198 @@ class OpDeleter {

using OpPtr = std::unique_ptr<TFE_Op, OpDeleter>;

-// Creates a vector of `count` new executors (threads).
-std::vector<ExecutorPtr> MakeExecutors(size_t count) {
-  std::vector<ExecutorPtr> executors;
-  executors.reserve(count);
-  for (int i = 0; i < count; ++i) {
-    executors.emplace_back(TFE_NewExecutor(true /* is_async */));
-  }
-  return executors;
-}
class StatusDeleter {
 public:
  void operator()(TF_Status* to_delete) const { TF_DeleteStatus(to_delete); }
};

using StatusPtr = std::unique_ptr<TF_Status, StatusDeleter>;

}  // namespace

// Allows a single op at a time to be launched without blocking.
//
// DeviceThread itself is thread-safe, in that StartExecute will block if there
// is a pending execution. Since StartExecute is equivalent to grabbing a lock,
// multiple DeviceThreads should always be accessed in the same order to avoid
// deadlocks.
class DeviceThread {
 public:
  // Starts a background thread waiting for `StartExecute`.
  explicit DeviceThread(const std::string& device)
      : status_(TF_NewStatus()),
        device_(device),
        op_(nullptr),
        thread_(tensorflow::Env::Default()->StartThread(
            tensorflow::ThreadOptions(), "parallel_device_execute",
            std::bind(&DeviceThread::Run, this))) {}
  ~DeviceThread();

  // Requests that the worker thread execute the specified operation. Blocks
  // until the previously pending operation (a StartExecute without a Join) has
  // finished, if any.
  void StartExecute(TFE_Context* context, const char* operation_name,
                    std::vector<TFE_TensorHandle*> inputs,
                    const TFE_OpAttrs* attributes, int expected_max_outputs);
  // Block until the previous `StartExecute` operation has executed. Forwards
  // the status from `TFE_Execute` and returns outputs if the status is OK.
  std::vector<TensorHandlePtr> Join(TF_Status* status);

 private:
  void Run();

  void Execute(TFE_Context* context, const char* operation_name,
               std::vector<TFE_TensorHandle*> inputs,
               const TFE_OpAttrs* attributes, int expected_max_outputs,
               std::vector<TensorHandlePtr>* outputs, TF_Status* status) const
      TF_EXCLUSIVE_LOCKS_REQUIRED(execution_mutex_);

  enum class ExecutionState {
    kReadyToExecute,
    kHasResult,
    kIdle,
    kShuttingDown,
  };

  tensorflow::mutex execution_mutex_;
  ExecutionState execution_state_ TF_GUARDED_BY(execution_mutex_) =
      ExecutionState::kIdle;
  // Tells the worker thread that there is new work.
  tensorflow::condition_variable start_execute_;
  // The worker thread notifies that work has finished.
  tensorflow::condition_variable finished_execute_;
  // Notifies a StartExecute that the previous Join has finished.
  tensorflow::condition_variable finished_join_;

  // Temporary state between `StartExecute` and `Join`.
  // Inputs
  TFE_Context* context_ TF_GUARDED_BY(execution_mutex_);
  const char* operation_name_ TF_GUARDED_BY(execution_mutex_);
  std::vector<TFE_TensorHandle*> op_inputs_ TF_GUARDED_BY(execution_mutex_);
  const TFE_OpAttrs* attributes_ TF_GUARDED_BY(execution_mutex_);
  int expected_max_outputs_ TF_GUARDED_BY(execution_mutex_);
  // Outputs
  std::vector<TensorHandlePtr> op_outputs_ TF_GUARDED_BY(execution_mutex_);
  StatusPtr status_ TF_GUARDED_BY(execution_mutex_);

  const std::string device_;
  mutable OpPtr op_ TF_GUARDED_BY(execution_mutex_);
  std::unique_ptr<Thread> thread_;
};

DeviceThread::~DeviceThread() {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    execution_state_ = ExecutionState::kShuttingDown;
  }
  start_execute_.notify_one();
}

void DeviceThread::Run() {
  while (true) {
    {
      tensorflow::mutex_lock l(execution_mutex_);
      while (execution_state_ == ExecutionState::kIdle ||
             execution_state_ == ExecutionState::kHasResult) {
        start_execute_.wait(l);
      }
      if (execution_state_ == ExecutionState::kShuttingDown) {
        return;
      } else if (execution_state_ == ExecutionState::kReadyToExecute) {
        // op_outputs_ may have been std::moved
        op_outputs_ = std::vector<TensorHandlePtr>();
        Execute(context_, operation_name_, std::move(op_inputs_), attributes_,
                expected_max_outputs_, &op_outputs_, status_.get());
        execution_state_ = ExecutionState::kHasResult;
      }
    }
    finished_execute_.notify_one();
  }
}

void DeviceThread::StartExecute(TFE_Context* context,
                                const char* operation_name,
                                std::vector<TFE_TensorHandle*> inputs,
                                const TFE_OpAttrs* attributes,
                                int expected_max_outputs) {
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kIdle) {
      // If there's already a pending execution, wait until Join finishes before
      // starting on the next operation.
      finished_join_.wait(l);
    }
    context_ = context;
    operation_name_ = operation_name;
    op_inputs_ = inputs;
    attributes_ = attributes;
    expected_max_outputs_ = expected_max_outputs;
    execution_state_ = ExecutionState::kReadyToExecute;
  }
  start_execute_.notify_one();
}

std::vector<TensorHandlePtr> DeviceThread::Join(TF_Status* status) {
  std::vector<TensorHandlePtr> result;
  {
    tensorflow::mutex_lock l(execution_mutex_);
    while (execution_state_ != ExecutionState::kHasResult) {
      finished_execute_.wait(l);
    }
    if (TF_GetCode(status_.get()) != TF_OK) {
      TF_SetStatus(status, TF_GetCode(status_.get()),
                   TF_Message(status_.get()));
    }
    execution_state_ = ExecutionState::kIdle;
    result = std::move(op_outputs_);
  }
  finished_join_.notify_one();
  return result;
}

void DeviceThread::Execute(TFE_Context* context, const char* operation_name,
                           std::vector<TFE_TensorHandle*> inputs,
                           const TFE_OpAttrs* attributes,
                           int expected_max_outputs,
                           std::vector<TensorHandlePtr>* outputs,
                           TF_Status* status) const {
  if (op_ == nullptr) {
    op_.reset(TFE_NewOp(context, operation_name, status));
    if (TF_GetCode(status) != TF_OK) return;
    TFE_OpSetDevice(op_.get(), device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  } else {
    TFE_OpReset(op_.get(), operation_name, device_.c_str(), status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  TFE_OpAddAttrs(op_.get(), attributes);
  for (int input_index = 0; input_index < inputs.size(); ++input_index) {
    TFE_OpAddInput(op_.get(), inputs[input_index], status);
    if (TF_GetCode(status) != TF_OK) return;
  }
  std::vector<TFE_TensorHandle*> unwrapped_results(expected_max_outputs);
  int real_num_outputs = expected_max_outputs;
  if (TF_GetCode(status) != TF_OK) return;
  TFE_Execute(op_.get(), unwrapped_results.data(), &real_num_outputs, status);
  if (TF_GetCode(status) != TF_OK) return;
  unwrapped_results.resize(real_num_outputs);
  outputs->reserve(real_num_outputs);
  for (TFE_TensorHandle* unwrapped_result : unwrapped_results) {
    outputs->emplace_back(unwrapped_result);
  }
}

ParallelDevice::ParallelDevice(const std::vector<std::string>& devices)
-    : underlying_devices_(devices),
-      executors_(MakeExecutors(underlying_devices_.size())) {}
    : underlying_devices_(devices) {
  device_threads_.reserve(devices.size());
  for (int device_index = 0; device_index < devices.size(); ++device_index) {
    device_threads_.emplace_back(
        new DeviceThread(devices[device_index].c_str()));
  }
}

// Necessary for a unique_ptr to a forward-declared type.
ParallelDevice::~ParallelDevice() = default;

std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice(
    TFE_Context* context, TFE_TensorHandle* tensor, TF_Status* status) const {
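Since StartExecute behaves like acquiring a lock, a caller drives every DeviceThread in StartExecute/Join pairs, always in the same device order. A minimal sketch of the intended pattern (the op name and empty input list are placeholders; ParallelDevice::Execute below is the real caller):

    // Sketch: launch on all threads first, then collect, in one fixed order.
    for (const std::unique_ptr<DeviceThread>& thread : device_threads) {
      thread->StartExecute(context, "AddV2", /*inputs=*/{}, attributes,
                           /*expected_max_outputs=*/1);  // placeholder values
    }
    for (const std::unique_ptr<DeviceThread>& thread : device_threads) {
      std::vector<TensorHandlePtr> outputs = thread->Join(status);
      if (TF_GetCode(status) != TF_OK) break;
    }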
@@ -108,72 +287,34 @@ ParallelDevice::Execute(TFE_Context* context,
  // Compute per-device per-output tensors
  std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
  per_device_output_tensors.reserve(underlying_devices_.size());
-  // TODO(allenl): Add a TFE_ExecuteWithExecutor API so we don't have to keep
-  // setting the thread-local executor like this.
-  TFE_Executor* previous_executor(TFE_ContextGetExecutorForThread(context));
-  auto reset_executor =
-      tensorflow::gtl::MakeCleanup([context, previous_executor]() {
-        TFE_ContextSetExecutorForThread(context, previous_executor);
-        TFE_DeleteExecutor(previous_executor);
-      });
-  int first_op_output_count;
+  int first_op_output_count = 0;
  for (int device_index = 0; device_index < underlying_devices_.size();
       ++device_index) {
-    TFE_Executor* executor = executors_[device_index].get();
-    // Note that the `reset_executor` cleanup sets the thread's executor back to
-    // the value before this function ran.
-    TFE_ContextSetExecutorForThread(context, executor);
-    OpPtr op(TFE_NewOp(context, operation_name, status));
-    if (TF_GetCode(status) != TF_OK) return result;
-    TFE_OpSetDevice(op.get(), underlying_devices_[device_index].c_str(),
-                    status);
-    TFE_OpAddAttrs(op.get(), attributes);
+    DeviceThread* device_thread = device_threads_[device_index].get();
+    std::vector<TFE_TensorHandle*> device_inputs;
+    device_inputs.reserve(device_inputs.size());
    for (int input_index = 0; input_index < inputs.size(); ++input_index) {
      // Parallel tensors are divided between operations by device.
-      TFE_OpAddInput(op.get(), inputs[input_index]->tensor(device_index),
-                     status);
-      if (TF_GetCode(status) != TF_OK) return result;
+      device_inputs.push_back(inputs[input_index]->tensor(device_index));
    }
-    std::vector<TFE_TensorHandle*> op_outputs(expected_max_outputs);
-    int real_num_outputs = expected_max_outputs;
-    // For nested devices, the inner device sees the async executor we've
-    // set. Inner parallel devices will just overwrite this with their own and
-    // then set it back to ours before returning. This means parallel devices
-    // which consist of several aliased parallel devices would hypothetically
-    // deadlock if the outer parallel device ran one collective with a group
-    // size equal to the total number of aliased physical devices. Currently
-    // physical devices cannot participate in a single collective reduction
-    // multiple times, so this would fail earlier.
-    //
-    // TODO(allenl): Keep a map from outer executor to list of inner executors
-    // rather than a single list of executors so aliased nested parallel devices
-    // don't re-use an executor.
-    TFE_Execute(op.get(), op_outputs.data(), &real_num_outputs, status);
+    device_thread->StartExecute(context, operation_name,
+                                std::move(device_inputs), attributes,
+                                expected_max_outputs);
  }
+  for (int device_index = 0; device_index < underlying_devices_.size();
+       ++device_index) {
+    DeviceThread* device_thread = device_threads_[device_index].get();
+    per_device_output_tensors.push_back(device_thread->Join(status));
+    if (TF_GetCode(status) != TF_OK) return result;
    if (device_index == 0) {
-      first_op_output_count = real_num_outputs;
+      first_op_output_count = per_device_output_tensors.rbegin()->size();
    } else {
-      if (real_num_outputs != first_op_output_count) {
+      if (per_device_output_tensors.rbegin()->size() != first_op_output_count) {
        TF_SetStatus(status, TF_INTERNAL,
                     "Parallel ops produced different numbers of tensors.");
        return result;
      }
    }
-    if (TF_GetCode(status) != TF_OK) return result;
-    std::vector<TensorHandlePtr> this_outputs;
-    this_outputs.reserve(real_num_outputs);
-    for (int output_num = 0; output_num < real_num_outputs; ++output_num) {
-      this_outputs.emplace_back(op_outputs[output_num]);
-    }
-    per_device_output_tensors.push_back(std::move(this_outputs));
  }
-  for (int device_index = 0; device_index < underlying_devices_.size();
-       ++device_index) {
-    TFE_Executor* executor = executors_[device_index].get();
-    // TODO(b/157523095): Syncing the executor here shouldn't be
-    // necessary. Currently async+remote is missing cross-executor
-    // coordination.
-    TFE_ExecutorWaitForAllPendingNodes(executor, status);
-    if (TF_GetCode(status) != TF_OK) return result;
-  }
  // For each output of the original operation, pack the per-device
  // TensorHandles we've computed into a single parallel TensorHandle.
@@ -41,16 +41,8 @@ class TensorHandleDeleter {

using TensorHandlePtr = std::unique_ptr<TFE_TensorHandle, TensorHandleDeleter>;

-class ExecutorDeleter {
- public:
-  void operator()(TFE_Executor* to_delete) const {
-    TFE_DeleteExecutor(to_delete);
-  }
-};
-
-using ExecutorPtr = std::unique_ptr<TFE_Executor, ExecutorDeleter>;
-
class ParallelTensor;
+class DeviceThread;

// Forwards operations to `devices`, maintaining ParallelTensor with components
// placed on each underlying device.

@@ -58,6 +50,8 @@ class ParallelDevice {
 public:
  explicit ParallelDevice(const std::vector<std::string>& devices);

+  ~ParallelDevice();
+
  // Helper to copy a tensor handle from another device once for each component
  // of the ParallelDevice.
  //

@@ -94,9 +88,19 @@ class ParallelDevice {
  // A sequence of device names, indicating which devices replicated operations
  // are forwarded to.
  const std::vector<std::string> underlying_devices_;
-  // A sequence of TFE_Executors, one per device, for executing operations in
+  // A sequence of thread wrappers, one per device, for executing operations in
  // parallel.
-  const std::vector<ExecutorPtr> executors_;
+  //
+  // Conceptually this is a thread pool with one thread per device. It requires
+  // less synchronization than a thread pool would for this task, since Execute
+  // acquires each thread in order (and so only one Execute will schedule
+  // blocking collective operations at a time), and avoids some dynamic
+  // allocation/scheduling.
+  //
+  // TODO(allenl): Keep a map from outer thread to list of inner threads rather
+  // than a single list of threads so aliased nested parallel devices don't
+  // re-use a thread.
+  std::vector<std::unique_ptr<DeviceThread>> device_threads_;
};

// Contains a tuple of tensors, one on each of the `underlying_devices_` of the
@@ -407,7 +407,7 @@ TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
  return TensorHandlePtr(result_handle);
}

-TEST(PARALLEL_DEVICE, TestCollective) {
+void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(

@@ -423,6 +423,9 @@ TEST(PARALLEL_DEVICE, TestCollective) {
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+  std::unique_ptr<TFE_Executor, decltype(&TFE_DeleteExecutor)> executor(
+      TFE_NewExecutor(async), TFE_DeleteExecutor);
+  TFE_ContextSetExecutorForThread(context.get(), executor.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{

@@ -452,8 +455,16 @@ TEST(PARALLEL_DEVICE, TestCollective) {
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 3.);
  ExpectScalarEq<float>(result_components[1].get(), 3.);
+  // Destroying the context's default executor first isn't safe.
+  context.reset();
}

+TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }
+
+// Note that ops on the parallel device currently don't execute
+// asynchronously. The test is just that we don't get deadlocks.
+TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }
+
void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
@@ -26,5 +26,6 @@ cc_library(
    deps = [
        "//tensorflow/c:tf_status",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
+       "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
    ],
)
@@ -15,11 +15,22 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>

+#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"

// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
+namespace gcs = google::cloud::storage;
+
+// We can cast `google::cloud::StatusCode` to `TF_Code` because they have the
+// same integer values. See
+// https://github.com/googleapis/google-cloud-cpp/blob/6c09cbfa0160bc046e5509b4dd2ab4b872648b4a/google/cloud/status.h#L32-L52
+static inline void TF_SetStatusFromGCSStatus(
+    const google::cloud::Status& gcs_status, TF_Status* status) {
+  TF_SetStatus(status, static_cast<TF_Code>(gcs_status.code()),
+               gcs_status.message().c_str());
+}
+
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
@@ -52,6 +63,20 @@ namespace tf_read_only_memory_region {
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {

+// TODO(vnvo2409): Add lazy-loading and customizing parameters.
+static void Init(TF_Filesystem* filesystem, TF_Status* status) {
+  google::cloud::StatusOr<gcs::Client> client =
+      gcs::Client::CreateDefaultClient();
+  if (!client) {
+    TF_SetStatusFromGCSStatus(client.status(), status);
+    return;
+  }
+  filesystem->plugin_filesystem = plugin_memory_allocate(sizeof(gcs::Client));
+  auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
+  (*gcs_client) = client.value();
+  TF_SetStatus(status, TF_OK, "");
+}
+
// TODO(vnvo2409): Implement later

}  // namespace tf_gcs_filesystem
@@ -60,6 +85,10 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
                                        const char* uri) {
  TF_SetFilesystemVersionMetadata(ops);
  ops->scheme = strdup(uri);
+
+  ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
+      plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
+  ops->filesystem_ops->init = tf_gcs_filesystem::Init;
}

void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
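Init stores the client by value in memory the plugin itself allocated, so the filesystem operations still marked "Implement later" would presumably recover it with the same cast; a hypothetical sketch (this op is not part of the commit):

    // Hypothetical follow-up op: recover the client that Init stored.
    static void Stat(const TF_Filesystem* filesystem, const char* path,
                     TF_FileStatistics* stats, TF_Status* status) {
      auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
      // ... issue gcs_client calls, then convert errors with
      // TF_SetStatusFromGCSStatus.
    }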
@@ -57,6 +57,7 @@ cc_library(
        ":concrete_function",
        ":saved_model_api",
        "//tensorflow/core:lib",
+       "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:optional",
    ],
)
@@ -21,6 +21,7 @@ limitations under the License.

#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
|
||||
Status TFSavedModelAPIImpl::Load(
|
||||
const std::string& directory,
|
||||
const absl::optional<std::unordered_set<std::string>>& tags,
|
||||
TFSavedModelAPIImpl* out) {
|
||||
EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out) {
|
||||
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
|
||||
return errors::Unimplemented(
|
||||
"TFSavedModelAPIImpl loading is unimplemented currently");
|
||||
|
@@ -23,14 +23,13 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

class TFSavedModelAPIImpl : public SavedModelAPI {
 public:
-  TFSavedModelAPIImpl() = default;

  Status GetFunction(const std::string& function_path,
                     ConcreteFunction** function) override;
@@ -40,13 +39,14 @@ class TFSavedModelAPIImpl : public SavedModelAPI {
  static Status Load(
      const std::string& directory,
      const absl::optional<std::unordered_set<std::string>>& tags,
-      TFSavedModelAPIImpl* out);
+      EagerContext* context, std::unique_ptr<TFSavedModelAPIImpl>* out);

  std::vector<ConcreteFunction*> ListFunctions() override;

  ~TFSavedModelAPIImpl() override = default;

 private:
+  TFSavedModelAPIImpl() = default;
  std::vector<ConcreteFunction> functions_;
};
@@ -144,7 +144,9 @@ cc_library(
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:tfe_context_internal",
        "//tensorflow/c/experimental/saved_model/core:saved_model_api",
+       "//tensorflow/c/experimental/saved_model/core:tf_saved_model_impl",
        "//tensorflow/core:lib",
+       "//tensorflow/core/common_runtime/eager:context",
        "@com_google_absl//absl/types:optional",
    ],
)
@@ -22,11 +22,15 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
+#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_impl.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_list_type.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/saved_model_api_type.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_internal.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

extern "C" {

@@ -34,10 +38,21 @@
TF_SavedModel* TF_LoadSavedModel(const char* dirname, TFE_Context* ctx,
                                 TF_Status* status) {
  std::string saved_model_dir(dirname);
-  std::unique_ptr<tensorflow::SavedModelAPI> result =
-      tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, absl::nullopt,
-                                                 &status->status);
+  std::unique_ptr<tensorflow::SavedModelAPI> result;
+
+  if (tensorflow::unwrap(ctx)->UsesTFRT()) {
+    status->status = tensorflow::errors::Unimplemented(
+        "TFRT SavedModel implementation will be added in the future");
+  } else {
+    std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
+    status->status = tensorflow::TFSavedModelAPIImpl::Load(
+        dirname, absl::nullopt,
+        tensorflow::down_cast<tensorflow::EagerContext*>(
+            tensorflow::unwrap(ctx)),
+        &saved_model);
+    result = std::move(saved_model);
+  }

  if (!status->status.ok()) {
    return nullptr;
  }

@@ -54,9 +69,20 @@ TF_SavedModel* TF_LoadSavedModelWithTags(const char* dirname, TFE_Context* ctx,
    tagset.insert(std::string(tags[i]));
  }

-  std::unique_ptr<tensorflow::SavedModelAPI> result =
-      tensorflow::unwrap(ctx)->LoadSavedModelAPI(dirname, std::move(tagset),
-                                                 &status->status);
+  std::unique_ptr<tensorflow::SavedModelAPI> result;
+  if (tensorflow::unwrap(ctx)->UsesTFRT()) {
+    status->status = tensorflow::errors::Unimplemented(
+        "TFRT SavedModel implementation will be added in the future");
+  } else {
+    std::unique_ptr<tensorflow::TFSavedModelAPIImpl> saved_model;
+    status->status = tensorflow::TFSavedModelAPIImpl::Load(
+        dirname, tagset,
+        tensorflow::down_cast<tensorflow::EagerContext*>(
+            tensorflow::unwrap(ctx)),
+        &saved_model);
+    result = std::move(saved_model);
+  }

  if (!status->status.ok()) {
    return nullptr;
  }
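For reference, a short usage sketch of the C API entry point as changed here, from the client side; the model path is a placeholder and the cleanup call for the returned handle is omitted:

    TF_Status* status = TF_NewStatus();
    TF_SavedModel* model =
        TF_LoadSavedModel("/path/to/saved_model", ctx, status);  // placeholder path
    if (TF_GetCode(status) != TF_OK) {
      // On a TFRT-backed context this currently reports Unimplemented.
    }
    TF_DeleteStatus(status);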
@@ -106,6 +106,7 @@ cc_library(
    hdrs = ["loader.h"],
    deps = [
        ":constants",
+       ":loader_util",
        ":reader",
    ] + if_not_mobile([
        "//tensorflow/core:core_cpu",

@@ -132,6 +133,17 @@ cc_library(
    ],
)

+cc_library(
+    name = "loader_util",
+    srcs = ["loader_util.cc"],
+    hdrs = ["loader_util.h"],
+    deps = [":constants"] + if_not_mobile([
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "//tensorflow/core:protos_all_cc",
+    ]),
+)
+
tf_cc_test(
    name = "bundle_v2_test",
    srcs = ["bundle_v2_test.cc"],
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
+#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

@@ -29,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
-#include "tensorflow/core/platform/protobuf_internal.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
@@ -191,41 +191,6 @@ Status RunInitOp(const RunOptions& run_options, const string& export_dir,
  return Status::OK();
}

-// A SavedModel may store the name of the initialization op to run in the
-// in the SignatureDef (v2) or a collection (v1). If an init_op collection
-// exists, then the collection must contain exactly one op.
-Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
-                 string* init_op_name) {
-  const auto& sig_def_map = meta_graph_def.signature_def();
-  const auto& init_op_sig_it =
-      meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
-  if (init_op_sig_it != sig_def_map.end()) {
-    *init_op_name = init_op_sig_it->second.outputs()
-                        .find(kSavedModelInitOpSignatureKey)
-                        ->second.name();
-    return Status::OK();
-  }
-
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  string init_op_collection_key;
-  if (collection_def_map.find(kSavedModelMainOpKey) !=
-      collection_def_map.end()) {
-    init_op_collection_key = kSavedModelMainOpKey;
-  } else {
-    init_op_collection_key = kSavedModelLegacyInitOpKey;
-  }
-
-  const auto init_op_it = collection_def_map.find(init_op_collection_key);
-  if (init_op_it != collection_def_map.end()) {
-    if (init_op_it->second.node_list().value_size() != 1) {
-      return errors::FailedPrecondition(
-          strings::StrCat("Expected exactly one main op in : ", export_dir));
-    }
-    *init_op_name = init_op_it->second.node_list().value(0);
-  }
-  return Status::OK();
-}
-
Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,

@@ -263,32 +228,6 @@ Status RunRestore(const RunOptions& run_options, const string& export_dir,
                   nullptr /* outputs */, &run_metadata, session);
}

-Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
-                        std::vector<AssetFileDef>* asset_file_defs) {
-  // With SavedModel v2, we write asset file def into metagraph instead of
-  // collection, so read from metagraph first.
-  if (meta_graph_def.asset_file_def_size() > 0) {
-    for (const auto& asset : meta_graph_def.asset_file_def()) {
-      asset_file_defs->push_back(asset);
-    }
-    return Status::OK();
-  }
-  // Fall back to read from collection to be backward compatible with v1.
-  const auto& collection_def_map = meta_graph_def.collection_def();
-  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
-  if (assets_it == collection_def_map.end()) {
-    return Status::OK();
-  }
-  const auto& any_assets = assets_it->second.any_list().value();
-  for (const auto& any_asset : any_assets) {
-    AssetFileDef asset_file_def;
-    TF_RETURN_IF_ERROR(
-        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
-    asset_file_defs->push_back(asset_file_def);
-  }
-  return Status::OK();
-}
-
Status ReadSavedModelDebugInfoIfPresent(
    const string& export_dir,
    std::unique_ptr<GraphDebugInfo>* debug_info_proto) {

@@ -322,7 +261,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,

  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(
-      GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
+      internal::GetAssetFileDefs(bundle->meta_graph_def, &asset_file_defs));
  TF_RETURN_IF_ERROR(
      RunRestore(run_options, export_dir,
                 bundle->meta_graph_def.saver_def().restore_op_name(),

@@ -336,7 +275,7 @@ Status LoadSavedModelInternal(const SessionOptions& session_options,
  const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
-      GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
+      internal::GetInitOp(export_dir, bundle->meta_graph_def, &init_op_name));
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, bundle->meta_graph_def,
                               asset_file_defs, bundle->session.get(),
                               init_op_name));
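The two helpers removed above move verbatim into the new loader_util translation unit below, so the loader (and any other SavedModel tooling) now reaches them through the internal namespace. A sketch of a call site outside this file; the surrounding function and the already-read MetaGraphDef are assumed:

    #include "tensorflow/cc/saved_model/loader_util.h"

    // Sketch: resolve the init op exactly as LoadSavedModelInternal now does.
    string init_op_name;
    TF_RETURN_IF_ERROR(tensorflow::internal::GetInitOp(
        export_dir, meta_graph_def, &init_op_name));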
tensorflow/cc/saved_model/loader_util.cc (new file, +90 lines)
@@ -0,0 +1,90 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader_util.h"

#include <vector>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf_internal.h"

namespace tensorflow {
namespace internal {

// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name) {
  const auto& sig_def_map = meta_graph_def.signature_def();
  const auto& init_op_sig_it =
      meta_graph_def.signature_def().find(kSavedModelInitOpSignatureKey);
  if (init_op_sig_it != sig_def_map.end()) {
    *init_op_name = init_op_sig_it->second.outputs()
                        .find(kSavedModelInitOpSignatureKey)
                        ->second.name();
    return Status::OK();
  }

  const auto& collection_def_map = meta_graph_def.collection_def();
  string init_op_collection_key;
  if (collection_def_map.find(kSavedModelMainOpKey) !=
      collection_def_map.end()) {
    init_op_collection_key = kSavedModelMainOpKey;
  } else {
    init_op_collection_key = kSavedModelLegacyInitOpKey;
  }

  const auto init_op_it = collection_def_map.find(init_op_collection_key);
  if (init_op_it != collection_def_map.end()) {
    if (init_op_it->second.node_list().value_size() != 1) {
      return errors::FailedPrecondition(
          strings::StrCat("Expected exactly one main op in : ", export_dir));
    }
    *init_op_name = init_op_it->second.node_list().value(0);
  }
  return Status::OK();
}

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs) {
  // With SavedModel v2, we write asset file def into metagraph instead of
  // collection, so read from metagraph first.
  if (meta_graph_def.asset_file_def_size() > 0) {
    for (const auto& asset : meta_graph_def.asset_file_def()) {
      asset_file_defs->push_back(asset);
    }
    return Status::OK();
  }
  // Fall back to read from collection to be backward compatible with v1.
  const auto& collection_def_map = meta_graph_def.collection_def();
  const auto assets_it = collection_def_map.find(kSavedModelAssetsKey);
  if (assets_it == collection_def_map.end()) {
    return Status::OK();
  }
  const auto& any_assets = assets_it->second.any_list().value();
  for (const auto& any_asset : any_assets) {
    AssetFileDef asset_file_def;
    TF_RETURN_IF_ERROR(
        ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef"));
    asset_file_defs->push_back(asset_file_def);
  }
  return Status::OK();
}

}  // namespace internal
}  // namespace tensorflow
tensorflow/cc/saved_model/loader_util.h (new file, +39 lines)
@@ -0,0 +1,39 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
#define TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_

#include <string>

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"

namespace tensorflow {
namespace internal {

// A SavedModel may store the name of the initialization op to run in the
// SignatureDef (v2) or a collection (v1). If an init_op collection
// exists, then the collection must contain exactly one op.
Status GetInitOp(const string& export_dir, const MetaGraphDef& meta_graph_def,
                 string* init_op_name);

Status GetAssetFileDefs(const MetaGraphDef& meta_graph_def,
                        std::vector<AssetFileDef>* asset_file_defs);

}  // namespace internal
}  // namespace tensorflow

#endif  // TENSORFLOW_CC_SAVED_MODEL_LOADER_UTIL_H_
@@ -33,6 +33,7 @@ MarkForCompilationPassFlags* mark_for_compilation_flags;
XlaDeviceFlags* device_flags;
XlaOpsCommonFlags* ops_flags;
IntroduceFloatingPointJitterPassFlags* jitter_flags;
+MlirCommonFlags* mlir_flags;

std::vector<Flag>* flag_list;
absl::once_flag flags_init;

@@ -166,6 +167,9 @@
  jitter_flags = new IntroduceFloatingPointJitterPassFlags;
  jitter_flags->jitter_amount = 1e-5;

+  mlir_flags = new MlirCommonFlags;
+  mlir_flags->tf_mlir_enable_mlir_bridge = false;
+
  auto setter_for_jitter_tensor_names = [](string sequence) {
    jitter_flags->tensor_names = absl::StrSplit(sequence, ',');
    return true;

@@ -211,7 +215,11 @@
       Flag("tf_introduce_floating_point_jitter_amount",
            &jitter_flags->jitter_amount,
            "The amount of jitter to introduce. This amount is added to each "
-           "element in the tensors named in `tensor_names.")});
+           "element in the tensors named in `tensor_names."),
+
+       Flag("tf_mlir_enable_mlir_bridge",
+            &mlir_flags->tf_mlir_enable_mlir_bridge,
+            "Enables experimental MLIR-Based TensorFlow Compiler Bridge.")});

  AppendMarkForCompilationPassFlagsInternal(flag_list);
  xla::ParseFlagsFromEnvAndDieIfUnknown("TF_XLA_FLAGS", *flag_list);

@@ -250,6 +258,11 @@ GetIntroduceFloatingPointJitterPassFlags() {
  return *jitter_flags;
}

+MlirCommonFlags* GetMlirCommonFlags() {
+  absl::call_once(flags_init, &AllocateAndParseFlags);
+  return mlir_flags;
+}
+
void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
  absl::call_once(flags_init, &AllocateAndParseFlags);
  AppendMarkForCompilationPassFlagsInternal(flag_list);
@@ -133,6 +133,11 @@ struct IntroduceFloatingPointJitterPassFlags {
  std::vector<string> tensor_names;
};

+// Flags for common MLIR configurations.
+struct MlirCommonFlags {
+  bool tf_mlir_enable_mlir_bridge;
+};
+
// Return a pointer to the DumpGraphFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.

@@ -148,6 +153,8 @@ const XlaOpsCommonFlags& GetXlaOpsCommonFlags();
const IntroduceFloatingPointJitterPassFlags&
GetIntroduceFloatingPointJitterPassFlags();

+MlirCommonFlags* GetMlirCommonFlags();
+
// Appends the flag definitions associated with
// MarkForCompilationPassFlags/DumpGraphFlags to `flag_list`.
//
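Because GetMlirCommonFlags() returns a pointer rather than a const reference, callers can both read and override the bridge toggle after flag parsing. A sketch; the include path matches this header, the call sites are assumed:

    #include "tensorflow/compiler/jit/flags.h"

    // Sketch: consult, and in tests override, the new MLIR bridge flag.
    bool enabled = tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge;
    tensorflow::GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge = true;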
@@ -30,7 +30,7 @@ cc_library(
    hdrs = ["op_or_arg_name_mapper.h"],
    deps = [
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
    ],
)

@@ -42,7 +42,7 @@ cc_library(
        ":init_mlir",
        "//tensorflow/core:lib",
        "//tensorflow/core/platform:logging",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirOptLib",

@@ -86,7 +86,7 @@ cc_library(
    hdrs = ["init_mlir.h"],
    deps = [
        "//tensorflow/core:lib",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)

@@ -102,7 +102,7 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
        "//tensorflow/core:core_cpu",
        "@com_google_absl//absl/container:flat_hash_set",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Shape",
        "@llvm-project//mlir:StandardOps",

@@ -155,7 +155,7 @@ tf_cc_binary(
        "//tensorflow/core:tensorflow",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Support",

@@ -225,7 +225,7 @@ cc_library(
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/lite/schema:schema_fbs",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:DerivedAttributeOpInterface",
        "@llvm-project//mlir:Dialect",

@@ -253,7 +253,7 @@ cc_library(
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",

@@ -272,7 +272,7 @@ cc_library(
    deps = [
        ":tensorflow_lite",
        "//tensorflow/compiler/mlir/tensorflow",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",

@@ -289,7 +289,7 @@ cc_library(
    ],
    deps = [
        ":tensorflow_lite",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
    ],

@@ -304,7 +304,7 @@ tf_cc_test(
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",

@@ -357,7 +357,7 @@ cc_library(
        "//tensorflow/core/platform:logging",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/memory",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",

@@ -383,7 +383,7 @@ cc_library(
        ":validators",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",

@@ -416,7 +416,7 @@ cc_library(
        "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/memory",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",

@@ -441,7 +441,7 @@ cc_library(
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:StandardOps",

@@ -494,8 +494,8 @@ tf_native_cc_binary(
        "converter_gen.cc",
    ],
    deps = [
-       "@llvm-project//llvm:support",
-       "@llvm-project//llvm:tablegen",
+       "@llvm-project//llvm:Support",
+       "@llvm-project//llvm:TableGen",
        "@llvm-project//mlir:TableGen",
    ],
)

@@ -541,8 +541,8 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
-       "@llvm-project//llvm:analysis",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Analysis",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],

@@ -619,7 +619,7 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
        "@flatbuffers",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",

@@ -653,7 +653,7 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",

@@ -713,7 +713,7 @@ cc_library(
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:MlirTranslateMain",
        "@llvm-project//mlir:QuantOps",

@@ -743,7 +743,7 @@ cc_library(
        "tf_tfl_translate_cl.h",
    ],
    deps = [
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
    alwayslink = 1,
)

@@ -755,7 +755,7 @@ cc_library(
    ],
    deps = [
        "//tensorflow/compiler/mlir/lite/quantization:quantization_config",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)

@@ -780,7 +780,7 @@ tf_cc_binary(
        ":tf_tfl_translate_cl_options",
        ":tf_to_tfl_flatbuffer",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        # TODO(b/155809683): Link only necessary dialects.
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",

@@ -805,7 +805,7 @@ tf_cc_binary(
        ":flatbuffer_translate_lib",
        ":flatbuffer_translate_registeration",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        # TODO(b/155809683): Link only necessary dialects.
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",

@@ -874,7 +874,7 @@ cc_library(
        "//tensorflow/lite/tools/optimize:quantize_weights",
        "//tensorflow/stream_executor/lib",
        "@com_google_absl//absl/types:span",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Parser",

@@ -894,6 +894,6 @@ cc_library(
        "//tensorflow/lite/experimental/mlir:__subpackages__",
    ],
    deps = [
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)
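All of the BUILD hunks above (and the remaining ones below) are a single mechanical rename: the @llvm-project Bazel overlay switched its library labels from lowercase to CamelCase, so each dependency "@llvm-project//llvm:support" becomes "@llvm-project//llvm:Support", with the same change for ":tablegen" to ":TableGen" and ":analysis" to ":Analysis".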
@@ -868,6 +868,8 @@ StatusOr<FuncOp> ConvertSubgraph(
        subgraph, &builder, "outputs", func_outputs));
    }
    func.setAttr("tf.entry_function", builder.getDictionaryAttr(attributes));
+  } else {
+    func.setVisibility(FuncOp::Visibility::Private);
  }

  absl::flat_hash_set<const tflite::OperatorT*> pruned_subgraph_ops;
@@ -27,7 +27,7 @@ cc_library(
        "//tensorflow/lite/toco:toco_flags_proto_cc",
        "//tensorflow/lite/toco:types_proto_cc",
        "//tensorflow/stream_executor/lib",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",

@@ -56,7 +56,7 @@ cc_library(
        "//tensorflow/lite/toco:toco_flags_proto_cc",
        "//tensorflow/lite/toco:types_proto_cc",
        "//tensorflow/stream_executor/lib",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",

@@ -85,7 +85,7 @@ cc_library(
        "//tensorflow/lite/toco:toco_flags_proto_cc",
        "//tensorflow/lite/toco:types_proto_cc",
        "//tensorflow/stream_executor/lib",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:Support",
@@ -80,7 +80,7 @@ cc_library(
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",

@@ -106,7 +106,7 @@ cc_library(
    deps = [
        "//tensorflow/core:lib_proto_parsing",
        "@com_google_absl//absl/memory",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",

@@ -125,7 +125,7 @@ cc_library(
    deps = [
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)

@@ -135,8 +135,8 @@ tf_native_cc_binary(
        "tools/op_quant_spec_getters_gen.cc",
    ],
    deps = [
-       "@llvm-project//llvm:support",
-       "@llvm-project//llvm:tablegen",
+       "@llvm-project//llvm:Support",
+       "@llvm-project//llvm:TableGen",
        "@llvm-project//mlir:TableGen",
    ],
)

@@ -157,7 +157,7 @@ cc_library(
    deps = [
        ":numerical_utils",
        "@com_google_absl//absl/types:optional",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:Support",

@@ -172,7 +172,7 @@ cc_library(
        ":device_target",
        ":quantization_lib",
        "@com_google_absl//absl/memory",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",
@@ -36,7 +36,7 @@ cc_library(
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],

@@ -54,7 +54,7 @@ cc_library(
    deps = [
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:QuantOps",
        "@llvm-project//mlir:StandardOps",

@@ -73,7 +73,7 @@ tf_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:AllPassesAndDialects",
    ],
)
@@ -27,7 +27,7 @@ cc_library(
        "//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
        "//tensorflow/compiler/mlir/tensorflow",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
        "@llvm-project//mlir:QuantOps",

@@ -32,7 +32,7 @@ cc_library(
        "//tensorflow/lite/core/api",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:Pass",
    ],
@@ -54,7 +54,7 @@ tf_native_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)

@@ -70,6 +70,6 @@ tf_native_cc_binary(
        "//tensorflow/lite:framework",
        "//tensorflow/lite/schema:schema_fbs",
        "@com_google_absl//absl/strings",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)
@ -0,0 +1,12 @@
|
||||
// RUN: tf-opt %s -tfl-legalize-tf='run-tfl-runtime-verification=false' | FileCheck %s
|
||||
|
||||
func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3x3xbf16> {
|
||||
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
|
||||
return %0: tensor<3x3xbf16>
|
||||
|
||||
// CHECK-LABEL: broadcast_to_bf16
|
||||
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
|
||||
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
|
||||
// CHECK: return [[MUL]] : tensor<3x3xbf16>
|
||||
}
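// Editor's note (a sketch of the lowering, not part of the test): TFL has no
// dedicated broadcast op, so `tf.BroadcastTo` is legalized as a multiply
// against a ones tensor of the target shape, conceptually:
//   %ones = "tfl.fill"(%shape, %one)   // ones tensor of the requested shape
//   %out  = "tfl.mul"(%input, %ones)   // tfl.mul's implicit broadcast does the rest
// The CHECK lines above pin down exactly this fill-then-multiply pattern.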
|
@ -1021,24 +1021,6 @@ func @matmul_transposed(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> t
|
||||
// CHECK: "tfl.fully_connected"(%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> tensor<40x40xf32>
}

func @concat2Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>

// CHECK-LABEL: concat2Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
}

func @concat3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.Concat"(%0, %arg0, %arg1, %arg2) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
return %1 : tensor<2x3xi32>

// CHECK-LABEL: concat3Tensors
// CHECK: "tfl.concatenation"(%arg0, %arg1, %arg2) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x3xi32>
}

func @concatv2With3Tensors(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>, %arg2: tensor<2x1xi32>) -> tensor<2x3xi32> {
%0 = "tf.Const"() { value = dense<-1> : tensor<i32> } : () -> tensor<i32>
%1 = "tf.ConcatV2"(%arg0, %arg1, %arg2, %0) : (tensor<2x1xi32>, tensor<2x1xi32>, tensor<2x1xi32>, tensor<i32>) -> tensor<2x3xi32>
|
@ -13,14 +13,12 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
||||
return %3 : tensor<f32>
}

// CHECK-NOT: add
// CHECK-NOT: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
|
@ -42,65 +40,31 @@ func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
|
||||
return %3 : tensor<f32>
}

// CHECK-NOT: addormul
// CHECK-NOT: sub
// CHECK-NOT: mul
// CHECK-NOT: add

func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @addormul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = constant dense<false> : tensor<i1>
%1 = "tf.If"(%0, %arg1, %arg0) {else_branch = @mul, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %1 : tensor<*xf32>
}

func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
func @mul(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Multiply"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

// Verify that branch functions with multiple references are not erased.

func @main(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<i1>) -> (tensor<f32>, tensor<f32>) {
%0 = "tf.Placeholder.input"(%arg0) : (tensor<f32>) -> tensor<f32>
%1 = "tf.Placeholder.input"(%arg1) : (tensor<f32>) -> tensor<f32>
%2 = constant dense<true> : tensor<i1>

// CHECK: tf.Add
%3 = "tf.If"(%2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

// CHECK: tf.If
%4 = "tf.If"(%arg2, %0, %1) {else_branch = @sub, then_branch = @add, is_stateless = true} : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %3, %4 : tensor<f32>, tensor<f32>
}

// CHECK: add
// CHECK: sub
func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

// -----

// Verify unused if with functions without side-effects are removed.

// Verify unused if with functions without side-effects is removed.
// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
@ -118,26 +82,22 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
return %4 : tensor<3x15x14x8xf32>
}

func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}

func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_00(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<true> : tensor<i1>
return %cst : tensor<i1>
}

// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_00
// CHECK-NOT: func @_functionalize_if_then_branch_00

// -----

// Verify unused if with function with side-effects is not removed.

// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
@ -155,27 +115,25 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
return %4 : tensor<3x15x14x8xf32>
}

func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}

func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_01(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}

// CHECK: func @main
// CHECK: tf.If
// CHECK: return
// CHECK: func @_functionalize_if_else_branch_01
// CHECK: func @_functionalize_if_then_branch_01

// -----

// Verify unused if with function with side-effects is removed if op says
// stateless.

// CHECK-LABEL: main
func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
attributes {tf.entry_function = {inputs = "input", outputs = "Conv2D"}} {
%cst = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
|
@ -193,18 +151,15 @@ func @main(%arg0: tensor<3x15x14x3xf32>) -> tensor<3x15x14x8xf32>
|
||||
return %4 : tensor<3x15x14x8xf32>
}

func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_else_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%cst = constant dense<false> : tensor<i1>
return %cst : tensor<i1>
}

func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> {
func @_functionalize_if_then_branch_02(%arg0: tensor<*xi1>, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
%0 = "tf.blah"() : () -> tensor<i1>
return %0 : tensor<i1>
}

// CHECK: func @main
// CHECK-NOT: tf.If
// CHECK: return
// CHECK-NOT: func @_functionalize_if_else_branch_02
// CHECK-NOT: func @_functionalize_if_then_branch_02
|
@ -94,12 +94,6 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}

// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());

pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
|
@ -162,6 +156,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
|
||||
// so that it can target constants introduced once TensorFlow Identity ops
// are removed during legalization.
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
// This pass should be always at the end of the floating point model
|
@ -237,6 +232,7 @@ void CreateTFLStandardPipeline(OpPassManager& pm,
|
||||
mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true));
pm.addPass(mlir::TFL::CreateOptimizePass());
pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pm.addPass(mlir::createSymbolDCEPass());

// Canonicalize, CSE etc.
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
@ -123,7 +123,6 @@ bool HasSameStaticShapes(Operation* op) {
|
||||
// operands are properly supported in declarative rewrite rule specification.

DECL_CONVERT_OP(Assert);
DECL_CONVERT_OP(Concat);
DECL_CONVERT_OP(ConcatV2);
DECL_CONVERT_OP(MatMul);
DECL_CONVERT_OP(MatrixDiagV2);
|
@ -184,25 +183,6 @@ LogicalResult ConvertTFRandomUniformOp::matchAndRewrite(
|
||||
return failure();
}

LogicalResult ConvertTFConcatOp::matchAndRewrite(
Operation* op, PatternRewriter& rewriter) const {
auto tf_concat_op = cast<TF::ConcatOp>(op);

auto values = tf_concat_op.values();
auto output_type = tf_concat_op.output().getType();
// Extract axis attribute from constant concat_dims tensor
ElementsAttr axis;
if (!matchPattern(tf_concat_op.concat_dim(), m_Constant(&axis)))
return failure();

StringAttr fused_activation_function =
StringAttr::get("NONE", rewriter.getContext());
rewriter.replaceOpWithNewOp<TFL::ConcatenationOp>(
op, output_type, values, mlir::TFL::ExtractSingleElementAsInteger(axis),
fused_activation_function);
return success();
}

// Converts any IntegerAttr to an IntegerAttr of an i32 type.
// The value won't change in the new attribute, but if the value is out of
// the bound of i32, the function returns a failure.
|
@ -517,6 +497,14 @@ StatusOr<ConstantOp> CreateConstOpWithSingleValue(PatternRewriter* rewriter,
|
||||
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
case mlir::StandardTypes::BF16: {
auto floatType = mlir::FloatType::getBF16(element_type.getContext());
auto floatAttr =
mlir::FloatAttr::get(floatType, static_cast<float>(value));
std::vector<Attribute> floatValues({floatAttr});
attr = DenseElementsAttr::get(scalar_type, floatValues);
break;
}
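// Editor's note (assumption, not in the original patch): the BF16 case added
// above mirrors the adjacent float-typed cases, except that DenseElementsAttr
// has no direct bf16 constructor from a double, so the value is first narrowed
// to float and wrapped in a bf16 FloatAttr before the attribute is built.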
case mlir::StandardTypes::F32: {
attr =
DenseElementsAttr::get<float>(scalar_type, static_cast<float>(value));
|
@ -756,11 +744,11 @@ void LegalizeTF::runOnFunction() {
|
||||

// Add the generated patterns to the list.
populateWithGenerated(context, &patterns);
patterns.insert<ConvertTFConcatOp, ConvertTFConcatV2Op, ConvertTFMatMulOp,
ConvertTFMatrixDiagV2Op, ConvertTFMatrixDiagV3Op,
ConvertTFPackOp, ConvertTFReshapeOp, ConvertTFSplitOp,
ConvertTFSplitVOp, ConvertTFStridedSliceOp, ConvertTFUnpackOp,
ConvertTFAssertOp, ConvertTFReciprocalOp,
patterns
.insert<ConvertTFConcatV2Op, ConvertTFMatMulOp, ConvertTFMatrixDiagV2Op,
ConvertTFMatrixDiagV3Op, ConvertTFPackOp, ConvertTFReshapeOp,
ConvertTFSplitOp, ConvertTFSplitVOp, ConvertTFStridedSliceOp,
ConvertTFUnpackOp, ConvertTFAssertOp, ConvertTFReciprocalOp,
ConvertTFRandomUniformOp, ConvertTFBroadcastToOp>(context);

// Ophint python converter converted tf node pattern.
|
@ -32,8 +32,6 @@ namespace mlir {
|
||||
namespace TFL {
namespace {

using FuncSet = llvm::SmallSet<FuncOp, 4>;

// Module pass to optimize TensorFlow functional ops.
struct OptimizeFunctionalOpsPass
: public PassWrapper<OptimizeFunctionalOpsPass, OperationPass<ModuleOp>> {
|
@ -44,8 +42,8 @@ struct OptimizeFunctionalOpsPass
|
||||
// op operands' types.
//
// Requires the function has exactly one block.
static void UpdateFuncType(FuncOp func) {
Operation* terminator = &func.getBlocks().front().back();
void UpdateFuncType(FuncOp func) {
Operation* terminator = func.front().getTerminator();
auto return_types = llvm::to_vector<4>(terminator->getOperandTypes());

FunctionType func_type = func.getType();
|
@ -57,7 +55,7 @@ static void UpdateFuncType(FuncOp func) {
|
||||
}
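// Editor's note (not in the original patch): dropping `static` from these
// helpers is a no-op for linkage; they live in an anonymous namespace, which
// already gives them internal linkage, so the shorter spelling is stylistic.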

// TODO(jpienaar): Remove when recursive side-effect modeling is added.
static bool IsSideEffectFree(FuncOp func) {
bool IsSideEffectFree(FuncOp func) {
return !func.getBody()
.walk([&](Operation* op) {
if (!MemoryEffectOpInterface::hasNoEffect(op) &&
|
@ -72,8 +70,8 @@ static bool IsSideEffectFree(FuncOp func) {
|
||||
// function body based on the conditional value.
class FoldIfOp : public OpRewritePattern<TF::IfOp> {
public:
explicit FoldIfOp(MLIRContext* context, FuncSet* inlined_funcs)
: OpRewritePattern<TF::IfOp>(context), inlined_funcs_(inlined_funcs) {}
explicit FoldIfOp(MLIRContext* context)
: OpRewritePattern<TF::IfOp>(context) {}

LogicalResult matchAndRewrite(TF::IfOp op,
PatternRewriter& rewriter) const override {
|
@ -82,7 +80,7 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
// updated if operands' shapes change after inlining. Without this
// restriction, it would require tensor cast ops.
FuncOp parent_op = op.getParentOfType<FuncOp>();
if (parent_op.getBlocks().size() != 1) return failure();
if (!llvm::hasSingleElement(parent_op)) return failure();

// Find the then and else branch functions.
SymbolTable table(op.getParentOfType<ModuleOp>());
|
@ -95,8 +93,6 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
if (op.use_empty() &&
(op.is_stateless() ||
(IsSideEffectFree(then_branch) && IsSideEffectFree(else_branch)))) {
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
rewriter.eraseOp(op.getOperation());
return success();
}
|
@ -118,14 +114,14 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
// Make sure that the function has exactly one block to simplify inlining.
// TFLite doesn't use control flow with blocks so functions with more than
// one blocks are not encountered in practice.
if (func.getBody().getBlocks().size() != 1) return failure();
if (!llvm::hasSingleElement(func)) return failure();

BlockAndValueMapping mapper;
for (int i = 0, e = func.getNumArguments(); i != e; ++i)
mapper.map(func.getArgument(i), op.getOperand(i + 1));

llvm::SmallVector<Value, 4> updated_results;
for (auto& op_to_inline : func.getBody().front()) {
for (auto& op_to_inline : func.front()) {
// If this is a terminator, identify the values to use to replace the
// original If op.
if (op_to_inline.isKnownTerminator()) {
|
@ -145,64 +141,26 @@ class FoldIfOp : public OpRewritePattern<TF::IfOp> {
|
||||
// return type should be updated.
UpdateFuncType(parent_op);

// Track functions that could be erased if this op was the last reference
// of the function.
inlined_funcs_->insert(then_branch);
inlined_funcs_->insert(else_branch);
return success();
}

private:
FuncSet* inlined_funcs_;
};

// Erases functions from the given candidates that are not referenced by any of
// the ops in the module.
static void EraseDeadFuncs(const FuncSet& candidate_funcs, ModuleOp module) {
if (candidate_funcs.empty()) return;

SymbolTable manager(module);

// Identify the functions that are used as symbols in the module and shouldn't
// be erased.
FuncSet in_use_funcs;
manager.getOp()->walk([&](Operation* op) {
for (auto attr : op->getAttrs()) {
if (auto symbol = attr.second.dyn_cast<FlatSymbolRefAttr>()) {
auto func = manager.lookup<FuncOp>(symbol.getValue());
in_use_funcs.insert(func);
}
}
});

for (FuncOp func : candidate_funcs) {
if (!in_use_funcs.count(func)) manager.erase(func);
}
}

void OptimizeFunctionalOpsPass::runOnOperation() {
OwningRewritePatternList patterns;

FuncSet inlined_funcs;
patterns.insert<FoldIfOp>(&getContext(), &inlined_funcs);
patterns.insert<FoldIfOp>(&getContext());

ModuleOp module = getOperation();
applyPatternsAndFoldGreedily(module, patterns);

// Erase inlined functions that don't have any references.
//
// TODO(hinsu): Update this to not erase entry points once TFLite support to
// have multiple entry points is implemented. Until then, it is safe to
// erase these functions.
EraseDeadFuncs(inlined_funcs, module);
}

PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateOptimizeFunctionalOpsPass() {
return std::make_unique<OptimizeFunctionalOpsPass>();
}

static PassRegistration<OptimizeFunctionalOpsPass> pass(
"tfl-optimize-functional-ops", "Optimize TensorFlow functional ops");
} // namespace TFL
} // namespace mlir
|
@ -29,7 +29,7 @@ cc_library(
|
||||
# place for core related components.
"//tensorflow/compiler/mlir/tensorflow:graph_optimization_pass_registration",
"//tensorflow/compiler/mlir/tensorflow:import_utils",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
|
@ -20,7 +20,7 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@pybind11",
|
@ -35,7 +35,7 @@ tf_python_pybind_extension(
|
||||
deps = [
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@pybind11",
],
)
|
@ -187,7 +187,7 @@ gentbl(
|
||||
td_file = "transforms/legalize_hlo_patterns.td",
td_srcs = [
"//tensorflow/compiler/mlir/xla:hlo_ops_td_files",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
|
@ -204,7 +204,7 @@ cc_library(
|
||||
":tensorflow",
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/core:framework",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
|
@ -225,7 +225,7 @@ cc_library(
|
||||
"ir/tf_attributes.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -240,7 +240,7 @@ cc_library(
|
||||
"ir/tf_types.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
],
|
@ -293,7 +293,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:CallOpInterfacesIncGen",
"@llvm-project//mlir:DerivedAttributeOpInterface",
|
@ -388,7 +388,7 @@ cc_library(
|
||||
":tensorflow",
"//tensorflow/core:framework",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
|
@ -416,6 +416,7 @@ cc_library(
|
||||
"transforms/fold_switch.cc",
"transforms/freeze_global_tensors.cc",
"transforms/functional_control_flow_to_cfg.cc",
"transforms/fused_kernel_matcher.cc",
"transforms/generated_canonicalize.inc",
"transforms/generated_optimize.inc",
"transforms/gpu_fusion.cc",
|
@ -424,7 +425,6 @@ cc_library(
|
||||
"transforms/layout_optimization.cc",
"transforms/mark_function_visibility.cc",
"transforms/materialize_mlir_passthrough_op.cc",
"transforms/op_fusion.cc",
"transforms/optimize.cc",
"transforms/optimize_global_tensors.cc",
"transforms/parallel_execute_to_islands.cc",
|
@ -500,7 +500,7 @@ cc_library(
|
||||
"//tensorflow/core/protobuf/tpu:dynamic_padding_proto_cc",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
|
@ -609,7 +609,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
|
@ -639,7 +639,7 @@ cc_library(
|
||||
":parse_text_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
|
@ -669,7 +669,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
|
@ -712,7 +712,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -722,7 +722,7 @@ cc_library(
|
||||
srcs = ["translate/translate_tf_dialect_op.cc"],
deps = [
":export_tf_dialect_op",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Translation",
|
@ -781,7 +781,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
)
|
@ -799,7 +799,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
|
@ -816,7 +816,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -838,7 +838,7 @@ cc_library(
|
||||
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -882,7 +882,7 @@ cc_library(
|
||||
deps = [
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -925,7 +925,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:Support",
|
@ -957,7 +957,7 @@ cc_library(
|
||||
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
|
@ -986,7 +986,7 @@ cc_library(
|
||||
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
|
@ -1011,7 +1011,7 @@ cc_library(
|
||||
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
|
@ -1027,7 +1027,7 @@ cc_library(
|
||||
"translate/tf_mlir_translate_cl.h",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
|
@ -1044,7 +1044,7 @@ cc_library(
|
||||
":translate_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Translation",
],
|
@ -1060,7 +1060,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -1071,8 +1071,8 @@ tf_native_cc_binary(
|
||||
"translate/derived_attr_populator_gen.cc",
],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:tablegen",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:TableGen",
"@llvm-project//mlir:TableGen",
],
)
|
@ -1134,7 +1134,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
":tensorflow_passes",
":translate_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
"@llvm-project//mlir:Pass",
|
@ -1266,7 +1266,7 @@ cc_library(
|
||||
":tensorflow",
":tensorflow_types",
"//tensorflow/core:framework",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
alwayslink = 1,
|
@ -1285,7 +1285,7 @@ cc_library(
|
||||
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -1300,7 +1300,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/protobuf/tpu:topology_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -1313,7 +1313,7 @@ cc_library(
|
||||
":tensorflow",
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
|
@ -1331,7 +1331,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
|
@ -1344,7 +1344,7 @@ cc_library(
|
||||
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -1359,7 +1359,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:test",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -1378,7 +1378,7 @@ cc_library(
|
||||
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core/platform:logging",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
|
@ -1398,7 +1398,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:test",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
|
@ -1409,7 +1409,7 @@ cc_library(
|
||||
hdrs = ["utils/bridge_logger.h"],
deps = [
":dump_mlir_util",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
|
@ -1425,7 +1425,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/core:framework",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
|
@ -1443,7 +1443,7 @@ cc_library(
|
||||
":tensorflow",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
|
@ -36,7 +36,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:errors",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
|
@ -794,6 +794,34 @@ This op is deprecated. Prefer `tf.nn.batch_normalization`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_BatchToSpaceOp : TF_Op<"BatchToSpace", [NoSideEffect]> {
let summary = "BatchToSpace for 4-D tensors of type T.";

let description = [{
This is a legacy version of the more general BatchToSpaceND.

Rearranges (permutes) data from batch into blocks of spatial data, followed by
cropping. This is the reverse transformation of SpaceToBatch. More specifically,
this op outputs a copy of the input tensor where values from the `batch`
dimension are moved in spatial blocks to the `height` and `width` dimensions,
followed by cropping along the `height` and `width` dimensions.
}];

let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$crops,

Confined<I64Attr, [IntMinValue<2>]>:$block_size
);

let results = (outs
TF_Tensor:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tidx = TF_DerivedOperandTypeAttr<1>;
}
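// Editor's illustration (not part of the generated definition): with
// block_size = 2 and zero crops, a [4, 1, 1, 1] input is rearranged into a
// [1, 2, 2, 1] output, i.e. the four batch entries become one 2x2 spatial block.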

def TF_BatchToSpaceNDOp : TF_Op<"BatchToSpaceND", [NoSideEffect]> {
let summary = "BatchToSpace for N-D tensors of type T.";
|
@ -1219,6 +1247,35 @@ subsequent operation and then be optimized away, however.)
|
||||
}];
}

def TF_BucketizeOp : TF_Op<"Bucketize", [NoSideEffect, SameOperandsAndResultShape]> {
let summary = "Bucketizes 'input' based on 'boundaries'.";

let description = [{
For example, if the inputs are
boundaries = [0, 10, 100]
input = [[-5, 10000]
[150, 10]
[5, 100]]

then the output will be
output = [[0, 3]
[3, 2]
[1, 3]]
}];

let arguments = (ins
TensorOf<[F32, F64, I32, I64]>:$input,

F32ArrayAttr:$boundaries
);

let results = (outs
I32Tensor:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_CaseOp : TF_Op<"Case", []> {
let summary = [{
An n-way switch statement which calls a single branch function.
|
@ -1257,6 +1314,8 @@ An n-way switch statement, implementing the following:
|
||||

TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<1>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

let hasCanonicalizer = 1;
}

def TF_CastOp : TF_Op<"Cast", [NoSideEffect, SameOperandsAndResultShape]> {
|
@ -1519,6 +1578,8 @@ def TF_ConcatOp : TF_Op<"Concat", [NoSideEffect]> {
|
||||
let verifier = [{
return Verify(*this);
}];

let hasCanonicalizer = 1;
}

def TF_ConcatOffsetOp : TF_Op<"ConcatOffset", [NoSideEffect]> {
|
@ -2320,6 +2381,21 @@ horizontal and vertices strides, `strides = [1, stride, stride, 1]`.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_DeviceIndexOp : TF_Op<"DeviceIndex", [NoSideEffect]> {
let summary = "Returns the index of the device the op runs on.";

let description = [{
}];

let arguments = (ins
StrArrayAttr:$device_names
);

let results = (outs
I32Tensor:$index
);
}
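// Editor's note (assumption based on the attribute names, not stated in the
// patch): `device_names` lists candidate device types and the op yields the
// index of the one the op is placed on, e.g. device_names = ["CPU", "GPU"]
// would produce 1 when the op runs on the GPU.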

def TF_DiagPartOp : TF_Op<"DiagPart", [NoSideEffect]> {
let summary = "Returns the diagonal part of the tensor.";
|
@ -2968,6 +3044,63 @@ i.e. `exp(x) - 1` or `e^(x) - 1`, where `x` is the input tensor.
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_FFTOp : TF_Op<"FFT", [NoSideEffect]> {
let summary = "Fast Fourier transform.";

let description = [{
Computes the 1-dimensional discrete Fourier transform over the inner-most
dimension of `input`.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_FFT2DOp : TF_Op<"FFT2D", [NoSideEffect]> {
let summary = "2D fast Fourier transform.";

let description = [{
Computes the 2-dimensional discrete Fourier transform over the inner-most
2 dimensions of `input`.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_FFT3DOp : TF_Op<"FFT3D", [NoSideEffect]> {
let summary = "3D fast Fourier transform.";

let description = [{
Computes the 3-dimensional discrete Fourier transform over the inner-most 3
dimensions of `input`.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_FakeQuantWithMinMaxArgsOp : TF_Op<"FakeQuantWithMinMaxArgs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = [{
Fake-quantize the 'inputs' tensor, type float to 'outputs' tensor of same type.
|
@ -3727,6 +3860,161 @@ table will be immutable.
|
||||
);
}

def TF_IFFTOp : TF_Op<"IFFT", [NoSideEffect]> {
let summary = "Inverse fast Fourier transform.";

let description = [{
Computes the inverse 1-dimensional discrete Fourier transform over the
inner-most dimension of `input`.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_IFFT2DOp : TF_Op<"IFFT2D", [NoSideEffect]> {
let summary = "Inverse 2D fast Fourier transform.";

let description = [{
Computes the inverse 2-dimensional discrete Fourier transform over the
inner-most 2 dimensions of `input`.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_IFFT3DOp : TF_Op<"IFFT3D", [NoSideEffect]> {
let summary = "Inverse 3D fast Fourier transform.";

let description = [{
Computes the inverse 3-dimensional discrete Fourier transform over the
inner-most 3 dimensions of `input`.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_IRFFTOp : TF_Op<"IRFFT", [NoSideEffect]> {
let summary = "Inverse real-valued fast Fourier transform.";

let description = [{
Computes the inverse 1-dimensional discrete Fourier transform of a real-valued
signal over the inner-most dimension of `input`.

The inner-most dimension of `input` is assumed to be the result of `RFFT`: the
`fft_length / 2 + 1` unique components of the DFT of a real-valued signal. If
`fft_length` is not provided, it is computed from the size of the inner-most
dimension of `input` (`fft_length = 2 * (inner - 1)`). If the FFT length used to
compute `input` is odd, it should be provided since it cannot be inferred
properly.

Along the axis `IRFFT` is computed on, if `fft_length / 2 + 1` is smaller
than the corresponding dimension of `input`, the dimension is cropped. If it is
larger, the dimension is padded with zeros.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);

let results = (outs
TF_F32OrF64Tensor:$output
);

TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_IRFFT2DOp : TF_Op<"IRFFT2D", [NoSideEffect]> {
let summary = "Inverse 2D real-valued fast Fourier transform.";

let description = [{
Computes the inverse 2-dimensional discrete Fourier transform of a real-valued
signal over the inner-most 2 dimensions of `input`.

The inner-most 2 dimensions of `input` are assumed to be the result of `RFFT2D`:
The inner-most dimension contains the `fft_length / 2 + 1` unique components of
the DFT of a real-valued signal. If `fft_length` is not provided, it is computed
from the size of the inner-most 2 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.

Along each axis `IRFFT2D` is computed on, if `fft_length` (or
`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);

let results = (outs
TF_F32OrF64Tensor:$output
);

TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_IRFFT3DOp : TF_Op<"IRFFT3D", [NoSideEffect]> {
let summary = "Inverse 3D real-valued fast Fourier transform.";

let description = [{
Computes the inverse 3-dimensional discrete Fourier transform of a real-valued
signal over the inner-most 3 dimensions of `input`.

The inner-most 3 dimensions of `input` are assumed to be the result of `RFFT3D`:
The inner-most dimension contains the `fft_length / 2 + 1` unique components of
the DFT of a real-valued signal. If `fft_length` is not provided, it is computed
from the size of the inner-most 3 dimensions of `input`. If the FFT length used
to compute `input` is odd, it should be provided since it cannot be inferred
properly.

Along each axis `IRFFT3D` is computed on, if `fft_length` (or
`fft_length / 2 + 1` for the inner-most dimension) is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];

let arguments = (ins
TensorOf<[TF_Complex128, TF_Complex64]>:$input,
I32Tensor:$fft_length
);

let results = (outs
TF_F32OrF64Tensor:$output
);

TF_DerivedResultTypeAttr Treal = TF_DerivedResultTypeAttr<0>;
TF_DerivedOperandTypeAttr Tcomplex = TF_DerivedOperandTypeAttr<0>;
}

def TF_IdentityNOp : TF_Op<"IdentityN", [NoSideEffect]> {
let summary = [{
Returns a list of tensors with the same shapes and contents as the input
|
@ -4142,6 +4430,30 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_LRNGradOp : TF_Op<"LRNGrad", [NoSideEffect]> {
let summary = "Gradients for Local Response Normalization.";

let description = [{
}];

let arguments = (ins
TensorOf<[BF16, F16, F32]>:$input_grads,
TensorOf<[BF16, F16, F32]>:$input_image,
TensorOf<[BF16, F16, F32]>:$output_image,

DefaultValuedAttr<I64Attr, "5">:$depth_radius,
DefaultValuedAttr<F32Attr, "1.0f">:$bias,
DefaultValuedAttr<F32Attr, "1.0f">:$alpha,
DefaultValuedAttr<F32Attr, "0.5f">:$beta
);

let results = (outs
TensorOf<[BF16, F16, F32]>:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_LeakyReluOp : TF_Op<"LeakyRelu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes rectified linear: `max(features, features * alpha)`.";
|
@ -6333,6 +6645,66 @@ the dimension is padded with zeros.
|
||||
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}

def TF_RFFT2DOp : TF_Op<"RFFT2D", [NoSideEffect]> {
let summary = "2D real-valued fast Fourier transform.";

let description = [{
Computes the 2-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 2 dimensions of `input`.

Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.

Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];

let arguments = (ins
TF_F32OrF64Tensor:$input,
I32Tensor:$fft_length
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Treal = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}

def TF_RFFT3DOp : TF_Op<"RFFT3D", [NoSideEffect]> {
let summary = "3D real-valued fast Fourier transform.";

let description = [{
Computes the 3-dimensional discrete Fourier transform of a real-valued signal
over the inner-most 3 dimensions of `input`.

Since the DFT of a real signal is Hermitian-symmetric, `RFFT3D` only returns the
`fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
of `output`: the zero-frequency term, followed by the `fft_length / 2`
positive-frequency terms.

Along each axis `RFFT3D` is computed on, if `fft_length` is smaller than the
corresponding dimension of `input`, the dimension is cropped. If it is larger,
the dimension is padded with zeros.
}];

let arguments = (ins
TF_F32OrF64Tensor:$input,
I32Tensor:$fft_length
);

let results = (outs
TensorOf<[TF_Complex128, TF_Complex64]>:$output
);

TF_DerivedOperandTypeAttr Treal = TF_DerivedOperandTypeAttr<0>;
TF_DerivedResultTypeAttr Tcomplex = TF_DerivedResultTypeAttr<0>;
}

def TF_RandomGammaGradOp : TF_Op<"RandomGammaGrad", [NoSideEffect, ResultsBroadcastableShape]>,
WithBroadcastableBinOpBuilder {
let summary = [{
|
@ -8134,6 +8506,34 @@ def TF_SoftsignGradOp : TF_Op<"SoftsignGrad", [NoSideEffect, SameOperandsAndResu
|
||||
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_SpaceToBatchOp : TF_Op<"SpaceToBatch", [NoSideEffect]> {
let summary = "SpaceToBatch for 4-D tensors of type T.";

let description = [{
This is a legacy version of the more general SpaceToBatchND.

Zero-pads and then rearranges (permutes) blocks of spatial data into batch.
More specifically, this op outputs a copy of the input tensor where values from
the `height` and `width` dimensions are moved to the `batch` dimension. After
the zero-padding, both `height` and `width` of the input must be divisible by the
block size.
}];

let arguments = (ins
TF_Tensor:$input,
TF_I32OrI64Tensor:$paddings,

Confined<I64Attr, [IntMinValue<2>]>:$block_size
);

let results = (outs
TF_Tensor:$output
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandTypeAttr Tpaddings = TF_DerivedOperandTypeAttr<1>;
}
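// Editor's illustration (not part of the generated definition): with
// block_size = 2 and zero paddings this is the inverse of the BatchToSpace
// example given earlier, turning a [1, 2, 2, 1] spatial input into a
// [4, 1, 1, 1] batch output.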

def TF_SpaceToBatchNDOp : TF_Op<"SpaceToBatchND", [NoSideEffect]> {
let summary = "SpaceToBatch for N-D tensors of type T.";
|
@ -10892,6 +11292,50 @@ create these operators.
|
||||
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
}

def TF__FusedMatMulOp : TF_Op<"_FusedMatMul", [NoSideEffect]> {
let summary = [{
Performs a MatMul followed by a specified series of operations.
}];

let description = [{
The inputs to the MatMul are specified by `a` and `b`. The series of operations
that follows is specified by the `fused_ops` attribute, which is a list of TF op
names specified as strings (e.g. "Relu"). They are performed in order, where the
(first) input to each op is the output of the preceding op. The first input and
the output of each fused_op must be of type T.

Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd",A],
where A is one of {"Elu","Relu","Relu6"}.

* The first input to BiasAdd is the MatMul result, and the additional BiasAdd
input is specified by `args`.
* If there is an op A specified, the output of the BiasAdd is the input to op A,
and op A produces the _FusedMatMul output. Otherwise, the BiasAdd produces the
_FusedMatMul output.

*NOTE*: Do not invoke this operator directly in Python. Grappler is
expected to create these operators.
}];

let arguments = (ins
F32Tensor:$a,
F32Tensor:$b,
Variadic<F32Tensor>:$args,

DefaultValuedAttr<BoolAttr, "false">:$transpose_a,
DefaultValuedAttr<BoolAttr, "false">:$transpose_b,
DefaultValuedAttr<StrArrayAttr, "{}">:$fused_ops,
DefaultValuedAttr<F32Attr, "0.0001f">:$epsilon
);

let results = (outs
F32Tensor:$product
);

TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
TF_DerivedOperandSizeAttr num_args = TF_DerivedOperandSizeAttr<2>;
}
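// Editor's sketch (hypothetical attribute values, not from the patch): a
// MatMul + BiasAdd + Relu fusion would be expressed as
//   fused_ops = ["BiasAdd", "Relu"]
// with the bias vector passed as the single extra operand in `args`.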

def TF__HostComputeMlirOp : TF_Op<"_HostComputeMlir", []> {
let summary = "A host-side computation called from a TPU device.";

|
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/DialectImplementation.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/Identifier.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
|
@ -68,6 +69,17 @@ limitations under the License.
|
||||
namespace mlir {
namespace TF {

// Propagates underscore and device attributes from src to dst.
// TODO(b/158769932): This should be a general feature instead post some policy
// discussion.
static void PropagateAttributes(Operation *src, Operation *dst) {
auto device = mlir::Identifier::get("device", src->getContext());
for (auto named_attr : src->getAttrs()) {
if (*named_attr.first.begin() == '_' || named_attr.first == device)
dst->setAttr(named_attr.first, named_attr.second);
}
}
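// Editor's illustration (hypothetical attribute values): given a source op
// carrying {_xla_outside_compilation = "0", device = "/CPU:0", branches = [...]},
// the helper above copies `_xla_outside_compilation` (underscore-prefixed) and
// `device` onto `dst`, while op-specific attributes such as `branches` are
// deliberately left behind.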

//===----------------------------------------------------------------------===//
// TF op helper functions
//===----------------------------------------------------------------------===//
|
@ -786,11 +798,49 @@ static LogicalResult Verify(BroadcastToOp op) {
|
||||
}

//===----------------------------------------------------------------------===//
// CastOp
// CaseOp
//===----------------------------------------------------------------------===//

class FoldConstantCaseOp : public OpRewritePattern<TF::CaseOp> {
public:
explicit FoldConstantCaseOp(MLIRContext *context)
: OpRewritePattern<TF::CaseOp>(context) {}
LogicalResult matchAndRewrite(TF::CaseOp op,
PatternRewriter &rewriter) const override;
};

LogicalResult FoldConstantCaseOp::matchAndRewrite(
TF::CaseOp op, PatternRewriter &rewriter) const {
// Extract the constant cond value.
DenseIntElementsAttr branch;
if (!matchPattern(op.branch_index(), m_Constant(&branch))) return failure();

// Only attempt to fold scalar valued case statements.
// TODO(jpienaar): This can be removed if CaseOp's verifier covers it.
if (!branch.getType().cast<RankedTensorType>().getShape().empty())
return failure();

int index = *branch.getValues<int>().begin();
// TODO(jpienaar): This can be removed if CaseOp's verifier covers it.
if (index >= op.branches().size()) return failure();

auto func = op.branches()[index].cast<SymbolRefAttr>();
auto empty = rewriter.getStringAttr("");
auto call_op = rewriter.create<PartitionedCallOp>(
op.getLoc(), op.getResultTypes(), op.getOperands().drop_front(), func,
/*config=*/empty, /*config_proto=*/empty, /*executor_type=*/empty);
PropagateAttributes(op.getOperation(), call_op);
rewriter.replaceOp(op, call_op.getResults());
return success();
}

void CaseOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<FoldConstantCaseOp>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LeakyReluOp
|
||||
// CastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
|
||||
@ -823,6 +873,11 @@ static LogicalResult Verify(OpT op) {
|
||||
/*mask_one_dim=*/true, op.getOperation());
|
||||
}
|
||||
|
||||
void ConcatOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
|
||||
MLIRContext *context) {
|
||||
results.insert<ConvertToConcatV2>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ConcatOffsetOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -334,7 +334,7 @@ SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
bool IsExported(Operation *op) {
auto exported_names =
op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
return exported_names && exported_names.size() != 0;
return exported_names && !exported_names.empty();
}

bool HasTfSavedModelSemantics(ModuleOp module) {
|
@ -133,6 +133,16 @@ func @testSameCastTypeAcrossBasicBlocks(tensor<8x16x32x64xf32>) -> tensor<8x16x3
// CHECK: return %arg0
}

// CHECK-LABEL: testConcatCanonicalization
func @testConcatCanonicalization(%arg0: tensor<2x1xi32>, %arg1: tensor<2x1xi32>) -> tensor<2x2xi32> {
// CHECK: %[[AXIS:.*]] = "tf.Const"
%0 = "tf.Const"() { value = dense<1> : tensor<i32> } : () -> tensor<i32>

// CHECK: "tf.ConcatV2"(%arg0, %arg1, %[[AXIS]])
%1 = "tf.Concat"(%0, %arg0, %arg1) : (tensor<i32>, tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
return %1 : tensor<2x2xi32>
}

// CHECK-LABEL: testLogOfSoftmax
func @testLogOfSoftmax(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Softmax"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>
@ -550,3 +560,29 @@ func @foldFill() -> (tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>) {

return %2, %3, %4 : tensor<3x2x1xf32>, tensor<*xf32>, tensor<*xcomplex<f32>>
}

// CHECK-LABEL: foldCase
func @foldCase(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>) {
%2 = constant dense<1> : tensor<i32>
%3 = constant dense<0> : tensor<i32>

// CHECK: PartitionedCall
// CHECK-SAME: device = "noodle"
// CHECK-SAME: f = @add
%4 = "tf.Case"(%2, %arg0, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], device = "noodle"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
// CHECK: PartitionedCall
// CHECK-SAME: _cluster_launch = "not_ready"
// CHECK-SAME: f = @sub
%5 = "tf.Case"(%3, %4, %arg1) {branches = [@sub, @add], output_shapes = [#tf.shape<>], _cluster_launch = "not_ready"} : (tensor<i32>, tensor<f32>, tensor<f32>) -> tensor<f32>
return %5 : tensor<f32>
}

func @add(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Add"(%arg0, %arg1): (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

func @sub(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
|
@ -370,7 +370,6 @@ func @decompose_resource_gather_op(%indices : tensor<5xi32>) -> tensor<2x5x16xi3

// Tests that composite tf.ResourceScatterUpdate operation is decomposed.


// CHECK-LABEL: @decompose_resource_scatter_update_op
// CHECK-SAME: ([[INDEX:%.+]]: tensor<2x?xi32>, [[UPDATE:%.+]]: tensor<?x?x?xi32>)
func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates: tensor<?x?x?xi32>) {
@ -384,3 +383,34 @@ func @decompose_resource_scatter_update_op(%indices : tensor<2x?xi32>, %updates:

return
}

// -----

// Tests that tf.VariableShape operation is decomposed.

// CHECK-LABEL: @decompose_variable_shape_i32
func @decompose_variable_shape_i32(%input: tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi32> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi32>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%arg0)
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[READ]])
// CHECK: return %[[SHAPE]]
return %0 : tensor<3xi32>
}

// CHECK-LABEL: @decompose_variable_shape_i64
func @decompose_variable_shape_i64(%input: tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi64> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?x?xf32>>>) -> tensor<3xi64>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%arg0)
// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[READ]])
// CHECK: return %[[SHAPE]]
return %0 : tensor<3xi64>
}

// CHECK-LABEL: @decompose_variable_shape_no_subtype
func @decompose_variable_shape_no_subtype(%input: tensor<!tf.resource>) -> tensor<3xi32> {
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource>) -> tensor<3xi32>
// CHECK: "tf.VariableShape"
// CHECK-NOT: "tf.ReadVariableOp"
// CHECK-NOT: "tf.Shape"
return %0 : tensor<3xi32>
}
|
@ -1,4 +1,4 @@
// RUN: tf-opt %s -tf-op-fusion | FileCheck %s
// RUN: tf-opt %s -tf-fused-kernel-matcher | FileCheck %s

//===----------------------------------------------------------------------===//
// Conv2D + BiasAdd + <Activation> fusions.
@ -107,3 +107,54 @@ func @conv2D_dataFormatMismatch(%arg0: tensor<128xf32>, %arg1: tensor<1x1x3x128x
%3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32>
return %3 : tensor<*xf32>
}

//===----------------------------------------------------------------------===//
// MatMul + BiasAdd + <Activation> fusions.
//===----------------------------------------------------------------------===//

// CHECK-LABEL: matmulBiasAdd
func @matmulBiasAdd(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Identity"(%4) : (tensor<*xf32>) -> tensor<*xf32>
return %5 : tensor<*xf32>
}

// CHECK-LABEL: matmulBiasAdd_relu
func @matmulBiasAdd_relu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Relu"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}

// CHECK-LABEL: matmulBiasAdd_relu6
func @matmulBiasAdd_relu6(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Relu6"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Relu6"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}

// CHECK-LABEL: matmulBiasAdd_elu
func @matmulBiasAdd_elu(%arg0: tensor<64xf32>, %arg1: tensor<8x32xf32>, %arg2: tensor<32x64xf32>) -> (tensor<*xf32>) {
// CHECK: %[[VAL_3:.*]] = "tf._FusedMatMul"(%arg1, %arg2, %arg0) {epsilon = 0.000000e+00 : f32, fused_ops = ["BiasAdd", "Elu"], transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>, tensor<64xf32>) -> tensor<*xf32>
// CHECK: %[[VAL_4:.*]] = "tf.Identity"(%[[VAL_3]]) : (tensor<*xf32>) -> tensor<*xf32>
// CHECK: return %[[VAL_4]]
%3 = "tf.MatMul"(%arg1, %arg2) {transpose_a = false, transpose_b = false} : (tensor<8x32xf32>, tensor<32x64xf32>) -> tensor<*xf32>
%4 = "tf.BiasAdd"(%3, %arg0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<64xf32>) -> tensor<*xf32>
%5 = "tf.Elu"(%4) : (tensor<*xf32>) -> tensor<*xf32>
%6 = "tf.Identity"(%5) : (tensor<*xf32>) -> tensor<*xf32>
return %6 : tensor<*xf32>
}
|
@ -49,5 +49,5 @@ library {
}
}

# CHECK-DAG: func @custom_relu{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}() attributes {tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
# CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
|
@ -124,5 +124,5 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @foo111}

# CHECK-LABEL: func @foo110() {
# CHECK-LABEL: func @foo111() {
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
|
@ -57,7 +57,7 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}

# CHECK-LABEL: func @foo0() {
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @bar0}

# CHECK-LABEL: func @bar0() {
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}
|
@ -7,6 +7,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_single_outside_compiled_op(%arg0: tensor<i32>) {
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -27,6 +28,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_single_outside_compiled_op_no_operands() {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -49,6 +51,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
%a = "tf.A"() : () -> tensor<i32>
// CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -69,6 +72,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_aliased_output() -> (tensor<i32>, tensor<i32>, tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -96,8 +100,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_all_cluster_op(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[B_OUT]], %arg0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -117,8 +124,11 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_multiple_outside_compiled_ops(%arg0: tensor<i32>) {
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.C"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -141,6 +151,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: %[[LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
//
@ -173,6 +184,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
@ -199,6 +211,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[B_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( {
@ -226,7 +239,9 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%arg0, %[[CLUSTER_OUT]]#1)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.D"(%[[C_OUT]], %arg0, %[[CLUSTER_OUT]]#0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
"tf_device.cluster"() ( {
@ -258,6 +273,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: %[[LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[D_OUT:.*]] = "tf.D"(%[[CLUSTER_OUT]]#0, %[[A_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster:5 = "tf_device.cluster"() ( {
@ -286,6 +302,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: %[[B_OUT:.*]] = "tf.B"(%[[CLUSTER_OUT]], %[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
@ -320,6 +337,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
func @head_tail_simple_extraction(%arg0: tensor<i32>) -> tensor<i32> {
// CHECK: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%arg0)
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
//
@ -335,6 +353,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK: %[[TAIL_LAUNCH_OUT:.*]] = "tf_device.launch"
// CHECK-NEXT: %[[C_OUT:.*]] = "tf.C"(%[[CLUSTER_OUT]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[C_OUT]]
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
%cluster = "tf_device.cluster"() ( {
@ -353,6 +372,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: %[[HEAD_LAUNCH_OUT:.*]] = "tf_device.launch"()
// CHECK-NEXT: %[[A_OUT:.*]] = "tf.A"(%[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return %[[A_OUT]]
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
//
@ -370,6 +390,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
//
// CHECK-NEXT: "tf_device.launch"()
// CHECK-NEXT: "tf.D"(%[[HEAD_LAUNCH_OUT]], %[[CLUSTER_OUT]], %[[RI]])
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
tf_device.replicate([%arg0, %arg1] as %ri : tensor<i32>) {n = 2 : i32} {
|
@ -1,4 +1,4 @@
// RUN: tf-opt %s -split-input-file -tf-tpu-host-computation-expansion | FileCheck %s --dump-input-on-failure
// RUN: tf-opt %s -split-input-file -tf-tpu-host-computation-expansion | FileCheck %s

// Tests expansion of outside compiled ops at head/tail of TPU computation.

@ -26,7 +26,7 @@ func @cast_at_head_expanded(%arg0: tensor<?xi32>) {
"tf.B"(%1) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}

@ -44,7 +44,7 @@ func @check_consecutive_unary_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%2) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}

@ -59,7 +59,7 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
"tf.B"(%1) : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) {} : () -> ()
return
}

@ -67,9 +67,9 @@ func @check_only_necesarily_ops_outside_compiled(%arg0: tensor<?xi32>) {
func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Cast"
// CHECK-NOT: _xla_outside_compilation = ""
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation = ""
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: "tf.B"
"tf_device.cluster"() ( {
%1 = "tf.Cast"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
@ -77,6 +77,19 @@ func @check_only_necesarily_ops_outside_compiled_with_chained_ops(%arg0: tensor<
"tf.B"(%2) : (tensor<?xi32>) -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
}) : () -> ()
return
}

// CHECK-LABEL: func @check_op_without_usage_not_outside_compiled
func @check_op_without_usage_not_outside_compiled(%arg0: tensor<?xi32>) {
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
"tf_device.cluster"() ( {
"tf.Identity"(%arg0) : (tensor<?xi32>) -> (tensor<?xi32>)
"tf.C"() : () -> ()
tf_device.return
}) : () -> ()
return
}
|
@ -87,7 +87,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
// because DecomposeResourceOpsPass uses pattern rewriter which hoists
// changed constants out of tf_device.Launch.
func_pm.addPass(TFDevice::CreateDecomposeResourceOpsPass());

pm.addNestedPass<FuncOp>(CreateTPUHostComputationExpansionPass());
pm.addPass(CreateTPUExtractHeadTailOutsideCompilationPass());
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());
|
@ -83,6 +83,13 @@ def BitcastSameType : Pat<(TF_BitcastOp:$res $arg), (replaceWithValue $arg),
def BitcastNested : Pat<(TF_BitcastOp (TF_BitcastOp $arg)),
(TF_BitcastOp $arg)>;

//===----------------------------------------------------------------------===//
// Concat op patterns.
//===----------------------------------------------------------------------===//

def ConvertToConcatV2 : Pat<(TF_ConcatOp $axis, $inputs),
(TF_ConcatV2Op $inputs, $axis)>;

//===----------------------------------------------------------------------===//
// Conj op patterns.
//===----------------------------------------------------------------------===//
|
@ -50,6 +50,24 @@ static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
return UnrankedTensorType::get(element_type);
}

static bool HasResourceSubtype(Value resource) {
return resource.getType()
.cast<TensorType>()
.getElementType()
.cast<ResourceType>()
.getSubtypes()
.size() == 1;
}

static Type GetResourceSubtype(Value resource) {
return resource.getType()
.cast<TensorType>()
.getElementType()
.cast<ResourceType>()
.getSubtypes()
.front();
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
} // namespace

|
@ -23,7 +23,7 @@ class GetScalarOfType<int value> : NativeCodeCall<

// Creates a tf.ReadVariable op that reads a resource `$2` that has the same
// element type as `$1`. The op created will use location of `$0`.
def CreateTFReadVariableOp: NativeCodeCall<
def CreateTFReadVariableOp : NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
" $0.getLoc(),"
" GetResourceSubtypeOrDefault("
@ -31,6 +31,12 @@ def CreateTFReadVariableOp: NativeCodeCall<
" $2)"
>;

def CheckHasResourceSubtype : Constraint<CPred<"HasResourceSubtype($0)">>;

def CreateTFReadVariableOpFromResourceHandle : NativeCodeCall<
"$_builder.create<TF::ReadVariableOp>("
"$0.getLoc(), GetResourceSubtype($1), $1)">;

def DecomposeAssignAddVariableOp :
Pat<
(TF_AssignAddVariableOp:$src_op $resource, $value),
@ -315,3 +321,9 @@ def DecomposeResourceScatterUpdate : Pat<
$updates
)
)>;

// Pattern to decompose tf.VariableShape into tf.ReadVariable and tf.Shape.
def DecomposeVariableShape : Pat<
(TF_VariableShapeOp:$src_op $resource),
(TF_ShapeOp (CreateTFReadVariableOpFromResourceHandle $src_op, $resource)),
[(CheckHasResourceSubtype $resource)]>;
|
@ -0,0 +1,225 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements fusions performed in the old Remapper Grappler pass.
// That pass has specific cases for GPU and based on different target
// configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR pass
// covers the general CPU case and at the moment does not account for any
// target-specific configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct FusedKernelMatcherPass
: public PassWrapper<FusedKernelMatcherPass, FunctionPass> {
void runOnFunction() override;
};

// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}

bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
// <Contraction> + BiasAdd + <Activation> -> <FusedContraction>
//
// Note that fusion with activation is preferred, but a contraction and BiasAdd
// can also be replaced by a _FusedConv2D if there is no other activation
// function.
// i.e., this class also supports the following fusion:
// <Contraction> + BiasAdd -> <FusedContraction>
//
// TODO(b/158266331): Support fusing activation chains of arbitrary length.
template <typename SrcOpT, typename FusedOpT>
class FuseContractionWithBiasAdd : public OpRewritePattern<SrcOpT> {
public:
using OpRewritePattern<SrcOpT>::OpRewritePattern;
// Class users should override this method if there are any op-specific
// compatibility requirements between the contraction op and the BiasAdd op.
virtual bool AreFuseCompatible(SrcOpT contraction_op, BiasAddOp bias_add,
PatternRewriter &rewriter) const {
return true;
}

LogicalResult matchAndRewrite(SrcOpT contraction,
PatternRewriter &rewriter) const override {
auto context = rewriter.getContext();
// If the contraction is used in multiple places, fusing it will only create
// more contraction nodes, which is slower.
if (!contraction.getResult().hasOneUse())
return rewriter.notifyMatchFailure(contraction,
"result is used by multiple ops");

BiasAddOp bias_add = GetBiasAdd(contraction.getResult());
if (!bias_add) {
return rewriter.notifyMatchFailure(
contraction, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}

if (!AreFuseCompatible(contraction, bias_add, rewriter)) {
return rewriter.notifyMatchFailure(
contraction, "cannot fuse with the subsequent BiasAdd op");
}

SmallVector<Location, 3> locations{contraction.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{
StringAttr::get(GetOpNameWithoutDialect(bias_add), context)};

// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);

// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();
Type result_type;

// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(
StringAttr::get(GetOpNameWithoutDialect(activation), context));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}

auto fused_loc = rewriter.getFusedLoc(locations);

// The fused contraction has the same operands as the original contraction
// with `bias` from the BiasAddOp appended.
SmallVector<Value, 4> operands(contraction.operand_begin(),
contraction.operand_end());
operands.push_back(bias_add.bias());

// The fused contraction has the same attributes as the original
// contraction, with two additions: the list of ops which have been fused
// together; epsilon (only with FusedBatchNorm).
std::vector<NamedAttribute> attrs = contraction.getAttrs();
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, context);
attrs.push_back(
NamedAttribute(Identifier::get("fused_ops", context), fused_ops_attr));
// Epsilon is used only in fusions with the FusedBatchNorm op, so we zero it
// here.
Attribute epsilon = rewriter.getF32FloatAttr(0);
attrs.push_back(
NamedAttribute(Identifier::get("epsilon", context), epsilon));

Value fused_op = rewriter.create<FusedOpT>(fused_loc, result_type,
ValueRange(operands), attrs);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, ValueRange({fused_op}));
return success();
}
};

// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
class FuseConv2DBiasAdd
: public FuseContractionWithBiasAdd<Conv2DOp, _FusedConv2DOp> {
public:
using FuseContractionWithBiasAdd<Conv2DOp,
_FusedConv2DOp>::FuseContractionWithBiasAdd;
// Verify that the Conv2D and BiasAdd data formats match. This is necessary
// for the ops to fuse correctly, the fused Conv2D op has one data format
// attribute which is shared.
bool AreFuseCompatible(Conv2DOp conv, BiasAddOp bias_add,
PatternRewriter &rewriter) const override {
// Verify that the data formats match and are valid for fusion.
if (conv.data_format() != bias_add.data_format()) {
rewriter.notifyMatchFailure(conv, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << conv.data_format() << ")";
});
return false;
}
return true;
}
};

// Performs a fusion of the following pattern(s), if possible:
// MatMulOp + BiasAdd + <Activation> -> _FusedMatMulOp
using FuseMatMulBiasAdd = FuseContractionWithBiasAdd<MatMulOp, _FusedMatMulOp>;

void FusedKernelMatcherPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd, FuseMatMulBiasAdd>(&getContext());

applyPatternsAndFoldGreedily(func, patterns);
}

} // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass() {
return std::make_unique<FusedKernelMatcherPass>();
}

static PassRegistration<FusedKernelMatcherPass> pass(
"tf-fused-kernel-matcher",
"Matches computations corresponding to optimized fused kernels");

} // namespace TF

} // namespace mlir
|
@ -1,174 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdio>
#include <iostream>

#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Diagnostics.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {

namespace TF {

namespace {

// Note: This implements the fusions performed in the old Remapper Grappler
// pass. That pass has specific cases for GPU and based on different
// target configurations on both CPU and GPU (Intel MKL, ROCm, etc.). This MLIR
// pass covers the general CPU case and at the moment does not account for any
// specific target configurations.
// TODO(b/158265178): Support GPU-specific fusions.
// TODO(b/158266710): Support CPU MKL configurations.

// Optimizes TF computations by fusing subgraphs/nodes onto more efficient
// implementations to decrease the number of operations needed to perform a
// computation.
struct OpFusionPass : public PassWrapper<OpFusionPass, FunctionPass> {
void runOnFunction() override;
};

// Returns an op's name with the dialect prefix stripped off.
StringRef GetOpNameWithoutDialect(Operation *op) {
return op->getName().getStringRef().split(".").second;
}

bool IsActivationFunction(Operation *op) {
return isa<EluOp>(op) || isa<ReluOp>(op) || isa<Relu6Op>(op);
}

// Finds and returns an activation op that uses the result of `op`. If there are
// multiple such activations, one is returned (with no guarantee as to which
// one). If there are no activation functions that use the output, returns
// nullptr.
Operation *GetActivation(Value op) {
for (auto &use : op.getUses()) {
if (IsActivationFunction(use.getOwner())) return use.getOwner();
}
return nullptr;
}

// Finds and returns a BiasAdd that uses the result of `op` as the `value`
// input. If there are multiple such BiasAdds, one is returned (with no
// guarantee as to which one). If there are no BiasAdds that use the output,
// returns a null BiasAddOp.
BiasAddOp GetBiasAdd(Value op) {
for (auto &use : op.getUses()) {
auto bias_add = dyn_cast_or_null<BiasAddOp>(use.getOwner());
// If it's a BiasAdd, check that the conv op is the first input.
if (bias_add && bias_add.value() == op) return bias_add;
}
// No BiasAddOps found among uses.
return BiasAddOp();
}

// Performs a fusion of the following pattern(s), if possible:
// Conv2D + BiasAdd + <Activation> -> _FusedConv2D
//
// Note that fusion with activation is preferred, but a Conv2D and BiasAdd can
// also be replaced by a _FusedConv2D if there is no other activation function.
// i.e., this class also supports the following fusion:
// Conv2D + BiasAdd -> _FusedConv2D
//
// TODO(b/158266331): Support fusing Conv2D + BiasAdd + a chain of activations.
class FuseConv2DBiasAdd : public OpRewritePattern<Conv2DOp> {
public:
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(Conv2DOp op,
PatternRewriter &rewriter) const override {
// If the convolution is used in multiple places, fusing it will only create
// more convolutions, which is slower.
if (!op.getResult().hasOneUse())
return rewriter.notifyMatchFailure(op, "result is used by multiple ops");

BiasAddOp bias_add = GetBiasAdd(op);
if (!bias_add) {
return rewriter.notifyMatchFailure(
op, "does not feed into a tf.BiasAdd/tf.BiasAddV1 op");
}

// Check that Conv and BiasAdd formats match.
if (op.data_format() != bias_add.data_format()) {
return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "data format does not match Conv2D data format ("
<< bias_add.data_format() << " vs " << op.data_format() << ")";
});
}

SmallVector<Location, 3> locations{op.getLoc(), bias_add.getLoc()};
SmallVector<Attribute, 2> fused_ops{StringAttr::get(
GetOpNameWithoutDialect(bias_add), rewriter.getContext())};
Type result_type;

// BiasAdd may or may not feed into an activation function.
auto activation = GetActivation(bias_add);

// If there is an activation, only fuse it if this is the only op to use the
// result of the BiasAdd.
bool fuse_activation = activation && bias_add.output().hasOneUse();

// Include info about the activation function if applicable.
if (fuse_activation) {
locations.push_back(activation->getLoc());
fused_ops.push_back(StringAttr::get(GetOpNameWithoutDialect(activation),
rewriter.getContext()));
result_type = activation->getResultTypes().front();
} else {
result_type = bias_add.getResult().getType();
}

auto loc = rewriter.getFusedLoc(locations);
ArrayAttr fused_ops_attr = ArrayAttr::get(fused_ops, rewriter.getContext());
// Epsilon is used only in fusions with the BatchNorm op.
APFloat epsilon = APFloat(0.0f);
auto fused_op = rewriter.create<_FusedConv2DOp>(
loc, result_type, op.input(), op.filter(), bias_add.bias(),
op.strides(), op.padding(), op.explicit_paddings(), op.data_format(),
op.dilations(), op.use_cudnn_on_gpu(), fused_ops_attr, epsilon);
auto op_to_replace = fuse_activation ? activation : bias_add;
rewriter.replaceOp(op_to_replace, {fused_op});
return success();
}
};

void OpFusionPass::runOnFunction() {
OwningRewritePatternList patterns;
auto func = getFunction();
patterns.insert<FuseConv2DBiasAdd>(&getContext());

applyPatternsAndFoldGreedily(func, patterns);
}

} // namespace

std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass() {
return std::make_unique<OpFusionPass>();
}

static PassRegistration<OpFusionPass> pass(
"tf-op-fusion",
"Replaces commonly occurring subgraphs with optimized fused kernels");

} // namespace TF

} // namespace mlir
|
@ -54,7 +54,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreateTFOptimizePass();

// Creates pass to rewrite RecvTPUEmbeddingActivationsOp and
// SendTPUEmbeddingGradients ops to internal variants.
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps();
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOpsPass();

// Performs specific fusion for GPU targets.
std::unique_ptr<OperationPass<FuncOp>> CreateGpuOpFusionPass();
@ -148,8 +148,10 @@ CreateTensorArrayOpsDecompositionPass();
// Create a pass that legalizes HLO to TF dialect.
std::unique_ptr<OperationPass<FuncOp>> CreateLegalizeHloToTfPass();

// Creates a pass that performs fusion of common sequences of ops.
std::unique_ptr<OperationPass<FuncOp>> CreateOpFusionPass();
// Matches sequences of ops to TensorFlow fused kernels. This pass should not be
// generally used beyond exporting to runtimes that support these ops. In the
// future these fusions may be codegen'd automatically.
std::unique_ptr<OperationPass<FuncOp>> CreateFusedKernelMatcherPass();
} // namespace TF

namespace tf_executor {
|
@ -101,7 +101,7 @@ void RewriteTPUEmbeddingOps::runOnFunction() {

} // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOps() {
std::unique_ptr<OperationPass<FuncOp>> CreateRewriteTPUEmbeddingOpsPass() {
return std::make_unique<RewriteTPUEmbeddingOps>();
}

|
@ -209,8 +209,10 @@ void CreateHeadComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> head_outside_compiled_ops,
llvm::StringRef host_device) {
Block* launch_block = new Block;
for (Operation* head_outside_compiled_op : head_outside_compiled_ops)
for (Operation* head_outside_compiled_op : head_outside_compiled_ops) {
head_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
head_outside_compiled_op->moveBefore(launch_block, launch_block->end());
}

tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/true, launch_block, host_device);
@ -294,8 +296,10 @@ void CreateTailComputation(OpBuilder* builder, tf_device::ClusterOp cluster,
llvm::ArrayRef<Operation*> tail_outside_compiled_ops,
llvm::StringRef host_device) {
Block* launch_block = new Block;
for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops)
for (Operation* tail_outside_compiled_op : tail_outside_compiled_ops) {
tail_outside_compiled_op->removeAttr(kXlaOutsideCompilationAttr);
tail_outside_compiled_op->moveBefore(launch_block, launch_block->begin());
}

tf_device::LaunchOp launch = CreateLaunchForBlock(
builder, cluster, /*before=*/false, launch_block, host_device);
|
@ -92,10 +92,13 @@ void ExpandHeadOutsideCompiledOps(tf_device::ClusterOp cluster,

for (auto head_outside_compiled_op :
llvm::reverse(head_outside_compiled_ops)) {
if (HasOutsideCompilationAttribute(head_outside_compiled_op)) continue;
auto users = head_outside_compiled_op->getUsers();
if (users.empty() ||
HasOutsideCompilationAttribute(head_outside_compiled_op))
continue;

bool should_expand_op_to_host_computation = true;
for (auto consumer_op : head_outside_compiled_op->getUsers()) {
for (auto consumer_op : users) {
if (should_expand_op_to_host_computation &&
!HasOutsideCompilationAttribute(consumer_op)) {
should_expand_op_to_host_computation = false;
|
@ -219,8 +219,7 @@ class ImporterBase {
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs,
bool function_graph);
llvm::ArrayRef<mlir::NamedAttribute> attrs);

// Finds out the function definition for the given function name from the
// graph and converts it to a function of the module. This method is called
@ -1302,8 +1301,7 @@ Status ImporterBase::ConvertLibFunction(llvm::StringRef func_name) {

TF_RETURN_IF_ERROR(child_importer.Convert(
mlir_func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes,
llvm::makeArrayRef(attributes.begin(), attributes.end()),
/*function_graph=*/true));
llvm::makeArrayRef(attributes.begin(), attributes.end())));
return Status::OK();
}

@ -1405,7 +1403,7 @@ Status ImporterBase::Convert(
const absl::InlinedVector<OutputTensor, 4>& arg_nodes,
const absl::InlinedVector<OutputTensor, 4>& ret_nodes,
const absl::InlinedVector<Node*, 4>& control_ret_nodes,
llvm::ArrayRef<mlir::NamedAttribute> attrs, bool function_graph) {
llvm::ArrayRef<mlir::NamedAttribute> attrs) {
// TODO(b/122040776): Uses debug info for FunctionDef.
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
func_name, func_type, attrs);
@ -2222,8 +2220,15 @@ StatusOr<mlir::OwningModuleRef> GraphDefImporter::Convert(
PopulateTfVersions(module.get(), graph.versions());

TF_RETURN_IF_ERROR(importer.ImporterBase::Convert(
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs,
specs.graph_as_function));
func_name, func_type, arg_nodes, ret_nodes, control_ret_nodes, attrs));

// Mark main function public, others private.
for (auto function : module.get().getOps<mlir::FuncOp>()) {
auto visibility = function.getName() == func_name
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
function.setVisibility(visibility);
}
return module;
}

@ -2888,6 +2893,16 @@ void AdjustBoundInputArgTypes(mlir::ModuleOp module) {
}
}

// Marks the visibility of functions in the saved model module.
void MarkSavedModelFunctionVisibility(mlir::ModuleOp module) {
for (auto func : module.getOps<mlir::FuncOp>()) {
auto visibility = mlir::tf_saved_model::IsExported(func)
? mlir::FuncOp::Visibility::Public
: mlir::FuncOp::Visibility::Private;
func.setVisibility(visibility);
}
}

// Reorder the ops in the module to make testing easier and less dependent
// on implementation details such as the order of functions in the
// FunctionDefLibrary.
@ -3130,6 +3145,7 @@ Status CreateSavedModelIR(
AdjustBoundInputArgTypes(module);
module.setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(module);
MarkSavedModelFunctionVisibility(module);
return Status::OK();
}

@ -3299,6 +3315,7 @@ SavedModelSignatureDefImporter::ConvertSignatures() {
mlir::OpBuilder builder(module_->getBodyRegion());
module_->setAttr("tf_saved_model.semantics", builder.getUnitAttr());
SortSavedModelModule(*module_);
MarkSavedModelFunctionVisibility(*module_);

return std::move(module_);
}
|
@ -57,7 +57,7 @@ cc_library(
],
deps = [
":tfjs_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SideEffects",
@ -109,7 +109,7 @@ cc_library(
":tensorflow_js",
":tensorflow_js_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -190,7 +190,7 @@ cc_library(
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -229,7 +229,7 @@ tf_cc_binary(
"//tensorflow/core/platform:errors",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
|
@ -11,7 +11,7 @@ cc_library(
deps = [
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
|
@ -107,7 +107,7 @@ gentbl(
td_file = "transforms/legalize_tf_patterns.td",
td_srcs = [
":hlo_ops_td_files",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files",
],
@ -177,7 +177,8 @@ cc_library(
"//tensorflow/compiler/xla/client:sharding_builder",
"//tensorflow/core:framework",
"//tensorflow/core/kernels:conv_grad_shape_utils",
"@llvm-project//llvm:support",
"//tensorflow/core/lib/bfloat16",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
@ -217,7 +218,7 @@ cc_library(
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -233,7 +234,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
@ -249,7 +250,7 @@ cc_library(
":hlo",
":lhlo",
":map_hlo_to_lhlo_op",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StandardOps",
],
)
@ -272,7 +273,7 @@ cc_library(
":map_xla_to_scalar_op",
"//tensorflow/compiler/xla:status",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -287,7 +288,7 @@ cc_library(
deps = [
":lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
@ -321,7 +322,7 @@ cc_library(
":lhlo",
":map_xla_to_scalar_op",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:Pass",
@ -354,7 +355,7 @@ cc_library(
":lhlo",
":map_xla_to_scalar_op",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:GPUDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LinalgOps",
@ -372,7 +373,7 @@ cc_library(
deps = [
":lhlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:LinalgOps",
"@llvm-project//mlir:LinalgTransforms",
"@llvm-project//mlir:Pass",
@ -416,7 +417,7 @@ cc_library(
srcs = ["transforms/cycle_detector.cc"],
hdrs = ["transforms/cycle_detector.h"],
deps = [
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
],
alwayslink = 1,
)
@ -437,8 +438,8 @@ cc_library(
deps = [
":cycle_detector",
":hlo",
"@llvm-project//llvm:ir",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -466,7 +467,7 @@ cc_library(
srcs = ["transforms/legalize_control_flow.cc"],
deps = [
":hlo",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -490,7 +491,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:hlo",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -504,7 +505,7 @@ cc_library(
deps = [
":hlo",
":xla_legalize_to_standard_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -522,7 +523,7 @@ gentbl(
td_file = "transforms/lower_complex_patterns.td",
td_srcs = [
":hlo_ops_td_files",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:StdOpsTdFiles",
],
)
@ -537,7 +538,7 @@ cc_library(
deps = [
":hlo",
":xla_dialect_registration",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -564,7 +565,7 @@ cc_library(
srcs = ["transforms/unfuse_batch_norm.cc"],
deps = [
":hlo",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
@ -637,7 +638,7 @@ cc_library(
":infer_fusibility_op_interface",
":xla_canonicalize_inc_gen",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
@ -659,6 +660,7 @@ cc_library(
deps = [
":attribute_importer",
":hlo",
":hlo_module_importer",
":hlo_utils",
":type_to_shape",
"//tensorflow/compiler/xla:comparison_util",
@ -671,7 +673,7 @@ cc_library(
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/container:flat_hash_map",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -692,7 +694,7 @@ cc_library(
deps = [
":hlo_ops_base_inc_gen",
":lhlo_ops_inc_gen",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -748,7 +750,7 @@ cc_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core/platform:types",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
@ -798,7 +800,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:Analysis",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Pass",
|
||||
@ -846,7 +848,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
],
|
||||
@ -877,7 +879,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Translation",
|
||||
],
|
||||
@ -888,8 +890,8 @@ tf_native_cc_binary(
|
||||
name = "operator_writer_gen",
|
||||
srcs = ["operator_writer_gen.cc"],
|
||||
deps = [
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//llvm:tablegen",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:TableGen",
|
||||
"@llvm-project//mlir:Support",
|
||||
"@llvm-project//mlir:TableGen",
|
||||
],
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.

#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"

#include <unordered_map>

#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
@ -79,30 +81,35 @@ bool DotIsDefault(const HloInstruction* instruction) {
}
}  // namespace

StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
    mlir::ModuleOp module, mlir::Builder* builder,
    std::unordered_map<HloComputation*, FuncOp>* function_map,
    HloComputation* computation) {
  HloFunctionImporter importer(module, builder, function_map);
  return importer.ImportFunction(computation);
Status HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation, mlir::ModuleOp module,
    std::unordered_map<const HloComputation*, FuncOp>* function_map,
    mlir::Builder* builder) {
  HloFunctionImporter importer(module, function_map, builder);
  return importer.ImportAsFunc(computation).status();
}

StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
    HloComputation* computation) {
  auto& imported = (*function_map_)[computation];
Status HloFunctionImporter::ImportAsRegion(
    const xla::HloComputation& computation, mlir::Region* region,
    mlir::Builder* builder) {
  HloFunctionImporter importer(region->getParentOfType<mlir::ModuleOp>(), {},
                               builder);
  return importer.ImportAsRegion(computation, region);
}

StatusOr<mlir::FuncOp> HloFunctionImporter::ImportAsFunc(
    const HloComputation& computation) {
  auto& imported = (*function_map_)[&computation];
  if (imported) return imported;

  llvm::SmallVector<Type, 4> args, rets;
  TF_RETURN_IF_ERROR(
      GetMlirTypes(computation->parameter_instructions(), &args));
  TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets));
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
  TF_RETURN_IF_ERROR(GetMlirTypes({computation.root_instruction()}, &rets));
  auto func_type = mlir::FunctionType::get(args, rets, context_);

  string computation_name =
      computation->parent()->entry_computation() == computation
      computation.parent()->entry_computation() == &computation
          ? "main"
          : SanitizeFunctionName(computation->name());
          : SanitizeFunctionName(computation.name());

  // Construct the MLIR function and map arguments.
  llvm::ArrayRef<mlir::NamedAttribute> attrs;
@ -119,31 +126,30 @@ StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
  return function;
}

tensorflow::Status HloFunctionImporter::ImportComputation(
    HloComputation* computation, mlir::Region* region) {
tensorflow::Status HloFunctionImporter::ImportAsRegion(
    const HloComputation& computation, mlir::Region* region) {
  // TODO(hinsu): Store computation name as an attribute for round-trip.
  auto* block = new mlir::Block;
  region->push_back(block);

  llvm::SmallVector<Type, 4> args;
  TF_RETURN_IF_ERROR(
      GetMlirTypes(computation->parameter_instructions(), &args));
  TF_RETURN_IF_ERROR(GetMlirTypes(computation.parameter_instructions(), &args));
  block->addArguments(args);

  return ImportInstructions(computation, block);
}

tensorflow::Status HloFunctionImporter::ImportInstructions(
    HloComputation* computation, mlir::Block* block) {
    const HloComputation& computation, mlir::Block* block) {
  // Setup the input parameters.
  const int num_parameters = computation->num_parameters();
  const int num_parameters = computation.num_parameters();
  for (int i = 0; i < num_parameters; i++) {
    auto hlo_parameter = computation->parameter_instruction(i);
    auto hlo_parameter = computation.parameter_instruction(i);
    instruction_value_map_[hlo_parameter] = block->getArgument(i);
  }

  mlir::OpBuilder builder = mlir::OpBuilder::atBlockEnd(block);
  for (auto instruction : computation->MakeInstructionPostOrder()) {
  for (auto instruction : computation.MakeInstructionPostOrder()) {
    TF_ASSIGN_OR_RETURN(auto new_operation,
                        ImportInstruction(instruction, &builder));
    if (new_operation) {
@ -156,7 +162,7 @@ tensorflow::Status HloFunctionImporter::ImportInstructions(

  // Setup the return type (HLO only supports a single return value).
  TF_ASSIGN_OR_RETURN(auto result,
                      GetMlirValue(computation->root_instruction()));
                      GetMlirValue(computation.root_instruction()));

  // Create terminator op depending on the parent op of this region.
  if (llvm::isa<FuncOp>(block->getParentOp())) {
@ -249,7 +255,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
    }
    case HloOpcode::kCall: {
      TF_ASSIGN_OR_RETURN(FuncOp function,
                          ImportFunction(instruction->to_apply()));
                          ImportAsFunc(*instruction->to_apply()));
      mlir::Operation* new_operation =
          func_builder->create<mlir::CallOp>(loc, function, operands);
      return new_operation;
@ -365,7 +371,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(

      auto scatter_op = func_builder->create<mlir::xla_hlo::ScatterOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportComputation(scatter->to_apply(),
      TF_RETURN_IF_ERROR(ImportAsRegion(*scatter->to_apply(),
                                        &scatter_op.update_computation()));
      return scatter_op.getOperation();
    }
@ -387,9 +393,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
      auto select_scatter_op =
          func_builder->create<mlir::xla_hlo::SelectAndScatterOp>(
              loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportComputation(select_scatter->select(),
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->select(),
                                        &select_scatter_op.select()));
      TF_RETURN_IF_ERROR(ImportComputation(select_scatter->scatter(),
      TF_RETURN_IF_ERROR(ImportAsRegion(*select_scatter->scatter(),
                                        &select_scatter_op.scatter()));
      return select_scatter_op.getOperation();
    }
@ -414,8 +420,8 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
          loc, result_type, operands,
          builder_->getI64IntegerAttr(sort_instruction->sort_dimension()),
          builder_->getBoolAttr(sort_instruction->is_stable()));
      TF_RETURN_IF_ERROR(ImportComputation(sort_instruction->to_apply(),
                                           &sort_op.comparator()));
      TF_RETURN_IF_ERROR(
          ImportAsRegion(*sort_instruction->to_apply(), &sort_op.comparator()));
      return sort_op.getOperation();
    }
    case HloOpcode::kConditional: {
@ -430,9 +436,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(

        auto op = func_builder->create<mlir::xla_hlo::IfOp>(loc, rets, operands,
                                                            attributes);
        TF_RETURN_IF_ERROR(ImportComputation(instruction->true_computation(),
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->true_computation(),
                                          &op.true_branch()));
        TF_RETURN_IF_ERROR(ImportComputation(instruction->false_computation(),
        TF_RETURN_IF_ERROR(ImportAsRegion(*instruction->false_computation(),
                                          &op.false_branch()));
        return op.getOperation();
      }
@ -448,8 +454,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
           llvm::enumerate(instruction->branch_computations())) {
        auto index = index_and_computation.index();
        HloComputation* computation = index_and_computation.value();
        TF_RETURN_IF_ERROR(
            ImportComputation(computation, &op.branches()[index]));
        TF_RETURN_IF_ERROR(ImportAsRegion(*computation, &op.branches()[index]));
      }
      return op.getOperation();
    }
@ -468,7 +473,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
        attributes.push_back(ConvertChannelHandle(all_reduce->channel_id()));
      auto all_reduce_op = func_builder->create<mlir::xla_hlo::AllReduceOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(ImportComputation(all_reduce->to_apply(),
      TF_RETURN_IF_ERROR(ImportAsRegion(*all_reduce->to_apply(),
                                        &all_reduce_op.computation()));
      return all_reduce_op.getOperation();
    }
@ -481,7 +486,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
          llvm::makeArrayRef(operands).drop_front(num_inputs),
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportComputation(instruction->to_apply(), &reduce.body()));
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
      return reduce.getOperation();
    }
    case HloOpcode::kReverse: {
@ -517,9 +522,9 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
      auto op = func_builder->create<mlir::xla_hlo::WhileOp>(
          loc, operands[0].getType(), operands[0]);
      TF_RETURN_IF_ERROR(
          ImportComputation(instruction->while_condition(), &op.cond()));
          ImportAsRegion(*instruction->while_condition(), &op.cond()));
      TF_RETURN_IF_ERROR(
          ImportComputation(instruction->while_body(), &op.body()));
          ImportAsRegion(*instruction->while_body(), &op.body()));
      return op.getOperation();
    }
    case HloOpcode::kGetTupleElement: {
@ -580,7 +585,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
      auto reduce = func_builder->create<mlir::xla_hlo::ReduceWindowOp>(
          loc, result_type, operands, attributes);
      TF_RETURN_IF_ERROR(
          ImportComputation(instruction->to_apply(), &reduce.body()));
          ImportAsRegion(*instruction->to_apply(), &reduce.body()));
      return reduce.getOperation();
    }
    case HloOpcode::kMap: {
@ -588,7 +593,7 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
          loc, result_type, operands,
          ConvertDimensions(instruction->dimensions()));
      TF_RETURN_IF_ERROR(
          ImportComputation(instruction->to_apply(), &op.computation()));
          ImportAsRegion(*instruction->to_apply(), &op.computation()));
      return op.getOperation();
    }
    case HloOpcode::kConvolution: {
@ -42,29 +42,39 @@ class Shape;
// Helper class for importing HloComputations.
class HloFunctionImporter {
 public:
  static StatusOr<mlir::FuncOp> ImportFunction(
      mlir::ModuleOp module, mlir::Builder* builder,
      std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map,
      xla::HloComputation* computation);
  // Imports the given computation as a function in the given module. This also
  // imports any computations referred to by instructions in this computation.
  static Status ImportAsFunc(const xla::HloComputation& computation,
                             mlir::ModuleOp module,
                             std::unordered_map<const xla::HloComputation*,
                                                mlir::FuncOp>* function_map,
                             mlir::Builder* builder);

  // Imports the given hlo computation to the specified region.
  static Status ImportAsRegion(const xla::HloComputation& computation,
                               mlir::Region* region, mlir::Builder* builder);

 private:
  HloFunctionImporter(
      mlir::ModuleOp module, mlir::Builder* builder,
      std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map)
  HloFunctionImporter(mlir::ModuleOp module,
                      std::unordered_map<const xla::HloComputation*,
                                         mlir::FuncOp>* function_map,
                      mlir::Builder* builder)
      : context_(module.getContext()),
        module_(module),
        builder_(builder),
        function_map_(function_map) {}

  StatusOr<mlir::FuncOp> ImportFunction(xla::HloComputation* computation);
  // Imports the given computation as a new function, if it hasn't been already
  // imported.
  StatusOr<mlir::FuncOp> ImportAsFunc(const xla::HloComputation& computation);

  // Imports the given computation in the specified region.
  tensorflow::Status ImportComputation(HloComputation* computation,
  tensorflow::Status ImportAsRegion(const HloComputation& computation,
                                    mlir::Region* region);

  // Imports instructions from the given computation in the specified block.
  // Assumes that the block already has correct arguments populated.
  tensorflow::Status ImportInstructions(HloComputation* computation,
  tensorflow::Status ImportInstructions(const HloComputation& computation,
                                        mlir::Block* block);

  // Imports an instruction.
@ -125,7 +135,7 @@ class HloFunctionImporter {
  mlir::Builder* builder_;

  // Mapping from HloComputation to the created MLIR function.
  std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_;
  std::unordered_map<const xla::HloComputation*, mlir::FuncOp>* function_map_;

  // Mapping from HloInstructions to the associated MLIR values.
  std::unordered_map<xla::HloInstruction*, mlir::Value> instruction_value_map_;
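To make the relationship between the two entry points concrete, here is a minimal sketch of a caller driving the importer; the function name ImportEntry and the surrounding variables are illustrative assumptions, not part of this interface:

// Sketch only. Assumes `module`, `hlo_module`, and `region` are supplied by
// the surrounding import pipeline.
tensorflow::Status ImportEntry(mlir::ModuleOp module,
                               const xla::HloModule& hlo_module,
                               mlir::Region* region) {
  std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map;
  mlir::Builder builder(module.getContext());
  // Emits the computation (and any computation it references) as FuncOps in
  // `module`, deduplicated through `function_map`.
  TF_RETURN_IF_ERROR(xla::HloFunctionImporter::ImportAsFunc(
      *hlo_module.entry_computation(), module, &function_map, &builder));
  // Alternatively, inlines the same computation into the body of `region`.
  return xla::HloFunctionImporter::ImportAsRegion(
      *hlo_module.entry_computation(), region, &builder);
}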
@ -33,11 +33,11 @@ namespace xla {
Status HloModuleImporter::Import(const xla::HloModule& module) {
  // TODO(hinsu): Only import the entry computation here once all HLO ops with
  // reference to other computation are updated to have a region instead of a
  // function attribute.
  for (const auto& computation : module.computations()) {
    auto result = HloFunctionImporter::ImportFunction(
        module_, &builder_, &function_map_, computation);
    TF_RETURN_IF_ERROR(result.status());
  // function attribute. Currently the importer test doesn't refer to all the
  // computations from the entry computation so tests may need some update.
  for (const auto* computation : module.computations()) {
    TF_RETURN_IF_ERROR(HloFunctionImporter::ImportAsFunc(
        *computation, module_, &function_map_, &builder_));
  }

  return Status::OK();
@ -54,7 +54,7 @@ class HloModuleImporter {
  // Map for tracking which MLIR functions map to which HLO Computation. This
  // tracks functions as they are imported and provides a quick lookup for
  // functions invoked by control flow related operations (e.g. while, call).
  std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map_;
  std::unordered_map<const xla::HloComputation*, mlir::FuncOp> function_map_;
};

}  // namespace xla
@ -836,21 +836,33 @@ LogicalResult ConcatenateOp::inferReturnTypes(
  auto dimension = dimension_attr.getInt();

  auto first_type = (*operands.begin()).getType().cast<ShapedType>();

  auto out_element = first_type.getElementType();

  for (auto operand : operands.getTypes()) {
    auto element_type = getElementTypeOrSelf(operand);
    if (element_type != out_element) {
      return failure();
    }
  }

  // If an input is unranked the output shape is unranked.
  if (!first_type.hasRank()) {
    inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
    return success();
  }

  auto out_shape = llvm::to_vector<6>(first_type.getShape());
  out_shape[dimension] = 0;

  for (auto operand : operands.getTypes()) {
    auto type = operand.cast<ShapedType>();
    auto dim = type.getShape()[dimension];

    // Validate the element types match.
    if (type.getElementType() != out_element) {
      return failure();
    if (!type.hasRank()) {
      inferredReturnTypes.push_back(UnrankedTensorType::get(out_element));
      return success();
    }

    // If the dimension is dynamic we know the output dimension is dynamic.
    auto dim = type.getShape()[dimension];
    if (dim == -1) {
      out_shape[dimension] = -1;
      break;
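A worked illustration of the inference above: concatenating tensor<2x3xf32> and tensor<2x4xf32> along dimension 1 infers tensor<2x7xf32>; if the second operand is tensor<2x?xf32> instead, its dynamic extent (-1) propagates and the result is tensor<2x?xf32>; and if any operand is unranked, the inferred type collapses to the unranked tensor<*xf32>.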
@ -937,26 +949,39 @@ OpFoldResult ConcatenateOp::fold(ArrayRef<Attribute> operands) {
}

static LogicalResult Verify(ConcatenateOp op) {
  auto firstType = op.getOperand(0).getType().cast<RankedTensorType>();
  Type element_type = getElementTypeOrSelf(op.getOperand(0).getType());
  RankedTensorType first_ranked_type;
  int num_operands = op.getNumOperands();
  for (int i = 0; i < num_operands; i++) {
    auto second_type = op.getOperand(i).getType().dyn_cast<ShapedType>();
    if (second_type.getElementType() != element_type) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match element type", i));
    }

  auto firstShape = firstType.getShape();
  int numOperands = op.getNumOperands();
  for (int i = 1; i < numOperands; i++) {
    auto secondType = op.getOperand(i).getType().cast<RankedTensorType>();
    if (!second_type.hasRank()) {
      continue;
    }

    if (firstType.getRank() != secondType.getRank()) {
    if (!first_ranked_type) {
      first_ranked_type = second_type.cast<RankedTensorType>();
      continue;
    }

    if (first_ranked_type.getRank() != second_type.getRank()) {
      return op.emitOpError(
          llvm::formatv("operands (0) and ({0}) do not match rank", i));
    }

    auto secondShape = secondType.getShape();
    for (int d = 0; d < firstType.getRank(); ++d) {
      if (firstShape[d] != secondShape[d] && d != op.dimension()) {
    auto first_shape = first_ranked_type.getShape();
    auto second_shape = second_type.getShape();
    for (int d = 0; d < first_ranked_type.getRank(); ++d) {
      if (first_shape[d] != second_shape[d] && d != op.dimension()) {
        return op.emitOpError(llvm::formatv(
            "operands (0) and ({0}) non-concat dimensions do not match "
            "({1}) != ({2})",
            i, llvm::make_range(firstShape.begin(), firstShape.end()),
            llvm::make_range(secondShape.begin(), secondShape.end())));
            i, llvm::make_range(first_shape.begin(), first_shape.end()),
            llvm::make_range(second_shape.begin(), second_shape.end())));
      }
    }
  }
@ -358,7 +358,7 @@ def HLO_SubOp : HLO_BinaryElementwiseOp<"subtract",
}

//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
// XLA binary logical elementwise op definitions.
//===----------------------------------------------------------------------===//

// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations
@ -379,15 +379,6 @@ def HLO_XorOp : HLO_BinaryLogicalElementwiseOp<"xor">, BASE_HLO_XorOp;
// XLA communication op definitions.
//===----------------------------------------------------------------------===//

// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
def ChannelHandle : StructAttr<"ChannelHandle", HLO_Dialect, [
                StructFieldAttr<"handle", I64Attr>,
                StructFieldAttr<"type", I64Attr>]> {
  let description = "two 64-bit integers 'handle' and 'type'";
}

// InfeedOp corresponds to 'InfeedWithToken' xla client API and not 'Infeed'.
// InfeedWithToken allows ordering of infeed HLO instructions using tokens.
def HLO_InfeedOp : HLO_Op<"infeed", []> {
@ -451,7 +442,7 @@ def HLO_SendOp : HLO_Op<"send", []> {
  let arguments = (ins
    HLO_TensorOrTuple:$operand,
    HLO_Token:$token,
    ChannelHandle:$channel_id,
    ChannelHandle<HLO_Dialect>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

@ -476,7 +467,7 @@ def HLO_RecvOp : HLO_Op<"recv", []> {

  let arguments = (ins
    HLO_Token:$token,
    ChannelHandle:$channel_id,
    ChannelHandle<HLO_Dialect>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$is_host_transfer
  );

@ -564,16 +555,8 @@ def HLO_CaseOp: HLO_Op<"case", [RecursiveSideEffects]>,

def HLO_WhileOp: HLO_Op<"while", [RecursiveSideEffects,
                                  SameOperandsAndResultType]> {
  string summary = "While operator";

  string description = [{
    Returns the result of executing a body function until the cond body returns
    true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];

                                  SameOperandsAndResultType]>,
                 BASE_HLO_WhileOp {
  let arguments = (ins HLO_TensorOrTuple:$val);

  let regions = (region AnyRegion:$cond, AnyRegion:$body);
@ -590,7 +573,7 @@ def HLO_AllReduceOp : HLO_Op<"all_reduce",
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$replica_groups,
    OptionalAttr<ChannelHandle>:$channel_id
    OptionalAttr<ChannelHandle<HLO_Dialect>>:$channel_id
  );
  let regions = (region SizedRegion<1>:$computation);
  let results = (outs HLO_Tensor);
@ -584,6 +584,15 @@ class BASE_HLO_CaseOp {
// XLA parallelism related op definitions.
//===----------------------------------------------------------------------===//

// Represents a unique identifier for each Send/Recv instruction pair or
// optionally for collective instructions (AllReduce, CollectivePermute,
// AllToAll). Non-positive channel_id handle is equivalent to no channel id.
class ChannelHandle<Dialect dialect> : StructAttr<"ChannelHandle", dialect, [
                StructFieldAttr<"handle", I64Attr>,
                StructFieldAttr<"type", I64Attr>]> {
  let description = "two 64-bit integers 'handle' and 'type'";
}
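Because the class takes its dialect as a template parameter, the same struct attribute definition serves both dialects; it is instantiated as ChannelHandle<HLO_Dialect>:$channel_id in hlo_ops.td and as OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id in lhlo_ops.td.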
class BASE_HLO_ReplicaIdOp {
  string summary = "ReplicaId operator";

@ -1258,4 +1267,45 @@ class BASE_HLO_RngNormalOp {
  }];
}
class BASE_HLO_ReducePrecisionOp {
  string summary = "Reduce precision operator";

  string description = [{
    Models the effect of converting floating-point values to a lower-precision
    format (such as IEEE-FP16) and back to the original format. The number of
    exponent and mantissa bits in the lower-precision format can be specified
    arbitrarily, although all bit sizes may not be supported on all hardware
    implementations.

    See https://www.tensorflow.org/xla/operation_semantics#reduceprecision.
  }];
}
class BASE_HLO_InfeedOp {
  string summary = "Infeed operator";

  string description = [{
    Reads a single data item from the implicit Infeed streaming interface of
    the device, interpreting the data as the given shape and its layout, and
    returns an LHLO op of the data. Multiple Infeed operations are allowed in a
    computation, but there must be a total order among the Infeed operations.
    For example, two Infeeds in the code below have a total order since there
    is a dependency between the while loops.

    See https://www.tensorflow.org/xla/operation_semantics#infeed
  }];
}

class BASE_HLO_WhileOp {
  string summary = "While operator";

  string description = [{
    Returns the result of executing a body function until the cond body returns
    true.

    See https://www.tensorflow.org/xla/operation_semantics#while.
  }];
}

#endif // HLO_OPS_BASE
@ -14,6 +14,20 @@ limitations under the License.
==============================================================================*/

// This is the operation definition file for LXLA.
//
// This file largely overlaps with hlo_ops.td at a logic level. It's tempting to
// merge these two files together, but we need to consider the following
// obstacles:
// * We need to have a common representation for arguments. That is to say,
// HLO_Array<X> translates to HLO_Tensor<X> in HLO dialect, and
// Arg<LHLO_Buffer<X>, "", [Mem(Read|Write)]> in LHLO. Array types within tuples
// also need to be transformed.
// * As of now, TableGen's dag functions are not sufficient to accomplish the
// one above.
// * Traits aren't identical, but need to be copied. For example,
// SameOperandAndResultType in HLO corresponds to SameTypeOperands in LHLO.
// * Also, currently HLO describes the API in XLA's client side, not service
// side. LHLO aims for the service side.
#ifndef LHLO_OPS
#define LHLO_OPS
@ -38,11 +52,17 @@ def LHLO_IntBuffer : MemRefOf<[HLO_Int]>;
// Any floating-point tensor types
def LHLO_FpBuffer : MemRefOf<[AnyFloat]>;

def LHLO_ComplexBuffer : MemRefOf<[AnyComplex]>;

def LHLO_FpOrComplexBuffer : MemRefOf<[AnyFloat, AnyComplex]>;

def LHLO_PredBuffer : MemRefOf<[HLO_Pred]>;

// Any integer or floating-point tensor types
def LHLO_IntOrFpBuffer : MemRefOf<[HLO_Int, AnyFloat]>;

def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;

def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;

def LHLO_TupleBuffer : NestedTupleOf<[LHLO_Buffer]>;
@ -74,88 +94,126 @@ def LHLO_IotaOp : LHLO_Op<"iota", []>, BASE_HLO_IotaOp {
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_unary_functions

class LHLO_UnaryElementwiseOp<string mnemonic> :
    LHLO_Op<mnemonic, [SameTypeOperands]> {
  let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
                       Arg<LHLO_Buffer, "", [MemWrite]>:$output);
class LHLO_UnaryElementwiseOp<string mnemonic,
                              Type BufferType = LHLO_Buffer,
                              list<OpTrait> traits = [SameTypeOperands]>
    : LHLO_Op<mnemonic, traits> {
  let arguments = (ins Arg<BufferType, "", [MemRead]>:$input,
                       Arg<BufferType, "", [MemWrite]>:$output);
}

def LHLO_AbsOp: LHLO_UnaryElementwiseOp<"abs">, BASE_HLO_AbsOp;

def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil">, BASE_HLO_CeilOp;
// TODO(timshen): add a custom verifier.
def LHLO_BitcastConvertOp:
    LHLO_UnaryElementwiseOp<"bitcast_convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_BitcastConvertOp;

def LHLO_ConvertOp : LHLO_Op<"convert", [SameOperandsShape]>, BASE_HLO_ConvertOp {
  let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
                       Arg<LHLO_Buffer, "", [MemWrite]>:$output);
}
def LHLO_CeilOp: LHLO_UnaryElementwiseOp<"ceil", LHLO_FpBuffer>, BASE_HLO_CeilOp;

def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine">, BASE_HLO_CosOp;
def LHLO_ClzOp: LHLO_UnaryElementwiseOp<"count_leading_zeros", LHLO_IntBuffer>, BASE_HLO_ClzOp;

def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential">, BASE_HLO_ExpOp;
// TODO(timshen): add a custom verifier.
def LHLO_ConvertOp : LHLO_UnaryElementwiseOp<"convert", LHLO_Buffer, [SameOperandsShape]>, BASE_HLO_ConvertOp;

def LHLO_CosOp: LHLO_UnaryElementwiseOp<"cosine", LHLO_FpOrComplexBuffer>, BASE_HLO_CosOp;

def LHLO_ExpOp: LHLO_UnaryElementwiseOp<"exponential", LHLO_FpOrComplexBuffer>, BASE_HLO_ExpOp;

def LHLO_Expm1Op: LHLO_UnaryElementwiseOp<"exponential_minus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Expm1Op;

def LHLO_FloorOp: LHLO_UnaryElementwiseOp<"floor", LHLO_FpBuffer>, BASE_HLO_FloorOp;

def LHLO_ImagOp: LHLO_Op<"imag", [SameOperandsShape]>, BASE_HLO_ImagOp {
  let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
                       Arg<LHLO_Buffer, "", [MemWrite]>:$output);
  let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}

def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log">, BASE_HLO_LogOp;
def LHLO_IsFiniteOp: LHLO_Op<"is_finite", [SameOperandsShape]>, BASE_HLO_IsFiniteOp {
  let arguments = (ins Arg<LHLO_FpBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_PredBuffer, "", [MemWrite]>:$output);
}

def LHLO_LogOp: LHLO_UnaryElementwiseOp<"log", LHLO_FpOrComplexBuffer>, BASE_HLO_LogOp;

def LHLO_Log1pOp: LHLO_UnaryElementwiseOp<"log_plus_one", LHLO_FpOrComplexBuffer>, BASE_HLO_Log1pOp;

def LHLO_NegOp: LHLO_UnaryElementwiseOp<"negate">, BASE_HLO_NegOp;

def LHLO_NotOp: LHLO_UnaryElementwiseOp<"not", LHLO_PredOrIntBuffer>, BASE_HLO_NotOp;

def LHLO_PopulationCountOp: LHLO_UnaryElementwiseOp<"popcnt", LHLO_IntBuffer>, BASE_HLO_PopulationCountOp;

def LHLO_RealOp: LHLO_Op<"real", [SameOperandsShape]>, BASE_HLO_RealOp {
  let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$input,
                       Arg<LHLO_Buffer, "", [MemWrite]>:$output);
  let arguments = (ins Arg<LHLO_ComplexBuffer, "", [MemRead]>:$input,
                       Arg<LHLO_FpBuffer, "", [MemWrite]>:$output);
}

def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt">, BASE_HLO_RsqrtOp;
def LHLO_RoundOp: LHLO_UnaryElementwiseOp<"round_nearest_afz", LHLO_FpBuffer>, BASE_HLO_RoundOp;

def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt">, BASE_HLO_SqrtOp;
def LHLO_RsqrtOp: LHLO_UnaryElementwiseOp<"rsqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_RsqrtOp;

def LHLO_SqrtOp: LHLO_UnaryElementwiseOp<"sqrt", LHLO_FpOrComplexBuffer>, BASE_HLO_SqrtOp;

def LHLO_SignOp: LHLO_UnaryElementwiseOp<"sign">, BASE_HLO_SignOp;

def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine">, BASE_HLO_SinOp;
def LHLO_SinOp: LHLO_UnaryElementwiseOp<"sine", LHLO_FpOrComplexBuffer>, BASE_HLO_SinOp;

def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh">, BASE_HLO_TanhOp;
def LHLO_TanhOp: LHLO_UnaryElementwiseOp<"tanh", LHLO_FpOrComplexBuffer>, BASE_HLO_TanhOp;

//===----------------------------------------------------------------------===//
// XLA binary elementwise op definitions.
//===----------------------------------------------------------------------===//
// See https://www.tensorflow.org/xla/operation_semantics#element-wise_binary_arithmetic_operations

class LHLO_BinaryElementwiseOp<string mnemonic, list<OpTrait> traits> :
class LHLO_BinaryElementwiseOp<string mnemonic, Type BufferType = LHLO_Buffer,
                               list<OpTrait> traits = [SameTypeOperands]> :
        LHLO_Op<mnemonic, traits> {
  let arguments = (ins
      Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
      Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
      Arg<LHLO_Buffer, "", [MemWrite]>:$out,
      Arg<BufferType, "", [MemRead]>:$lhs,
      Arg<BufferType, "", [MemRead]>:$rhs,
      Arg<BufferType, "", [MemWrite]>:$out,
      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}

def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add", []>, BASE_HLO_AddOp;
def LHLO_AddOp : LHLO_BinaryElementwiseOp<"add">, BASE_HLO_AddOp;

def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", LHLO_PredOrIntBuffer>, BASE_HLO_AndOp;

def LHLO_Atan2Op : LHLO_BinaryElementwiseOp<"atan2", LHLO_FpOrComplexBuffer>, BASE_HLO_Atan2Op;

def LHLO_ComplexOp: LHLO_Op<"complex", [SameOperandsShape]>, BASE_HLO_ComplexOp {
  let arguments = (ins Arg<LHLO_Buffer, "", [MemRead]>:$lhs,
                       Arg<LHLO_Buffer, "", [MemRead]>:$rhs,
                       Arg<LHLO_Buffer, "", [MemWrite]>:$output);
  let arguments = (ins
      Arg<LHLO_FpBuffer, "", [MemRead]>:$lhs,
      Arg<LHLO_FpBuffer, "", [MemRead]>:$rhs,
      Arg<LHLO_ComplexBuffer, "", [MemWrite]>:$output,
      OptionalAttr<BroadcastDimAttr>:$broadcast_dimensions
  );
}

def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide", []>, BASE_HLO_DivOp;
def LHLO_DivOp : LHLO_BinaryElementwiseOp<"divide">, BASE_HLO_DivOp;

def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum", []>, BASE_HLO_MaxOp;
def LHLO_MaxOp : LHLO_BinaryElementwiseOp<"maximum">, BASE_HLO_MaxOp;

def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum", []>, BASE_HLO_MinOp;
def LHLO_MinOp : LHLO_BinaryElementwiseOp<"minimum">, BASE_HLO_MinOp;

def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply", []>, BASE_HLO_MulOp;
def LHLO_MulOp : LHLO_BinaryElementwiseOp<"multiply">, BASE_HLO_MulOp;

def LHLO_RemOp :
    LHLO_BinaryElementwiseOp<"remainder", []>, BASE_HLO_RemOp;
def LHLO_OrOp : LHLO_BinaryElementwiseOp<"or", LHLO_PredOrIntBuffer>, BASE_HLO_OrOp;

def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract", []>, BASE_HLO_SubOp;
def LHLO_PowOp : LHLO_BinaryElementwiseOp<"power">, BASE_HLO_PowOp;

def LHLO_AndOp: LHLO_BinaryElementwiseOp<"and", []>, BASE_HLO_AndOp;
def LHLO_RemOp : LHLO_BinaryElementwiseOp<"remainder", LHLO_IntOrFpBuffer>, BASE_HLO_RemOp;

def LHLO_OrOp: LHLO_BinaryElementwiseOp<"or", []>, BASE_HLO_OrOp;
def LHLO_ShiftLeftOp : LHLO_BinaryElementwiseOp<"shift_left", LHLO_IntBuffer>, BASE_HLO_ShiftLeftOp;

def LHLO_ShiftRightArithmeticOp : LHLO_BinaryElementwiseOp<"shift_right_arithmetic", LHLO_IntBuffer>, BASE_HLO_ShiftRightArithmeticOp;

def LHLO_ShiftRightLogicalOp : LHLO_BinaryElementwiseOp<"shift_right_logical", LHLO_IntBuffer>, BASE_HLO_ShiftRightLogicalOp;

def LHLO_SubOp : LHLO_BinaryElementwiseOp<"subtract">, BASE_HLO_SubOp;

def LHLO_XorOp : LHLO_BinaryElementwiseOp<"xor", LHLO_PredOrIntBuffer>, BASE_HLO_XorOp;

//===----------------------------------------------------------------------===//
// XLA control flow op definitions.
@ -210,6 +268,16 @@ def LHLO_CaseOp: LHLO_Op<"case", [
  let regions = (region VariadicRegion<SizedRegion<1>>:$branches);
}

// TODO(timshen): Add a custom syntax for this.
def LHLO_WhileOp: LHLO_Op<"while", [SameTypeOperands]>, BASE_HLO_WhileOp {
  let arguments = (ins
    Arg<LHLO_BufferOrTuple, "", [MemRead]>:$val,
    Arg<LHLO_BufferOrTuple, "", [MemWrite]>:$output
  );

  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
}

//===----------------------------------------------------------------------===//
// XLA tuple op definitions.
//===----------------------------------------------------------------------===//
@ -269,7 +337,9 @@ def HLO_DynamicUpdateSliceOp: LHLO_Op<"dynamic-update-slice", []> {

def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
    [NoSideEffect, DeclareOpInterfaceMethods<ViewLikeOpInterface>]> {
  let summary = "static memref cast operation";
  let summary = [{
    Modifies the offset, sizes and strides of a statically shaped memref.
  }];
  let description = [{
    Allows to modify the offset, sizes and strides of a statically shaped memref.
@ -357,7 +427,23 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
// XLA Other op definitions.
//===----------------------------------------------------------------------===//

def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
def LHLO_BatchNormGradOp : LHLO_Op<"batch_norm_grad", []>,
    BASE_HLO_BatchNormGradOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$mean,
    Arg<LHLO_Buffer, "", [MemRead]>:$variance,
    Arg<LHLO_Buffer, "", [MemRead]>:$grad_output,
    Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );

}

def LHLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
    BASE_HLO_BatchNormInferenceOp {

  let arguments = (ins
@ -372,6 +458,19 @@ def HLO_BatchNormInferenceOp : LHLO_Op<"batch_norm_inference", []>,
  );
}

def LHLO_BatchNormTrainingOp : LHLO_Op<"batch_norm_training", []>,
    BASE_HLO_BatchNormTrainingOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemRead]>:$scale,
    Arg<LHLO_Buffer, "", [MemRead]>:$offset,
    Arg<LHLO_TupleBuffer, "", [MemWrite]>:$output,
    F32Attr:$epsilon,
    I64Attr:$feature_index
  );
}

def LHLO_BroadcastOp : LHLO_Op<"broadcast",
    []>, BASE_HLO_BroadcastOp {
  let arguments = (ins
@ -531,6 +630,88 @@ def LHLO_TransposeOp: LHLO_Op<"transpose", []>, BASE_HLO_TransposeOp {
  );
}

def LHLO_ReducePrecisionOp: LHLO_Op<"reduce_precision", [SameTypeOperands]>,
    BASE_HLO_ReducePrecisionOp {
  let arguments = (ins
    Arg<LHLO_FpBuffer, "", [MemRead]>:$operand,
    Arg<LHLO_FpBuffer, "", [MemWrite]>:$output,
    I32Attr:$exponent_bits,
    I32Attr:$mantissa_bits
  );
}

def LHLO_AllReduceOp : LHLO_Op<"all_reduce", [SameTypeOperands]>,
    BASE_HLO_AllReduceOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$replica_groups,
    DefaultValuedAttr<BoolAttr, "false">:$constrain_layout,
    OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id,
    DefaultValuedAttr<BoolAttr, "false">:$use_global_device_ids
  );
  let regions = (region SizedRegion<1>:$computation);
}

def LHLO_CollectivePermuteOp: LHLO_Op<"collective_permute", [SameTypeOperands]>,
    BASE_HLO_CollectivePermuteOp {

  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    I64ElementsAttr:$source_target_pairs,
    OptionalAttr<ChannelHandle<LHLO_Dialect>>:$channel_id
  );
}

def LHLO_FftOp: LHLO_Op<"fft", []>, BASE_HLO_FftOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    HLO_FftTypeAttr:$fft_type,
    I64ElementsAttr:$fft_length
  );
}

def LHLO_CholeskyOp: LHLO_Op<"cholesky", [SameOperandsElementType]>, BASE_HLO_CholeskyOp {
  let arguments = (ins
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
    DefaultValuedAttr<BoolAttr, "false">:$lower
  );
}

def LHLO_Infeed: LHLO_Op<"infeed", []>, BASE_HLO_InfeedOp {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemWrite]>:$output,
    DefaultValuedAttr<StrAttr, "">:$config
  );
}

def LHLO_Outfeed: LHLO_Op<"outfeed", []> {
  let arguments = (ins
    Arg<LHLO_Buffer, "", [MemRead]>:$operand,
    DefaultValuedAttr<StrAttr, "">:$config
  );
}

def LHLO_ReplicaIdOp : LHLO_Op<"replica_id", []>, BASE_HLO_ReplicaIdOp {
  let arguments = (ins Arg<MemRefOf<[UI32]>, "", [MemWrite]>);
}

def LHLO_TriangularSolveOp: LHLO_Op<"triangular_solve", [SameOperandsElementType]>,
    BASE_HLO_TriangularSolveOp {
  let arguments = (ins
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$a,
    Arg<LHLO_FpOrComplexBuffer, "", [MemRead]>:$b,
    Arg<LHLO_FpOrComplexBuffer, "", [MemWrite]>:$output,
    BoolAttr:$left_side,
    BoolAttr:$lower,
    BoolAttr:$unit_diagonal,
    HLO_TransposeAttr:$transpose_a
  );
}

//===----------------------------------------------------------------------===//
// Late operations
//===----------------------------------------------------------------------===//
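The convention running through these definitions: where an HLO op returns tensor results, its LHLO counterpart takes every operand and result as a pre-allocated memref buffer, with inputs marked [MemRead] and results written through [MemWrite] arguments, matching the service-side, buffer-assigned view described in the file's header comment.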
@ -19,10 +19,12 @@ limitations under the License.
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"

@ -118,6 +120,76 @@ StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal(
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::FftInternal(
    const Shape& shape, XlaOp operand, FftType fft_type,
    absl::Span<const int64> fft_length) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::xla_hlo::FftOp>(
      loc_, ty, GetValue(operand),
      builder_.getStringAttr(FftType_Name(fft_type)),
      GetI64ElementsAttr(fft_length, &builder_));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::ReduceInternal(
    const Shape& shape, absl::Span<const XlaOp> all_operands,
    const XlaComputation& computation,
    absl::Span<const int64> dimensions_to_reduce) {
  // Reduce takes two sets of variadic operands, inputs and init_values.
  // all_operands contains both of these so split operands into two parts.
  int64_t num_args = all_operands.size() / 2;
  auto op = builder_.create<mlir::xla_hlo::ReduceOp>(
      loc_, GetValues(all_operands.first(num_args)),
      GetValues(all_operands.subspan(num_args)),
      GetI64ElementsAttr(dimensions_to_reduce, &builder_));
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
  if (op.getNumResults() == 1) return MakeXlaOp(op.getResult(0));
  auto tuple = builder_.create<mlir::xla_hlo::TupleOp>(loc_, op.getResults());
  return MakeXlaOp(tuple);
}

StatusOr<XlaOp> MlirHloBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  llvm::SmallVector<int64, 4> sizes, strides, base_dilations, win_dilations;
  llvm::SmallVector<int64, 8> padding;
  for (const auto& dim : window.dimensions()) {
    sizes.push_back(dim.size());
    strides.push_back(dim.stride());
    base_dilations.push_back(dim.base_dilation());
    win_dilations.push_back(dim.window_dilation());
    padding.push_back(dim.padding_low());
    padding.push_back(dim.padding_high());
  }
  auto padding_ty =
      mlir::RankedTensorType::get({static_cast<int64_t>(padding.size()) / 2, 2},
                                  builder_.getIntegerType(64));
  auto op = builder_.create<mlir::xla_hlo::ReduceWindowOp>(
      loc_, ty, GetValue(operand), GetValue(init_value),
      GetI64ElementsAttr(sizes, &builder_),
      GetI64ElementsAttr(strides, &builder_),
      GetI64ElementsAttr(base_dilations, &builder_),
      GetI64ElementsAttr(win_dilations, &builder_),
      mlir::DenseIntElementsAttr::get(padding_ty, padding));
  TF_RETURN_IF_ERROR(ImportComputation(computation.proto(), &op.body()));
  return MakeXlaOp(op);
}

XlaOp MlirHloBuilder::Iota(const Shape& shape, int64 iota_dimension) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(
        mlir::Type ty,
        ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
    auto op = builder_.create<mlir::xla_hlo::IotaOp>(
        loc_, ty,
        builder_.getIntegerAttr(builder_.getI64Type(), iota_dimension));
    return MakeXlaOp(op);
  });
}

StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> permutation) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
@ -127,6 +199,15 @@ StatusOr<XlaOp> MlirHloBuilder::TransposeInternal(
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::RevInternal(
    const Shape& shape, XlaOp operand, absl::Span<const int64> dimensions) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::xla_hlo::ReverseOp>(
      loc_, ty, GetValue(operand), GetI64ElementsAttr(dimensions, &builder_));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
    const Shape& shape, XlaOp input, XlaOp start_indices,
    const GatherDimensionNumbers& dimension_numbers,
@ -140,6 +221,24 @@ StatusOr<XlaOp> MlirHloBuilder::GatherInternal(
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::ScatterInternal(
    const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
    const XlaComputation& update_computation,
    const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
    bool unique_indices) {
  TF_ASSIGN_OR_RETURN(mlir::Type ty, ConvertShapeToType<mlir::RankedTensorType>(
                                         shape, builder_));
  auto op = builder_.create<mlir::xla_hlo::ScatterOp>(
      loc_, ty, GetValue(input), GetValue(scatter_indices), GetValue(updates),
      ConvertScatterDimensionNumbers(dimension_numbers, &builder_),
      builder_.getBoolAttr(indices_are_sorted),
      builder_.getBoolAttr(unique_indices));

  TF_RETURN_IF_ERROR(
      ImportComputation(update_computation.proto(), &op.update_computation()));
  return MakeXlaOp(op);
}

StatusOr<XlaOp> MlirHloBuilder::RngOpInternal(
    RandomDistribution distribution, absl::Span<const XlaOp> parameters,
    const Shape& shape) {
@ -348,6 +447,18 @@ StatusOr<XlaOp> MlirHloBuilder::CreateOp(
  return MakeXlaOp(op->getResult(0));
}

Status MlirHloBuilder::ImportComputation(const HloModuleProto& computation,
                                         mlir::Region* region) {
  TF_ASSIGN_OR_RETURN(auto module_config,
                      xla::HloModule::CreateModuleConfigFromProto(
                          computation, xla::DebugOptions()));
  TF_ASSIGN_OR_RETURN(auto hlo_module, xla::HloModule::CreateFromProto(
                                           computation, module_config));

  return HloFunctionImporter::ImportAsRegion(*hlo_module->entry_computation(),
                                             region, &builder_);
}

StatusOr<const Shape*> MlirHloBuilder::GetShapePtr(XlaOp op) const {
  TF_RETURN_IF_ERROR(first_error());
  TF_RETURN_IF_ERROR(CheckOpBuilder(op));
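Taken together, these overrides let ordinary XLA client-API calls emit xla_hlo ops instead of HLO instructions. A hedged illustration, assuming `input` is an XlaOp previously created on an MlirHloBuilder used wherever an XlaBuilder is expected:

// Illustration only: this standard client call dispatches to FftInternal
// above and produces an xla_hlo.fft op in the MLIR module.
xla::XlaOp fft = xla::Fft(input, xla::FftType::FFT, /*fft_length=*/{16});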
@ -120,15 +120,40 @@ class MlirHloBuilder : public XlaBuilder {
      int64 feature_group_count, int64 batch_group_count,
      const PrecisionConfig* precision_config) override;

  StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
                              FftType fft_type,
                              absl::Span<const int64> fft_length) override;

  StatusOr<XlaOp> ReduceInternal(
      const Shape& shape, absl::Span<const XlaOp> all_operands,
      const XlaComputation& computation,
      absl::Span<const int64> dimensions_to_reduce) override;

  StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape, XlaOp operand,
                                       XlaOp init_value,
                                       const XlaComputation& computation,
                                       Window window) override;

  XlaOp Iota(const Shape& shape, int64 iota_dimension) override;

  StatusOr<XlaOp> TransposeInternal(
      const Shape& shape, XlaOp operand,
      absl::Span<const int64> permutation) override;

  StatusOr<XlaOp> RevInternal(const Shape& shape, XlaOp operand,
                              absl::Span<const int64> dimensions) override;

  StatusOr<XlaOp> GatherInternal(
      const Shape& shape, XlaOp input, XlaOp start_indices,
      const GatherDimensionNumbers& dimension_numbers,
      absl::Span<const int64> slice_sizes, bool indices_are_sorted) override;

  StatusOr<XlaOp> ScatterInternal(
      const Shape& shape, XlaOp input, XlaOp scatter_indices, XlaOp updates,
      const XlaComputation& update_computation,
      const ScatterDimensionNumbers& dimension_numbers, bool indices_are_sorted,
      bool unique_indices) override;

  StatusOr<XlaOp> RngOpInternal(RandomDistribution distribution,
                                absl::Span<const XlaOp> parameters,
                                const Shape& shape) override;
@ -196,6 +221,9 @@ class MlirHloBuilder : public XlaBuilder {
                              llvm::ArrayRef<XlaOp> operands,
                              llvm::ArrayRef<mlir::NamedAttribute> attributes = {});

  Status ImportComputation(const HloModuleProto& computation,
                           mlir::Region* region);

  mlir::OpBuilder builder_;
  mlir::Location loc_;
@ -68,10 +68,11 @@ static StringRef GetClientBuilder(const Operator& op) {
  return kOpToXLABuilderMap->lookup(op_name);
}

static void BuildOperator(const Operator& op, raw_ostream* output) {
  auto& os = *output;
  os << "  auto& value_map = *lowering_context.values;\n"
     << "  auto result = xla_op.getResult();\n";
static void BuildOperator(const Operator& op, raw_ostream& os) {
  os << "mlir::LogicalResult ExportXlaOp(mlir::xla_hlo::"
     << op.getCppClassName() << " op, OpLoweringContext ctx) {\n"
     << "  auto& value_map = *ctx.values;\n"
     << "  auto result = op.getResult();\n";

  // Build a conversion for each of the arguments.
  int operand_number = 0;
@ -82,15 +83,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
    if (auto* operand_cst = arg.dyn_cast<NamedTypeConstraint*>()) {
      // Handle a non-variadic operand.
      if (!operand_cst->isVariableLength()) {
        os << "  auto xla_arg_" << index
           << " = value_map[*xla_op.getODSOperands(" << operand_number++
           << ").begin()];\n";
        os << "  auto xla_arg_" << index << " = value_map[*op.getODSOperands("
           << operand_number++ << ").begin()];\n";
        continue;
      }

      // Otherwise, this is a variadic operand list.
      os << "  std::vector<xla::XlaOp> xla_arg_" << index << ";\n"
         << "  for (auto operand : xla_op.getODSOperands(" << operand_number++
         << "  for (auto operand : op.getODSOperands(" << operand_number++
         << "))\n    xla_arg_" << index
         << ".push_back(value_map[operand]);\n";
      continue;
@ -99,8 +99,8 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
    // Otherwise, this is an attribute.
    auto named_attr = arg.get<NamedAttribute*>();
    os << "  auto xla_arg_" << index << " = "
       << GetDefaultAttrExport(*named_attr) << "(xla_op."
       << op.getArgName(index) << "());\n";
       << GetDefaultAttrExport(*named_attr) << "(op." << op.getArgName(index)
       << "());\n";
  }

  // Emit call to client API
@ -109,7 +109,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
  // If all operands are variadic, then pass the builder explicitly to xla
  // client API call
  if (op.getNumOperands() == op.getNumVariableLengthOperands()) {
    os << "lowering_context.builder";
    os << "ctx.builder";
    if (op.getNumArgs() != 0) os << ", ";
  }

@ -120,6 +120,7 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {

  os << "  value_map[result] = xla_result;\n";
  os << "  return mlir::success();\n";
  os << "}\n";
}

// The function below has a non-constant reference as that is required by LLVM's
@ -128,6 +129,14 @@ static void BuildOperator(const Operator& op, raw_ostream* output) {
static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
  emitSourceFileHeader("MLIR XLA Builders", os);

  // Emit all the helper functions.
  for (const auto* def : records.getAllDerivedDefinitions("HLO_Op")) {
    Operator op(def);

    // Skip operations that have a custom exporter.
    if (!def->getValueAsBit("hasCustomHLOConverter")) BuildOperator(op, os);
  }

  // Emit a function to generate an XLA operation for the operations with
  // auto-generated builders.
  os << "mlir::LogicalResult ExportXlaOperator(\n"
@ -153,12 +162,11 @@ static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
    // Cast to the current operation and build the exporter.
    os << "  if (auto xla_op = llvm::dyn_cast<mlir::xla_hlo::"
       << op.getCppClassName() << ">(op)) {\n";
    if (def->getValueAsBit("hasCustomHLOConverter")) {
      os << "    return mlir::xla_hlo::ExportXlaOp(xla_op, "
            "lowering_context);\n";
    } else {
      BuildOperator(op, &os);
    }
    os << "    return ";
    // The autogenerated converters aren't in the same namespace.
    // TODO(jpienaar): Reconsider this.
    if (def->getValueAsBit("hasCustomHLOConverter")) os << "mlir::xla_hlo::";
    os << "ExportXlaOp(xla_op, lowering_context);\n";
    os << "  }\n";
  }
@ -30,7 +30,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@llvm-project//llvm:support",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
],
)
@ -9,7 +9,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = "shape.to_extent_tensor"(%[[BCAST_S]])
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
@ -17,7 +17,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
@ -34,7 +34,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
@ -51,7 +51,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
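Note on the recurring CHECK churn in these test files: the only change being tracked is a syntax update in the shape dialect, where shape.to_extent_tensor moved from the quoted generic op form to a custom assembly form. A minimal before/after sketch (the %shape operand and the tensor<2xindex> result type are illustrative placeholders, not taken from any one test above):

    // Generic form, as the CHECK lines previously matched:
    %extents = "shape.to_extent_tensor"(%shape) : (!shape.shape) -> tensor<2xindex>
    // Custom assembly form matched after this change; only the
    // extent-tensor result type is spelled out.
    %extents = shape.to_extent_tensor %shape : tensor<2xindex>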
@ -184,14 +184,16 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {

// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], 0 : memref<?x?xf32>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index

// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], 1 : memref<?x?xf32>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
@ -389,15 +391,18 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -411,15 +416,18 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[DIM0:.*]] = dim %arg0, 0 : memref<?x?xf32>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
// CHECK: %[[IC0:.*]] = index_cast %[[DIM0]] : index to i64
// CHECK: %[[DIM1:.*]] = dim %arg0, 1 : memref<?x?xf32>
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements(%[[IC0]], %[[IC1]]) : tensor<2xi64>
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
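The extra constants in the updated CHECK lines above come from an upstream change to the standard dim op: the dimension index is now an SSA operand rather than an integer attribute, so a constant index must be materialized first. A minimal sketch of the two forms (%arg0 is a placeholder memref, not taken from a specific test):

    // Old form: the dimension was written as an attribute.
    %d0 = dim %arg0, 0 : memref<?x?xf32>
    // New form: the dimension is a constant index operand.
    %c0 = constant 0 : index
    %d0 = dim %arg0, %c0 : memref<?x?xf32>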
@ -340,36 +340,36 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, 0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[RESHAPE_MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]

// -----

@ -407,7 +407,8 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-NEXT: ^bb0(%[[LHS:.*]]: f32, %[[RHS:.*]]: f32):
// CHECK: %[[RESULT:.*]] = addf %[[LHS]], %[[RHS]]
// CHECK-NEXT: linalg.yield %[[RESULT]] : f32
@ -554,4 +555,5 @@ func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
} : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -14,10 +14,10 @@ func @batchmatmulv2_basic(%arg0: tensor<1x4x2xf32>, %arg1: tensor<3x2x4xf32>) ->
// CHECK: [[RHSHEAD:%.*]], [[RHSTAIL:%.*]] = "shape.split_at"([[RHSSHAPE]], [[CM2]]) : (!shape.shape, i32) -> (!shape.shape, !shape.shape)
// CHECK: [[BCASTHEAD:%.*]] = "shape.broadcast"([[LHSHEAD]], [[RHSHEAD]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[LHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[LHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
// CHECK: [[LHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[LHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[LHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[LHS]], [[LHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<1x4x2xf32>, tensor<3xindex>) -> tensor<3x4x2xf32>
// CHECK: [[RHSBCASTSHAPE:%.*]] = "shape.concat"([[BCASTHEAD]], [[RHSTAIL]]) : (!shape.shape, !shape.shape) -> !shape.shape
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = "shape.to_extent_tensor"([[RHSBCASTSHAPE]]) : (!shape.shape) -> tensor<3xindex>
// CHECK: [[RHSSHAPEEXTENTS:%.*]] = shape.to_extent_tensor [[RHSBCASTSHAPE]] : tensor<3xindex>
// CHECK: [[RHSBCAST:%.*]] = "xla_hlo.dynamic_broadcast_in_dim"([[RHS]], [[RHSSHAPEEXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<3x2x4xf32>, tensor<3xindex>) -> tensor<3x2x4xf32>
// CHECK: [[RESULT:%.*]] = "xla_hlo.dot_general"([[LHSBCAST]], [[RHSBCAST]]) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<1> : tensor<1xi64>}} : (tensor<3x4x2xf32>, tensor<3x2x4xf32>) -> tensor<3x4x4xf32>
// CHECK: return [[RESULT]] : tensor<3x4x4xf32>
@ -27,7 +27,7 @@ func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
@ -42,7 +42,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
// CHECK: %[[UNUSED_LHS_SHAPE:.+]] = shape.const_shape [4, 1, 1]
// CHECK: %[[UNUSED_RHS_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK: %[[RESULT_SHAPE:.+]] = shape.const_shape [4, 4, 4, 4]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1, 2, 3]> : tensor<4xi64>}
// CHECK: xla_hlo.add %[[LHS_BCAST]], %[[RHS_BCAST]]
@ -55,7 +55,7 @@ func @add_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: xla_hlo.add %4, %5 : tensor<?x?xi32>
@ -203,7 +203,7 @@ func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1>
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
@ -216,7 +216,7 @@ func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "EQ"}
@ -284,7 +284,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.const_shape [1]
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = shape.const_shape [1, 2]
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
@ -297,7 +297,7 @@ func @greater_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi1
// CHECK-DAG: %[[LHS_SHAPE:.+]] = shape.shape_of %arg0
// CHECK-DAG: %[[RHS_SHAPE:.+]] = shape.shape_of %arg1
// CHECK-DAG: %[[RESULT_SHAPE:.+]] = "shape.broadcast"(%[[LHS_SHAPE]], %[[RHS_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]])
// CHECK-DAG: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]]
// CHECK-DAG: %[[LHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg0, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK-DAG: %[[RHS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: "xla_hlo.compare"(%[[LHS_BCAST]], %[[RHS_BCAST]]) {comparison_direction = "GT"}
@ -187,6 +187,70 @@ func @dynamic_update_slice(%arg0: tensor<3x4xi32>, %arg1: tensor<2x2xi32>, %arg2
return %0: tensor<3x4xi32>
}

// CHECK-LABEL: @sparse_to_dense
// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xi32>, %[[ARG1:.*]]: tensor<3xf32>, %[[ARG2:.*]]: tensor<f32>)
func @sparse_to_dense(%arg0: tensor<3x2xi32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>) -> tensor<3x3xf32> {

// CHECK: %[[CST:.*]] = xla_hlo.constant dense<3> : tensor<2xi32>
// CHECK: %[[DEFAULT:.*]] = "xla_hlo.broadcast_in_dim"(%[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<3x3xf32>

// CHECK: %[[RESULT:.*]] = "xla_hlo.scatter"(%[[DEFAULT]], %[[ARG0]], %[[ARG1]]) ( {
// CHECK: ^bb0(%[[ARG3:.*]]: tensor<f32>, %[[ARG4:.*]]: tensor<f32>): // no predecessors
// CHECK: "xla_hlo.return"(%[[ARG4]]) : (tensor<f32>) -> ()
// CHECK: })
// CHECK-SAME: indices_are_sorted = false
// CHECK-SAME: scatter_dimension_numbers
// CHECK-SAME: index_vector_dim = 1 : i64
// CHECK-SAME: inserted_window_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: scatter_dims_to_operand_dims = dense<[0, 1]> : tensor<2xi64>
// CHECK-SAME: update_window_dims = dense<[]> : tensor<0xi64>
// CHECK-SAME: unique_indices = false
// CHECK-SAME: (tensor<3x3xf32>, tensor<3x2xi32>, tensor<3xf32>) -> tensor<3x3xf32>

// return %[[RESULT]] : tensor<3x3xf32>

%cst = xla_hlo.constant dense<3> : tensor<2xi32>
%0 = "tf.SparseToDense"(%arg0, %cst, %arg1, %arg2) {validate_indices = true}: (tensor<3x2xi32>, tensor<2xi32>, tensor<3xf32>, tensor<f32>) -> tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}

// CHECK-LABEL: fft
func @fft(%arg0: tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>> {
// CHECK: "xla_hlo.fft"(%arg0)
%0 = "tf.FFT"(%arg0) : (tensor<3x5x8xcomplex<f32>>) -> tensor<3x5x8xcomplex<f32>>
return %0 : tensor<3x5x8xcomplex<f32>>
}

// CHECK-LABEL: reverse_sequence
func @reverse_sequence(%arg0: tensor<4x2x3x1x1xi32>, %arg1: tensor<3xi32>) -> tensor<4x2x3x1x1xi32> {
// CHECK-NOT: tf.ReverseSequence
%0 = "tf.ReverseSequence"(%arg0, %arg1) {batch_dim = 2 : i64, seq_dim = 0 : i64}: (tensor<4x2x3x1x1xi32>, tensor<3xi32>) -> tensor<4x2x3x1x1xi32>
return %0 : tensor<4x2x3x1x1xi32>
}

// CHECK-LABEL: mirror_pad
func @mirror_pad(%arg0: tensor<2x3xcomplex<f64>>) -> tensor<4x7xcomplex<f64>> {
%0 = xla_hlo.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32>
// CHECK-NOT: tf.MirrorPad
%1 = "tf.MirrorPad"(%arg0, %0) {mode = "SYMMETRIC"} : (tensor<2x3xcomplex<f64>>, tensor<2x2xi32>) -> tensor<4x7xcomplex<f64>>
return %1 : tensor<4x7xcomplex<f64>>
}

// CHECK-LABEL: bucketize
func @bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> {
// CHECK-NOT: tf.Bucketize
%0 = "tf.Bucketize"(%arg0) {boundaries = [0.000000e+00 : f32, 3.000000e+00 : f32, 8.000000e+00 : f32, 1.100000e+01 : f32]} : (tensor<2x5xf32>) -> tensor<2x5xi32>
return %0 : tensor<2x5xi32>
}

// CHECK-LABEL: arg_min
func @arg_min(%arg0: tensor<6xf64>) -> tensor<i32> {
// CHECK-NOT: ArgMin
%0 = xla_hlo.constant dense<0> : tensor<i32>
%1 = "tf.ArgMin"(%arg0, %0) : (tensor<6xf64>, tensor<i32>) -> tensor<i32>
return %1 : tensor<i32>
}

// TODO(hinsu): Add a test with a valid TF op for which tf2xla kernel is
// available but doesn't support this instance.
}
@ -420,7 +420,7 @@ func @fusedBatchNormGradV3_Training_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: tens
// CHECK-LABEL: func @biasAdd_NHWC
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]])
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<3> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -431,7 +431,7 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens
// CHECK-LABEL: func @biasAdd_NCHW
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]])
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -442,7 +442,7 @@ func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tens
// CHECK-LABEL: func @biasAdd_dynamic
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
// CHECK: %[[ARG0_SHAPE:.+]] = shape.shape_of %arg0
// CHECK: %[[ARG0_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[ARG0_SHAPE]])
// CHECK: %[[ARG0_EXTENTS:.+]] = shape.to_extent_tensor %[[ARG0_SHAPE]]
// CHECK: %[[ARG1_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%arg1, %[[ARG0_EXTENTS]])
// CHECK-SAME: {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK: %[[RESULT:.+]] = xla_hlo.add %arg0, %[[ARG1_BCAST]]
@ -1445,7 +1445,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_MAX:.*]] = "xla_hlo.convert"(%[[MAX]]) : (tensor<2xf32>) -> tensor<2xf32>

// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_MAX:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_MAX]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[SHIFTED_INP:.*]] = xla_hlo.subtract %[[ARG0]], %[[BCAST_MAX]]
// CHECK: %[[EXP:.*]] = "xla_hlo.exponential"(%[[SHIFTED_INP]])
@ -1460,7 +1460,7 @@ func @simple_softmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>

// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[CASTED_SUM]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = xla_hlo.divide %[[EXP]], %[[BCAST_SUM]]
// CHECK: return %[[RESULT]]
@ -1517,7 +1517,7 @@ func @simple_logsoftmax(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK: %[[CASTED_SUM:.*]] = "xla_hlo.convert"(%[[SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[LOG:.*]] = "xla_hlo.log"(%[[CASTED_SUM]]) : (tensor<2xf32>) -> tensor<2xf32>
// CHECK: %[[RESULT_SHAPE:.+]] = shape.shape_of %[[ARG0]]
// CHECK: %[[RESULT_EXTENTS:.+]] = "shape.to_extent_tensor"(%[[RESULT_SHAPE]]) : (!shape.shape) -> tensor<2xindex>
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_SHAPE]] : tensor<2xindex>
// CHECK: %[[BCAST_SUM:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[LOG]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<0> : tensor<1xi64>}
// CHECK: %[[RESULT:.*]] = xla_hlo.subtract {{.*}}, %[[BCAST_SUM]]
// CHECK: return %[[RESULT]]
@ -1544,58 +1544,34 @@ func @rfft_1D(%arg0: tensor<8xf32>) -> tensor<8xcomplex<f32>> {

// CHECK-LABEL: func @shape_1D
func @shape_1D(%arg0: tensor<?xf32>) -> tensor<1xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]]) {dimension = 0 : i64}
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<?xf32>) -> tensor<1xi32>

// CHECK: return [[CONCAT]]
// CHECK: return [[CAST]]
return %0 : tensor<1xi32>
}

// CHECK-LABEL: func @shape_2D
func @shape_2D(%arg0: tensor<?x?xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT0:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[EXTENT1:%.+]] = shape.get_extent [[SHAPE]], 1
// CHECK-DAG: [[TO_INDEX0:%.+]] = shape.size_to_index [[EXTENT0]]
// CHECK-DAG: [[TO_INDEX1:%.+]] = shape.size_to_index [[EXTENT1]]
// CHECK-DAG: [[CAST0:%.+]] = index_cast [[TO_INDEX0]]
// CHECK-DAG: [[CAST1:%.+]] = index_cast [[TO_INDEX1]]
// CHECK-DAG: [[TENSOR0:%.+]] = tensor_from_elements([[CAST0]])
// CHECK-DAG: [[TENSOR1:%.+]] = tensor_from_elements([[CAST1]])
// CHECK-DAG: [[RESHAPE0:%.+]] = "xla_hlo.reshape"([[TENSOR0]])
// CHECK-DAG: [[RESHAPE1:%.+]] = "xla_hlo.reshape"([[TENSOR1]])
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE0]], [[RESHAPE1]]) {dimension = 0 : i64}
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<?x?xf32>) -> tensor<2xi32>

// CHECK: return [[CONCAT]]
return %0 : tensor<2xi32>
}

// CHECK-LABEL: func @shape_with_const
func @shape_with_const(%arg0: tensor<?x3xf32>) -> tensor<2xi32> {
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK-DAG: [[EXTENT:%.+]] = shape.get_extent [[SHAPE]], 0
// CHECK-DAG: [[TO_INDEX:%.+]] = shape.size_to_index [[EXTENT]]
// CHECK-DAG: [[CAST:%.+]] = index_cast [[TO_INDEX]]
// CHECK-DAG: [[TENSOR:%.+]] = tensor_from_elements([[CAST]])
// CHECK-DAG: [[RESHAPE:%.+]] = "xla_hlo.reshape"([[TENSOR]])
// CHECK-DAG: [[CONST:%.+]] = xla_hlo.constant dense<3>
// CHECK-DAG: [[CONCAT:%.+]] = "xla_hlo.concatenate"([[RESHAPE]], [[CONST]]) {dimension = 0 : i64}
%0 = "tf.Shape"(%arg0) : (tensor<?x3xf32>) -> tensor<2xi32>

// CHECK: return [[CONCAT]]
// CHECK: return [[CAST]]
return %0 : tensor<2xi32>
}

// CHECK-LABEL: func @shape_rankless
func @shape_rankless(%arg0: tensor<*xf32>) -> tensor<?xi32> {
// CHECK: [[SHAPE:%.+]] = shape.shape_of %arg0
// CHECK: [[TENSOR:%.+]] = shape.to_extent_tensor [[SHAPE]]
// CHECK: [[CAST:%.+]] = index_cast [[TENSOR]]
%0 = "tf.Shape"(%arg0) : (tensor<*xf32>) -> tensor<?xi32>

// CHECK: return [[CAST]]
return %0 : tensor<?xi32>
}

@ -1884,7 +1860,7 @@ func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<2xf32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<1xindex>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<1xindex>
// CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<2xf32>
// CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<2xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<2xf32>) -> tensor<2xf32>
@ -1906,7 +1882,7 @@ func @sigmoid_complex(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xcomplex<f32>> {
func @sigmoid_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-DAG: [[SCALAR:%.+]] = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
// CHECK-DAG: [[SHAPE:%.+]] = shape.shape_of %arg0 : tensor<*xf32>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = "shape.to_extent_tensor"([[SHAPE]]) : (!shape.shape) -> tensor<?xindex>
// CHECK-DAG: [[SHAPE_VAL:%.+]] = shape.to_extent_tensor [[SHAPE]] : tensor<?xindex>
// CHECK-DAG: [[HALF:%.+]] = "xla_hlo.dynamic_broadcast_in_dim"([[SCALAR]], [[SHAPE_VAL]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-DAG: [[R1:%.+]] = xla_hlo.multiply %arg0, [[HALF]] : tensor<*xf32>
// CHECK-DAG: [[R2:%.+]] = "xla_hlo.tanh"([[R1]]) : (tensor<*xf32>) -> tensor<*xf32>
@ -3826,42 +3802,6 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
return %0: tensor<4x?x16xf32>
}

//===----------------------------------------------------------------------===//
// tf.VariableShape legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: @variable_shape32
func @variable_shape32(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi32> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi32>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi32>)
// CHECK: return [[CST_CAST]]
return %0: tensor<3xi32>
}

// CHECK-LABEL: @variable_shape64
func @variable_shape64(%input: tensor<!tf.resource<tensor<2x4x8xf32>>>) -> tensor<3xi64> {
// CHECK: [[CST:%.*]] = xla_hlo.constant dense<[2, 4, 8]> : tensor<3xi64>
// CHECK: [[CST_CAST:%.*]] = tensor_cast [[CST]]
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<2x4x8xf32>>>) -> (tensor<3xi64>)
// CHECK: return [[CST_CAST]]
return %0: tensor<3xi64>
}

// CHECK-LABEL: @variable_shape_unknown_resource
func @variable_shape_unknown_resource(%input: tensor<!tf.resource>) -> tensor<?xi32> {
// CHECK: tf.VariableShape
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource>) -> (tensor<?xi32>)
return %0: tensor<?xi32>
}

// CHECK-LABEL: @variable_shape_unknown_resource_shape
func @variable_shape_unknown_resource_shape(%input: tensor<!tf.resource<tensor<?x?xf32>>>) -> tensor<2xi32> {
// CHECK: tf.VariableShape
%0 = "tf.VariableShape"(%input) : (tensor<!tf.resource<tensor<?x?xf32>>>) -> (tensor<2xi32>)
return %0: tensor<2xi32>
}

//===----------------------------------------------------------------------===//
// tf.AvgPool legalization
//===----------------------------------------------------------------------===//
@ -4025,3 +3965,87 @@ func @qr(%arg0: tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x7
%0:2 = "tf.Qr"(%arg0) {full_matrices = false} : (tensor<500x100x75xf32>) -> (tensor<500x100x75xf32>, tensor<500x75x75xf32>)
return %0#0, %0#1 : tensor<500x100x75xf32>, tensor<500x75x75xf32>
}

//===----------------------------------------------------------------------===//
// tf.Softplus legalization
//===----------------------------------------------------------------------===//

// CHECK-LABEL: func @softplus_f16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf16>)
func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.220700e-04> : tensor<f16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf16>) -> tensor<8x16xf16>

// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf16>
return %0 : tensor<8x16xf16>
}

// CHECK-LABEL: func @softplus_bf16
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xbf16>)
func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<7.812500e-03> : tensor<bf16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<bf16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xbf16>) -> tensor<8x16xbf16>

// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xbf16>
return %0 : tensor<8x16xbf16>
}

// CHECK-LABEL: func @softplus_f32
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf32>)
func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<1.1920929E-7> : tensor<f32>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf32>) -> tensor<8x16xf32>

// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf32>
return %0 : tensor<8x16xf32>
}

// CHECK-LABEL: func @softplus_f64
// CHECK-SAME: ([[FEATURES:%.*]]: tensor<8x16xf64>)
func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
// CHECK-DAG: [[FEATURES_EXP:%.*]] = "xla_hlo.exponential"([[FEATURES]])
// CHECK-DAG: [[EPSILON:%.*]] = xla_hlo.constant dense<2.2204460492503131E-16> : tensor<f64>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "xla_hlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = xla_hlo.constant dense<2.000000e+00> : tensor<f64>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "xla_hlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "xla_hlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "xla_hlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
%0 = "tf.Softplus"(%arg0) : (tensor<8x16xf64>) -> tensor<8x16xf64>

// CHECK: return [[ENTRY_SELECT]] : tensor<8x16xf64>
return %0 : tensor<8x16xf64>
}
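For reference, the four Softplus tests above all verify the same numerically stable evaluation; a sketch of the math the CHECK sequences encode, with epsilon the machine epsilon of the element type (matching the per-type constants in the tests) and the threshold choice following the tf2xla kernel:

    t = log(epsilon) + 2
    softplus(x) = x            if x > -t   (x large: log(1+e^x) ~ x)
                = e^x          if x <  t   (x very negative: log(1+e^x) ~ e^x)
                = log(1+e^x)   otherwise   (computed via log_plus_one(exp(x)))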
@ -173,7 +173,8 @@ func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return
}
// CHECK: linalg.indexed_generic {{{.*}}indexing_maps = [#[[RESULT_MAP]]]
// CHECK: linalg.indexed_generic
// CHECK-SAME: indexing_maps = [#[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[D0:.*]]: index, %[[D1:.*]]: index, %[[RESULT:.*]]: f32):
// CHECK-NEXT: %[[INT_CAST:.*]] = index_cast %[[D1]] : index to i32
// CHECK-NEXT: %[[FLOAT_CAST:.*]] = sitofp %[[INT_CAST]] : i32 to f32
@ -190,7 +191,8 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
} : (memref<f32>, memref<4x2x1xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32

@ -206,7 +208,8 @@ func @broadcast(%operand: memref<4x?x16xf32>,
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.+]]: f32, %{{.+}}: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32

@ -222,7 +225,8 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK-NEXT: ^bb0(%[[OPERAND:.*]]: f32, %[[RESULT:.*]]: f32):
// CHECK-NEXT: linalg.yield %[[OPERAND]] : f32

@ -645,39 +649,42 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, 0, d1)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1, 0, 0)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy

// -----

// CHECK-DAG: #[[OPERAND_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
// CHECK: linalg.reshape %{{.*}} [#[[MAP1]], #[[MAP2]]]
// CHECK-NEXT: linalg.copy

// -----
@ -103,9 +103,10 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK-SAME: [[RESULT_BUF:%.*]]: memref<?x?xf32>) {
// CHECK-DAG: [[C0:%.*]] = constant 0 : index
// CHECK-DAG: [[C1:%.*]] = constant 1 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], 0 : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], 1 : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], 2 : memref<?x?x?xf32>
// CHECK-DAG: [[C2:%.*]] = constant 2 : index
// CHECK: [[DIM0:%.*]] = dim [[ARG_BUF]], [[C0]] : memref<?x?x?xf32>
// CHECK: [[DIM1:%.*]] = dim [[ARG_BUF]], [[C1]] : memref<?x?x?xf32>
// CHECK: [[DIM2:%.*]] = dim [[ARG_BUF]], [[C2]] : memref<?x?x?xf32>
// CHECK: [[INIT:%.*]] = load [[INIT_BUF]]
// CHECK: scf.parallel ([[I:%.*]], [[K:%.*]]) = ([[C0]], [[C0]])
// CHECK-SAME: to ([[DIM0]], [[DIM2]]) step ([[C1]], [[C1]]) {
@ -1,8 +1,66 @@
|
||||
// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
|
||||
|
||||
func @enforce_same_shape(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
|
||||
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
|
||||
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @ceil
|
||||
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point values}}
|
||||
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @cos
|
||||
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @cos
|
||||
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sin
|
||||
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @sin
|
||||
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -25,16 +83,40 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @convert_memref
|
||||
func @convert_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
|
||||
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @exp_memref
|
||||
func @exp_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.exponential"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
|
||||
func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @exp
|
||||
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @exp
|
||||
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
|
||||
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -48,6 +130,22 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// -----

// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
  "xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
  return
}

// -----

func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point or complex-type values}}
  "xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
  "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
@ -64,6 +162,46 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// -----

// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
  "xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
  return
}

// -----

func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point or complex-type values}}
  "xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
  "xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
  "xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
  return
}

// -----

func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point or complex-type values}}
  "xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
  "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
@ -80,6 +218,30 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// -----

// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
  "xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
  return
}

// -----

func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point or complex-type values}}
  "xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
  // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
  "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
  "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
@ -129,13 +291,77 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// -----

// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
  "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
  "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
  return
}

// -----

func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
  "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
  "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
  "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
  return
}

// -----

func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
  "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
  "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
  "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
  return
}

// -----

func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
  "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
  "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
@ -248,3 +474,392 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
      : memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
  return
}

// -----

// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
  "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
  return
}

// -----

func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point or complex-type values}}
  "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
  return
}

// -----

func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
  // expected-error@+1{{requires the same shape for all operands}}
  "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @clz_memrefs
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
  "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
  return
}

// -----

// CHECK-LABEL: func @floor_memrefs
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point values}}
  "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @imag_memrefs
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
  return
}

// -----

func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error@+1{{must be memref of complex-type values}}
  "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @real_memrefs
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
  return
}

// -----

func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error@+1{{must be memref of complex-type values}}
  "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
  "xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
  return
}

// -----

// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
  "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
  return
}

// -----

func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point or complex-type values}}
  "xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
  "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
  return
}

// -----

func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
  "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
  "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @round_memrefs
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  // expected-error@+1{{must be memref of floating-point values}}
  "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
  "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
  "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
  "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
  return
}

// -----

func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
  // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
  "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
  "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
      %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
      "xla_hlo.return"(%max) : (tensor<f32>) -> ()
    })
  { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()

  "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
    ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
      %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
      "xla_hlo.return"(%max) : (tensor<f32>) -> ()
    })
  {
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
    channel_id = { handle = 5 : i64, type = 2 : i64 },
    constrain_layout = true,
    use_global_device_ids = true
  }: (memref<10xf32>, memref<10xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
  "xla_lhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()

  "xla_lhlo.collective_permute"(%arg0, %arg_out) {
    source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
    channel_id = { handle = 5 : i64, type = 2 : i64 }
  } : (memref<128x32xf32>, memref<128x32xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
  "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
  return
}

// -----

// CHECK-LABEL: func @batch_norm_grad_memrefs
func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
                              %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
                              %arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
  "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
         tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
  return
}

// -----

// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
                                   %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
  "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @batch_norm_training_memrefs
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
                                  %arg_out: tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> () {
  "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
      : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, tuple<memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>>) -> ()
  return
}

// -----

// CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
  "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
  "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
  "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
  "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
  "xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
  "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
      : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
  return
}

// -----

// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
  "xla_lhlo.while"(%arg0, %arg_out) (
    { ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
    { ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
  ) : (memref<i64>, memref<i64>) -> ()
  return
}

@ -304,6 +304,46 @@ func @collective_permute_duplicate_sources(%arg0: tensor<128x32xf32>) -> tensor<

// -----

// CHECK-LABEL: @concat_1D
func @concat_1D(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<3xi32> {
  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}

// -----

func @concat_1D_type_error(%arg0: tensor<1xi32>, %arg1: tensor<2xf32>) -> tensor<3xi32> {
  // expected-error@+1 {{'xla_hlo.concatenate' op requires the same element type for all operands and results}}
  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xf32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}

// -----

// CHECK-LABEL: @concat_1D_unranked
func @concat_1D_unranked(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<*xi32> {
  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<*xi32>
  return %0 : tensor<*xi32>
}

// -----

func @concat_1D_unranked_error(%arg0: tensor<1xi32>, %arg1: tensor<*xi32>) -> tensor<3xi32> {
  // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<*xi32>) -> tensor<3xi32>
  return %0 : tensor<3xi32>
}

// -----

func @concat_1D_error(%arg0: tensor<1xi32>, %arg1: tensor<2xi32>) -> tensor<4xi32> {
  // expected-error@+1 {{'xla_hlo.concatenate' op inferred type incompatible with return type of operation}}
  %0 = "xla_hlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1xi32>, tensor<2xi32>) -> tensor<4xi32>
  return %0 : tensor<4xi32>
}

// -----

// CHECK-LABEL: func @clamp
func @clamp(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  %0 = "xla_hlo.clamp"(%arg0, %arg0, %arg0) : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

@ -104,16 +104,20 @@ func @batchNormInference_dynamic_shape(
    %x: tensor<?x?x?x?xf32>, %scale: tensor<?xf32>, %offset: tensor<?xf32>,
    %mean: tensor<?xf32>, %variance: tensor<?xf32>)
    -> tensor<?x?x?x?xf32> {
  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
  // CHECK-DAG: %[[C2:.*]] = constant 2 : index
  // CHECK-DAG: %[[C3:.*]] = constant 3 : index
  // CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
  // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], 0 : tensor<?xf32>
  // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
  // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
  // CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
  // CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], 0 : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], 1 : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], 2 : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], 3 : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
  // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
  // CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>

@ -16,8 +16,7 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
  %flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>

  // Restore original shape.
  %shape_as_extent_tensor = "shape.to_extent_tensor"(%shape)
      : (!shape.shape) -> tensor<?xindex>
  %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
  %b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
      : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

@ -35,7 +34,7 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
  // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
  // CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  // CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
  // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = "shape.to_extent_tensor"(%[[SHAPE]]) : (!shape.shape) -> tensor<?xindex>
  // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
  // CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
  // CHECK-NEXT: return %[[B]] : tensor<*xf32>
  %b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>

@ -18,6 +18,7 @@ limitations under the License.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <limits>
#include <numeric>

#include "llvm/ADT/ArrayRef.h"
@ -25,6 +26,7 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Shape/IR/Shape.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@ -54,6 +56,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_shape_util.h"
#include "tensorflow/core/kernels/conv_grad_shape_utils.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

@ -881,6 +884,31 @@ static Type GetAccumulationType(Type ty) {
  return (ty.isF16() || ty.isBF16()) ? FloatType::getF32(ty.getContext()) : ty;
}

//===----------------------------------------------------------------------===//
// Softplus op utilities.
//===----------------------------------------------------------------------===//

static DenseElementsAttr GetEpsilonValue(Type ty) {
  auto element_ty = ty.cast<TensorType>().getElementType();
  auto scalar_ty = RankedTensorType::get({}, element_ty);
  if (element_ty.isF16()) {
    uint16_t raw_epsilon = Eigen::NumTraits<Eigen::half>::epsilon().x;
    auto value = APFloat(APFloat::IEEEhalf(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isBF16()) {
    uint16_t raw_epsilon = tensorflow::bfloat16::epsilon().value;
    auto value = APFloat(APFloat::BFloat(), APInt(16, raw_epsilon));
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF32()) {
    auto value = APFloat(std::numeric_limits<float>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  } else if (element_ty.isF64()) {
    auto value = APFloat(std::numeric_limits<double>::epsilon());
    return DenseElementsAttr::get(scalar_ty, value);
  }
  llvm_unreachable("unsupported element type for tf.SoftPlus");
}
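
// As a worked numeric example (epsilon values assumed from the standard
// library, not taken from this diff): for f32,
// std::numeric_limits<float>::epsilon() is ~1.1920929e-07, so log(epsilon)
// is ~-15.94. With the +2 offset added by the Softplus pattern that
// consumes this constant, inputs above ~13.94 lower to the identity branch,
// inputs below ~-13.94 lower to exp(features), and everything in between
// uses log1p(exp(features)).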

//===----------------------------------------------------------------------===//
// ArgMax/ArgMin op utilities.
//===----------------------------------------------------------------------===//
@ -4387,45 +4415,6 @@ class ConvertRandomShuffleOp : public OpRewritePattern<TF::RandomShuffleOp> {
  }
};

// Converts tf.VariableShape op to a XLA HLO constant representing the variable
// shape.
class ConvertVariableShapeOp : public OpRewritePattern<TF::VariableShapeOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::VariableShapeOp op,
                                PatternRewriter &rewriter) const override {
    // The input type should be a tensor<!tf.resource<resource-type>>. We need
    // to get the inner resource type.
    auto input_type = op.input().getType().cast<TensorType>();
    auto subtypes =
        input_type.getElementType().cast<TF::ResourceType>().getSubtypes();
    // It can be missing; then we cannot convert.
    if (subtypes.empty()) return failure();

    auto resource_type = subtypes[0].cast<TensorType>();
    if (!resource_type.hasStaticShape()) return failure();

    auto resource_shape = resource_type.getShape();
    Attribute const_attr;

    // We need to match the original op result's element type.
    auto element_type = op.getType().cast<TensorType>().getElementType();
    unsigned bitwidth = element_type.cast<IntegerType>().getWidth();
    if (bitwidth == 32) {
      SmallVector<int32_t, 4> shape(resource_shape.begin(),
                                    resource_shape.end());
      const_attr = GetI32ElementsAttr(shape, &rewriter);
    } else {
      assert(bitwidth == 64);
      const_attr = GetI64ElementsAttr(resource_shape, &rewriter);
    }

    rewriter.replaceOpWithNewOp<xla_hlo::ConstOp>(op, const_attr);
    return success();
  }
};

// Converts an XlaSharding op to a XLA HLO shard op with sharding attributes.
class ConvertXlaShardingOp : public OpRewritePattern<TF::XlaShardingOp> {
 public:
@ -4621,45 +4610,19 @@ class ConvertShapeOp : public OpRewritePattern<TF::ShapeOp> {
  LogicalResult matchAndRewrite(TF::ShapeOp op,
                                PatternRewriter &rewriter) const override {
    Value input = op.input();
    auto input_ty = input.getType().dyn_cast<RankedTensorType>();
    // If the shape is static it can be canonicalized.
    if (!input_ty || input_ty.hasStaticShape()) {

    auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);
    auto result_ty = op.getResult().getType().dyn_cast<RankedTensorType>();
    if (!result_ty) {
      return failure();
    }

    auto result_ty = op.getResult().getType().cast<RankedTensorType>();
    auto element_ty = result_ty.getElementType();
    auto index_tensor =
        RankedTensorType::get(result_ty.getShape(), rewriter.getIndexType());
    auto extent_tensor = rewriter.create<shape::ToExtentTensorOp>(
        op.getLoc(), index_tensor, shape_op);

    int64_t rank = input_ty.getRank();
    auto shape_op = rewriter.create<shape::ShapeOfOp>(op.getLoc(), input);

    auto index_ty = RankedTensorType::get({1}, element_ty);
    llvm::SmallVector<Value, 4> dim_values;
    for (int64_t i = 0; i < rank; ++i) {
      if (!input_ty.isDynamicDim(i)) {
        auto dim_attr = DenseElementsAttr::get(
            index_ty,
            rewriter.getIntegerAttr(element_ty, input_ty.getDimSize(i)));
        auto index = rewriter.create<xla_hlo::ConstOp>(op.getLoc(), dim_attr);
        dim_values.push_back(index);
        continue;
      }

      auto extent_op = rewriter.create<shape::GetExtentOp>(
          op.getLoc(), shape_op, rewriter.getI64IntegerAttr(i));
      auto index_op = rewriter.create<shape::SizeToIndexOp>(
          op.getLoc(), rewriter.getIndexType(), extent_op);
      auto int_op =
          rewriter.create<IndexCastOp>(op.getLoc(), element_ty, index_op);
      auto from_tensor = rewriter.create<TensorFromElementsOp>(
          op.getLoc(), int_op.getResult());
      auto reshape_op =
          rewriter.create<ReshapeOp>(op.getLoc(), index_ty, from_tensor);
      dim_values.push_back(reshape_op);
    }

    rewriter.replaceOpWithNewOp<ConcatenateOp>(op, result_ty, dim_values,
                                               rewriter.getI64IntegerAttr(0));
    rewriter.replaceOpWithNewOp<IndexCastOp>(op, result_ty, extent_tensor);
    return success();
  }
};
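
// As an illustrative sketch (assumed types, not FileCheck-verified output),
// the rewritten pattern turns
//   %0 = "tf.Shape"(%arg) : (tensor<?x8xf32>) -> tensor<2xi32>
// into roughly
//   %shape  = shape.shape_of %arg : tensor<?x8xf32>
//   %extent = shape.to_extent_tensor %shape : tensor<2xindex>
//   %0      = index_cast %extent : tensor<2xindex> to tensor<2xi32>
// replacing the previous per-dimension constant/get_extent/concatenate
// expansion.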
@ -5250,7 +5213,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
      ConvertTensorScatterUpdateOp, ConvertTileOp, ConvertTopKV2Op,
      ConvertUnpackOp, ConvertUnsortedSegmentMaxOp, ConvertUnsortedSegmentMinOp,
      ConvertUnsortedSegmentProdOp, ConvertUnsortedSegmentSumOp,
      ConvertRandomShuffleOp, ConvertVariableShapeOp, ConvertXlaShardingOp,
      ConvertRandomShuffleOp, ConvertXlaShardingOp,
      ConvertXlaDynamicUpdateSliceOp>(op->getContext());

  // Populate with CHLO->HLO lowerings to account for TF ops legalized to

@ -618,3 +618,39 @@ def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r),
          (HLO_MulOp
           (HLO_MulOp $r, $l),
           (HLO_SubOp (HLO_ConstOp (ConstantSplat<"1"> $l)), $l))>;

//===----------------------------------------------------------------------===//
// Softplus op.
//===----------------------------------------------------------------------===//

def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">;

def : Pattern<(TF_SoftplusOp AnyTensor:$features),
              [
                (HLO_ExpOp:$features_exp $features),
                (HLOClient_BroadcastAddOp:$threshold
                  (HLO_LogOp (HLO_ConstOp (EpsilonValue $features))),
                  (HLO_ConstOp (GetScalarOfType<2> $features)),
                  (NullDenseIntElementsAttr)
                ),
                (HLO_SelectOp:$output
                  (HLOClient_BroadcastCompareOp
                    $features,
                    (HLO_NegOp $threshold),
                    (NullDenseIntElementsAttr),
                    HLO_COMPARISON_DIRECTION_GT
                  ),
                  $features,
                  (HLO_SelectOp
                    (HLOClient_BroadcastCompareOp
                      $features,
                      $threshold,
                      (NullDenseIntElementsAttr),
                      HLO_COMPARISON_DIRECTION_LT
                    ),
                    $features_exp,
                    (HLO_Log1pOp $features_exp)
                  )
                ),
                (replaceWithValue $output)
              ]>;
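
// A rough sketch of what this pattern expands to -- hand-written for
// illustration against a scalar f32 input %features, with the epsilon
// constant assumed from GetEpsilonValue, and not FileCheck-verified:
//   %eps   = xla_hlo.constant dense<1.192093e-07> : tensor<f32>
//   %two   = xla_hlo.constant dense<2.0> : tensor<f32>
//   %exp   = "xla_hlo.exponential"(%features) : (tensor<f32>) -> tensor<f32>
//   %leps  = "xla_hlo.log"(%eps) : (tensor<f32>) -> tensor<f32>
//   %thr   = xla_hlo.add %leps, %two : tensor<f32>
//   %nthr  = "xla_hlo.negate"(%thr) : (tensor<f32>) -> tensor<f32>
//   %big   = "xla_hlo.compare"(%features, %nthr) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
//   %small = "xla_hlo.compare"(%features, %thr) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
//   %l1p   = "xla_hlo.log_plus_one"(%exp) : (tensor<f32>) -> tensor<f32>
//   %safe  = "xla_hlo.select"(%small, %exp, %l1p) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
//   %out   = "xla_hlo.select"(%big, %features, %safe) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>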

@ -89,6 +89,8 @@ static bool IsOpWhitelisted(Operation* op) {
    TypeID::get<TF::AddV2Op>(),
    TypeID::get<TF::AngleOp>(),
    TypeID::get<TF::ApproximateEqualOp>(),
    TypeID::get<TF::ArgMaxOp>(),
    TypeID::get<TF::ArgMinOp>(),
    TypeID::get<TF::AsinhOp>(),
    TypeID::get<TF::AsinOp>(),
    TypeID::get<TF::Atan2Op>(),
@ -100,6 +102,7 @@ static bool IsOpWhitelisted(Operation* op) {
    TypeID::get<TF::BitwiseAndOp>(),
    TypeID::get<TF::BitwiseOrOp>(),
    TypeID::get<TF::BitwiseXorOp>(),
    TypeID::get<TF::BucketizeOp>(),
    TypeID::get<TF::CastOp>(),
    TypeID::get<TF::ClipByValueOp>(),
    TypeID::get<TF::ComplexAbsOp>(),
@ -116,13 +119,24 @@ static bool IsOpWhitelisted(Operation* op) {
    TypeID::get<TF::ErfcOp>(),
    TypeID::get<TF::ErfOp>(),
    TypeID::get<TF::Expm1Op>(),
    TypeID::get<TF::FFT2DOp>(),
    TypeID::get<TF::FFT3DOp>(),
    TypeID::get<TF::FFTOp>(),
    TypeID::get<TF::FloorDivOp>(),
    TypeID::get<TF::FloorModOp>(),
    TypeID::get<TF::GatherNdOp>(),
    TypeID::get<TF::GreaterEqualOp>(),
    TypeID::get<TF::GreaterOp>(),
    TypeID::get<TF::IFFT2DOp>(),
    TypeID::get<TF::IFFT3DOp>(),
    TypeID::get<TF::IFFTOp>(),
    TypeID::get<TF::IRFFT2DOp>(),
    TypeID::get<TF::IRFFT3DOp>(),
    TypeID::get<TF::IRFFTOp>(),
    TypeID::get<TF::InvertOp>(),
    TypeID::get<TF::InvOp>(),
    TypeID::get<TF::LRNOp>(),
    TypeID::get<TF::LRNGradOp>(),
    TypeID::get<TF::LeakyReluGradOp>(),
    TypeID::get<TF::LeakyReluOp>(),
    TypeID::get<TF::LeftShiftOp>(),
@ -134,16 +148,20 @@ static bool IsOpWhitelisted(Operation* op) {
    TypeID::get<TF::LogicalOrOp>(),
    TypeID::get<TF::LogOp>(),
    TypeID::get<TF::MatMulOp>(),
    TypeID::get<TF::MirrorPadOp>(),
    TypeID::get<TF::MulOp>(),
    TypeID::get<TF::NegOp>(),
    TypeID::get<TF::NotEqualOp>(),
    TypeID::get<TF::PadOp>(),
    TypeID::get<TF::PlaceholderWithDefaultOp>(),
    TypeID::get<TF::PowOp>(),
    TypeID::get<TF::RFFT2DOp>(),
    TypeID::get<TF::RFFT3DOp>(),
    TypeID::get<TF::RealDivOp>(),
    TypeID::get<TF::ReciprocalOp>(),
    TypeID::get<TF::ReciprocalGradOp>(),
    TypeID::get<TF::Relu6GradOp>(),
    TypeID::get<TF::ReverseSequenceOp>(),
    TypeID::get<TF::RightShiftOp>(),
    TypeID::get<TF::RintOp>(),
    TypeID::get<TF::RoundOp>(),
@ -156,6 +174,7 @@ static bool IsOpWhitelisted(Operation* op) {
    TypeID::get<TF::SoftplusGradOp>(),
    TypeID::get<TF::SoftsignGradOp>(),
    TypeID::get<TF::SoftsignOp>(),
    TypeID::get<TF::SparseToDenseOp>(),
    TypeID::get<TF::SqrtGradOp>(),
    TypeID::get<TF::SquareOp>(),
    TypeID::get<TF::SubOp>(),

@ -16,7 +16,6 @@ limitations under the License.
// This file implements logic for lowering HLO/LHLO dialect to Linalg dialect.

#include "absl/memory/memory.h"
#include "llvm/ADT/APInt.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"  // from @llvm-project
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@ -39,13 +38,9 @@ limitations under the License.
namespace mlir {
namespace {

ArrayAttr GetNParallelLoopsAttrs(unsigned nParallelLoops, Builder* b) {
  auto parallelLoopTypeAttr = b->getStringAttr("parallel");
  SmallVector<Attribute, 3> iteratorTypes;
  for (int i = 0; i < nParallelLoops; ++i) {
    iteratorTypes.push_back(parallelLoopTypeAttr);
  }
  return b->getArrayAttr(iteratorTypes);
SmallVector<StringRef, 3> GetNParallelLoopsAttrs(unsigned nParallelLoops) {
  static constexpr StringRef kParallelIterType = "parallel";
  return SmallVector<StringRef, 3>(nParallelLoops, kParallelIterType);
}
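
// Illustrative usage (not from this commit's tests): for nloops == 2 the
// helper now yields {"parallel", "parallel"}, which the linalg::GenericOp
// builder overloads used below attach to the generated op as
//   iterator_types = ["parallel", "parallel"]
// so callers no longer wrap each entry in a StringAttr/ArrayAttr by hand.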

template <bool isLHLO = true>
@ -90,7 +85,7 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
    }

    // Construct the indexing maps needed for linalg.generic ops.
    SmallVector<Attribute, 2> indexingMaps;
    SmallVector<AffineMap, 2> indexing_maps;
    SmallVector<Type, 4> bodyArgTypes, bodyResultTypes, opResultTypes;

    // This doesnt account for implicit broadcast, but the working assumption
@ -107,9 +102,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
           !shapedType.isa<RankedTensorType>()) ||
          shapedType.getRank() != nloops)
        return nullptr;
      indexingMaps.emplace_back(AffineMapAttr::get(
      indexing_maps.emplace_back(
          nloops ? rewriter.getMultiDimIdentityMap(nloops)
                 : AffineMap::get(nloops, 0, rewriter.getContext())));
                 : AffineMap::get(nloops, 0, rewriter.getContext()));
      return shapedType;
    };
    for (const auto& arg : llvm::enumerate(args)) {
@ -132,11 +127,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {

    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, opResultTypes, args,
        rewriter.getI64IntegerAttr(bodyArgTypes.size()),     // args_in
        rewriter.getI64IntegerAttr(bodyResultTypes.size()),  // args_out
        rewriter.getArrayAttr(indexingMaps),
        GetNParallelLoopsAttrs(nloops, &rewriter),
        /*doc=*/nullptr, /*library_call=*/nullptr);
        /*inputCount=*/bodyArgTypes.size(),
        /*outputCount=*/bodyResultTypes.size(), indexing_maps,
        GetNParallelLoopsAttrs(nloops));

    // Add a block to the region.
    auto* region = &linalgOp.region();
@ -297,8 +290,8 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
/// Base class for lowering xla operations that have one operand and one result,
/// and are semantically equivalent to a copy of the input to the output (like
/// transpose, some reshape, etc.). The derived classes need to provide a method
/// `getIndexingMapsAttr` that returns an ArrayAttr containing AffineMapAttr for
/// the index maps of the input and the output.
/// `getIndexingMaps` that returns AffineMaps for the index maps of the input
/// and the output.
template <typename Derived, typename OpTy, bool isLHLO = true>
class DataMovementOpConverter : public OpConversionPattern<OpTy> {
 public:
@ -310,17 +303,17 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
    if (!verifyXLAOpBufferOrTensorSemantics<isLHLO>(op)) return failure();
    auto operandType = op.operand().getType().template cast<ShapedType>();
    auto resultType = getXLAOpResultType<isLHLO>(op);
    ArrayAttr indexingMapsAttr = Derived::getIndexingMapsAttr(op, &rewriter);
    if (!indexingMapsAttr) return failure();

    SmallVector<AffineMap, 2> indexing_maps =
        Derived::getIndexingMaps(op, &rewriter);
    if (indexing_maps.empty()) return failure();

    OpBuilder::InsertionGuard linalgOpGuard(rewriter);
    auto nloops = resultType.getRank();
    auto loc = op.getLoc();
    auto linalgOp = rewriter.create<linalg::GenericOp>(
        loc, isLHLO ? ArrayRef<Type>{} : resultType, args,
        rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
        indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
        /*doc=*/nullptr, /*library_call=*/nullptr);
        loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
        /*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops));

    auto* region = &linalgOp.region();
    auto* block = rewriter.createBlock(region, region->end());
@ -344,7 +337,8 @@ class BroadcastConverter
  using DataMovementOpConverter<BroadcastConverter, OpTy,
                                isLHLO>::DataMovementOpConverter;

  static ArrayAttr getIndexingMapsAttr(OpTy broadcastOp, Builder* b) {
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy broadcastOp,
                                                   Builder* b) {
    ShapedType inputType =
        broadcastOp.operand().getType().template cast<ShapedType>();
    unsigned inputRank = inputType.getRank();
@ -368,8 +362,7 @@ class BroadcastConverter
      inputMap =
          AffineMap::get(nloops, /*symbolCount=*/0, inputDimExprs, context);
    }
    return b->getAffineMapArrayAttr(
        {inputMap, b->getMultiDimIdentityMap(nloops)});
    return {inputMap, b->getMultiDimIdentityMap(nloops)};
  }
};
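
// Illustrative example (assumed shapes, not from a test in this commit):
// broadcasting a rank-1 input into a rank-2 result maps the input to the
// trailing loop dimension, so getIndexingMaps returns
//   input:  affine_map<(d0, d1) -> (d1)>
//   output: affine_map<(d0, d1) -> (d0, d1)>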

@ -381,8 +374,8 @@ class HloBroadcastInDimConverter
                                xla_hlo::BroadcastInDimOp,
                                false>::DataMovementOpConverter;

  static ArrayAttr getIndexingMapsAttr(xla_hlo::BroadcastInDimOp broadcastOp,
                                       Builder* b) {
  static SmallVector<AffineMap, 2> getIndexingMaps(
      xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
    auto resultType = getXLAOpResultType<false>(broadcastOp);
    auto operandType =
        broadcastOp.operand().getType().template cast<ShapedType>();
@ -390,9 +383,8 @@ class HloBroadcastInDimConverter

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operandType.getRank() == 0) {
      return b->getAffineMapArrayAttr(
          {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
           b->getMultiDimIdentityMap(nloops)});
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operandShape = operandType.getShape();
@ -409,9 +401,9 @@ class HloBroadcastInDimConverter
                             : b->getAffineDimExpr(size));
      }
    }
    return b->getAffineMapArrayAttr(
        {AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)});
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

@ -447,11 +439,9 @@ class LhloBroadcastInDimConverter
          rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
      auto linalgOp = rewriter.create<linalg::GenericOp>(
          loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
          rewriter.getI64IntegerAttr(0), rewriter.getI64IntegerAttr(1),
          rewriter.getAffineMapArrayAttr(
              {rewriter.getMultiDimIdentityMap(nloops)}),
          GetNParallelLoopsAttrs(nloops, &rewriter),
          /*doc=*/nullptr, /*library_call=*/nullptr);
          /*inputCount=*/0, /*outputCount=*/1,
          llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
          GetNParallelLoopsAttrs(nloops));

      auto* region = &linalgOp.region();
      auto* block = rewriter.createBlock(region, region->end());
@ -460,16 +450,15 @@ class LhloBroadcastInDimConverter
      rewriter.setInsertionPointToEnd(block);
      rewriter.create<linalg::YieldOp>(loc, val);
    } else {
      ArrayAttr indexingMapsAttr = getIndexingMapsAttr(
          op, broadcast_dims, result_shape, operand_type, &rewriter);
      auto indexing_maps = getIndexingMaps(op, broadcast_dims, result_shape,
                                           operand_type, &rewriter);

      OpBuilder::InsertionGuard linalgOpGuard(rewriter);
      auto linalgOp = rewriter.create<linalg::GenericOp>(
          loc, llvm::None,
          llvm::makeArrayRef({operand, operand_adaptor.output()}),
          rewriter.getI64IntegerAttr(1), rewriter.getI64IntegerAttr(1),
          indexingMapsAttr, GetNParallelLoopsAttrs(nloops, &rewriter),
          /*doc=*/nullptr, /*library_call=*/nullptr);
          /*inputCount=*/1, /*outputCount=*/1, indexing_maps,
          GetNParallelLoopsAttrs(nloops));

      auto* region = &linalgOp.region();
      auto* block = rewriter.createBlock(region, region->end());
@ -504,14 +493,14 @@ class LhloBroadcastInDimConverter
    }

    SmallVector<int64_t, 2> new_shape, new_strides, broadcast_dims;
    SmallVector<SmallVector<AffineExpr, 2>, 4> collapsed_dims_list;
    SmallVector<AffineExpr, 2> collapsed_dims;
    SmallVector<linalg::ReassociationIndices, 4> collapsed_dims_list;
    linalg::ReassociationIndices collapsed_dims;
    for (const auto& item :
         enumerate(op.broadcast_dimensions().getIntValues())) {
      size_t index = item.index();
      int dim = item.value().getSExtValue();

      collapsed_dims.push_back(rewriter.getAffineDimExpr(index));
      collapsed_dims.push_back(index);

      bool expansion_needed =
          operand_shape[index] == 1 && result_shape[dim] != 1;
@ -542,31 +531,28 @@ class LhloBroadcastInDimConverter
    // `linalg.reshape` is inserted only if necessary, i.e. when the rank can be
    // reduced.
    if (new_shape.size() < operand_shape.size()) {
      SmallVector<ArrayRef<AffineExpr>, 4> reassociation_maps;
      for (const auto& dims : collapsed_dims_list)
        reassociation_maps.push_back(dims);
      auto new_memref_type = MemRefType::get(
          new_shape, operand_type.getElementType(),
          makeStridedLinearLayoutMap(new_strides, operand_offset,
                                     rewriter.getContext()));
      operand = rewriter.create<linalg::ReshapeOp>(op.getLoc(), new_memref_type,
                                                   operand_adaptor.operand(),
                                                   reassociation_maps);
                                                   collapsed_dims_list);
    }
    return std::make_pair(operand, broadcast_dims);
  }

  ArrayAttr getIndexingMapsAttr(xla_lhlo::BroadcastInDimOp op,
  SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
                                            ArrayRef<int64_t> broadcastDims,
                                            ArrayRef<int64_t> resultShape,
                                MemRefType operandType, Builder* b) const {
                                            MemRefType operandType,
                                            Builder* b) const {
    unsigned nloops = resultShape.size();

    // The input is a scalar, i.e. this is a scalar broadcast op.
    if (operandType.getRank() == 0) {
      return b->getAffineMapArrayAttr(
          {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
           b->getMultiDimIdentityMap(nloops)});
      return {AffineMap::get(nloops, /*symbolCount=*/0, b->getContext()),
              b->getMultiDimIdentityMap(nloops)};
    }

    auto operandShape = operandType.getShape();
@ -584,99 +570,9 @@ class LhloBroadcastInDimConverter
      }
      dimExprs.push_back(b->getAffineDimExpr(size));
    }
    return b->getAffineMapArrayAttr(
        {AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)});
  }
};

/// Pattern for the special case where reshape is adding or removing a dimension
/// of size 1. These can be lowered to a linalg.generic op.
///
/// For example a
///   "xla_hlo.reshape"(..) : (tensor<12x1x42xi32) -> tensor<12x42xi32>
/// can have indexing maps
/// [affine_map<(d0, d1) -> (d0, 0, d1)>, affine_map<(d0, d1) -> (d0, d1)>]
///
/// Similarly a
///   "xla_hlo.reshape"(..) : (tensor<12x42xi32>) -> tensor<12x1x42xi32>
/// can have indexing maps
/// [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d0, d1,
/// d2)>]

// TODO(ravishankarm): This pattern needs to be removed. The general reshape
// lowering hits a corner case where the following sequence of operations
// cannot be fused cause the resulting indexing map is not invertible.
//
// %r = linalg.reshape %s [affine_map<(d0, d1, d2) -> (d0, d1)>,
//                         affine_map<(d0, d1, d2) -> (d2)>]
//        : tensor<5x5xf32> into tensor<5x1x5xf32>
// %f = linalg.generic
//        {...
//         indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
//                          affine_map<(d0, d1, d2) -> (d0, d2)>],
//         iterator_types = ["parallel", "parallel", "parallel"]} %r {..}
//        : tensor<5x1x5xf32> -> tensor<5x5xf32>
//
// The resolution of this requires a canonicalization on linalg ops where the
// dims of size 1 are removed. This pattern can be removed after that.
template <typename OpTy, bool isLHLO = true>
class ReshapeAddRemoveDimConverter
    : public DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
                                     OpTy, isLHLO> {
 public:
  ReshapeAddRemoveDimConverter(MLIRContext* context)
      : DataMovementOpConverter<ReshapeAddRemoveDimConverter<OpTy, isLHLO>,
                                OpTy, isLHLO>(context, 100) {}

  static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto operandType =
        op.getOperation()->getOperand(0).getType().template cast<ShapedType>();
    if (!resultType.hasStaticShape() || !operandType.hasStaticShape())
      return nullptr;

    auto nloops = resultType.getRank();
    SmallVector<AffineExpr, 2> inputExprs;
    unsigned resultIndex = 0, operandIndex = 0;
    auto resultShape = resultType.getShape();
    auto operandShape = operandType.getShape();

    while (resultIndex < resultShape.size() &&
           operandIndex < operandShape.size()) {
      if (resultShape[resultIndex] == operandShape[operandIndex]) {
        // Copy over the affine expr when the size of the result and operand
        // match at a dim
        inputExprs.push_back(b->getAffineDimExpr(resultIndex));
        resultIndex++;
        operandIndex++;
      } else if (resultShape[resultIndex] == 1) {
        // If size at result is 1, then ignore this dimension for the input, it
        // is an extra dim added.
        resultIndex++;
      } else if (operandShape[operandIndex] == 1) {
        // If the operandShape is 1, then add a (0) for the operand map since
        // this dimension is dropped.
        inputExprs.push_back(b->getAffineConstantExpr(0));
        operandIndex++;
      } else {
        return nullptr;
      }
    }
    // Make sure all remaining dimensions of the operand and result are ones.
    auto checkRemainingDims = [](int64_t dim) { return dim != 1; };
    if ((resultIndex < resultShape.size() &&
         llvm::any_of(resultShape.drop_front(resultIndex),
                      checkRemainingDims)) ||
        (operandIndex < operandShape.size() &&
         llvm::any_of(operandShape.drop_front(operandIndex),
                      checkRemainingDims)))
      return nullptr;
    inputExprs.resize(operandShape.size(), b->getAffineConstantExpr(0));
    return b->getAffineMapArrayAttr(
        {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)});
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, dimExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};

@ -687,7 +583,7 @@ class TransposeConverter
 public:
  using DataMovementOpConverter<TransposeConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
  static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = resultType.getRank();
@ -697,9 +593,9 @@ class TransposeConverter
      inputExprs[permutation.value().getZExtValue()] =
          b->getAffineDimExpr(permutation.index());
    }
    return b->getAffineMapArrayAttr(
        {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
         b->getMultiDimIdentityMap(nloops)});
    return {
        AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
        b->getMultiDimIdentityMap(nloops)};
  }
};
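
// Illustrative example (assumed permutation, not from a test in this
// commit): for a 2-d transpose with permutation [1, 0] the loop above
// produces
//   input:  affine_map<(d0, d1) -> (d1, d0)>
//   output: affine_map<(d0, d1) -> (d0, d1)>
// i.e. the output uses the identity map while the input is read with its
// dimensions swapped.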

@ -722,13 +618,6 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
    if (!operandType.hasStaticShape() || !resultType.hasStaticShape())
      return failure();

    // TODO(ravishankarm): To make this pattern not match the pattern that
    // ReshapeAddRemoveDimConverter is for, check that condition here. Remove
    // this when ReshapeAddRemoveDimConverter pattern is removed.
    if (ReshapeAddRemoveDimConverter<OpTy, isLHLO>::getIndexingMapsAttr(
            reshapeOp, &rewriter))
      return failure();

    // Compute the reassociation maps for the linalg operation.
    ArrayRef<int64_t> srcShape =
        (operandType.getRank() > resultType.getRank() ? operandType.getShape()
@ -737,22 +626,25 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
        (operandType.getRank() > resultType.getRank() ? resultType.getShape()
                                                      : operandType.getShape());
    unsigned currSrcDim = 0, currDstDim = 0;
    SmallVector<SmallVector<AffineExpr, 4>, 4> exprs(dstShape.size());
    SmallVector<linalg::ReassociationExprs, 4> reassociationMap(
        dstShape.size());
    while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
      int64_t dstSize = dstShape[currDstDim];
      int64_t srcSize = srcShape[currSrcDim];
      while (srcSize < dstSize && currSrcDim < srcShape.size()) {
        exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        srcSize *= srcShape[currSrcDim];
      }
      if (srcSize == dstSize) {
        exprs[currDstDim].push_back(rewriter.getAffineDimExpr(currSrcDim++));
        reassociationMap[currDstDim].push_back(
            rewriter.getAffineDimExpr(currSrcDim++));
        // If the next dim in dstShape is not 1, treat subsequent dims in
        // srcShape which are 1 to be collapsed.
        if (currDstDim == dstShape.size() - 1 ||
            dstShape[currDstDim + 1] != 1) {
          while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
            exprs[currDstDim].push_back(
            reassociationMap[currDstDim].push_back(
                rewriter.getAffineDimExpr(currSrcDim++));
          }
        }
@ -763,18 +655,15 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
    }
    if (currSrcDim != srcShape.size()) return failure();

    SmallVector<ArrayRef<AffineExpr>, 4> reassociationMaps;
    for (auto& expr : exprs) reassociationMaps.push_back(expr);

    if (isLHLO) {
      Value reshapeBuffer = rewriter.create<linalg::ReshapeOp>(
          reshapeOp.getLoc(), resultType, args[0], reassociationMaps);
          reshapeOp.getLoc(), resultType, args[0], reassociationMap);
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(
          reshapeOp, reshapeBuffer, args[1], /*inputPermutation =*/nullptr,
          /*outputPermutation =*/nullptr);
    } else {
      rewriter.replaceOpWithNewOp<linalg::TensorReshapeOp>(
          reshapeOp, resultType, args[0], reassociationMap);
    }
    return success();
  }
||||
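The loop above partitions the higher-rank shape into contiguous groups whose products equal the dims of the lower-rank shape; each group becomes one reassociation entry of the linalg reshape. A standalone sketch of the grouping with STL types in place of AffineExpr (names illustrative; the inner bounds check is tightened slightly so the sketch never reads past the end of srcShape):

#include <cstdint>
#include <iostream>
#include <vector>

// srcShape must be the higher-rank shape, dstShape the lower-rank one.
// Returns the source-dim indices grouped per destination dim, or an empty
// result where the pattern would return failure().
std::vector<std::vector<int64_t>> ComputeReassociation(
    const std::vector<int64_t>& srcShape,
    const std::vector<int64_t>& dstShape) {
  std::vector<std::vector<int64_t>> groups(dstShape.size());
  size_t currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    // Accumulate source dims until their product reaches the dst size.
    while (srcSize < dstSize && currSrcDim + 1 < srcShape.size()) {
      groups[currDstDim].push_back(currSrcDim++);
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize != dstSize) return {};
    groups[currDstDim].push_back(currSrcDim++);
    // Trailing unit dims collapse into the current group unless the next
    // dst dim is itself 1 and will claim them.
    if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
      while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1)
        groups[currDstDim].push_back(currSrcDim++);
    }
    ++currDstDim;
  }
  if (currSrcDim != srcShape.size()) return {};
  return groups;
}

int main() {
  // Collapsing a 2x3x4x1 buffer into 6x4: groups {0,1} and {2,3}.
  auto groups = ComputeReassociation({2, 3, 4, 1}, {6, 4});
  for (const auto& g : groups) {
    for (int64_t d : g) std::cout << d << " ";
    std::cout << "| ";
  }  // prints: 0 1 | 2 3 |
}
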
@ -796,18 +685,14 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {

    // Construct the indexing maps needed for linalg.generic ops.
    unsigned nloops = resultMemrefType.getRank();
-   SmallVector<Attribute, 2> indexingMaps;
-   indexingMaps.emplace_back(
-       AffineMapAttr::get(rewriter.getMultiDimIdentityMap(nloops)));
-
    auto loc = iotaOp.getLoc();
    auto linalgOp = rewriter.create<linalg::IndexedGenericOp>(
        loc, ArrayRef<Type>{}, args,
-       rewriter.getI64IntegerAttr(0),  // args_in
-       rewriter.getI64IntegerAttr(1),  // args_out
-       rewriter.getArrayAttr(indexingMaps),
-       GetNParallelLoopsAttrs(nloops, &rewriter),
-       /*doc=*/nullptr, /*library_call=*/nullptr);
+       0,  // args_in
+       1,  // args_out
+       llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
+       GetNParallelLoopsAttrs(nloops));

    // Add a block to the region.
    auto* region = &linalgOp.region();

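The hunk above only switches the IndexedGenericOp builder from attribute-wrapped arguments to plain values; what the generated region computes is unchanged: each output element is its coordinate along the iota dimension, cast to the element type. For reference, a standalone sketch of that semantics for a row-major buffer (hypothetical helper, not from the patch):

#include <cstdint>
#include <iostream>
#include <vector>

// Reference semantics of the iota lowering: out[i0,...,ik] is the
// coordinate along `iotaDimension`, converted to the element type.
void IotaFill(std::vector<float>& out, const std::vector<int64_t>& shape,
              int64_t iotaDimension) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int64_t d = static_cast<int64_t>(shape.size()) - 2; d >= 0; --d)
    strides[d] = strides[d + 1] * shape[d + 1];
  for (size_t linear = 0; linear < out.size(); ++linear)
    out[linear] = static_cast<float>((linear / strides[iotaDimension]) %
                                     shape[iotaDimension]);
}

int main() {
  std::vector<float> out(2 * 3);
  IotaFill(out, {2, 3}, /*iotaDimension=*/1);
  for (float v : out) std::cout << v << " ";  // prints: 0 1 2 0 1 2
}
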
@ -857,7 +742,7 @@ class ReverseConverter
 public:
  using DataMovementOpConverter<ReverseConverter<OpTy, isLHLO>, OpTy,
                                isLHLO>::DataMovementOpConverter;
-  static ArrayAttr getIndexingMapsAttr(OpTy op, Builder* b) {
+  static SmallVector<AffineMap, 2> getIndexingMaps(OpTy op, Builder* b) {
    auto resultType =
        getXLAOpResultType<isLHLO>(op).template cast<ShapedType>();
    auto nloops = resultType.getRank();

@ -871,9 +756,9 @@ class ReverseConverter
      int n = resultType.getShape()[i];
      inputExprs[i] = b->getAffineConstantExpr(n - 1) - inputExprs[i];
    }
-   return b->getAffineMapArrayAttr(
-       {AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
-        b->getMultiDimIdentityMap(nloops)});
+   return {
+       AffineMap::get(nloops, /*symbolCount=*/0, inputExprs, b->getContext()),
+       b->getMultiDimIdentityMap(nloops)};
  }
};

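Reverse keeps an identity map except on the reversed dimensions, where the operand reads (n - 1) - d_i, exactly the inputExprs[i] = getAffineConstantExpr(n - 1) - inputExprs[i] rewrite above. A small illustration with plain strings (names illustrative, not from the patch):

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Reversing dims {0, 2} of a 2x3x4 result: those dims read (n-1) - d_i.
  std::vector<int64_t> shape = {2, 3, 4};
  std::vector<int64_t> reverseDims = {0, 2};
  std::vector<std::string> inputExprs;
  for (size_t i = 0; i < shape.size(); ++i)
    inputExprs.push_back("d" + std::to_string(i));
  for (int64_t i : reverseDims)
    inputExprs[i] = std::to_string(shape[i] - 1) + " - d" + std::to_string(i);
  // Operand map: (d0, d1, d2) -> (1 - d0, d1, 3 - d2); result map identity.
  for (const auto& e : inputExprs) std::cout << "(" << e << ") ";
}
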
@ -946,7 +831,7 @@ void populateLHLOToLinalgConversionPattern(MLIRContext* context,
                   PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
                   PointwiseToLinalgConverter<xla_lhlo::SubOp>,
                   PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
-                  ReshapeAddRemoveDimConverter<xla_lhlo::ReshapeOp>,
+                  ReshapeOpConverter<xla_lhlo::ReshapeOp>,
                   ReverseConverter<xla_lhlo::ReverseOp>,
                   ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
                   SliceConverter

@ -1045,7 +930,6 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
                  PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
                  PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
                  PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
-                 ReshapeAddRemoveDimConverter<xla_hlo::ReshapeOp, false>,
                  ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
                  ReverseConverter<xla_hlo::ReverseOp, false>,
                  TransposeConverter<xla_hlo::TransposeOp, false>>(context);

@ -185,6 +185,7 @@ tf_xla_py_test(
    name = "argminmax_test",
    size = "small",
    srcs = ["argminmax_test.py"],
+   enable_mlir_bridge = True,
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@ -253,6 +254,7 @@ tf_xla_py_test(
    name = "bucketize_op_test",
    size = "small",
    srcs = ["bucketize_op_test.py"],
+   enable_mlir_bridge = True,
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@ -692,6 +694,7 @@ tf_xla_py_test(
    name = "fft_test",
    size = "medium",
    srcs = ["fft_test.py"],
+   enable_mlir_bridge = True,
    python_version = "PY3",
    shard_count = 6,
    tags = [

@ -805,6 +808,7 @@ tf_xla_py_test(
    name = "lrn_ops_test",
    size = "medium",
    srcs = ["lrn_ops_test.py"],
+   enable_mlir_bridge = True,
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@ -1129,6 +1133,7 @@ tf_xla_py_test(
    name = "reverse_sequence_op_test",
    size = "medium",
    srcs = ["reverse_sequence_op_test.py"],
+   enable_mlir_bridge = True,
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@ -1218,6 +1223,7 @@ tf_xla_py_test(
    name = "sparse_to_dense_op_test",
    size = "small",
    srcs = ["sparse_to_dense_op_test.py"],
+   enable_mlir_bridge = True,
    python_version = "PY3",
    tags = [
        "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip

@ -1225,8 +1225,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
                  [7, 7, 7, 7, 7, 7]],
                 dtype=dtype))

-  @test_util.disable_mlir_bridge(
-      "Requires concatenate op support in MlirHloBuilder")
  def testSymmetricMirrorPad(self):
    mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "SYMMETRIC")
    for dtype in self.numeric_types:

@ -1258,8 +1256,6 @@ class BinaryOpsTest(xla_test.XLATestCase):
          np.array([[0, 0], [0, 0]], dtype=np.int32),
          expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))

-  @test_util.disable_mlir_bridge(
-      "Requires concatenate op support in MlirHloBuilder")
  def testReflectMirrorPad(self):
    mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
    for dtype in self.numeric_types:

@ -21,6 +21,7 @@ from __future__ import print_function
 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test

@ -57,6 +58,7 @@ class BucketizationOpTest(xla_test.XLATestCase):
          expected_out, sess.run(op,
                                 {p: [[-5, 0, 2, 3, 5], [8, 10, 11, 12, 0]]}))

+  @test_util.disable_mlir_bridge("Error handling")
  def testInvalidBoundariesOrder(self):
    with self.session() as sess:
      p = array_ops.placeholder(dtypes.int32)

@ -78,10 +78,17 @@ class ConvolutionNodeNameTest(xla_test.XLATestCase):

    xla_names = _GetNodeNames(use_xla=True)
    no_xla_names = _GetNodeNames(use_xla=False)
-   self.assertListEqual(
-       xla_names,
-       no_xla_names,
-   )
+
+   # CPU path creates some additional nodes to handle dilations.
+   # TODO(b/138804006): Remove this when CPU & GPU support dilations.
+   filtered_no_xla_names = []
+   for name in no_xla_names:
+     if ("dilation_rate" in name or "filter_shape" in name or "stack" in name):
+       continue
+     else:
+       filtered_no_xla_names.append(name)
+
+   self.assertListEqual(xla_names, filtered_no_xla_names)

  def testConv1DNodeNameMatch(self):
    input_sizes = [8, 16, 3]

@ -22,6 +22,7 @@ import numpy as np

 from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.platform import test

@ -101,6 +102,7 @@ class SparseToDenseTest(xla_test.XLATestCase):
    with self.assertRaisesWithPredicateMatch(ValueError, "must be rank 1"):
      _SparseToDense([1, 3], [[5], [3]], 1, -1)

+  @test_util.disable_mlir_bridge("Error handling")
  def testBadValue(self):
    with self.session(), self.test_scope():
      with self.assertRaisesOpError(

@ -108,12 +110,14 @@ class SparseToDenseTest(xla_test.XLATestCase):
          r"should be \[\] or \[2\]"):
      _SparseToDense([1, 3], [5], [[5], [3]], -1)

+  @test_util.disable_mlir_bridge("Error handling")
  def testBadNumValues(self):
    with self.session(), self.test_scope():
      with self.assertRaisesOpError(
          r"sparse_values has incorrect shape \[3\], should be \[\] or \[2\]"):
        _SparseToDense([1, 3], [5], [1, 2, 3], -1)

+  @test_util.disable_mlir_bridge("Error handling")
  def testBadDefault(self):
    with self.session(), self.test_scope():
      with self.assertRaisesOpError("default_value should be a scalar"):

@ -918,16 +918,12 @@ class UnaryOpsTest(xla_test.XLATestCase):
        np.array([1, 0x100000003f800000], np.int64),
        expected=np.array([1, 0x100000003f800000], np.uint64))

-  @test_util.disable_mlir_bridge(
-      "TODO(b/153812660): Handle tf.InvertPermutation compilation")
  def testInvertPermutation(self):
    self._assertOpOutputMatchesExpected(
        array_ops.invert_permutation,
        np.array([1, 2, 0], np.int32),
        expected=np.array([2, 0, 1], dtype=np.int32))

-  @test_util.disable_mlir_bridge(
-      "TODO(b/153812660): Handle tf.InvertPermutation compilation")
  def testInvertPermutationTwiceIsNoop(self):
    self._assertOpOutputMatchesExpected(
        lambda x: array_ops.invert_permutation(array_ops.invert_permutation(x)),

@ -1144,8 +1140,6 @@ class UnaryOpsTest(xla_test.XLATestCase):
        rtol=rtol,
        atol=atol)

-  @test_util.disable_mlir_bridge(
-      "bf16 type not supported in CreateDenseElementsAttrFromLiteral")
  def testSoftplus(self):
    for dtype in self.float_types & {dtypes.float32, dtypes.float64}:
      self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype)

@ -132,7 +132,7 @@ cc_library(
        "//tensorflow/stream_executor:device_memory_allocator",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:span",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)

@ -150,7 +150,7 @@ cc_library(
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/core:stream_executor_no_cuda",
        "@com_google_absl//absl/memory",
-       "@llvm-project//llvm:support",
+       "@llvm-project//llvm:Support",
    ],
)

Some files were not shown because too many files have changed in this diff.