Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/ubuntu-onednn-partials

This commit is contained in:
Abolfazl Shahbazi 2020-07-08 15:50:53 -07:00
commit 9fa46cf554
302 changed files with 5214 additions and 3441 deletions

View File

@ -50,6 +50,9 @@
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* We have replaced uses of "whitelist" with "allowlist" where possible.
Please see https://developers.google.com/style/word-list#blacklist for more
context.
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors

View File

@ -44,7 +44,7 @@ Even if the untrusted party only supplies the serialized computation
graph (in the form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
set of computation primitives available to TensorFlow is powerful enough that
you should assume that the TensorFlow process effectively executes arbitrary
code. One common solution is to whitelist only a few safe Ops. While this is
code. One common solution is to allow only a few safe Ops. While this is
possible in theory, we still recommend you sandbox the execution.
Whether a user-provided checkpoint is safe depends on the computation graph.
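A minimal sketch of the allowlisting idea described above, assuming the op names have already been extracted from the untrusted `GraphDef` (the `node_ops` argument and `kSafeOps` set are illustrative, not TensorFlow APIs):

```cpp
#include <set>
#include <string>
#include <vector>

// Illustrative only: refuse a graph that uses any op outside a fixed
// allowlist. Even with such a check, sandboxed execution is recommended.
bool GraphUsesOnlyAllowedOps(const std::vector<std::string>& node_ops) {
  static const std::set<std::string> kSafeOps = {"Const", "Identity", "Add"};
  for (const std::string& op : node_ops) {
    if (kSafeOps.count(op) == 0) return false;  // Unknown op: reject.
  }
  return true;
}
```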

View File

@ -260,6 +260,36 @@ config_setting(
visibility = ["//visibility:public"],
)
config_setting(
name = "armeabi",
values = {"cpu": "armeabi"},
visibility = ["//visibility:public"],
)
config_setting(
name = "armeabi-v7a",
values = {"cpu": "armeabi-v7a"},
visibility = ["//visibility:public"],
)
config_setting(
name = "arm64-v8a",
values = {"cpu": "arm64-v8a"},
visibility = ["//visibility:public"],
)
selects.config_setting_group(
name = "arm_any",
match_any = [
":arm",
":armeabi",
":armeabi-v7a",
":arm64-v8a",
":linux_aarch64",
":linux_armhf",
],
)
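# Usage sketch (hypothetical target, not part of this diff): rules can now
# branch on any ARM variant with a single select() key, e.g.
#   deps = select({":arm_any": [":neon_kernels"], "//conditions:default": []}),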
config_setting(
name = "freebsd",
values = {"cpu": "freebsd"},

View File

@ -337,10 +337,13 @@ tensorflow::Status CreateRemoteContexts(
});
}
counter.Wait();
tensorflow::StatusGroup sg;
for (int i = 0; i < num_remote_workers; i++) {
TF_RETURN_IF_ERROR(statuses[i]);
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
sg.Update(statuses[i]);
}
}
return tensorflow::Status::OK();
return sg.as_summary_status();
}
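// [Editor's sketch, not part of this diff] The change above aggregates every
// worker failure with tensorflow::StatusGroup instead of returning on the
// first one; the minimal pattern is:
//   tensorflow::StatusGroup sg;
//   for (const auto& s : statuses) {
//     if (!s.ok()) sg.Update(s);  // record each failing worker
//   }
//   return sg.as_summary_status();  // OK if the group is empty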
tensorflow::Status UpdateRemoteContexts(
@ -611,10 +614,21 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
// Initialize remote eager workers.
if (reset_context) {
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
const tensorflow::Status s = CreateRemoteContexts(
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
server_def, remote_eager_workers.get(), context->Executor().Async(),
context->LazyCopyFunctionRemoteInputs(), base_request));
context->LazyCopyFunctionRemoteInputs(), base_request);
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
// the CreateRemoteContexts to fail. We currently only log instead of
// directly returning the error, since returning here will cause the server
// object to be destroyed (which currently CHECK-fails). The client will
// see additional errors if ops are subsequently sent to the failed workers.
if (TF_PREDICT_FALSE(!s.ok())) {
LOG(ERROR) << "Error when creating contexts on remote targets: "
<< s.error_message()
<< "\nExecuting remote ops or functions on these remote "
"targets will fail.";
}
} else {
// The master's context_view_id will be incremented by one
// the UpdateRemoteMaster call later. We want all new workers and
@ -644,15 +658,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
auto* device_mgr = grpc_server->worker_env()->device_mgr;
std::shared_ptr<tensorflow::WorkerSession> worker_session;
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
TF_RETURN_IF_ERROR(
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->CreateSession(
session_name, server_def, base_request.cluster_device_attributes(),
true));
LOG_AND_RETURN_IF_ERROR(
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
session_name, &worker_session));
// Initialize remote tensor communication based on worker session.
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
tensorflow::eager::CreateClusterFLR(context_id, context,
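For context, `LOG_AND_RETURN_IF_ERROR` is a file-local convenience macro in this translation unit; its exact definition sits outside this hunk, but a plausible sketch of its shape is:

```cpp
// Sketch only: log the error before propagating it, unlike
// TF_RETURN_IF_ERROR, which returns without logging.
#define LOG_AND_RETURN_IF_ERROR(...)                    \
  do {                                                  \
    const ::tensorflow::Status _status = (__VA_ARGS__); \
    if (TF_PREDICT_FALSE(!_status.ok())) {              \
      LOG(ERROR) << _status.error_message();            \
      return _status;                                   \
    }                                                   \
  } while (0)
```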

View File

@ -52,6 +52,25 @@ cc_library(
],
)
cc_library(
name = "cleanup",
hdrs = ["cleanup.h"],
)
cc_library(
name = "ram_file_block_cache",
srcs = ["ram_file_block_cache.cc"],
hdrs = ["ram_file_block_cache.h"],
deps = [
":cleanup",
":file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"@com_google_absl//absl/base",
"@com_google_absl//absl/synchronization",
],
)
tf_cc_test(
name = "gcs_filesystem_test",
srcs = [

View File

@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
// destructor. The easiest way to use MakeCleanup is with a lambda argument,
// capturing the return value in an 'auto' local variable. Most users will not
// need more sophisticated syntax than that.
//
// Example:
// void func() {
// FILE* fp = fopen("data.txt", "r");
// if (fp == nullptr) return;
// auto fp_cleaner = gtl::MakeCleanup([fp] { fclose(fp); });
// // No matter what, fclose(fp) will happen.
// DataObject d;
// while (ReadDataObject(fp, &d)) {
// if (d.IsBad()) {
// LOG(ERROR) << "Bad Data";
// return;
// }
// PushGoodData(d);
// }
// }
//
// You can use Cleanup<F> directly, instead of using MakeCleanup and auto,
// but there's rarely a reason to do that.
//
// You can call 'release()' on a Cleanup object to cancel the cleanup.
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
#include <type_traits>
#include <utility>
#include "tensorflow/core/platform/macros.h"
namespace tf_gcs_filesystem {
// A move-only RAII object that calls a stored cleanup functor when
// destroyed. Cleanup<F> is the return type of gtl::MakeCleanup(F).
template <typename F>
class Cleanup {
public:
Cleanup() : released_(true), f_() {}
template <typename G>
explicit Cleanup(G&& f) // NOLINT
: f_(std::forward<G>(f)) {} // NOLINT(build/c++11)
Cleanup(Cleanup&& src) // NOLINT
: released_(src.is_released()), f_(src.release()) {}
// Implicitly move-constructible from any compatible Cleanup<G>.
// The source will be released as if src.release() were called.
// A moved-from Cleanup can be safely destroyed or reassigned.
template <typename G>
Cleanup(Cleanup<G>&& src) // NOLINT
: released_(src.is_released()), f_(src.release()) {}
// Assignment to a Cleanup object behaves like destroying it
// and making a new one in its place, analogous to unique_ptr
// semantics.
Cleanup& operator=(Cleanup&& src) { // NOLINT
if (!released_) f_();
released_ = src.released_;
f_ = src.release();
return *this;
}
~Cleanup() {
if (!released_) f_();
}
// Releases the cleanup function instead of running it.
// Hint: use c.release()() to run early.
F release() {
released_ = true;
return std::move(f_);
}
bool is_released() const { return released_; }
private:
static_assert(!std::is_reference<F>::value, "F must not be a reference");
bool released_ = false;
F f_;
};
template <int&... ExplicitParameterBarrier, typename F,
typename DecayF = typename std::decay<F>::type>
Cleanup<DecayF> MakeCleanup(F&& f) {
return Cleanup<DecayF>(std::forward<F>(f));
}
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
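A short usage sketch of `MakeCleanup` from the header above (the file names are illustrative):

```cpp
#include <cstdio>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h"

// Copy the first byte of src_path to dst_path; both files are closed on
// every exit path via RAII cleanups.
void CopyFirstByte(const char* src_path, const char* dst_path) {
  FILE* src = fopen(src_path, "r");
  if (src == nullptr) return;
  auto src_cleaner = tf_gcs_filesystem::MakeCleanup([src] { fclose(src); });
  int c = fgetc(src);
  if (c == EOF) return;  // src_cleaner still closes src here.
  FILE* dst = fopen(dst_path, "w");
  if (dst == nullptr) return;
  auto dst_cleaner = tf_gcs_filesystem::MakeCleanup([dst] { fclose(dst); });
  fputc(c, dst);
  // To cancel a cleanup instead of running it: dst_cleaner.release();
}
```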

View File

@ -1,8 +1,11 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

View File

@ -0,0 +1,317 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include <cstring>
#include <memory>
#include <sstream>
#include <utility>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h"
namespace tf_gcs_filesystem {
bool RamFileBlockCache::BlockNotStale(const std::shared_ptr<Block>& block) {
absl::MutexLock l(&block->mu);
if (block->state != FetchState::FINISHED) {
return true; // No need to check for staleness.
}
if (max_staleness_ == 0) return true; // Not enforcing staleness.
return timer_seconds_() - block->timestamp <= max_staleness_;
}
std::shared_ptr<RamFileBlockCache::Block> RamFileBlockCache::Lookup(
const Key& key) {
absl::MutexLock lock(&mu_);
auto entry = block_map_.find(key);
if (entry != block_map_.end()) {
if (BlockNotStale(entry->second)) {
if (cache_stats_ != nullptr) {
cache_stats_->RecordCacheHitBlockSize(entry->second->data.size());
}
return entry->second;
} else {
// Remove the stale block and continue.
RemoveFile_Locked(key.first);
}
}
// Insert a new empty block, setting the bookkeeping to sentinel values
// in order to update them as appropriate.
auto new_entry = std::make_shared<Block>();
lru_list_.push_front(key);
lra_list_.push_front(key);
new_entry->lru_iterator = lru_list_.begin();
new_entry->lra_iterator = lra_list_.begin();
new_entry->timestamp = timer_seconds_();
block_map_.emplace(std::make_pair(key, new_entry));
return new_entry;
}
// Remove blocks from the cache until we do not exceed our maximum size.
void RamFileBlockCache::Trim() {
while (!lru_list_.empty() && cache_size_ > max_bytes_) {
RemoveBlock(block_map_.find(lru_list_.back()));
}
}
/// Move the block to the front of the LRU list if it isn't already there.
void RamFileBlockCache::UpdateLRU(const Key& key,
const std::shared_ptr<Block>& block,
TF_Status* status) {
absl::MutexLock lock(&mu_);
if (block->timestamp == 0) {
// The block was evicted from another thread. Allow it to remain evicted.
return TF_SetStatus(status, TF_OK, "");
}
if (block->lru_iterator != lru_list_.begin()) {
lru_list_.erase(block->lru_iterator);
lru_list_.push_front(key);
block->lru_iterator = lru_list_.begin();
}
// Check for inconsistent state. If there is a block later in the same file
// in the cache, and our current block is not block size, this likely means
// we have inconsistent state within the cache. Note: it's possible some
// incomplete reads may still go undetected.
if (block->data.size() < block_size_) {
Key fmax = std::make_pair(key.first, std::numeric_limits<size_t>::max());
auto fcmp = block_map_.upper_bound(fmax);
if (fcmp != block_map_.begin() && key < (--fcmp)->first) {
return TF_SetStatus(status, TF_INTERNAL,
"Block cache contents are inconsistent.");
}
}
Trim();
return TF_SetStatus(status, TF_OK, "");
}
void RamFileBlockCache::MaybeFetch(const Key& key,
const std::shared_ptr<Block>& block,
TF_Status* status) {
bool downloaded_block = false;
auto reconcile_state = MakeCleanup([this, &downloaded_block, &key, &block] {
// Perform this action in a cleanup callback to avoid locking mu_ after
// locking block->mu.
if (downloaded_block) {
absl::MutexLock l(&mu_);
// Do not update state if the block is already to be evicted.
if (block->timestamp != 0) {
// Use capacity() instead of size() to account for all memory
// used by the cache.
cache_size_ += block->data.capacity();
// Put to beginning of LRA list.
lra_list_.erase(block->lra_iterator);
lra_list_.push_front(key);
block->lra_iterator = lra_list_.begin();
block->timestamp = timer_seconds_();
}
}
});
// Loop until either block content is successfully fetched, or our request
// encounters an error.
absl::MutexLock l(&block->mu);
TF_SetStatus(status, TF_OK, "");
while (true) {
switch (block->state) {
case FetchState::ERROR:
// TF_FALLTHROUGH_INTENDED
case FetchState::CREATED:
block->state = FetchState::FETCHING;
block->mu.Unlock(); // Release the lock while making the API call.
block->data.clear();
block->data.resize(block_size_, 0);
size_t bytes_transferred;
block_fetcher_(key.first, key.second, block_size_, block->data.data(),
&bytes_transferred, status);
if (cache_stats_ != nullptr) {
cache_stats_->RecordCacheMissBlockSize(bytes_transferred);
}
block->mu.Lock(); // Reacquire the lock immediately afterwards
if (TF_GetCode(status) == TF_OK) {
block->data.resize(bytes_transferred, 0);
// Shrink the data capacity to the actual size used.
// NOLINTNEXTLINE: shrink_to_fit() may not shrink the capacity.
std::vector<char>(block->data).swap(block->data);
downloaded_block = true;
block->state = FetchState::FINISHED;
} else {
block->state = FetchState::ERROR;
}
block->cond_var.SignalAll();
return;
case FetchState::FETCHING:
block->cond_var.WaitWithTimeout(&block->mu, absl::Minutes(1));
if (block->state == FetchState::FINISHED) {
return TF_SetStatus(status, TF_OK, "");
}
// Re-loop in case of errors.
break;
case FetchState::FINISHED:
return TF_SetStatus(status, TF_OK, "");
}
}
return TF_SetStatus(
status, TF_INTERNAL,
"Control flow should never reach the end of RamFileBlockCache::Fetch.");
}
void RamFileBlockCache::Read(const std::string& filename, size_t offset,
size_t n, char* buffer, size_t* bytes_transferred,
TF_Status* status) {
*bytes_transferred = 0;
if (n == 0) {
return TF_SetStatus(status, TF_OK, "");
}
if (!IsCacheEnabled() || (n > max_bytes_)) {
// The cache is effectively disabled, so we pass the read through to the
// fetcher without breaking it up into blocks.
return block_fetcher_(filename, offset, n, buffer, bytes_transferred,
status);
}
// Calculate the block-aligned start and end of the read.
size_t start = block_size_ * (offset / block_size_);
size_t finish = block_size_ * ((offset + n) / block_size_);
if (finish < offset + n) {
finish += block_size_;
}
size_t total_bytes_transferred = 0;
// Now iterate through the blocks, reading them one at a time.
for (size_t pos = start; pos < finish; pos += block_size_) {
Key key = std::make_pair(filename, pos);
// Look up the block, fetching and inserting it if necessary, and update the
// LRU iterator for the key and block.
std::shared_ptr<Block> block = Lookup(key);
if (!block) {
std::cerr << "No block for key " << key.first << "@" << key.second;
abort();
}
MaybeFetch(key, block, status);
if (TF_GetCode(status) != TF_OK) return;
UpdateLRU(key, block, status);
if (TF_GetCode(status) != TF_OK) return;
// Copy the relevant portion of the block into the result buffer.
const auto& data = block->data;
if (offset >= pos + data.size()) {
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
*bytes_transferred = total_bytes_transferred;
std::stringstream os;
os << "EOF at offset " << offset << " in file " << filename
<< " at position " << pos << " with data size " << data.size();
return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
}
auto begin = data.begin();
if (offset > pos) {
// The block begins before the slice we're reading.
begin += offset - pos;
}
auto end = data.end();
if (pos + data.size() > offset + n) {
// The block extends past the end of the slice we're reading.
end -= (pos + data.size()) - (offset + n);
}
if (begin < end) {
size_t bytes_to_copy = end - begin;
memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy);
total_bytes_transferred += bytes_to_copy;
}
if (data.size() < block_size_) {
// The block was a partial block and thus signals EOF at its upper bound.
break;
}
}
*bytes_transferred = total_bytes_transferred;
return TF_SetStatus(status, TF_OK, "");
}
bool RamFileBlockCache::ValidateAndUpdateFileSignature(
const std::string& filename, int64_t file_signature) {
absl::MutexLock lock(&mu_);
auto it = file_signature_map_.find(filename);
if (it != file_signature_map_.end()) {
if (it->second == file_signature) {
return true;
}
// Remove the file from cache if the signatures don't match.
RemoveFile_Locked(filename);
it->second = file_signature;
return false;
}
file_signature_map_[filename] = file_signature;
return true;
}
size_t RamFileBlockCache::CacheSize() const {
absl::MutexLock lock(&mu_);
return cache_size_;
}
void RamFileBlockCache::Prune() {
while (!stop_pruning_thread_.WaitForNotificationWithTimeout(
absl::Microseconds(1000000))) {
absl::MutexLock lock(&mu_);
uint64_t now = timer_seconds_();
while (!lra_list_.empty()) {
auto it = block_map_.find(lra_list_.back());
if (now - it->second->timestamp <= max_staleness_) {
// The oldest block is not yet expired. Come back later.
break;
}
// We need to make a copy of the filename here, since it could otherwise
// be used within RemoveFile_Locked after `it` is deleted.
RemoveFile_Locked(std::string(it->first.first));
}
}
}
void RamFileBlockCache::Flush() {
absl::MutexLock lock(&mu_);
block_map_.clear();
lru_list_.clear();
lra_list_.clear();
cache_size_ = 0;
}
void RamFileBlockCache::RemoveFile(const std::string& filename) {
absl::MutexLock lock(&mu_);
RemoveFile_Locked(filename);
}
void RamFileBlockCache::RemoveFile_Locked(const std::string& filename) {
Key begin = std::make_pair(filename, 0);
auto it = block_map_.lower_bound(begin);
while (it != block_map_.end() && it->first.first == filename) {
auto next = std::next(it);
RemoveBlock(it);
it = next;
}
}
void RamFileBlockCache::RemoveBlock(BlockMap::iterator entry) {
// This signals that the block is removed, and should not be inadvertently
// reinserted into the cache in UpdateLRU.
entry->second->timestamp = 0;
lru_list_.erase(entry->second->lru_iterator);
lra_list_.erase(entry->second->lra_iterator);
cache_size_ -= entry->second->data.capacity();
block_map_.erase(entry);
}
} // namespace tf_gcs_filesystem
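The block alignment in `Read` above rounds the requested byte range out to block boundaries before looking blocks up. A standalone sketch of that arithmetic with an illustrative 8-byte block size (real caches use much larger blocks):

```cpp
#include <cstddef>
#include <iostream>

int main() {
  const size_t block_size = 8;       // Illustrative; real block sizes are MBs.
  const size_t offset = 10, n = 11;  // Request bytes [10, 21).
  // Round the start down and the end up, as RamFileBlockCache::Read does.
  size_t start = block_size * (offset / block_size);         // 8
  size_t finish = block_size * ((offset + n) / block_size);  // 16
  if (finish < offset + n) finish += block_size;             // 24
  // The loop then fetches blocks keyed at offsets 8 and 16, covering [8, 24).
  std::cout << "start=" << start << " finish=" << finish << "\n";
  return 0;
}
```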

View File

@ -0,0 +1,267 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "absl/base/thread_annotations.h"
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h"
#include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem {
/// \brief An LRU block cache of file contents, keyed by {filename, offset}.
///
/// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS).
class RamFileBlockCache : public FileBlockCache {
public:
/// The callback executed when a block is not found in the cache, and needs to
/// be fetched from the backing filesystem. This callback is provided when the
/// cache is constructed. The `status` should be `TF_OK` as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
typedef std::function<void(const std::string& filename, size_t offset,
size_t buffer_size, char* buffer,
size_t* bytes_transferred, TF_Status* status)>
BlockFetcher;
RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
BlockFetcher block_fetcher,
std::function<uint64_t()> timer_seconds)
: block_size_(block_size),
max_bytes_(max_bytes),
max_staleness_(max_staleness),
block_fetcher_(block_fetcher),
timer_seconds_(timer_seconds),
pruning_thread_(nullptr,
[](TF_Thread* thread) { TF_JoinThread(thread); }) {
if (max_staleness_ > 0) {
TF_ThreadOptions thread_options;
TF_DefaultThreadOptions(&thread_options);
pruning_thread_.reset(
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
}
std::cout << "GCS file block cache is "
<< (IsCacheEnabled() ? "enabled" : "disabled");
}
~RamFileBlockCache() override {
if (pruning_thread_) {
stop_pruning_thread_.Notify();
// Destroying pruning_thread_ will block until Prune() receives the above
// notification and returns.
pruning_thread_.reset();
}
}
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
/// method will set `status` to:
///
/// 1) The error from the remote filesystem, if the read from the remote
/// filesystem failed.
/// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem
/// succeeded,
/// but the read returned a partial block, and the LRU cache contained a
/// block at a higher offset (indicating that the partial block should have
/// been a full block).
/// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but
/// the file contents do not extend past `offset` and thus nothing was
/// placed in `out`.
/// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was
/// placed
/// in `buffer`).
///
/// Caller is responsible for allocating memory for `buffer`.
/// `buffer` will be left unchanged in case of errors.
void Read(const std::string& filename, size_t offset, size_t n, char* buffer,
size_t* bytes_transferred, TF_Status* status) override;
// Validate the given file signature against the signature stored in the
// cache. Returns true if the signature is unchanged or the file has not been
// seen before. If the signature has changed, update the stored signature with
// the new one and remove the file from the cache.
bool ValidateAndUpdateFileSignature(const std::string& filename,
int64_t file_signature) override
ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all cached blocks for `filename`.
void RemoveFile(const std::string& filename) override
ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all cached data.
void Flush() override ABSL_LOCKS_EXCLUDED(mu_);
/// Accessors for cache parameters.
size_t block_size() const override { return block_size_; }
size_t max_bytes() const override { return max_bytes_; }
uint64_t max_staleness() const override { return max_staleness_; }
/// The current size (in bytes) of the cache.
size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_);
// Returns true if the cache is enabled. If false, the BlockFetcher callback
// is always executed during Read.
bool IsCacheEnabled() const override {
return block_size_ > 0 && max_bytes_ > 0;
}
// We cannot pass a capturing lambda as a function pointer to
// `TF_StartThread`, so we wrap `Prune` inside a static function.
static void PruneThread(void* param) {
auto ram_file_block_cache = static_cast<RamFileBlockCache*>(param);
ram_file_block_cache->Prune();
}
private:
/// The size of the blocks stored in the LRU cache, as well as the size of the
/// reads from the underlying filesystem.
const size_t block_size_;
/// The maximum number of bytes (sum of block sizes) allowed in the LRU cache.
const size_t max_bytes_;
/// The maximum staleness of any block in the LRU cache, in seconds.
const uint64_t max_staleness_;
/// The callback to read a block from the underlying filesystem.
const BlockFetcher block_fetcher_;
/// The callback to read timestamps.
const std::function<uint64_t()> timer_seconds_;
/// \brief The key type for the file block cache.
///
/// The file block cache key is a {filename, offset} pair.
typedef std::pair<std::string, size_t> Key;
/// \brief The state of a block.
///
/// A block begins in the CREATED stage. The first thread will attempt to read
/// the block from the filesystem, transitioning the state of the block to
/// FETCHING. After completing, if the read was successful the state should
/// be FINISHED. Otherwise the state should be ERROR. A subsequent read can
/// re-fetch the block if the state is ERROR.
enum class FetchState {
CREATED,
FETCHING,
FINISHED,
ERROR,
};
/// \brief A block of a file.
///
/// A file block consists of the block data, the block's current position in
/// the LRU cache, the timestamp (seconds since epoch) at which the block
/// was cached, a coordination lock, and state & condition variables.
///
/// Thread safety:
/// The iterator and timestamp fields should only be accessed while holding
/// the block-cache-wide mu_ instance variable. The state variable should only
/// be accessed while holding the Block's mu lock. The data vector should only
/// be accessed after state == FINISHED, and it should never be modified.
///
/// In order to prevent deadlocks, never grab the block-cache-wide mu_ lock
/// AFTER grabbing any block's mu lock. It is safe to grab mu without locking
/// mu_.
struct Block {
/// The block data.
std::vector<char> data;
/// A list iterator pointing to the block's position in the LRU list.
std::list<Key>::iterator lru_iterator;
/// A list iterator pointing to the block's position in the LRA list.
std::list<Key>::iterator lra_iterator;
/// The timestamp (seconds since epoch) at which the block was cached.
uint64_t timestamp;
/// Mutex to guard state variable
absl::Mutex mu;
/// The state of the block.
FetchState state ABSL_GUARDED_BY(mu) = FetchState::CREATED;
/// Wait on cond_var if state is FETCHING.
absl::CondVar cond_var;
};
/// \brief The block map type for the file block cache.
///
/// The block map is an ordered map from Key to Block.
typedef std::map<Key, std::shared_ptr<Block>> BlockMap;
/// Prune the cache by removing files with expired blocks.
void Prune() ABSL_LOCKS_EXCLUDED(mu_);
bool BlockNotStale(const std::shared_ptr<Block>& block)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Look up a Key in the block cache.
std::shared_ptr<Block> Lookup(const Key& key) ABSL_LOCKS_EXCLUDED(mu_);
void MaybeFetch(const Key& key, const std::shared_ptr<Block>& block,
TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_);
/// Trim the block cache to make room for another entry.
void Trim() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Update the LRU iterator for the block at `key`.
void UpdateLRU(const Key& key, const std::shared_ptr<Block>& block,
TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all blocks of a file, with mu_ already held.
void RemoveFile_Locked(const std::string& filename)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Remove the block `entry` from the block map and LRU list, and update the
/// cache size accordingly.
void RemoveBlock(BlockMap::iterator entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// The cache pruning thread that removes files with expired blocks.
std::unique_ptr<TF_Thread, std::function<void(TF_Thread*)>> pruning_thread_;
/// Notification for stopping the cache pruning thread.
absl::Notification stop_pruning_thread_;
/// Guards access to the block map, LRU list, and cached byte count.
mutable absl::Mutex mu_;
/// The block map (map from Key to Block).
BlockMap block_map_ ABSL_GUARDED_BY(mu_);
/// The LRU list of block keys. The front of the list identifies the most
/// recently accessed block.
std::list<Key> lru_list_ ABSL_GUARDED_BY(mu_);
/// The LRA (least recently added) list of block keys. The front of the list
/// identifies the most recently added block.
///
/// Note: blocks are added to lra_list_ only after they have successfully been
/// fetched from the underlying block store.
std::list<Key> lra_list_ ABSL_GUARDED_BY(mu_);
/// The combined number of bytes in all of the cached blocks.
size_t cache_size_ ABSL_GUARDED_BY(mu_) = 0;
// A filename->file_signature map.
std::map<std::string, int64_t> file_signature_map_ ABSL_GUARDED_BY(mu_);
};
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_
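A construction sketch based on the interface above. The zero-filling fetcher is a stand-in for a real GCS range read, and the fixed-zero timer is only for illustration:

```cpp
#include <cstdint>
#include <cstring>
#include <string>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h"

// Stand-in fetcher: a real plugin would issue a GCS range request for
// [offset, offset + buffer_size) of `filename`.
static void FakeFetch(const std::string& filename, size_t offset,
                      size_t buffer_size, char* buffer,
                      size_t* bytes_transferred, TF_Status* status) {
  memset(buffer, 0, buffer_size);
  *bytes_transferred = buffer_size;
  TF_SetStatus(status, TF_OK, "");
}

void Demo() {
  // 1 MiB blocks, 16 MiB budget, 300 s max staleness (starts pruning thread).
  tf_gcs_filesystem::RamFileBlockCache cache(
      1 << 20, 16 << 20, 300, FakeFetch, [] { return uint64_t{0}; });
  char buf[64];
  size_t n = 0;
  TF_Status* status = TF_NewStatus();
  cache.Read("gs://bucket/object", /*offset=*/0, sizeof(buf), buf, &n, status);
  TF_DeleteStatus(status);
}
```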

View File

@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/struct.pb.h"

View File

@ -248,15 +248,22 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
size_t len, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index);
auto* allocator = cc_ctx->get_allocator(attr);
void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator);
TF_Tensor* result = TF_NewTensor(dtype, dims, num_dims, data, len,
tensorflow::deallocate_buffer, allocator);
TF_SetOutput(context, index, result, status);
if (TF_GetCode(status) != TF_OK) {
TF_DeleteTensor(result);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
reinterpret_cast<tensorflow::int64*>(dims), num_dims);
tensorflow::Tensor* tensor;
tensorflow::Status s = cc_ctx->allocate_output(
index, tensorflow::TensorShape(dimarray), &tensor);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return result;
TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return tf_tensor;
}
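A usage sketch from a plugin kernel's compute callback, following the (partially elided) signature in this hunk; the kernel registration itself is assumed:

```cpp
#include <cstdint>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"

// Sketch: allocate a rank-1 float output of 4 elements at output index 0.
// After this change the buffer comes from allocate_output, so it honors the
// context's output allocator attributes.
static void MyComputeFunc(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();
  int64_t dims[1] = {4};
  TF_Tensor* out = TF_AllocateOutput(ctx, /*index=*/0, TF_FLOAT, dims,
                                     /*num_dims=*/1, 4 * sizeof(float), status);
  if (TF_GetCode(status) == TF_OK) {
    float* data = static_cast<float*>(TF_TensorData(out));
    for (int i = 0; i < 4; ++i) data[i] = 0.0f;
    TF_DeleteTensor(out);  // Drops our handle; the context keeps the output.
  }
  TF_DeleteStatus(status);
}
```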

View File

@ -1096,33 +1096,33 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
return true;
}
absl::flat_hash_set<string> GetOrCreateWhitelist() {
absl::flat_hash_map<string, std::vector<string>>* whitelist_table =
tensorflow::GetWhitelistTable();
absl::flat_hash_set<string> GetOrCreateAllowlist() {
absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
tensorflow::GetAllowlistTable();
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
absl::flat_hash_set<string> whitelist;
absl::flat_hash_set<string> allowlist;
for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
if (s == "FUSIBLE") {
for (auto pair : *whitelist_table) {
whitelist.insert(pair.second.begin(), pair.second.end());
for (auto pair : *allowlist_table) {
allowlist.insert(pair.second.begin(), pair.second.end());
}
} else if (whitelist_table->contains(s)) {
auto v = whitelist_table->at(s);
whitelist.insert(v.begin(), v.end());
} else if (allowlist_table->contains(s)) {
auto v = allowlist_table->at(s);
allowlist.insert(v.begin(), v.end());
} else if (!s.empty()) {
// Should be a user provided TF operation.
whitelist.insert(string(s));
allowlist.insert(string(s));
}
}
if (VLOG_IS_ON(2) && !whitelist.empty()) {
std::vector<string> vwhitelist(whitelist.begin(), whitelist.end());
absl::c_sort(vwhitelist);
if (VLOG_IS_ON(2) && !allowlist.empty()) {
std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
absl::c_sort(vallowlist);
VLOG(2) << "XLA clustering will only consider the following TF operations: "
<< absl::StrJoin(vwhitelist, " ");
<< absl::StrJoin(vallowlist, " ");
}
return whitelist;
return allowlist;
}
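// [Editor's note, not from this diff] The set built above is parsed from the
// --tf_xla_ops_to_cluster flag, so clustering can be restricted to, e.g.,
// fusible ops plus MatMul via:
//   TF_XLA_FLAGS="--tf_xla_ops_to_cluster=FUSIBLE,MatMul"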
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
@ -1156,12 +1156,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
auto whitelist = GetOrCreateWhitelist();
auto allowlist = GetOrCreateAllowlist();
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
// Check that user's provided TF operation really exists.
for (const auto& s : whitelist) {
for (const auto& s : allowlist) {
if (!all_ops.contains(string(s))) {
return errors::InvalidArgument(
"The operation '", s,
@ -1206,7 +1206,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
continue;
}
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
VLOG(1) << "Rejecting TF operation " << node->def().op()
<< " as it is not listed in --tf_xla_ops_to_cluster.";
continue;
@ -1781,7 +1781,7 @@ Status MarkForCompilationPass::RunForTest(
return MarkForCompilation(options, debug_options);
}
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
// Table format: category name: {list of TF operations in that category}
static absl::flat_hash_map<string, std::vector<string>>* result =
new absl::flat_hash_map<string, std::vector<string>>{
@ -1845,7 +1845,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
namespace testing {
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
absl::flat_hash_set<string> result{"AdjustContrastv2",
"AdjustHue",
"AdjustSaturation",

View File

@ -58,7 +58,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
RecursiveCompilabilityChecker::UncompilableNodesMap*
uncompilable_node_info = nullptr);
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable();
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
namespace testing {
// DO NOT USE IN PRODUCTION.
@ -66,8 +66,8 @@ namespace testing {
// Resets some internal state to let us write reliable unit tests.
void ResetClusterSequenceNumber();
// Return a list of operation that we choose not to put into the whitelist.
absl::flat_hash_set<string> GetKnownXLAWhitelistOp();
// Return a list of operations that we choose not to put into the allowlist.
absl::flat_hash_set<string> GetKnownXLAAllowlistOp();
} // namespace testing
} // namespace tensorflow

View File

@ -1802,34 +1802,34 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
}
}
TEST(XlaCompilationTest, XLALiteWhitelist) {
auto* whitelist_table = tensorflow::GetWhitelistTable();
absl::flat_hash_set<string> hwhitelist;
TEST(XlaCompilationTest, XLALiteAllowlist) {
auto* allowlist_table = tensorflow::GetAllowlistTable();
absl::flat_hash_set<string> hallowlist;
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
// Check that all the operations in the table are existing TF operations
for (auto pair : *whitelist_table) {
hwhitelist.insert(pair.second.begin(), pair.second.end());
for (auto pair : *allowlist_table) {
hallowlist.insert(pair.second.begin(), pair.second.end());
for (auto op : pair.second) {
ASSERT_TRUE(all_ops.contains(op));
}
}
// Check that all registered XLA operation are in the whitelist
// Check that all registered XLA operations are in the allowlist
// table or are known to not be in it.
absl::flat_hash_set<string> known_not_in_list =
tensorflow::testing::GetKnownXLAWhitelistOp();
tensorflow::testing::GetKnownXLAAllowlistOp();
std::vector<string> unknow_op;
for (string op : vall_ops) {
if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) {
if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) {
unknow_op.push_back(op);
}
}
EXPECT_TRUE(unknow_op.empty())
<< "Someone added support for a new TF opeations inside XLA. They must "
"be included in the XLALite whitelist or blacklist:\n"
"be included in the XLALite allowlist or blacklist:\n"
<< absl::StrJoin(unknow_op, "\n");
}
} // namespace

View File

@ -669,6 +669,8 @@ cc_library(
":lhlo_legalize_to_llvm", # build-cleaner: keep
":xla_materialize_broadcasts", # build-cleaner: keep
":xla_unfuse_batch_norm", # build-cleaner: keep
"@llvm-project//mlir:AffineToStandardTransforms",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LLVMDialect",

View File

@ -28,18 +28,18 @@ limitations under the License.
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
namespace mlir {
namespace xla_chlo {
namespace chlo {
class XlaHloClientDialect : public Dialect {
class HloClientDialect : public Dialect {
public:
explicit XlaHloClientDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_chlo"; }
explicit HloClientDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "chlo"; }
};
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
} // namespace xla_chlo
} // namespace chlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_

View File

@ -22,7 +22,7 @@ limitations under the License.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// xla_chlo ops to canonical mhlo ops.
// chlo ops to canonical mhlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics
@ -35,8 +35,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def HLOClient_Dialect : Dialect {
let name = "xla_chlo";
let cppNamespace = "xla_chlo";
let name = "chlo";
let cppNamespace = "chlo";
let summary = [{
XLA Client HLO Ops
}];

View File

@ -39,9 +39,9 @@ class OpBuilder;
namespace mhlo {
class XlaHloDialect : public Dialect {
class MhloDialect : public Dialect {
public:
explicit XlaHloDialect(MLIRContext *context);
explicit MhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "mhlo"; }
// Registered hook to materialize a constant operation from a given attribute

View File

@ -35,18 +35,18 @@ class OpBuilder;
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
namespace xla_lhlo {
namespace lmhlo {
class XlaLhloDialect : public Dialect {
class LmhloDialect : public Dialect {
public:
explicit XlaLhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_lhlo"; }
explicit LmhloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "lmhlo"; }
};
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
} // namespace xla_lhlo
} // namespace lmhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_

View File

@ -38,8 +38,8 @@ include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
def LHLO_Dialect : Dialect {
let name = "xla_lhlo";
let cppNamespace = "xla_lhlo";
let name = "lmhlo";
let cppNamespace = "lmhlo";
}
//===----------------------------------------------------------------------===//
@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
// A tuple-like pattern match syntax could work:
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
// ...
// }, {
// ...
@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
Example:
```mlir
%buf_transformed =
xla_lhlo.static_memref_cast %buf
lmhlo.static_memref_cast %buf
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
Example:
```mlir
%buf_transformed =
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);
let builders = [OpBuilder<
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
"Value operand, Value shape", [{
result.addOperands(operand);
result.addOperands(shape);
result.types.push_back(resultType);
}]>];
let extraClassDeclaration = [{
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
}];

View File

@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
struct HloToLhloOpImpl<mhlo::OpName> { \
using Type = xla_lhlo::OpName; \
using Type = lmhlo::OpName; \
}
MAP_HLO_TO_LHLO(AbsOp);

View File

@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace impl {
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
@ -33,32 +33,32 @@ template <typename LhloBinaryOpTy>
struct LhloToScalarOp;
template <>
struct LhloToScalarOp<xla_lhlo::AddOp> {
struct LhloToScalarOp<lmhlo::AddOp> {
using FOp = ::mlir::AddFOp;
using IOp = ::mlir::AddIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::CompareOp> {
struct LhloToScalarOp<lmhlo::CompareOp> {
using FOp = ::mlir::CmpFOp;
using IOp = ::mlir::CmpIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::DivOp> {
struct LhloToScalarOp<lmhlo::DivOp> {
using FOp = ::mlir::DivFOp;
using IOp = ::mlir::SignedDivIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::MulOp> {
struct LhloToScalarOp<lmhlo::MulOp> {
using FOp = ::mlir::MulFOp;
using IOp = ::mlir::MulIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::RemOp> {
struct LhloToScalarOp<lmhlo::RemOp> {
using FOp = ::mlir::RemFOp;
using IOp = ::mlir::SignedRemIOp;
};
template <>
struct LhloToScalarOp<xla_lhlo::SubOp> {
struct LhloToScalarOp<lmhlo::SubOp> {
using FOp = ::mlir::SubFOp;
using IOp = ::mlir::SubIOp;
};
@ -116,16 +116,17 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
@ -133,16 +134,17 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
lhs, zero_intval);
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
loc, result_types, args, b);
}
@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return args.front();
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
Type sourceType = args.front().getType();
@ -288,9 +295,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
// Dot Op converter from lhlo to affine only accepts float and integer types.
const auto& lhs = args[0];
const auto& rhs = args[1];
@ -312,17 +320,19 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
loc, result_types, args, b);
}
@ -361,66 +371,69 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
};
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
result_types, args,
b);
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return XlaCompareSelectOpToStdScalarOp<
IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
result_types, args,
b);
IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
loc, result_types, args, b);
}
if (element_type.isa<IntegerType>()) {
// xla_lhlo.neg(x, result) -> result = sub(0, x)
// lmhlo.neg(x, result) -> result = sub(0, x)
Value lhs = args[0];
auto integer_type = element_type.dyn_cast<IntegerType>();
auto zero_intval =
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
}
return nullptr;
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
@ -428,9 +441,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
Type element_type = args.front().getType();
if (element_type.isa<FloatType>()) {
FloatType float_type = element_type.cast<FloatType>();
@ -442,17 +456,19 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
loc, result_types, args, b);
}
template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
ArrayRef<Type> result_types,
ArrayRef<Value> args,
OpBuilder* b) {
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
loc, result_types, args, b);
}
@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
} // namespace impl
struct XlaOpToStdScalarOp {
// Implementation for LHLO ops except xla_lhlo::CompareOp.
// Implementation for LHLO ops except lmhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp {
// Implementation for HLO ops except mhlo::CompareOp.
template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, int i = 0) {
@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp {
args, b);
}
// Implementation for xla_lhlo::CompareOp.
// Implementation for lmhlo::CompareOp.
template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
LhloOpTy, xla_lhlo::CompareOp>::value>>
static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
LhloOpTy, lmhlo::CompareOp>::value>>
static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp {
static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
op.getLoc(), comparison_direction, result_types, args, b);
}
};
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
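A sketch of a typical call site for the trait above (an assumption about usage shape, not code from this diff; the builder and scalar arguments come from the surrounding lowering pattern):

```cpp
// Inside a lowering pattern: map one lmhlo op to the equivalent scalar
// standard op. XlaOpToStdScalarOp picks AddFOp or AddIOp from element type.
mlir::Value LowerAddToScalar(mlir::lmhlo::AddOp op,
                             mlir::ArrayRef<mlir::Type> result_types,
                             mlir::ArrayRef<mlir::Value> scalar_args,
                             mlir::OpBuilder* b) {
  return mlir::lmhlo::XlaOpToStdScalarOp::map(op, result_types, scalar_args,
                                              b);
}
```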

View File

@ -60,7 +60,7 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
} // namespace mhlo
namespace xla_lhlo {
namespace lmhlo {
// Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
@ -92,7 +92,7 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();
// Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();
} // namespace xla_lhlo
} // namespace lmhlo
namespace xla {

View File

@ -75,23 +75,23 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
} // namespace mhlo
namespace xla_lhlo {
namespace lmhlo {
/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
LLVMTypeConverter *converter,
OwningRewritePatternList *patterns);
} // namespace xla_lhlo
} // namespace lmhlo
namespace xla_chlo {
namespace chlo {
// Populates a collection of conversion patterns for legalizing client-HLO to
// HLO.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
} // namespace xla_chlo
} // namespace chlo
namespace xla {

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
namespace mlir {
namespace xla_chlo {
namespace chlo {
template <typename T>
static LogicalResult Verify(T op) {
@ -263,10 +263,10 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
//===----------------------------------------------------------------------===//
// xla_chlo Dialect Constructor
// chlo Dialect Constructor
//===----------------------------------------------------------------------===//
XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
HloClientDialect::HloClientDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
@ -274,5 +274,5 @@ XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
>();
}
} // namespace xla_chlo
} // namespace chlo
} // namespace mlir

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;
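For context, a hedged sketch of how downstream code now names these dialect classes; the helper and its use in a conversion target are assumptions, only the class names come from the registrations above:
static void markRenamedDialects(mlir::ConversionTarget& target) {
  // The renamed dialects replace XlaHloDialect, XlaLhloDialect and
  // XlaHloClientDialect respectively.
  target.addLegalDialect<mlir::mhlo::MhloDialect>();
  target.addLegalDialect<mlir::lmhlo::LmhloDialect>();
  target.addIllegalDialect<mlir::chlo::HloClientDialect>();
}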

View File

@ -62,9 +62,8 @@ namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace mhlo {
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
Location loc) {
Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
Type type, Location loc) {
// HLO dialect constants only support ElementsAttr, unlike the standard
// dialect constant, which supports all attributes.
if (value.isa<ElementsAttr>())
@ -2128,7 +2127,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//
XlaHloDialect::XlaHloDialect(MLIRContext* context)
MhloDialect::MhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
@ -2140,7 +2139,7 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
// allowUnknownOperations();
}
Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
Type MhloDialect::parseType(DialectAsmParser& parser) const {
StringRef data_type;
if (parser.parseKeyword(&data_type)) return Type();
@ -2149,7 +2148,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
return nullptr;
}
void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
if (type.isa<TokenType>()) {
os << "token";
return;

View File

@ -46,9 +46,9 @@ limitations under the License.
namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
namespace xla_lhlo {
namespace lmhlo {
XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
LmhloDialect::LmhloDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
addOperations<
#define GET_OP_LIST
@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result,
FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
namespace mlir {
namespace xla_chlo {
namespace chlo {
namespace {
@ -235,5 +235,5 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
context, patterns);
}
} // namespace xla_chlo
} // namespace chlo
} // namespace mlir

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_chlo {
namespace chlo {
namespace {
@ -31,9 +31,9 @@ struct TestChloLegalizeToHloPass
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
conversionTarget.addIllegalDialect<HloClientDialect>();
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
conversionTarget.addLegalDialect<mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
@ -49,9 +49,9 @@ struct TestChloLegalizeToHloPass
} // namespace
} // namespace xla_chlo
} // namespace chlo
} // namespace mlir
static mlir::PassRegistration<mlir::xla_chlo::TestChloLegalizeToHloPass> pass(
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
"test-xla-chlo-legalize-to-hlo",
"Test pass for applying chlo -> hlo legalization patterns");

View File

@ -44,7 +44,7 @@ template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
xla_lhlo::CopyOp, true>;
lmhlo::CopyOp, true>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<xla_lhlo::BroadcastInDimOp>(
rewriter.create<lmhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.replaceOp(op, {resultBuffer});
@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// Inserts a dynamic memref cast that changes the layout of the memref,
// using a 0-stride and the size of the target dimension where a size-1
// dimension expansion is necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
@ -214,12 +214,37 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
makeStridedLinearLayoutMap(dynamic_layout,
/*offset=*/0, b->getContext()));
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
loc, type_erased_memref_type, operand, sizes, strides);
return transformed_operand;
}
};
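// Converts mhlo.dynamic_reshape on tensors to lmhlo.reshape_memref_cast on
// buffers. Ranked results become ranked memrefs; unranked results become
// unranked memrefs in the default memory space.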
struct HloToLhloDynamicReshapeConverter
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
public:
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Type result_type;
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
result_type =
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
} else if (auto unranked_type =
op.getType().dyn_cast<UnrankedTensorType>()) {
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
} else {
return failure();
}
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
op, result_type, adaptor.operand(), adaptor.output_shape());
return success();
}
};
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public:
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
@ -241,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
buffer_args.push_back(
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
}
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
loc, llvm::None, buffer_args, op.getAttrs());
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
op.getAttrs());
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
@ -267,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<xla_lhlo::TerminatorOp>(loc);
rewriter.create<lmhlo::TerminatorOp>(loc);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
@ -296,8 +321,8 @@ class HloToLhloTensorStoreOpConverter
LogicalResult matchAndRewrite(
mlir::TensorStoreOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
op, llvm::None, operands.front(), operands.back());
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
operands.back());
return success();
}
};
@ -311,7 +336,7 @@ class HloToLhloTensorStoreOpConverter
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ({
// "lmhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "mhlo.add"(%0, %1) :
@ -320,7 +345,7 @@ class HloToLhloTensorStoreOpConverter
// %4 = "mhlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
@ -330,13 +355,13 @@ class HloToLhloTensorStoreOpConverter
// %arg1: memref<2x2xf32>,
// %arg2: memref<2x2xf32>,
// %arg3: memref<2x2xf32>) {
// "xla_lhlo.fusion"() ( {
// "lmhlo.fusion"() ( {
// %0 = alloc() : memref<2x2xf32>
// "xla_lhlo.add"(%arg1, %arg2, %0) :
// "lmhlo.add"(%arg1, %arg2, %0) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
// "lmhlo.multiply"(%0, %arg0, %arg3) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.terminator"() : () -> ()
// }) : () -> ()
// return
// }
@ -357,13 +382,13 @@ class HloToLhloTensorStoreOpConverter
// %arg2: memref<4xf32>) {
// %0 = alloc() : memref<4xf32>
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
// "lmhlo.maximum"(%arg0, %arg1, %0) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// %1 = alloc() : memref<4xf32>
// "xla_lhlo.add"(%arg0, %0, %1) :
// "lmhlo.add"(%arg0, %0, %1) :
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
// "lmhlo.terminator"() : () -> ()
// }
struct HloLegalizeToLhlo
@ -381,26 +406,31 @@ struct HloLegalizeToLhlo
OwningRewritePatternList patterns;
auto& context = getContext();
ConversionTarget target(context);
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
target.addLegalDialect<lmhlo::LmhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalOp<ModuleOp>();
target.addIllegalOp<mlir::TensorLoadOp>();
target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
target.addIllegalDialect<mhlo::XlaHloDialect>();
target.addIllegalDialect<mhlo::MhloDialect>();
BufferAssignmentTypeConverter converter;
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
auto inputs = op.getType().getInputs();
return llvm::all_of(inputs,
[](Type input) { return input.isa<MemRefType>(); }) &&
return llvm::all_of(inputs, isMemRefType) &&
converter.isLegal(&op.getBody());
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
return std::all_of(returnOp.operand_type_begin(),
returnOp.operand_type_end(),
[](Type type) { return type.isa<MemRefType>(); });
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
isMemRefType) &&
std::all_of(op.result_type_begin(), op.result_type_end(),
isMemRefType);
});
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
isMemRefType);
});
auto module = getOperation();
@ -411,12 +441,12 @@ struct HloLegalizeToLhlo
&converter, &patterns);
if (results_escape_function) {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
&converter, &patterns);
} else {
populateWithBufferAssignmentOpConversionPatterns<
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
&converter, &patterns);
}
@ -442,6 +472,7 @@ void populateHLOToLHLOConversionPattern(
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
HloToLhloDynamicReshapeConverter,
HloToLhloOpConverter<mhlo::AbsOp>,
HloToLhloOpConverter<mhlo::AddOp>,
HloToLhloOpConverter<mhlo::AndOp>,

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// Removes LHLO copy operations that copy from allocated buffers to block
@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
void runOnOperation() override {
llvm::SmallVector<mlir::Operation*, 2> eraseList;
auto operation = getOperation();
operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
// If this region contains more than one block, then ignore this copy
// operation.
if (copyOp.getParentRegion()->getBlocks().size() > 1) {
@ -101,5 +101,5 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass() {
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
"lhlo-copy-removal", "Removes redundant LHLO copy operations");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
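A hedged sketch of the rewrite this pass performs; the op and shapes are assumed for illustration:
// Before -lhlo-copy-removal:
//   %0 = alloc() : memref<2x2xf32>
//   "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
//   "lmhlo.copy"(%0, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
//   dealloc %0 : memref<2x2xf32>
//
// After: the temporary buffer, the copy, and the dealloc are gone.
//   "lmhlo.exponential"(%arg0, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()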

View File

@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
using linalg::LinalgOp;
@ -147,5 +147,5 @@ static PassRegistration<LhloFuseLinalg> legalize_pass(
"lhlo-fuse-linalg",
"Greedily fuse linalg ops obtained after LHLO lowering.");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
auto result =
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
op, element_type, {l, r, result}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
ValueRange induction_vars) {
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
op, element_type, {l, r}, &builder);
map_status = success(op_result != nullptr);
if (failed(map_status)) return;
@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<
BinaryOpConverter<xla_lhlo::AddOp>,
BinaryOpConverter<xla_lhlo::AndOp>,
BinaryOpConverter<xla_lhlo::DivOp>,
BinaryOpConverter<xla_lhlo::MaxOp>,
BinaryOpConverter<xla_lhlo::MinOp>,
BinaryOpConverter<xla_lhlo::MulOp>,
BinaryOpConverter<xla_lhlo::SubOp>,
BinaryOpConverter<lmhlo::AddOp>,
BinaryOpConverter<lmhlo::AndOp>,
BinaryOpConverter<lmhlo::DivOp>,
BinaryOpConverter<lmhlo::MaxOp>,
BinaryOpConverter<lmhlo::MinOp>,
BinaryOpConverter<lmhlo::MulOp>,
BinaryOpConverter<lmhlo::SubOp>,
DotOpConverter>(context);
// clang-format on
}
@ -157,5 +157,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
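For orientation, a hedged sketch of the transformation this file implements; the concrete shapes are assumed, and the loop body follows the BinaryOpConverter above:
// "lmhlo.add"(%lhs, %rhs, %out)
//     : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//
// is rewritten into an affine loop nest along the lines of
//
// affine.for %i = 0 to 4 {
//   %l = affine.load %lhs[%i] : memref<4xf32>
//   %r = affine.load %rhs[%i] : memref<4xf32>
//   %sum = addf %l, %r : f32
//   affine.store %sum, %out[%i] : memref<4xf32>
// }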

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// A simple translation of LHLO reduce operations to a corresponding gpu
@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
OwningRewritePatternList patterns;
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<ReduceOp>();
auto func = getFunction();
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
@ -192,5 +192,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
struct StaticMemRefCastOpConverter
@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
*converter, options);
}
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
@ -23,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
class TestLhloToLLVMPass
@ -38,11 +40,14 @@ class TestLhloToLLVMPass
populateStdToLLVMConversionPatterns(converter, patterns);
PopulateLhloToLLVMConversionPatterns(
LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns);
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
target.addIllegalDialect<XlaLhloDialect>();
target.addIllegalDialect<LmhloDialect>();
if (failed(applyFullConversion(m, target, patterns))) {
signalPassFailure();
@ -55,5 +60,5 @@ class TestLhloToLLVMPass
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
return b->create<scf::ParallelOp>(loc, lower, upper, step);
}
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops if there are
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
// contains the reduction operator.
//
// Example:
//
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
// "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// <LHLO ops>
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
// } : f32
// scf.yield
// }
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
public:
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
// TODO(b/137624192) Implement variadic reduce.
if (xla_reduce_op.out().size() != 1) return failure();
@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// scf.yield
// }
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceOp xla_reduce_op,
lmhlo::ReduceOp xla_reduce_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_op.getLoc();
DenseSet<int> reducing_dims;
@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// accumulator = reduction_operator(output[O], value)
// output[O] = accumulator
//
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
// scf::ReduceOp.
// The outer `ParallelOp` refers to the parallel loops that traverse the output
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// func @reduce_window(%arg: memref<112x112xf32>,
// %init: memref<f32>,
// %result: memref<56x56xf32>) {
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
// "lmhlo.reduce_window"(%arg, %init, %result) ( {
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
// "lmhlo.maximum"(%lhs, %rhs, %res)
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
// "xla_lhlo.terminator"() : () -> ()
// "lmhlo.terminator"() : () -> ()
// }) {
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
// return
// }
class ReduceWindowOpConverter
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
: public OpConversionPattern<lmhlo::ReduceWindowOp> {
public:
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
scf::ParallelOp output_loop, window_loop;
std::tie(output_loop, window_loop) =
@ -383,7 +383,7 @@ class ReduceWindowOpConverter
private:
std::pair<scf::ParallelOp, scf::ParallelOp>
CreateParallelLoopsToTraverseOutputAndWindow(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
lmhlo::ReduceWindowOp xla_reduce_window_op,
ConversionPatternRewriter* rewriter) const {
auto loc = xla_reduce_window_op.getLoc();
Value init_value =
@ -415,9 +415,8 @@ class ReduceWindowOpConverter
}
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
scf::ParallelOp output_loop, scf::ParallelOp window_loop,
ConversionPatternRewriter* rewriter) const {
lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
rewriter->setInsertionPointToStart(window_loop.getBody());
auto loc = xla_reduce_window_op.getLoc();
@ -481,12 +480,12 @@ class ReduceWindowOpConverter
// initialized_flag = true
// output(selected_index) = scatter(output(selected_index), source(S))
class SelectAndScatterOpConverter
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
: public OpConversionPattern<lmhlo::SelectAndScatterOp> {
public:
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
ConversionPatternRewriter& rewriter) const final {
auto loc = s_and_s_op.getLoc();
InitializeOutput(s_and_s_op, &rewriter);
@ -515,7 +514,7 @@ class SelectAndScatterOpConverter
}
private:
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
@ -533,7 +532,7 @@ class SelectAndScatterOpConverter
SmallVector<Value, 2> window_ivs;
scf::ForOp inner_loop;
};
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
@ -598,7 +597,7 @@ class SelectAndScatterOpConverter
SmallVector<Value, 4> ivs_val_flag_;
};
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
scf::ParallelOp loop_over_src,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
@ -636,9 +635,10 @@ class SelectAndScatterOpConverter
return window_loops.selected_ivs;
}
SmallVector<Value, 4> SelectOrInitialize(
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
IterArgs* ivs_val_flag, OpBuilder* b) const {
SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
ArrayRef<Value> operand_ivs,
IterArgs* ivs_val_flag,
OpBuilder* b) const {
auto loc = s_and_s_op.getLoc();
Value true_i1 = b->create<mlir::ConstantOp>(
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
scf::SCFDialect, XlaLhloDialect>();
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
xla_lhlo::SelectAndScatterOp>();
scf::SCFDialect, LmhloDialect>();
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
lmhlo::SelectAndScatterOp>();
if (failed(applyPartialConversion(func, target, patterns))) {
signalPassFailure();
@ -727,5 +727,5 @@ static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
"lhlo-legalize-to-parallel-loops",
"Legalize from LHLO dialect to parallel loops.");
} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

View File

@ -34,7 +34,7 @@ struct TestMaterializeBroadcastsPass
OwningRewritePatternList conversionPatterns;
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>();
conversionTarget.addLegalDialect<MhloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();

View File

@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
loc, opResultTypes, args, args_count, results_count, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace.
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
// That method needs to be moved out of there.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
op, bodyResultTypes,
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
// Create two loads from the input.
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
&rewriter);
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
};
//===----------------------------------------------------------------------===//
// xla_lhlo.convolution conversion pattern.
// lmhlo.convolution conversion pattern.
//===----------------------------------------------------------------------===//
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
/// Converts lmhlo.convolution operation to a linalg.conv op.
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
public:
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args,
lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
op.dimension_numbers()) {
const int inputSpatialRank =
llvm::size(dimensionNumbers.input_spatial_dimensions());
@ -388,14 +388,14 @@ class HloBroadcastInDimConverter
};
class LhloBroadcastInDimConverter
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
public:
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
auto result_shape = result_type.getShape();
@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const {
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
Value operand = operand_adaptor.operand();
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter
return std::make_pair(operand, broadcast_dims);
}
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
ArrayRef<int64_t> broadcastDims,
ArrayRef<int64_t> resultShape,
MemRefType operandType,
@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
}
};
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
public:
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto resultMemrefType =
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
}
};
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
public:
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
lmhlo::ConstOp constOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = constOp.getLoc();
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
@ -726,12 +726,12 @@ class ReverseConverter
}
};
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
public:
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
auto loc = sliceOp.getLoc();
auto argType =
@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
ConstConverter,
ConvToLinalgConverter,
IotaConverter,
LhloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
PointwiseToLinalgConverter<lmhlo::AbsOp>,
PointwiseToLinalgConverter<lmhlo::AddOp>,
PointwiseToLinalgConverter<lmhlo::AndOp>,
PointwiseToLinalgConverter<lmhlo::CeilOp>,
PointwiseToLinalgConverter<lmhlo::CompareOp>,
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
// TODO(ataei): Remove this pattern, CopyOp is folded away.
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
ReverseConverter<xla_lhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
PointwiseToLinalgConverter<lmhlo::CopyOp>,
PointwiseToLinalgConverter<lmhlo::CosOp>,
PointwiseToLinalgConverter<lmhlo::DivOp>,
PointwiseToLinalgConverter<lmhlo::ExpOp>,
PointwiseToLinalgConverter<lmhlo::ImagOp>,
PointwiseToLinalgConverter<lmhlo::LogOp>,
PointwiseToLinalgConverter<lmhlo::MaxOp>,
PointwiseToLinalgConverter<lmhlo::MinOp>,
PointwiseToLinalgConverter<lmhlo::MulOp>,
PointwiseToLinalgConverter<lmhlo::NegOp>,
PointwiseToLinalgConverter<lmhlo::RealOp>,
PointwiseToLinalgConverter<lmhlo::RemOp>,
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
PointwiseToLinalgConverter<lmhlo::SelectOp>,
PointwiseToLinalgConverter<lmhlo::SignOp>,
PointwiseToLinalgConverter<lmhlo::SinOp>,
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
PointwiseToLinalgConverter<lmhlo::SubOp>,
PointwiseToLinalgConverter<lmhlo::TanhOp>,
ReshapeOpConverter<lmhlo::ReshapeOp>,
ReverseConverter<lmhlo::ReverseOp>,
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
SliceConverter
>(context);
// clang-format on
}
// Converts LHLO ops to Linalg generic.
// Sample result for xla_lhlo::AddOp.
// Sample result for lmhlo::AddOp.
//
// "xla_lhlo.add"(%arg1, %arg2, %out) :
// "lmhlo.add"(%arg1, %arg2, %out) :
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
//
// will be converted to
@ -854,14 +854,14 @@ struct HloLegalizeToLinalg
} // namespace
namespace xla_lhlo {
namespace lmhlo {
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
return absl::make_unique<LhloLegalizeToLinalg>();
}
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo
} // namespace lmhlo
namespace mhlo {

View File

@ -152,7 +152,7 @@ struct TransformUnrankedHloPass
// Setup conversion target.
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<XlaHloDialect, StandardOpsDialect,
target.addLegalDialect<MhloDialect, StandardOpsDialect,
shape::ShapeDialect>();
target.addLegalOp<FuncOp>();
AddLegalOpOnRankedTensor<SqrtOp>(&target);

View File

@ -11,7 +11,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex>
}
@ -19,7 +19,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// -----
// CHECK-LABEL: @complex_ranked_components
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
return %1 : tensor<?x?xcomplex<f32>>
@ -28,7 +28,7 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// -----
// CHECK-LABEL: @compare_ranked_components
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
@ -37,7 +37,7 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
@ -46,7 +46,7 @@ func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
// TODO: Overly broad shapes are being returned. Tighten the calculation
// and update/extend these tests.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}

View File

@ -5,7 +5,7 @@
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -26,7 +26,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
@ -47,7 +47,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
return %0 : tensor<?x?xcomplex<f32>>
}
@ -68,7 +68,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}
@ -77,7 +77,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -86,7 +86,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -95,7 +95,7 @@ func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1:
func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -104,7 +104,7 @@ func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %a
func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@ -114,7 +114,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
%0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -122,7 +122,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -130,7 +130,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -138,7 +138,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -146,7 +146,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -154,7 +154,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -162,7 +162,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -170,7 +170,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -178,7 +178,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
%0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -186,7 +186,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -194,7 +194,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -202,7 +202,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -210,7 +210,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -218,7 +218,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -226,7 +226,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -234,6 +234,6 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

View File

@ -0,0 +1,34 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s
// CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %arg0 : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
// CHECK-NEXT: return [[ARG]] : memref<*xf32>
// -----
// CHECK-LABEL: func @dynamic_reshape_from_unranked
func @dynamic_reshape_from_unranked(
%operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
: (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// -----
// CHECK-LABEL: func @dynamic_reshape_to_unranked
func @dynamic_reshape_to_unranked(
%operand: tensor<?xf32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
: (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
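// -----

// A hypothetical ranked-to-ranked case exercising the remaining converter
// branch; the CHECK lines assume the same reshape_memref_cast lowering as the
// cases above.

// CHECK-LABEL: func @dynamic_reshape_ranked
func @dynamic_reshape_ranked(
    %operand: tensor<?xf32>, %shape: tensor<2xi32>) -> tensor<?x?xf32> {
  %reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
      : (tensor<?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
  return %reshaped : tensor<?x?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<2xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>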

View File

@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "xla_lhlo.copy"
// ESC-NOT: "lmhlo.copy"
// ESC-NEXT: return %[[ARG0]]
// -----
@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
// PRE-NEXT: return
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
tensor_store %tensor_result, %result : memref<2x2xf32>
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
// BOTH-NEXT: return
@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
tensor_store %tensor_result, %result : memref<2x2xi1>
return
}
@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<10x5xf32>
return
}
@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64>
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64>
%c1 = constant 1 : i64
%shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = call @external_func()
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
// BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
@ -204,12 +205,12 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
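The CHECK lines above pin down the stride computation for dynamic_broadcast_in_dim: an operand dimension whose extent is smaller than the matching result extent is being expanded, so the lmhlo.dynamic_memref_cast gives it stride 0 (re-reading the same element), while every other dimension keeps a unit stride. A rough Python rendering of that rule, ignoring the broadcast_dimensions mapping for brevity (illustrative only, not the actual pass):

# Expanded dims (operand extent < result extent) get stride 0; others 1.
def broadcast_strides(operand_dims, result_dims):
    return [0 if op_d < res_d else 1
            for op_d, res_d in zip(operand_dims, result_dims)]

assert broadcast_strides([1, 5], [4, 5]) == [0, 1]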
@ -228,7 +229,7 @@ func @complex(%real: memref<2x2xf32>,
%tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
return
}
@ -240,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -252,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -263,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
func @iota(%result: memref<10xi32>) {
%tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32>
return
}
@ -275,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -287,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -299,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store
tensor_store %tensor_result, %result : memref<2x2xf32>
return
@ -312,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -324,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -336,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -348,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -360,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -372,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -385,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -411,7 +412,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -436,7 +437,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -447,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]]
return %dot : tensor<1024x1024xf32>
}
@ -461,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
%c0 = constant 0 : index
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
// BOTH-SAME: padding = dense<[
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
View File
@ -3,10 +3,10 @@
// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) {
// CHECK-LABEL: func @remove_without_dealloc
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
// CHECK-LABEL: func @replace_dependency
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @keep_copies
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
"xla_lhlo.terminator"() : () -> ()
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// -----
@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
%arg1: memref<2x2xf32>,
%arg2: memref<2x2xf32>) {
%0 = alloc() {temp = true} : memref<2x2xf32>
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
dealloc %0 : memref<2x2xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
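Taken together, these cases define the copy-removal rewrite being tested: a locally allocated buffer that is written once, copied into a target, and deallocated can be folded away by writing directly into the copy target, but only if nothing writes the target (or touches the temp) between the definition and the copy, which is exactly what must_not_be_removed exercises. A toy Python model of that legality check (purely illustrative; the real pass operates on MLIR IR):

# Ops are (name, reads, writes) tuples; a toy stand-in for the IR walk.
def can_remove_copy(ops, temp, target, def_idx, copy_idx):
    for _name, reads, writes in ops[def_idx + 1:copy_idx]:
        if target in writes or temp in reads or temp in writes:
            return False  # an intervening access would change behavior
    return True

ops = [
    ("exp_a", {"arg0"}, {"tmp"}),   # definition of the temp buffer
    ("exp_b", {"arg1"}, {"arg2"}),  # intervening write to the target
    ("copy",  {"tmp"},  {"arg2"}),
]
assert not can_remove_copy(ops, "tmp", "arg2", 0, 2)  # must_not_be_removed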
View File
@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
%src: memref<56x56xf32>,
%init: memref<f32>,
%result: memref<112x112xf32>) {
"xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
"lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
// select
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):
"xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
"lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
(memref<f32>, memref<f32>, memref<i1>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}, {
// scatter
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %out) :
"lmhlo.add"(%lhs, %rhs, %out) :
(memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
} : (memref<112x112xf32>,
memref<56x56xf32>,
memref<f32>, memref<112x112xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}
// CHECK-LABEL: func @select_and_scatter(
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>,
@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
// Compute PRED.
// CHECK: "xla_lhlo.compare"(
// CHECK: "lmhlo.compare"(
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
// Compute scatter value.
// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
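For orientation, the op being lowered computes, per window over the input, the position chosen by the select region (GE, i.e. the first maximal element of the window) and then folds the corresponding source value into the result at that position using the scatter region (add). A compact 1-D Python sketch of those semantics, with padding ignored (a reading aid for the CHECK lines, not the generated code):

# 1-D select-and-scatter with select = GE and scatter = add.
def select_and_scatter_1d(arg, src, init, window, stride):
    out = [init] * len(arg)
    for w, s in enumerate(src):  # one window per source element
        lo = w * stride
        idxs = range(lo, min(lo + window, len(arg)))
        sel = max(idxs, key=lambda i: arg[i])  # GE keeps the first max
        out[sel] += s                          # scatter: add
    return out

assert select_and_scatter_1d([1, 4, 2, 3], [10, 20], 0, 2, 2) == [0, 10, 0, 20]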
View File
@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
// CHECK: return
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
return
}
@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: addf %{{.*}}, %{{.*}} : f32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: addi %{{.*}}, %{{.*}} : i32
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: and %{{.*}}, %{{.*}} : i32
"xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
"lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: divf %{{.*}}, %{{.*}} : f32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: divi_signed %{{.*}}, %{{.*}} : i32
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: mulf %{{.*}}, %{{.*}} : f32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: muli %{{.*}}, %{{.*}} : i32
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
%result: memref<7xf32>) -> () {
// CHECK: subf %{{.*}}, %{{.*}} : f32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
return
}
@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
%result: memref<7xi32>) -> () {
// CHECK: subi %{{.*}}, %{{.*}} : i32
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
return
}
@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
// CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) :
"lmhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
return
}
@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
// CHECK: return
"xla_lhlo.dot"(%lhs, %rhs, %result) :
"lmhlo.dot"(%lhs, %rhs, %result) :
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
return
}
View File
@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10xf32>,
%init: memref<f32>,
%result: memref<100xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
return
@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>,
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0>
// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0>
// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
// CHECK: }
// CHECK: gpu.terminator
// CHECK: }
View File
@ -4,7 +4,7 @@
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
%rhs: memref<?x?xf32>,
%result: memref<?x?xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
// CHECK-LABEL: func @element_wise_scalar
func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
%result: memref<f32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
return
}
@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
// CHECK-LABEL: func @minf
func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.minimum"(%lhs, %rhs, %result)
"lmhlo.minimum"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @maxi
func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.maximum"(%lhs, %rhs, %result)
"lmhlo.maximum"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @and
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.and"(%lhs, %rhs, %result)
"lmhlo.and"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result)
"lmhlo.exponential"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @log
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @copy
func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
"xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
"lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
return
}
@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
return
}
@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.select"(%pred, %lhs, %rhs, %result)
"lmhlo.select"(%pred, %lhs, %rhs, %result)
: (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
"lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return
}
// CHECK: linalg.indexed_generic
@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
"lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<f32>, memref<4x2x1xf32>) -> ()
return
@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<4x?x16xf32>,
%result: memref<4x2x1x4x?x16xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
"lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return
@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>,
// CHECK-LABEL: func @dynamic_broadcast_in_dim
func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
%result: memref<?x?x?x?x?xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<5xf32>, memref<5x10xf32>) -> ()
return
@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
%result: memref<5x10x100xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return
@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<5x10xf32>) -> ()
return
@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
%result: memref<1x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<1xf32>, memref<1x5xf32>) -> ()
return
@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
%result: memref<5x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (memref<1xf32>, memref<5x5xf32>) -> ()
return
@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
// CHECK-LABEL: func @constant
func @constant(%value: memref<i32>) {
"xla_lhlo.constant"(%value) {
"lmhlo.constant"(%value) {
value = dense<10> : tensor<i32>
} : (memref<i32>) -> ()
return
@ -335,7 +335,7 @@ func @constant(%value: memref<i32>) {
// CHECK-LABEL: func @absf
func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @absi
func @absi(%input: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>,
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: memref<2x2xi16>,
%result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>,
// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
return
}
// CHECK: linalg.generic
@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
return
}
// CHECK: linalg.generic
@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i32_to_i32
func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @convert_f32_to_f32
func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result)
"lmhlo.convert"(%input, %result)
: (memref<2x2xf32>, memref<2x2xi32>) -> ()
return
}
@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result)
"lmhlo.sine"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>,
// CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @negi
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @rem
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.remainder"(%lhs, %rhs, %result)
"lmhlo.remainder"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @rsqrt
func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sign
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @tanh
func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%cplx: memref<2x2xcomplex<f32>>) {
"xla_lhlo.complex"(%real, %imag, %cplx)
"lmhlo.complex"(%real, %imag, %cplx)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>,
// CHECK-LABEL: func @real
func @real(%cplx: memref<2x2xcomplex<f32>>,
%real: memref<2x2xf32>) {
"xla_lhlo.real"(%cplx, %real)
"lmhlo.real"(%cplx, %real)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
// CHECK-LABEL: func @imag
func @imag(%cplx: memref<2x2xcomplex<f32>>,
%imag: memref<2x2xf32>) {
"xla_lhlo.imag"(%cplx, %imag)
"lmhlo.imag"(%cplx, %imag)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,
// CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>)
func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
"xla_lhlo.slice"(%operand, %result) {
"lmhlo.slice"(%operand, %result) {
start_indices = dense<[0,1]> : tensor<2xi64>,
limit_indices = dense<[2,3]> : tensor<2xi64>,
strides = dense<[1,1]> : tensor<2xi64>
@ -653,7 +653,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
"xla_lhlo.reverse"(%arg0, %arg1) {
"lmhlo.reverse"(%arg0, %arg1) {
dimensions = dense<1> : tensor<1xi64>
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
return
@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]}
// With all attributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
// Dilation left unspecified; the lowering sets the default dilation since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero.
// CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}
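The two comments in the function above capture a small defaulting rule in this lowering: linalg.conv always carries a dilations attribute, so an absent rhs_dilation becomes [1, 1], while padding is attached only when some entry is non-zero. A hedged sketch of that logic (illustrative, not the pass itself):

# Attribute defaulting mirrored from the CHECK expectations above.
def linalg_conv_attrs(window_strides, rhs_dilation=None, padding=None):
    attrs = {"strides": window_strides}
    # linalg expects dilations, so default to 1 per spatial dimension.
    attrs["dilations"] = rhs_dilation if rhs_dilation is not None else [1, 1]
    # padding is emitted only when it is actually non-zero somewhere.
    if padding is not None and any(p != 0 for row in padding for p in row):
        attrs["padding"] = padding
    return attrs

assert "padding" not in linalg_conv_attrs([2, 1])
assert linalg_conv_attrs([2, 1])["dilations"] == [1, 1]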
View File
@ -2,7 +2,7 @@
// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = xla_lhlo.static_memref_cast %buf
%0 = lmhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return
}
@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_Y = constant 50 : index
%stride_X = constant 1 : index
%stride_Y = constant 0 : index
%0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
View File
@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10x5xf32>,
%init: memref<f32>,
%result: memref<100x5xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
return
@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>,
func @reduce_no_outer_loop(%arg: memref<100xf32>,
%init: memref<f32>,
%result: memref<1xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>}
: (memref<100xf32>, memref<f32>, memref<1xf32>) -> ()
return
@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: }
@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
func @dynamic_reduce(%arg: memref<?x?x?xf32>,
%init: memref<f32>,
%result: memref<?x?xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> ()
return
@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
func @reduce_window(%arg: memref<112x112xf32>,
%init: memref<f32>,
%result: memref<56x56xf32>) {
"xla_lhlo.reduce_window"(%arg, %init, %result) ( {
"lmhlo.reduce_window"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.maximum"(%lhs, %rhs, %res)
"lmhlo.maximum"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }

View File

@ -4,7 +4,7 @@
// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @add_memrefs
func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1
// CHECK-LABEL: func @abs_memref
func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @convert_memref
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
return
}
@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
return
}
@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) ->
func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>)
func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}
@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -
func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
// -----
func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
// expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}}
"lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
return
}
@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @div_memref
func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @max_memref
func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @min_memref
func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @mul_memref
func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @sub_memref
func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>
// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -
func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32
// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}
@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)
func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}
@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32
// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
return
}
@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -
// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
return
}
@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi
// CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
"xla_lhlo.reduce"(%input, %init, %out) ( {
"lmhlo.reduce"(%input, %init, %out) ( {
^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>):
"xla_lhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
return
}
@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf
// CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.fusion"() ( {
"lmhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) : () -> ()
return
}
@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
// CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
"xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
^bb0(%arg0: memref<f32>):
"xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
@ -430,7 +430,7 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
// -----
func @static_memref_cast(%in: memref<10x1xf32>) {
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
return
}
@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) {
func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
// expected-error @+1 {{operand must have static shape}}
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
return
}
@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
// expected-error @+1 {{result must have static shape}}
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return
}
@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
func @dynamic_memref_cast(%in: memref<?xf32>) {
%size = constant 10 : index
%step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
return
}
@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
%size = constant 10 : index
%step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>
// CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
%dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1)
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
// CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
%dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2)
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
// CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
%new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3)
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return
}
@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
func @reshape_memref_cast_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
}
@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch(
func @reshape_memref_cast_dst_ranked_shape_unranked(
%buf: memref<*xf32>, %shape: memref<?xi32>) {
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
return
}
@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked(
func @reshape_memref_cast_dst_shape_rank_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
return
}
@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
%shape: memref<1xi32>) {
// expected-error @+1 {{operand memref type should have identity affine map}}
xla_lhlo.reshape_memref_cast %buf(%shape)
lmhlo.reshape_memref_cast %buf(%shape)
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
-> memref<8xf32>
return
@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref
// CHECK-LABEL: func @atan2_memrefs
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref
// CHECK-LABEL: func @bitcast_convert_memrefs
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
return
}
@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) ->
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
return
}
@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) ->
// CHECK-LABEL: func @clz_memrefs
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @expm1_memrefs
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
// CHECK-LABEL: func @floor_memrefs
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @imag_memrefs
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @real_memrefs
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
"lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
return
}
@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error@+1{{must be memref of complex-type values}}
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @is_finite_memrefs
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
"lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
return
}
@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @log1p_memrefs
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
return
}
@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}
@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @not_memrefs
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
return
}
@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @popcnt_memrefs
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// CHECK-LABEL: func @reduce_precision_memrefs
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) ->
// CHECK-LABEL: func @round_memrefs
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
// CHECK-LABEL: func @shift_left_memrefs
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>,
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>,
// CHECK-LABEL: func @shift_right_logical_memrefs
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}
@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
return
}
@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
// CHECK-LABEL: func @all_reduce_memrefs
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
"lmhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
"lmhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> ()
// CHECK-LABEL: func @collective_permute_memrefs
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
"lmhlo.collective_permute"(%arg0, %arg_out) {
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
channel_id = { handle = 5 : i64, type = 2 : i64 }
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
// CHECK-LABEL: func @fft_memrefs
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
"lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
return
}
@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
%grad_offset: memref<8xf32>) -> () {
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
"lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
// CHECK-LABEL: func @batch_norm_inference_memrefs
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
"lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
return
}
@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
%batch_var: memref<8xf32>) -> () {
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
"lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
return
}
@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
// CHECK-LABEL: func @cholesky_memrefs
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
"lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
return
}
@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x
// CHECK-LABEL: func @infeed_memrefs
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
"lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
return
}
@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
// CHECK-LABEL: func @outfeed_memrefs
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
"lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
return
}
@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
// CHECK-LABEL: func @replica_id_memrefs
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
"lmhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
return
}
@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
// CHECK-LABEL: func @triangular_solve_memrefs
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
return
}
@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
"xla_lhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
"lmhlo.while"(%arg0, %arg_out) (
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<i64>) -> ()
return
}
@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
// CHECK-LABEL: func @while_memrefs
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
"xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () }
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
return
}
@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
// CHECK-LABEL: func @bitcast_memrefs
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
return
}
@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
// CHECK-LABEL: func @scatter_memrefs
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
"lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
// CHECK-LABEL: func @map_memrefs
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
"lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
return
}
@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
// CHECK-LABEL: func @sort_memrefs
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()

View File

@ -30,7 +30,7 @@ struct PassConfig {
explicit PassConfig(QuantizationSpecs specs)
: emit_builtin_tflite_ops(true),
lower_tensor_list_ops(false),
trim_functions_whitelist({}),
trim_functions_allowlist({}),
quant_specs(std::move(specs)),
form_clusters(false),
unfold_batch_matmul(true),
@ -44,8 +44,8 @@ struct PassConfig {
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
// TF ops before legalization to TF Lite dialect.
bool lower_tensor_list_ops;
// The whitelist of functions that would be preserved after trimming.
llvm::ArrayRef<std::string> trim_functions_whitelist;
// The allowlist of functions that would be preserved after trimming.
llvm::ArrayRef<std::string> trim_functions_allowlist;
// All information about quantization.
QuantizationSpecs quant_specs;
  // If `form_clusters` is true, clusters are formed by grouping consecutive
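
For downstream code, only the field name changes here. As a minimal sketch of an adapted call site (the `mlir::TFL` namespace, the header path, and the `QuantizationSpecs` type are assumed from the surrounding tree; the function itself is hypothetical):

#include <string>

#include "llvm/ADT/ArrayRef.h"
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h"

// Hypothetical call site: keep only the named functions when trimming.
mlir::TFL::PassConfig MakeTrimmingConfig(mlir::TFL::QuantizationSpecs specs) {
  mlir::TFL::PassConfig config(std::move(specs));
  static const std::string kKeep[] = {"main"};
  // Formerly `trim_functions_whitelist`; the ArrayRef points at static
  // storage so it stays valid for the config's lifetime.
  config.trim_functions_allowlist = llvm::makeArrayRef(kKeep);
  return config;
}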

View File

@ -71,7 +71,7 @@ limitations under the License.
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/string_util.h"
@ -101,7 +101,7 @@ using mlir::Value;
using tensorflow::OpOrArgLocNameMapper;
using tensorflow::OpOrArgNameMapper;
using tensorflow::Status;
using tflite::flex::IsWhitelistedFlexOp;
using tflite::flex::IsAllowlistedFlexOp;
using xla::StatusOr;
template <typename T>
@ -972,7 +972,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
// model is of an open op system.
//
// The following algorithm is followed:
// if flex is enabled and the op is whitelisted as flex
// if flex is enabled and the op is allowlisted as flex
// we emit op as flex.
// if custom is enabled
// we emit the op as custom.
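
Restated as code, the selection order the comment describes is roughly the following (a sketch; the enum and helper names are illustrative, not the exporter's real API):

// Prefer flex when it is enabled and the op is allowlisted as flex;
// otherwise fall back to custom emission if that is enabled.
enum class Emission { kFlex, kCustom, kUnsupported };

Emission ChooseEmission(bool flex_enabled, bool custom_enabled,
                        bool allowlisted_as_flex) {
  if (flex_enabled && allowlisted_as_flex) return Emission::kFlex;
  if (custom_enabled) return Emission::kCustom;
  return Emission::kUnsupported;
}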
@ -982,11 +982,11 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
}
// Flex op case
// Eventually, the whitelist will go away and we will rely on some TF op
// Eventually, the allowlist will go away and we will rely on some TF op
// trait (e.g. No side effect) to determine if it is a supported "Flex"
// op or not.
if (enabled_op_types_.contains(OpType::kSelectTf) &&
IsWhitelistedFlexOp(node_def->op())) {
IsAllowlistedFlexOp(node_def->op())) {
// Construct ops as flex op encoding TensorFlow node definition
// as custom options.
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
@ -1037,7 +1037,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
}
// Insert failed op to `flex_ops` or `custom_ops`.
if (IsWhitelistedFlexOp(node_def->op())) {
if (IsAllowlistedFlexOp(node_def->op())) {
failed_flex_ops_.insert(os.str());
} else {
failed_custom_ops_.insert(os.str());

View File

@ -443,8 +443,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
TF_ASSIGN_OR_RETURN(value,
ConvertFloatBuffer(shaped_type, float_type, buffer));
} else if (elem_type.isa<mlir::IntegerType>() ||
elem_type.isa<QuantizedType>()) {
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
TF_ASSIGN_OR_RETURN(value,
ConvertIntBuffer(shaped_type, elem_type, buffer));
} else if (elem_type.isa<mlir::TF::StringType>()) {
@ -456,8 +455,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
refs.push_back({ref.data(), ref.size()});
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
} else if (elem_type.isa<mlir::ComplexType>() ||
elem_type.isa<mlir::TF::TensorFlowType>()) {
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);

View File

@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() {
fn_.walk([&](Operation *op) {
if (op->isKnownTerminator() ||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
llvm::isa<quant::DequantizeCastOp>(op) ||
llvm::isa<quant::QuantizeCastOp>(op))
llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
return;
work_list_.push_back(op);
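
This hunk, the `elem_type.isa<...>` change above, and the two changes below all apply the same cleanup: a chain of `||`-ed single-type checks becomes one variadic check, which short-circuits across the type list exactly like the chain it replaces. A self-contained illustration against a toy LLVM-style RTTI hierarchy (the `Shape` classes are invented for the example):

#include "llvm/Support/Casting.h"

struct Shape {
  enum Kind { kCircle, kSquare, kLine };
  Kind kind;
  explicit Shape(Kind k) : kind(k) {}
};
struct Circle : Shape {
  Circle() : Shape(kCircle) {}
  static bool classof(const Shape *s) { return s->kind == kCircle; }
};
struct Square : Shape {
  Square() : Shape(kSquare) {}
  static bool classof(const Shape *s) { return s->kind == kSquare; }
};

bool HasArea(const Shape &s) {
  // Equivalent to: llvm::isa<Circle>(&s) || llvm::isa<Square>(&s)
  return llvm::isa<Circle, Square>(&s);
}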

View File

@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
Operation* def = pre_quantized.getDefiningOp();
if (!def) return failure();
if (llvm::isa<FixedOutputRangeInterface>(def) ||
llvm::isa<SameScalesOpInterface>(def) ||
if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
return failure();
}

View File

@ -1,4 +1,4 @@
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-allowlist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
// CHECK-LABEL: quantize_float_placeholder_only
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {

View File

@ -1,4 +1,4 @@
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" %s | FileCheck %s
func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
return %arg0 : tensor<1x4xf32>

View File

@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
return failure();
ShapedType filter_type = filter_cst.getType();
if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
if (llvm::isa<AddOp, SubOp>(binary_op)) {
auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
if (padding && padding.getValue() != "VALID") return failure();
@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
fc_op.setOperand(0, binary_op->getOperand(0));
fc_op.setOperand(2, new_bias_op);
} else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) {
} else if (llvm::isa<MulOp, DivOp>(binary_op)) {
// The fusion of mul/div is actually applying the following
// transformation:
// w * (x ' c) + b => (w ' c) x + b
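To see why this is sound, substitute into the affine computation y = w * x + b: if the input was first scaled by a constant c, then w * (x * c) + b = (w * c) * x + b, so c folds into the filter (and likewise for div). For the add/sub case handled above, w * (x + c) + b = w * x + (w * c + b), so the constant folds into the bias; that fold is only valid with VALID padding, since other padding modes pad the input with zeros that never received the +c offset, which is what the padding check earlier in this pattern guards against.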

View File

@ -61,7 +61,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
// pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
llvm::ArrayRef<std::string> trim_funcs_whitelist);
llvm::ArrayRef<std::string> trim_funcs_allowlist);
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
// pass.

View File

@ -35,9 +35,9 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
// NOLINTNEXTLINE
static llvm::cl::list<std::string> quantize_whitelist(
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
llvm::cl::desc("comma separated list of whitelisted functions to be "
static llvm::cl::list<std::string> quantize_allowlist(
"tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
llvm::cl::desc("comma separated list of allowlisted functions to be "
"quantized. Only used in tests"),
llvm::cl::CommaSeparated);
@ -108,7 +108,7 @@ class PrepareQuantizePass
// Get the min and max values from the quantization specification for the
// current function and argument index. Uses default values if
// the function is specified in the `quantize_whitelist`.
// the function is specified in the `quantize_allowlist`.
std::pair<llvm::Optional<double>, llvm::Optional<double>>
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
if (func_name == quant_specs_.target_func) {
@ -132,7 +132,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
// Skip this function because it isn't the target function from the spec or
// in the function allowlist.
if (target_func != func_name &&
!llvm::is_contained(quantize_whitelist, func_name)) {
!llvm::is_contained(quantize_allowlist, func_name)) {
return false;
}

View File

@ -29,12 +29,12 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
// The cmd line flag to specify the whitelist of functions. Rest are trimmed
// The cmd line flag to specify the allowlist of functions. Rest are trimmed
// after this pass is run.
// NOLINTNEXTLINE
static llvm::cl::list<std::string> trim_funcs_whitelist(
"tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"),
llvm::cl::desc("comma separated list of whitelisted functions. The first "
static llvm::cl::list<std::string> trim_funcs_allowlist(
"tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"),
llvm::cl::desc("comma separated list of allowlisted functions. The first "
"function specified will be used as main."),
llvm::cl::CommaSeparated);
@ -43,25 +43,25 @@ namespace TFL {
namespace {
// The pass to trim functions before we legalize to TFL
// dialect using the specified whitelist.
// dialect using the specified allowlist.
class TrimFunctionsPass
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
public:
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
: trim_funcs_whitelist_(trim_funcs_whitelist) {}
explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {}
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)
: trim_funcs_allowlist_(trim_funcs_allowlist) {}
private:
void runOnOperation() override;
bool TrimModule();
void Verify();
llvm::ArrayRef<std::string> trim_funcs_whitelist_;
llvm::ArrayRef<std::string> trim_funcs_allowlist_;
};
void TrimFunctionsPass::runOnOperation() {
// trim the functions in the module using the trim_funcs_whitelist_
// by removing functions not in the whitelist.
// trim the functions in the module using the trim_funcs_allowlist_
// by removing functions not in the allowlist.
if (TrimModule()) {
// verify the updated module is still valid, if not signal the
// pass as failed.
@ -70,20 +70,20 @@ void TrimFunctionsPass::runOnOperation() {
}
bool TrimFunctionsPass::TrimModule() {
// if no trim_funcs_whitelist_ is specified, this pass is a no-op.
if (trim_funcs_whitelist_.empty()) return false;
// if no trim_funcs_allowlist_ is specified, this pass is a no-op.
if (trim_funcs_allowlist_.empty()) return false;
llvm::SmallVector<FuncOp, 4> funcs_to_trim;
for (auto func : getOperation().getOps<FuncOp>()) {
if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
// If no main is specified in the whitelist, use the 1st func
// in trim_funcs_whitelist as the main.
if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) {
// If no main is specified in the allowlist, use the 1st func
// in trim_funcs_allowlist as the main.
// TODO(ashwinm): Currently tflite flatbuffer export assumes there is
// always a main. This is strictly not required for TFlite. We need to
// remove that restriction once we have support to attribute the main
// tensorflow function in MLIR TF import using an entry_point attr.
if (!llvm::is_contained(trim_funcs_whitelist_, "main") &&
func.getName() == trim_funcs_whitelist_[0]) {
if (!llvm::is_contained(trim_funcs_allowlist_, "main") &&
func.getName() == trim_funcs_allowlist_[0]) {
func.setName("main");
}
} else {
@ -99,7 +99,7 @@ bool TrimFunctionsPass::TrimModule() {
}
// validate that all reachable functions from the remaining functions are
// also in the whitelist.
// also in the allowlist.
void TrimFunctionsPass::Verify() {
// TODO(ashwinm): Instead, we should make sure that references to all
// SymbolRefAttrs of all ops are present.
@ -109,7 +109,7 @@ void TrimFunctionsPass::Verify() {
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
return getOperation().emitError()
<< func.getName() << " is not in the funcs whitelist";
<< func.getName() << " is not in the funcs allowlist";
return WalkResult::advance();
});
if (walk_result.wasInterrupted()) return signalPassFailure();
@ -121,13 +121,13 @@ void TrimFunctionsPass::Verify() {
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
// pass.
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
llvm::ArrayRef<std::string> trim_funcs_allowlist) {
return std::make_unique<TrimFunctionsPass>(trim_funcs_allowlist);
}
static PassRegistration<TrimFunctionsPass> pass(
"tfl-trim-funcs-tf",
"Trim functions to restrict them to a specified whitelist prior to "
"Trim functions to restrict them to a specified allowlist prior to "
"legalization to TensorFlow lite dialect");
} // namespace TFL
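As a worked example matching the updated test above: running tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" on a module containing @foo, @bar, and @foobar removes @foo, and because "main" is not in the allowlist, the first allowlisted function @bar is renamed to @main.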

View File

@ -624,6 +624,7 @@ cc_library(
"transforms/tpu_rewrite_pass.cc",
"transforms/tpu_sharding_identification_pass.cc",
"transforms/tpu_space_to_depth_pass.cc",
"transforms/tpu_update_embedding_enqueue_op_inputs.cc",
"transforms/tpu_variable_runtime_reformatting.cc",
"translate/breakup-islands.cc",
"translate/tf_executor_to_functional.cc",

View File

@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
var_handle.resource(),
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
&var_handle_name_id_map));
} else if (llvm::isa<TF::IdentityNOp>(op) ||
llvm::isa<TF::IdentityOp>(op)) {
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
for (auto operand_and_result :
llvm::zip(op->getOperands(), op->getResults())) {
forward_input_to_output(std::get<0>(operand_and_result),
@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op,
const ResourceAliasAnalysis& alias_analysis) {
// TODO(yuanzx): Add other types of resources.
return llvm::isa<TF::VarHandleOp>(op) ||
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
(llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
!FindAccessedResources(op, alias_analysis).empty());
}

View File

@ -569,8 +569,10 @@ void BuildReplicateOp(
// Add derived `operand_segment_sizes` attribute.
int32_t num_replicated_inputs = replicated_inputs.size() * n;
auto operand_segment_sizes = DenseIntElementsAttr::get(
VectorType::get({2}, builder->getI32Type()), {num_replicated_inputs, 0});
int32_t num_packed_inputs = packed_inputs.size();
auto operand_segment_sizes =
DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()),
{num_replicated_inputs, num_packed_inputs});
state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);
for (const auto& output_type : replica_output_types)
@ -600,6 +602,65 @@ void ReplicateOp::build(
packed_inputs, replica_output_types);
}
// Returns the number of packed block arguments.
unsigned ReplicateOp::GetNumPackedBlockArguments() {
return packed_inputs().size();
}
// Returns the number of replicated block arguments.
unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
return GetBody().getNumArguments() - GetNumPackedBlockArguments();
}
// Returns the replicated block arguments. A copy should be made if the
// replicate op is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
}
// Returns the packed block arguments. A copy should be made if the replicate op
// is being modified.
llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
}
// Checks if a block argument is replicated (forwarding replicated inputs).
bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
assert(block_arg.getOwner() == &GetBody());
return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
}
// Checks if a block argument is packed (forwarding a packed input).
bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
return !IsReplicatedBlockArgument(block_arg);
}
// Returns the operand index of the operand being forwarded as a
// replicated/packed block argument for a given replica. This assumes a valid
// block argument (of the replicate op) and a valid replica are provided.
unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
BlockArgument block_arg, unsigned replica) {
const int32_t num_replicas = nAttr().getInt();
assert(replica < num_replicas && block_arg.getOwner() == &GetBody());
const unsigned num_replicated_args = GetNumReplicatedBlockArguments();
if (block_arg.getArgNumber() < num_replicated_args)
return block_arg.getArgNumber() * num_replicas + replica;
return block_arg.getArgNumber() - num_replicated_args +
replicated_inputs().size();
}
// Returns the operand being forwarded as a replicated/packed block argument for
// a given replica. This assumes a valid block argument (of the replicate op)
// and a valid replica are provided.
Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
unsigned replica) {
const unsigned operand_index =
GetReplicaOperandIndexForBlockArgument(block_arg, replica);
return getOperand(operand_index);
}
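A self-contained sketch of the operand layout these helpers assume (a hypothetical free function, not the real API): replicated operands come first, grouped per input with one operand per replica, followed by the shared packed operands.

    #include <cassert>
    #include <iostream>

    // Mirrors GetReplicaOperandIndexForBlockArgument for illustration.
    unsigned ReplicaOperandIndex(unsigned arg_number, unsigned replica,
                                 unsigned num_replicated_args,
                                 unsigned num_replicas) {
      assert(replica < num_replicas);
      if (arg_number < num_replicated_args)
        return arg_number * num_replicas + replica;  // per-replica operand
      return num_replicated_args * num_replicas +
             (arg_number - num_replicated_args);  // packed operand, shared
    }

    int main() {
      // n = 2 replicas, 2 replicated block args, 1 packed block arg.
      std::cout << ReplicaOperandIndex(0, 1, 2, 2) << "\n";  // prints 1
      std::cout << ReplicaOperandIndex(1, 0, 2, 2) << "\n";  // prints 2
      std::cout << ReplicaOperandIndex(2, 0, 2, 2) << "\n";  // prints 4
      return 0;
    }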
//===----------------------------------------------------------------------===//
// Canonicalization patterns
//===----------------------------------------------------------------------===//

View File

@ -283,6 +283,14 @@ For example:
let extraClassDeclaration = [{
Block &GetBody() { return getOperation()->getRegion(0).front(); }
unsigned GetNumReplicatedBlockArguments();
unsigned GetNumPackedBlockArguments();
llvm::ArrayRef<BlockArgument> GetPackedBlockArguments();
llvm::ArrayRef<BlockArgument> GetReplicatedBlockArguments();
bool IsReplicatedBlockArgument(BlockArgument block_arg);
bool IsPackedBlockArgument(BlockArgument block_arg);
unsigned GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg, unsigned replica);
Value GetReplicaOperandForBlockArgument(BlockArgument block_arg, unsigned replica);
}];
let builders = [

View File

@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
// Allow inlining into tf.island regions if the incoming region has a single
// block.
return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
std::next(src->begin()) == src->end();
llvm::hasSingleElement(*src);
}
};
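llvm::hasSingleElement(*src) states the intent of the old iterator comparison directly: the incoming region consists of exactly one block (it also checks non-emptiness first, which the raw std::next comparison assumed).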

View File

@ -7270,6 +7270,8 @@ reshape(t, []) ==> 7
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
}
def TF_ResizeBilinearOp : TF_Op<"ResizeBilinear", [NoSideEffect]> {

View File

@ -506,28 +506,52 @@ LogicalResult FoldOperandsPermutation(
//===----------------------------------------------------------------------===//
namespace {
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
// known to be Identity (e.g X+0)
// Fold Arithmetic Op if one of the operands is a constant known to be an
// Identity (e.g. X+0, X*1, etc.). For commutative operations, fold if the
// known identity value is either the lhs or the rhs.
template <
typename OpT,
typename std::enable_if<llvm::is_one_of<
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
ArrayRef<Attribute> operands) {
auto result_op_type = arithmetic_op.getResult().getType();
auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
if (!result_op_type.template cast<ShapedType>().hasStaticShape()) return {};
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
auto result_type =
arithmetic_op.getResult().getType().template cast<ShapedType>();
// We only handle non-broadcastable case.
if (result_op_type != lhs_type) {
return {};
}
// We can fold an arithmetic operation only if we can prove that we will not
// accidentally hide a broadcasting error.
auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
ShapedType result_ty) -> bool {
// A scalar identity is broadcastable to any operand shape; we only need to
// check that the operand has the same shape as the result.
bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
if (scalar_identity) return operand_ty == result_ty;
// If identity is not a scalar, we must verify that all shapes are equal
// and statically known.
//
// TODO(ezhulenev): Fold if identity shape is statically known to be
// broadcastable to the operand shape.
return operand_ty == result_ty && identity_ty == result_ty &&
result_ty.hasStaticShape();
};
// Check that we have a constant operand on one side (candidate for identity).
const bool is_commutative =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
if (!rhs_attr && !(is_commutative && lhs_attr)) return {};
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
// value zero.
int identity =
const int identity =
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
std::is_same<OpT, RealDivOp>::value);
std::is_same<OpT, RealDivOp>::value)
? 1
: 0;
Type element_ty = lhs_type.getElementType();
Attribute identity_attr;
@ -539,23 +563,19 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
return {};
}
if (auto attr = operands[1].dyn_cast_or_null<DenseElementsAttr>()) {
if (attr.isSplat() && attr.getSplatValue() == identity_attr)
// Fold: Op(Operand, Identity) -> Operand.
if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr)
return arithmetic_op.x();
}
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
// TODO(chhe): we could fold and add an identity to force the broadcast.
if (result_op_type != rhs_type) {
return {};
}
bool is_symmetric =
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr)
// Fold: Op(Identity, Operand) -> Operand for commutative operations.
if (lhs_attr && is_commutative &&
is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr)
return arithmetic_op.y();
}
return {};
}
} // namespace
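In short, the rewritten folder handles Op(x, identity) for all five ops and Op(identity, x) for the commutative AddV2 and Mul, but only when folding provably cannot mask a broadcasting error: a scalar identity folds whenever the operand shape equals the result shape, while a non-scalar identity additionally requires all three shapes to be equal and static. The tests added later in this change (testSubOfZero, testAddV2IdentityScalar, testAddV2IdentityTensor) exercise exactly these cases.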
@ -1168,8 +1188,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result,
ShapedType type;
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
return ConstOp::build(builder, result, elem_attr);
} else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
value.isa<IntegerAttr>()) {
} else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
// All TensorFlow types must be tensor types. In the build() method,
// we want to provide more flexibility by allowing attributes of scalar
// types. But we need to wrap it up with ElementsAttr to construct
@ -2870,6 +2889,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
return unranked();
}
void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<RedundantReshape>(context);
}
//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//

View File

@ -1165,4 +1165,35 @@ array([0, 2, 2])
);
}
def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
let summary = "Calls a function placed on a specified TPU device.";
let arguments = (ins
Variadic<TF_Tensor>:$args,
I32Tensor:$device_ordinal,
SymbolRefAttr:$f,
DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
);
let results = (outs
Variadic<TF_Tensor>:$output
);
TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;
let extraClassDeclaration = [{
// Gets the argument operands to the called function.
operand_range getArgOperands() { return args(); }
// Returns the callee of this operation.
CallInterfaceCallable getCallableForCallee() {
return getAttrOfType<SymbolRefAttr>("f");
}
}];
let verifier = [{ return VerifyPartitionedCall(*this); }];
}
#endif // TF_OPS

View File

@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
Operation *op, NamedAttribute named_attr) {
if (named_attr.first == "tf_saved_model.exported_names") {
if (!isa<FuncOp>(op) && !isa<GlobalTensorOp>(op)) {
if (!isa<FuncOp, GlobalTensorOp>(op)) {
return op->emitError() << "'tf_saved_model.exported_names' must be on a "
"'func' or 'tf_saved_model.global_tensor' op";
}

View File

@ -90,8 +90,7 @@ class TensorFlowType : public Type {
// Returns true if the specified type is a valid TensorFlow element type.
static inline bool IsValidTFElementType(Type type) {
return type.isa<ComplexType>() || type.isa<FloatType>() ||
type.isa<IntegerType>() || type.isa<TensorFlowType>();
return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
}
// Returns true if this is a valid TensorFlow tensor type.

View File

@ -190,6 +190,27 @@ func @testSubOfNeg(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8
// CHECK: return %0
}
// CHECK-LABEL: testSubOfZero
func @testSubOfZero(%arg0: tensor<?x1xf32>, %arg1: tensor<4x1xf32>) -> (tensor<?x1xf32>, tensor<4x1xf32>) {
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Sub"(%arg0, %0) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
%2 = "tf.Sub"(%arg1, %0) : (tensor<4x1xf32>, tensor<f32>) -> tensor<4x1xf32>
return %1, %2: tensor<?x1xf32>, tensor<4x1xf32>
// CHECK: return %arg0, %arg1
}
// CHECK-LABEL: testSubOfZeroWithBroadcasting
func @testSubOfZeroWithBroadcasting(%arg0: tensor<4x1xf32>) -> tensor<4x4xf32> {
// This is an identity arithmetic operation; however, we do not currently fold
// it because it involves broadcasting.
%0 = "tf.Const"() {value = dense<[[0.0, 0.0, 0.0, 0.0]]> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
%1 = "tf.Sub"(%arg0, %0) : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<4x4xf32>
// CHECK: return %1
}
// CHECK-LABEL: testSquareOfSub
func @testSquareOfSub(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
@ -257,6 +278,46 @@ func @testAddV2OfNegRight(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> t
// CHECK: return %0
}
// CHECK-LABEL: testAddV2IdentityScalar
func @testAddV2IdentityScalar(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<4xf32>) -> (tensor<f32>, tensor<?xf32>, tensor<4xf32>) {
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
// The identity scalar (0.0) is foldable with an operand of any shape because
// a scalar is safely broadcastable to any shape.
%1 = "tf.AddV2"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%2 = "tf.AddV2"(%arg1, %0) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
%3 = "tf.AddV2"(%arg2, %0) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
%4 = "tf.AddV2"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
%5 = "tf.AddV2"(%0, %2) : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
%6 = "tf.AddV2"(%0, %3) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %arg0, %arg1, %arg2
return %4, %5, %6: tensor<f32>, tensor<?xf32>, tensor<4xf32>
}
// CHECK-LABEL: testAddV2IdentityTensor
func @testAddV2IdentityTensor(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
%0 = "tf.Const"() {value = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf32>} : () -> tensor<4xf32>
// If the operand is a scalar, then the identity value (0.0 for addition) can
// be of any shape, because the operand is safely broadcastable to any shape.
//
// However, we can't fold this arithmetic operation because the operand
// shape does not match the result shape.
%1 = "tf.AddV2"(%arg0, %0) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "tf.AddV2"(%0, %arg0) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// If the operand has the same shape as the result, we can fold it.
%3 = "tf.AddV2"(%arg1, %0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "tf.AddV2"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK: return %1, %2, %arg1, %arg1
return %1, %2, %3, %4: tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: testDoubleConj
func @testDoubleConj(%arg0: tensor<8x16x32x64xcomplex<f32>>) -> tensor<8x16x32x64xcomplex<f32>> {
%0 = "tf.Conj"(%arg0) : (tensor<8x16x32x64xcomplex<f32>>) -> tensor<8x16x32x64xcomplex<f32>>
@ -302,6 +363,20 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi
// CHECK: return %arg0
}
// CHECK-LABEL: testRedundantReshape
func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
%0 = "tf.Const"() {value = dense<[8, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
%1 = "tf.Const"() {value = dense<[2, 8]> : tensor<2xi32>} : () -> tensor<2xi32>
%2 = "tf.Reshape"(%arg0, %0) : (tensor<4x4xi32>, tensor<2xi32>) -> tensor<8x2xi32>
%3 = "tf.Reshape"(%2, %1) : (tensor<8x2xi32>, tensor<2xi32>) -> tensor<2x8xi32>
return %3: tensor<2x8xi32>
// CHECK: %0 = "tf.Const"
// CHECK-SAME: value = dense<[2, 8]> : tensor<2xi32>
// CHECK: %1 = "tf.Reshape"(%arg0, %0)
// CHECK: return %1 : tensor<2x8xi32>
}
// CHECK-LABEL: testSelectScalarPred
func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> {
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>

View File

@ -2,17 +2,17 @@
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
return %0 : tensor<1x32x10x32xi32>
}
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
return %0 : tensor<?x?x?x?xi32>
}
@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
return %0 : tensor<4x4x4x4xi32>
}
@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
return %0 : tensor<?x?xi32>
}
@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
}
func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
%0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
%0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
return %0 : tensor<1x2xi32>
}
@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
%0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
%0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
}
func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
%0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
}
func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
%0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
%0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0 : tensor<1x4xi8>
}
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
}
func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
%0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
return %0 : tensor<1x4xi8>
}
func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32>
}
@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
%0 = mhlo.constant dense<0> : tensor<2x3xi32>
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%2 = mhlo.constant dense<0> : tensor<3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = mhlo.constant dense<1> : tensor<3xi32>
%9 = mhlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32>
}
@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
%0 = mhlo.constant dense<0> : tensor<3xi32>
%1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = mhlo.constant dense<0> : tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = mhlo.constant dense<1> : tensor<2x3xi32>
%9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = mhlo.divide %11, %12 : tensor<2x3xi32>
@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
}
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
return %2 : tensor<2x3xf16>
}
@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
return %0 : tensor<?xi1>
}
@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
}
func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
return %0 : tensor<1x2xi1>
}
@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> {
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
%1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32>
}
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
%1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32>
}
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %3 : tensor<1xi32>
}
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = mhlo.constant dense<0> : tensor<i32>
%1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %3 : tensor<?xi32>
}
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return %3 : tensor<4x8xf32>

View File

@ -86,6 +86,15 @@ func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32
return %0 : tensor<2x3xf32>
}
// CHECK-LABEL: @is_nan
func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
// CHECK: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[RESULT:.*]] = "tf.Equal"(%arg0, %[[NAN]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
%0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
// CHECK: return %[[RESULT]]
return %0 : tensor<3x4xi1>
}
// CHECK-LABEL: func @fill
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xi64>, %[[ARG1:.*]]: tensor<*xf32>)
func @fill(%arg0: tensor<*xi64>, %arg1: tensor<*xf32>) -> tensor<*xf32> {

View File

@ -0,0 +1,79 @@
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-update-embedding-enqueue-op-inputs | FileCheck %s
// CHECK-LABEL: func @check_enqueue_ops_update_for_eval
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<?x2xi32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x2xi32>
// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor<?x2xi32>
// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}
// -----
// CHECK-LABEL: func @check_enqueue_ops_update_for_training
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<?x2xi32>
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x2xi32>
// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor<?x2xi32>
// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
func @check_enqueue_ops_update_for_training(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
%3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
"tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]])
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}
// -----
func @check_enqueue_ops_with_different_attr_disallowed(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
// expected-error @+1 {{'tf.EnqueueTPUEmbeddingSparseTensorBatch' op must have a corresponding 'tf.RecvTPUEmbeddingActivations' op}}
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}
// -----
func @check_embedding_ops_with_missing_attribute_disallowed(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
// expected-error @+1 {{'tf.RecvTPUEmbeddingActivations' op requires attribute '_tpu_embedding_layer'}}
%2:2 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
return
}

View File

@ -107,6 +107,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
}
void CreateTPUBridgePipelineV1(OpPassManager &pm) {
pm.addPass(TF::CreateTFShapeInferencePass());
// For V1 compatibility, we process a module where the graph does not have
// feeds and fetches. We first extract the TPU computation into a submodule,
// where it'll be in a function with args and returned values, much more like

View File

@ -194,6 +194,13 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
(replaceWithValue $arg)>;
//===----------------------------------------------------------------------===//
// Reshape op patterns.
//===----------------------------------------------------------------------===//
def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
(TF_ReshapeOp $arg, $shape)>;
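That is, Reshape(Reshape(x, s1), s2) rewrites to Reshape(x, s2): the inner reshape is redundant because only the final shape matters, as the testRedundantReshape case added earlier in this change demonstrates.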
//===----------------------------------------------------------------------===//
// Select op patterns.
//===----------------------------------------------------------------------===//

View File

@ -154,6 +154,14 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp],
def LowerFillOp : Pat<(TF_FillOp $dims, $value),
(TF_BroadcastToOp $value, $dims)>;
//===----------------------------------------------------------------------===//
// NaN op patterns.
//===----------------------------------------------------------------------===//
def LowerIsNanOp : Pat<(TF_IsNanOp $x),
(TF_EqualOp $x, (TF_ConstOp:$nan (GetScalarNanOfType $x)),
/*incompatible_shape_error*/ConstBoolAttrTrue)>;
//===----------------------------------------------------------------------===//
// L2Loss op patterns.
//===----------------------------------------------------------------------===//

View File

@ -287,6 +287,11 @@ CreateTPUExtractHeadTailOutsideCompilationPass();
// that are only used for host computation.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass();
// Creates a pass that updates inputs to TPU embedding layer enqueue ops so that
// the correct ops are invoked during training and evaluation.
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass();
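Judging from the new tests above, the pass appears to rewrite the mode operand of tf.EnqueueTPUEmbeddingSparseTensorBatch (previously computed by a tf.SelectV2) to refer directly to the appropriate input: the training-mode string when a matching tf.SendTPUEmbeddingGradients is present, and the inference-mode string otherwise.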
// Creates a pass that extracts outside compilation (CPU ops inside a TPU cluster)
// ops to a separate parallel_execute region to run on CPU.
std::unique_ptr<OperationPass<ModuleOp>>

View File

@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo(
info.data_type = assign.value().getType();
continue;
}
if (isa<TF::StackPushV2Op>(user) || isa<TF::StackPopV2Op>(user)) {
if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
// Stacks will be handled by a separate pass.
do_not_touch = true;
break;

View File

@ -205,9 +205,9 @@ GetSubtypes(Type type) {
// Returns whether type can be further refined.
bool CanBeRefined(Type type) {
auto shape_type = type.dyn_cast<ShapedType>();
return shape_type && (!shape_type.hasStaticShape() ||
shape_type.getElementType().isa<TF::ResourceType>() ||
shape_type.getElementType().isa<TF::VariantType>());
return shape_type &&
(!shape_type.hasStaticShape() ||
shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
}
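// For illustration (hypothetical types, not from the surrounding code):
//   tensor<?x4xf32>                   -> refinable (dynamic dimension)
//   tensor<4xf32>                     -> not refinable (fully static shape,
//                                        plain element type)
//   tensor<!tf.resource<tensor<f32>>> -> refinable (resource/variant subtypes
//                                        may be refined even when the outer
//                                        shape is static)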
// Infers the shape from a (Stateful)PartitionedCall operation by looking up the
@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
// The shape function of these ops sometimes does not propagate subtypes
// (handle shapes) for resource and variant types. We use a simple passthrough
// to make sure they are preserved in the output.
if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp>(op)) {
return RefineTypeForPassThroughOperands(op, op->getOperands(),
op->getResults());
}
@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
// Handle call operations by looking up the callee and inferring the return
// shape as needed.
if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op))
if (isa<PartitionedCallOp, StatefulPartitionedCallOp>(op))
return InferShapeForCall(op);
// tf.Cast ops are only inferred if they have at least one user in the TF dialect
@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
};
auto new_element_type = shaped_type.getElementType();
// Populate the handle shapes for a resource/variant.
if (new_element_type.isa<TF::ResourceType>() ||
new_element_type.isa<TF::VariantType>()) {
if (new_element_type.isa<TF::ResourceType, TF::VariantType>()) {
auto handle_shapes_types = c.output_handle_shapes_and_types(output);
if (handle_shapes_types) {
SmallVector<TensorType, 1> subtypes;

View File

@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal(
llvm::StringMap<PartitionedCallStackOpsInfo>*
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
// Removes identity nodes in the block. The device computation does not
// need such nodes to carry information.
op.replaceAllUsesWith(op.getOperands());

View File

@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps(
llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {

View File

@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() {
// Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
auto remove_result = getFunction().walk([&](Operation* op) {
if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
!llvm::isa<TF::TPUReplicatedOutputOp>(op))
if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
return WalkResult::advance();
// Forward operand to result. When `num_replicas` attribute is 1, no

View File

@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Block.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
namespace mlir {
namespace TFTPU {
namespace {
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
struct TPUUpdateEmbeddingEnqueueOpInputs
: public PassWrapper<TPUUpdateEmbeddingEnqueueOpInputs, FunctionPass> {
void runOnFunction() override;
};
// Extracts the `_tpu_embedding_layer` attribute from TPU embedding ops and
// clears the attribute from the operation. This ensures that future
// optimization passes do not trigger additional logic due to the presence of
// this attribute.
LogicalResult ExtractEmbeddingAttribute(
Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
if (!embedding_attr)
return op->emitOpError("requires attribute '_tpu_embedding_layer'");
if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
return op->emitOpError(
"found duplicate TPU embedding ops potentially from multiple "
"TPUEmbedding layers");
op->removeAttr(kTPUEmbeddingAttr);
return success();
}
LogicalResult FindTPUEmbeddingOps(
FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
llvm::StringMap<Operation*>* recv_activation_op_map,
llvm::StringMap<Operation*>* send_gradient_op_map) {
auto walk_result = func_op.walk([&](Operation* op) {
if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
return WalkResult::interrupt();
if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
return WalkResult::interrupt();
if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op))
if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
return WalkResult::interrupt();
return WalkResult::advance();
});
return failure(walk_result.wasInterrupted());
}
// Updates the operands of TPU embedding enqueue ops depending on whether the
// graph is in training mode or in non-training mode. If a
// SendTPUEmbeddingGradients op is present, the graph is in training mode; in
// that case, feed the `then` branch value of each SelectV2 operand into the
// TPU embedding enqueue ops. A sketch of this rewrite follows the pass
// registration below.
LogicalResult UpdateEmbeddingEnqueueOpInput(
const llvm::StringMap<Operation*>& enqueue_op_map,
const llvm::StringMap<Operation*>& recv_activation_op_map,
const llvm::StringMap<Operation*>& send_gradient_op_map) {
for (const auto& it : enqueue_op_map) {
const auto& embedding_attr = it.getKey();
Operation* embedding_op = it.second;
if (!recv_activation_op_map.count(embedding_attr))
return embedding_op->emitOpError()
<< "must have a corresponding '"
<< TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
// TPU embedding enqueue ops take different inputs depending on whether the
// graph is in training mode or in eval/prediction mode. The inputs to the
// enqueue ops appear as operands of a SelectV2 op: the `then` branch operand
// is the input to use during training, and the `else` branch operand is the
// input to use during prediction/evaluation. If a SendTPUEmbeddingGradients
// op exists in the graph, the graph is in training mode, so forward the
// training input of the SelectV2 op as the operand to the TPU embedding
// enqueue op.
bool is_training = send_gradient_op_map.count(embedding_attr);
for (auto enqueue_operand : embedding_op->getOperands()) {
if (auto select = llvm::dyn_cast_or_null<TF::SelectV2Op>(
enqueue_operand.getDefiningOp())) {
enqueue_operand.replaceAllUsesWith(is_training ? select.t()
: select.e());
}
}
}
return success();
}
void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
OpBuilder builder(&getContext());
auto func_op = getFunction();
// All TPU embedding layer related ops are annotated with a
// `_tpu_embedding_layer` string attribute. Store all such ops in maps keyed
// by the value of their `_tpu_embedding_layer` attribute.
llvm::StringMap<Operation*> enqueue_op_map;
llvm::StringMap<Operation*> recv_activation_op_map;
llvm::StringMap<Operation*> send_gradient_op_map;
if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
&recv_activation_op_map,
&send_gradient_op_map)))
return signalPassFailure();
if (enqueue_op_map.size() != recv_activation_op_map.size()) {
func_op.emitError() << "expects the number of embedding enqueue ops to "
"match the number of '"
<< TF::RecvTPUEmbeddingActivationsOp::getOperationName()
<< "' ops";
return signalPassFailure();
}
if (failed(UpdateEmbeddingEnqueueOpInput(
enqueue_op_map, recv_activation_op_map, send_gradient_op_map)))
return signalPassFailure();
}
} // anonymous namespace
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputs>();
}
static PassRegistration<TPUUpdateEmbeddingEnqueueOpInputs> pass(
"tf-tpu-update-embedding-enqueue-op-inputs",
"Updates inputs to TPU embedding enqueue ops depending on whether graph "
"is in training mode or in evaluation mode.");
} // namespace TFTPU
} // namespace mlir
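A minimal sketch of the rewrite this pass performs (the op names are real, but
the value names, types, and operand list are simplified and illustrative).
When a SendTPUEmbeddingGradients op sharing the enqueue op's
`_tpu_embedding_layer` value exists, the graph is treated as training, so

  %mode = "tf.SelectV2"(%is_training, %train_mode, %eval_mode)
      : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
  "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%indices, %values, %mode) { ... }

becomes

  "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%indices, %values, %train_mode) { ... }

and in eval/prediction mode %eval_mode is forwarded instead.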

View File

@ -576,9 +576,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
// Adds nodes for operations.
for (Operation& inst : graph_op.GetBody()) {
for (auto type : inst.getResultTypes())
if (!type.isa<mlir::TensorType>() &&
!type.isa<mlir::tf_executor::ControlType>() &&
!type.isa<mlir::tf_executor::TokenType>())
if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
mlir::tf_executor::TokenType>())
return errors::InvalidArgument(
"Values must be of tensor type, TensorFlow control type, or "
"TensorFlow token type. Found ",

View File

@ -253,7 +253,7 @@ static void RegisterDialects() {
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
mlir::registerDialect<mlir::shape::ShapeDialect>();
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::MhloDialect>();
return true;
}();
(void)init_once;

View File

@ -88,7 +88,7 @@ struct MaterializeBroadcastsPass
mlir::OwningRewritePatternList conversionPatterns;
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<mlir::mhlo::XlaHloDialect>();
conversionTarget.addLegalDialect<mlir::mhlo::MhloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@ -128,7 +128,7 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
/*results_escape_functions=*/true));
pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());
pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloCopyRemovalPass());
if (failed(pm.run(module))) {
return InternalError("Lowering TF to LHLO failed.");

View File

@ -33,7 +33,7 @@ namespace xla {
static std::string GetMlirOpName(HloOpcode opcode) {
std::string op_name = HloOpcodeString(opcode);
absl::c_replace(op_name, '-', '_');
return mlir::mhlo::XlaHloDialect::getDialectNamespace().str() + "." + op_name;
return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
}
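// For example, assuming HloOpcodeString returns hyphenated names such as
// "all-reduce", this yields "mhlo.all_reduce" under the renamed dialect.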
static std::string ToString(mlir::Type ty) {

View File

@ -1,11 +1,11 @@
// RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope %s
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.abs
// CHECK: lmhlo.abs
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%abs = "mhlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %abs : tensor<2x2xf32>
@ -14,12 +14,12 @@ func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.add
// CHECK: lmhlo.add
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -29,12 +29,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lhlo.and
// CHECK: lmhlo.and
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
@ -44,11 +44,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.ceil
// CHECK: lmhlo.ceil
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -57,12 +57,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcomplex<f32>> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex<f32>>
// CHECK: lhlo.complex
// CHECK: lmhlo.complex
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
@ -72,11 +72,11 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex<f32>>
// CHECK: lhlo.cosine
// CHECK: lmhlo.cosine
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.cosine"(%value0) : (tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>>
@ -86,12 +86,12 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.divide
// CHECK: lmhlo.divide
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -101,11 +101,11 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.exponential
// CHECK: lmhlo.exponential
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -114,11 +114,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.log
// CHECK: lmhlo.log
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -127,12 +127,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.maximum
// CHECK: lmhlo.maximum
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -142,12 +142,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.minimum
// CHECK: lmhlo.minimum
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -157,12 +157,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.multiply
// CHECK: lmhlo.multiply
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -172,11 +172,11 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.negate
// CHECK: lmhlo.negate
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -185,11 +185,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
// CHECK: lhlo.real
// CHECK: lmhlo.real
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.real"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
return %res : tensor<1x2xf32>
@ -198,11 +198,11 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
// CHECK: lhlo.imag
// CHECK: lmhlo.imag
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.imag"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
return %res : tensor<1x2xf32>
@ -211,12 +211,12 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lhlo.remainder
// CHECK: lmhlo.remainder
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
@ -226,11 +226,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.rsqrt
// CHECK: lmhlo.rsqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -239,13 +239,13 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {xla_lhlo.params = 2
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.params = 2
// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8>
func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.select
// CHECK: lmhlo.select
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
// CHECK-NEXT: return
%0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
@ -255,11 +255,11 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.sign
// CHECK: lmhlo.sign
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -268,11 +268,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.sqrt
// CHECK: lmhlo.sqrt
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -281,12 +281,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
// CHECK: lhlo.subtract
// CHECK: lmhlo.subtract
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
// CHECK-NEXT: return
%res = "mhlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
@ -296,11 +296,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// -----
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
// CHECK: lhlo.tanh
// CHECK: lmhlo.tanh
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
%res = "mhlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %res : tensor<2x2xf32>
@ -311,11 +311,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32>
// CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32>
// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {xla_lhlo.alloc = 0
// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {xla_lhlo.alloc = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {lmhlo.alloc = 0
// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {lmhlo.alloc = 1
// CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32>
// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32>
// CHECK: "xla_lhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
// CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>> {
%res = "mhlo.sort"(%key, %value) ({
^bb0(%a: tensor<i32>, %b: tensor<i32>, %c: tensor<f32>, %d: tensor<f32>):

View File

@ -3,14 +3,14 @@
// Current allocation will lead to one buffer argument for the "value" and
// another one for the output, and no returned values.
// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index},
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true}
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 : index},
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true}
// CHECK-SAME: ) {
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// The only expected instruction is a copy from the input into the output.
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C0]]][] : memref<16xi8> to memref<2x2xf32>
// CHECK: xla_lhlo.copy
// CHECK: lmhlo.copy
// CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
return %value : tensor<2x2xf32>
}

View File

@ -23,8 +23,8 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
return %0 : tensor<2xf32>
}
// CHECK-LABEL: not_whitelisted_op
func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
// CHECK-LABEL: not_allowlisted_op
func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
// CHECK: tf.TensorListReserve
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
// CHECK: tf.TensorListGetItem

View File

@ -54,7 +54,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>
// CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
// CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
// CHECK: mhlo.constant
// CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
return %0#0 : tensor<8x8x8x8xf32>
}
@ -75,18 +75,18 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>,
// CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32}
// CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
// CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]
// CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]
// CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
// CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>
// CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3
// CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
// CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]
// CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
// CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
// CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]
// CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4
// CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
// CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]
// CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
// CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
// CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]
// CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]]
return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
@ -134,7 +134,7 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
@ -193,7 +193,7 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
@ -280,7 +280,7 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
@ -367,7 +367,7 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
@ -498,19 +498,19 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @floordiv_broadcast_i32
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
// CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
// CHECK: return [[SELECT]]
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
@ -520,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
// CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]]
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
// CHECK: return [[SELECT]]
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -541,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
// CHECK-LABEL: func @floordiv_f32
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
// CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0
// CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
// CHECK-NEXT: %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]])
// CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32>
%0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
@ -552,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
// CHECK-NEXT: mhlo.convert
// CHECK-NEXT: mhlo.convert
// CHECK-NEXT: xla_chlo.broadcast_divide
// CHECK-NEXT: chlo.broadcast_divide
// CHECK-NEXT: mhlo.floor
// CHECK-NEXT: mhlo.convert
// CHECK-NEXT: return
@ -562,7 +562,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
// CHECK-LABEL: func @floordiv_f16_broadcast
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
// CHECK-NEXT: xla_chlo.broadcast_divide
// CHECK-NEXT: chlo.broadcast_divide
// CHECK-NEXT: mhlo.floor
// CHECK-NEXT: return
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
@ -572,19 +572,19 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
// CHECK-LABEL: func @floordiv_dynamic
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
// CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
// CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
// CHECK: return [[SELECT]]
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
@ -600,15 +600,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x
// CHECK-LABEL: func @floormod_broadcast_numerator
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]]
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
// CHECK-NEXT: return [[SELECT]]
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
@ -617,15 +617,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>)
// CHECK-LABEL: func @floormod_broadcast_denominator
func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
// CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
// CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
// CHECK-NEXT: return [[SELECT]]
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
@ -634,15 +634,15 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32
// CHECK-LABEL: func @floormod_dynamic
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
// CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
// CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
// CHECK-NEXT: return [[SELECT]]
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
@ -979,10 +979,10 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: ten
// CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16>
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16>
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16>
// CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1>
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1>
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
// CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1>
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1>
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
@ -1000,10 +1000,10 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16>
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16>
// CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1>
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1>
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
// CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1>
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1>
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
// CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
@ -1315,7 +1315,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (te
// CHECK-LABEL: func @relu
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
// CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
%0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
return %0: tensor<1xi32>
}
@ -1323,7 +1323,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-LABEL: func @relu_unranked
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
// CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
// CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
%0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
return %0: tensor<?xi32>
}
@ -1351,7 +1351,7 @@ func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
// CHECK-DAG: %[[ZERO_SCALAR:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
// CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
// CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
// CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32>
%2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
@ -2473,10 +2473,10 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
// CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32>
// CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
// CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
// CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32>
// CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
// CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice"
@ -2605,7 +2605,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
// CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32>
// CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16>
// CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
// CHECK: return %[[RESULT]] : tensor<4x1xf16>
@ -2909,8 +2909,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
%1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
// CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"
// CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
%3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
return %3 : tensor<5xf32>
}
@ -2929,8 +2929,8 @@ func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>)
// CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
// CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
// CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
// CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
%2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
// CHECK: return [[ADD]]
@ -2951,8 +2951,8 @@ func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i3
// CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
// CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
// CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
// CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
%2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
// CHECK: return [[ADD]]
@ -2966,12 +2966,12 @@ func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]]
// CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM_CAST]])
// CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00>
// CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
// CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]]
// CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
// CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
// CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]]
// CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
// CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64}
// CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
// CHECK: return [[LINSPACE]]
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32>
@ -3266,13 +3266,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor<i32>) {
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
// CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 0
// CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]]
// CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]]
// CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 1
// CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]]
// CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]]
// CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 2
// CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]]
// CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]]
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor<i32>
// CHECK: return %[[MUL_2]]
return %size : tensor<i32>
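tf.Size is lowered as a running product over mhlo.get_dimension_size results, one broadcast_multiply per dimension. For a tensor<2x?x8xf32> whose dynamic dimension happens to be 5 at runtime, the chain computes 1 * 2 = 2, then 2 * 5 = 10, then 10 * 8 = 80 (the value 5 is an illustrative assumption).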
@ -3789,7 +3789,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
// CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
// CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
// CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
// CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]]
// CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]]
// CHECK: [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]])
// CHECK: "mhlo.return"([[NEW_TUPLE]])
// CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>
@ -3822,7 +3822,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
// CHECK: "mhlo.return"([[ADD]])
// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
// CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16>
// CHECK: return [[CONV16]]
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
@ -3844,7 +3844,7 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<[]>
// CHECK-SAME: -> tensor<10x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
@ -3876,7 +3876,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
@ -4059,7 +4059,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<[]>
// CHECK-SAME: -> tensor<10x12x16x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
@ -4236,10 +4236,10 @@ func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
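Written out, this is a numerically stable softplus: with threshold = log(eps) + 2 (a large negative number for each float type), the lowering returns x itself when x > -threshold, exp(x) when x < threshold (where log1p(exp(x)) is approximately exp(x)), and log1p(exp(x)) otherwise. The hunks below apply the same pattern to bf16, f32, and f64, differing only in the machine-epsilon constant.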
@ -4256,10 +4256,10 @@ func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
@ -4276,10 +4276,10 @@ func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
@ -4296,10 +4296,10 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])

View File

@ -48,7 +48,7 @@ class XlaBuilderTest : public ::testing::Test {
xla_builder_(name_, builder_, module_->getLoc()) {}
string SetupTest() {
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::MhloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}

View File

@ -715,7 +715,7 @@ static void CreateWhile32(Location loc, int num_iterations,
auto one =
builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_chlo::BroadcastAddOp>(
auto plus_one = builder->create<chlo::BroadcastAddOp>(
loc, old_values[0], one, scalar_broadcast_dims);
// Prepend with the updated loop induction variable.
new_values.insert(new_values.begin(), plus_one);
@ -1483,7 +1483,7 @@ class ConvertFusedBatchNormGradBase
RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
auto epsilon = rewriter.create<ConstOp>(
loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
auto add_op = rewriter.create<xla_chlo::BroadcastAddOp>(
auto add_op = rewriter.create<chlo::BroadcastAddOp>(
loc, var, epsilon.getResult(), scalar_broadcast_dims);
Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
@ -1601,7 +1601,7 @@ class ConvertFusedBatchNormV3Op
auto factor_const_op = rewriter.create<mhlo::ConstOp>(
op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));
Value corrected_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), batch_variance.getType(), batch_variance,
factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());
@ -1621,26 +1621,24 @@ class ConvertFusedBatchNormV3Op
rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));
// new_running_mean = alpha * old_mean + beta * batch_mean.
auto alpha_mul_old_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), op.mean().getType(), alpha, op.mean(),
/*broadcast_dimensions=*/DenseIntElementsAttr());
auto beta_mul_batch_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), batch_mean.getType(), beta, batch_mean,
/*broadcast_dimensions=*/DenseIntElementsAttr());
batch_mean = rewriter.create<xla_chlo::BroadcastAddOp>(
batch_mean = rewriter.create<chlo::BroadcastAddOp>(
op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
/*broadcast_dimensions=*/DenseIntElementsAttr());
// new_running_variance = alpha * old_variance + beta * batch_variance.
auto alpha_mul_old_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), op.variance().getType(), alpha, op.variance(),
/*broadcast_dimensions=*/DenseIntElementsAttr());
auto beta_mul_batch_variance =
rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), corrected_variance.getType(), beta,
corrected_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
corrected_variance = rewriter.create<xla_chlo::BroadcastAddOp>(
auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
}
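The new ops compute the standard moving-average update. Assuming alpha = 1 - exponential_avg_factor and beta = exponential_avg_factor (consistent with the attribute read a few lines above), a factor of 0.1 with an old running mean of 2.0 and a batch mean of 4.0 yields 0.9 * 2.0 + 0.1 * 4.0 = 2.2; the variance update is analogous, using the bias-corrected batch variance.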
@ -1810,7 +1808,7 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
Value divisor =
GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
Value result = rewriter.create<xla_chlo::BroadcastDivOp>(
Value result = rewriter.create<chlo::BroadcastDivOp>(
op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);
// Convert back if we enlarged the element type's bitwidth.
@ -1914,7 +1912,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
Value divisor =
GetScalarConstOfType(element_type, loc, window_count, &rewriter);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
out_grad_divided = rewriter.create<xla_chlo::BroadcastDivOp>(
out_grad_divided = rewriter.create<chlo::BroadcastDivOp>(
loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
} else {
assert(op.padding() == "SAME");
@ -2335,7 +2333,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
auto dim = rewriter.create<GetDimensionSizeOp>(
op.getLoc(), result_type, input,
rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
size = rewriter.create<xla_chlo::BroadcastMulOp>(
size = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), size->getResult(0), dim.getResult(),
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
}
@ -3021,10 +3019,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {
auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
auto scaled = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, op.delta(),
xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, result_type, scaled, op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return success();
@ -3101,10 +3099,10 @@ class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {
auto iota = rewriter.create<DynamicIotaOp>(
op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
auto scaled = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, delta_out_cast,
xla::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, result_type, scaled, start_out_cast,
xla::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
return success();
@ -3152,7 +3150,7 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
int64_t num = (*num_attr.begin()).getSExtValue();
// Calculate the scaling that needs to be applied to the iota.
auto step_numerator = rewriter.create<xla_chlo::BroadcastSubOp>(
auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
op.getLoc(), op.start().getType(), op.stop(), op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
Value step_denominator = rewriter.create<ConvertOp>(
@ -3160,11 +3158,11 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
if (num > 1) {
Value one = GetScalarConstOfType(result_type.getElementType(),
op.getLoc(), 1, &rewriter);
step_denominator = rewriter.create<xla_chlo::BroadcastSubOp>(
step_denominator = rewriter.create<chlo::BroadcastSubOp>(
op.getLoc(), step_denominator.getType(), step_denominator, one,
xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
}
auto step = rewriter.create<xla_chlo::BroadcastDivOp>(
auto step = rewriter.create<chlo::BroadcastDivOp>(
op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
xla::getBroadcastDimensionsAttr(&rewriter, step_numerator,
step_denominator));
@ -3172,10 +3170,10 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
// Scale the iota and add the offset.
auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
auto scaled = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, step,
xla::getBroadcastDimensionsAttr(&rewriter, iota, step));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, result_type, scaled, op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return success();
@ -3251,7 +3249,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
auto divisor = GetScalarConstOfType(reduce_element_type, loc,
divisor_count, &rewriter);
auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
result = rewriter.create<xla_chlo::BroadcastDivOp>(
result = rewriter.create<chlo::BroadcastDivOp>(
loc, result, divisor.getResult(), broadcast_dims);
}
@ -5008,11 +5006,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
Value iota = builder->create<IotaOp>(
loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
builder->getI64IntegerAttr(0));
Value gtk = builder->create<xla_chlo::BroadcastCompareOp>(
Value gtk = builder->create<chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder),
StringAttr::get("GT", builder->getContext()));
gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
Value x_after_k = builder->create<xla_chlo::BroadcastMulOp>(
Value x_after_k = builder->create<chlo::BroadcastMulOp>(
loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
// sigma = np.dot(x[k+1:], x[k+1:])
@ -5024,15 +5022,15 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
Value mu = builder->create<SqrtOp>(
loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));
Value sigma_is_zero = builder->create<xla_chlo::BroadcastCompareOp>(
Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
Value alpha_is_negative = builder->create<xla_chlo::BroadcastCompareOp>(
Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
loc, alpha, zero, GetI64ElementsAttr({}, builder),
StringAttr::get("LT", builder->getContext()));
auto batch_size_one = builder->create<BroadcastOp>(
loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
Value signed_mu = builder->create<xla_chlo::BroadcastMulOp>(
Value signed_mu = builder->create<chlo::BroadcastMulOp>(
loc,
builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
batch_size_one,
@ -5050,7 +5048,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
batch_size_one, divisor);
Value eqk = builder->create<xla_chlo::BroadcastCompareOp>(
Value eqk = builder->create<chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
@ -5064,7 +5062,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
// Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
// If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
// Note that the add performs a degenerate broadcast.
*v = builder->create<xla_chlo::BroadcastAddOp>(
*v = builder->create<chlo::BroadcastAddOp>(
loc, e_k,
StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
GetI64ElementsAttr(batch_dim_ids, builder),
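These hunks follow the textbook Householder construction: with alpha = x[k] and sigma = ||x[k+1:]||^2, mu = sqrt(alpha^2 + sigma) is the norm of x[k:], and v is scaled so that v[k] = 1. The resulting reflector I - tau * v * v^T zeroes the entries of x below position k, which is what the QR factorization applies column by column.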
@ -5154,12 +5152,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
auto iota = builder->create<IotaOp>(
loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
builder->getI64IntegerAttr(0));
Value predecessor_mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota, j, GetI64ElementsAttr({}, builder),
StringAttr::get("LT", builder->getContext()));
predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
a_type.getElementType());
Value mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
@ -5189,7 +5187,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc,
RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
builder->getI64IntegerAttr(minor_dim + 1));
Value xa_mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
@ -5226,7 +5224,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, taus.getType(), taus_zeros,
GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
builder));
Value taus_mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota_n, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
auto taus_update = builder->create<SelectOp>(
@ -5311,7 +5309,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, vs.getType(), zero,
GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
builder));
auto compare = builder->create<xla_chlo::BroadcastCompareOp>(
auto compare = builder->create<chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
StringAttr::get("GE", builder->getContext()));
auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
@ -5459,16 +5457,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
// CHLO first.
if (legalize_chlo) {
xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
}
ConversionTarget target(*context);
if (legalize_chlo) {
target.addIllegalDialect<xla_chlo::XlaHloClientDialect>();
target.addIllegalDialect<chlo::HloClientDialect>();
} else {
target.addLegalDialect<xla_chlo::XlaHloClientDialect>();
target.addLegalDialect<chlo::HloClientDialect>();
}
target.addLegalDialect<XlaHloDialect>();
target.addLegalDialect<MhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<shape::ShapeDialect>();
target.addLegalOp<CallOp>();

View File

@ -75,10 +75,10 @@ namespace {
template <typename T, size_t N>
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok
static bool IsOpWhitelisted(Operation* op) {
// White-listed TensorFlow ops are known to have well behaved tf2xla kernels
static bool IsOpAllowlisted(Operation* op) {
// Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
// building valid MLIR using MlirHloBuilder.
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
// all tf2xla kernels.
// clang-format off
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
@ -342,7 +342,7 @@ LogicalResult FuncLegalizer::Legalize() {
}
LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
if (!IsOpWhitelisted(op)) return success();
if (!IsOpAllowlisted(op)) return success();
// Only static shaped operands are supported in XLA builders for now.
for (Type ty : op->getOperandTypes()) {

View File

@ -190,51 +190,51 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
using ::xla::HloOpcode;
switch (instr->opcode()) {
case HloOpcode::kAbs:
return CreateOpWithoutAttrs<xla_lhlo::AbsOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr).status();
case HloOpcode::kAdd:
return CreateOpWithoutAttrs<xla_lhlo::AddOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::AddOp>(instr).status();
case HloOpcode::kAnd:
return CreateOpWithoutAttrs<xla_lhlo::AndOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr).status();
case HloOpcode::kCeil:
return CreateOpWithoutAttrs<xla_lhlo::CeilOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr).status();
case HloOpcode::kComplex:
return CreateOpWithoutAttrs<xla_lhlo::ComplexOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr).status();
case HloOpcode::kCopy:
return CreateOpWithoutAttrs<xla_lhlo::CopyOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr).status();
case HloOpcode::kCos:
return CreateOpWithoutAttrs<xla_lhlo::CosOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::CosOp>(instr).status();
case HloOpcode::kDivide:
return CreateOpWithoutAttrs<xla_lhlo::DivOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::DivOp>(instr).status();
case HloOpcode::kExp:
return CreateOpWithoutAttrs<xla_lhlo::ExpOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr).status();
case HloOpcode::kImag:
return CreateOpWithoutAttrs<xla_lhlo::ImagOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr).status();
case HloOpcode::kLog:
return CreateOpWithoutAttrs<xla_lhlo::LogOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::LogOp>(instr).status();
case HloOpcode::kMaximum:
return CreateOpWithoutAttrs<xla_lhlo::MaxOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr).status();
case HloOpcode::kMinimum:
return CreateOpWithoutAttrs<xla_lhlo::MinOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::MinOp>(instr).status();
case HloOpcode::kMultiply:
return CreateOpWithoutAttrs<xla_lhlo::MulOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::MulOp>(instr).status();
case HloOpcode::kNegate:
return CreateOpWithoutAttrs<xla_lhlo::NegOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::NegOp>(instr).status();
case HloOpcode::kReal:
return CreateOpWithoutAttrs<xla_lhlo::RealOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::RealOp>(instr).status();
case HloOpcode::kRemainder:
return CreateOpWithoutAttrs<xla_lhlo::RemOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::RemOp>(instr).status();
case HloOpcode::kRsqrt:
return CreateOpWithoutAttrs<xla_lhlo::RsqrtOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr).status();
case HloOpcode::kSelect:
return CreateOpWithoutAttrs<xla_lhlo::SelectOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr).status();
case HloOpcode::kSign:
return CreateOpWithoutAttrs<xla_lhlo::SignOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SignOp>(instr).status();
case HloOpcode::kSqrt:
return CreateOpWithoutAttrs<xla_lhlo::SqrtOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr).status();
case HloOpcode::kSubtract:
return CreateOpWithoutAttrs<xla_lhlo::SubOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SubOp>(instr).status();
case HloOpcode::kTanh:
return CreateOpWithoutAttrs<xla_lhlo::TanhOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr).status();
default:
llvm::errs() << instr->ToString();
return tensorflow::errors::Internal(
@ -246,7 +246,7 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitSortOp(
HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<xla_lhlo::SortOp>(instr));
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr);
sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
@ -379,16 +379,16 @@ Status LhloDialectEmitter::Initialize() {
block->addArgument(arg_type);
allocations_[alloc] = block->getArguments().back();
args_attrs.emplace_back();
args_attrs.back().set(builder_.getIdentifier("xla_lhlo.params"),
args_attrs.back().set(builder_.getIdentifier("lmhlo.params"),
builder_.getIndexAttr(alloc->parameter_number()));
} else {
block->addArgument(MemRefType::get({alloc->size()}, i8_type_));
allocations_[alloc] = block->getArguments().back();
args_attrs.emplace_back();
args_attrs.back().set(builder_.getIdentifier("xla_lhlo.alloc"),
args_attrs.back().set(builder_.getIdentifier("lmhlo.alloc"),
builder_.getIndexAttr(alloc->index()));
if (alloc->maybe_live_out())
args_attrs.back().set(builder_.getIdentifier("xla_lhlo.liveout"),
args_attrs.back().set(builder_.getIdentifier("lmhlo.liveout"),
builder_.getBoolAttr(true));
}
}
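After this change, an emitted LHLO entry function carries the renamed attributes on its buffer arguments; a hypothetical signature (memref sizes and indices made up for illustration) would look like:

func @main(%arg0: memref<16xf32> {lmhlo.params = 0 : index},
           %arg1: memref<256xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true}) {
  ...
}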

Some files were not shown because too many files have changed in this diff.