Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/ubuntu-onednn-partials
commit 9fa46cf554
@ -50,6 +50,9 @@
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* We have replaced uses of "whitelist" with "allowlist" where possible.
Please see https://developers.google.com/style/word-list#blacklist for more
context.
* <ADD RELEASE NOTES HERE>

## Thanks to our Contributors

@ -44,7 +44,7 @@ Even if the untrusted party only supplies the serialized computation
graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the
set of computation primitives available to TensorFlow is powerful enough that
you should assume that the TensorFlow process effectively executes arbitrary
code. One common solution is to whitelist only a few safe Ops. While this is
code. One common solution is to allow only a few safe Ops. While this is
possible in theory, we still recommend you sandbox the execution.

It depends on the computation graph whether a user provided checkpoint is safe.

@ -260,6 +260,36 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "armeabi",
|
||||
values = {"cpu": "armeabi"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "armeabi-v7a",
|
||||
values = {"cpu": "armeabi-v7a"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "arm64-v8a",
|
||||
values = {"cpu": "arm64-v8a"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
selects.config_setting_group(
|
||||
name = "arm_any",
|
||||
match_any = [
|
||||
":arm",
|
||||
":armeabi",
|
||||
":armeabi-v7a",
|
||||
":arm64-v8a",
|
||||
":linux_aarch64",
|
||||
":linux_armhf",
|
||||
],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "freebsd",
|
||||
values = {"cpu": "freebsd"},
|
||||
|
@ -337,10 +337,13 @@ tensorflow::Status CreateRemoteContexts(
|
||||
});
|
||||
}
|
||||
counter.Wait();
|
||||
tensorflow::StatusGroup sg;
|
||||
for (int i = 0; i < num_remote_workers; i++) {
|
||||
TF_RETURN_IF_ERROR(statuses[i]);
|
||||
if (TF_PREDICT_FALSE(!statuses[i].ok())) {
|
||||
sg.Update(statuses[i]);
|
||||
}
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
return sg.as_summary_status();
|
||||
}
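The change above stops CreateRemoteContexts from returning on the first failed worker status and instead folds every failure into a tensorflow::StatusGroup, returning a single summary status. A minimal standalone sketch of that aggregation pattern, assuming only the StatusGroup calls visible in the diff (Update and as_summary_status); the helper name is mine:

```cpp
#include <vector>

#include "tensorflow/core/platform/status.h"

// Hypothetical helper mirroring the new error handling: record every failed
// per-worker status instead of stopping at the first one, then hand the
// caller a single summary status.
tensorflow::Status SummarizeWorkerStatuses(
    const std::vector<tensorflow::Status>& statuses) {
  tensorflow::StatusGroup sg;
  for (const tensorflow::Status& s : statuses) {
    if (!s.ok()) sg.Update(s);  // OK statuses are skipped.
  }
  // Returns OK if nothing was recorded, otherwise a combined error.
  return sg.as_summary_status();
}
```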
|
||||
|
||||
tensorflow::Status UpdateRemoteContexts(
|
||||
@ -611,10 +614,21 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
|
||||
// Initialize remote eager workers.
|
||||
if (reset_context) {
|
||||
LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
|
||||
const tensorflow::Status s = CreateRemoteContexts(
|
||||
ctx, remote_workers, context_id, context_view_id, keep_alive_secs,
|
||||
server_def, remote_eager_workers.get(), context->Executor().Async(),
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request));
|
||||
context->LazyCopyFunctionRemoteInputs(), base_request);
|
||||
// NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause
|
||||
// the CreateRemoteContexts to fail. We currently only log instead of
|
||||
// directly returning the error, since returning here will cause the server
|
||||
// object to be destroyed (which currently CHECK-fails). The client will
|
||||
// see additional errors if ops are subsequently sent to the failed workers.
|
||||
if (TF_PREDICT_FALSE(!s.ok())) {
|
||||
LOG(ERROR) << "Error when creating contexts on remote targets: "
|
||||
<< s.error_message()
|
||||
<< "\nExecuting remote ops or functions on these remote "
|
||||
"targets will fail.";
|
||||
}
|
||||
} else {
|
||||
// The master's context_view_id will be incremented by one
|
||||
// the UpdateRemoteMaster call later. We want all new workers and
|
||||
@ -644,15 +658,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef(
|
||||
grpc_server->worker_env()->rendezvous_mgr->Find(context_id);
|
||||
auto* device_mgr = grpc_server->worker_env()->device_mgr;
|
||||
std::shared_ptr<tensorflow::WorkerSession> worker_session;
|
||||
TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
TF_RETURN_IF_ERROR(
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->CreateSession(
|
||||
session_name, server_def, base_request.cluster_device_attributes(),
|
||||
true));
|
||||
LOG_AND_RETURN_IF_ERROR(
|
||||
grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
|
||||
session_name, &worker_session));
|
||||
|
||||
// Initialize remote tensor communication based on worker session.
|
||||
TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
|
||||
|
||||
tensorflow::DistributedFunctionLibraryRuntime* cluster_flr =
|
||||
tensorflow::eager::CreateClusterFLR(context_id, context,
|
||||
|
@ -52,6 +52,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "cleanup",
|
||||
hdrs = ["cleanup.h"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "ram_file_block_cache",
|
||||
srcs = ["ram_file_block_cache.cc"],
|
||||
hdrs = ["ram_file_block_cache.h"],
|
||||
deps = [
|
||||
":cleanup",
|
||||
":file_block_cache",
|
||||
"//tensorflow/c:env",
|
||||
"//tensorflow/c:tf_status",
|
||||
"@com_google_absl//absl/base",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "gcs_filesystem_test",
|
||||
srcs = [
|
||||
|
111
tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h
Normal file
@ -0,0 +1,111 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its
|
||||
// destructor. The easiest way to use MakeCleanup is with a lambda argument,
|
||||
// capturing the return value in an 'auto' local variable. Most users will not
|
||||
// need more sophisticated syntax than that.
|
||||
//
|
||||
// Example:
|
||||
// void func() {
|
||||
// FILE* fp = fopen("data.txt", "r");
|
||||
// if (fp == nullptr) return;
|
||||
// auto fp_cleaner = gtl::MakeCleanup([fp] { fclose(fp); });
|
||||
// // No matter what, fclose(fp) will happen.
|
||||
// DataObject d;
|
||||
// while (ReadDataObject(fp, &d)) {
|
||||
// if (d.IsBad()) {
|
||||
// LOG(ERROR) << "Bad Data";
|
||||
// return;
|
||||
// }
|
||||
// PushGoodData(d);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// You can use Cleanup<F> directly, instead of using MakeCleanup and auto,
|
||||
// but there's rarely a reason to do that.
|
||||
//
|
||||
// You can call 'release()' on a Cleanup object to cancel the cleanup.
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
|
||||
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tf_gcs_filesystem {
|
||||
|
||||
// A move-only RAII object that calls a stored cleanup functor when
|
||||
// destroyed. Cleanup<F> is the return type of gtl::MakeCleanup(F).
|
||||
template <typename F>
|
||||
class Cleanup {
|
||||
public:
|
||||
Cleanup() : released_(true), f_() {}
|
||||
|
||||
template <typename G>
|
||||
explicit Cleanup(G&& f) // NOLINT
|
||||
: f_(std::forward<G>(f)) {} // NOLINT(build/c++11)
|
||||
|
||||
Cleanup(Cleanup&& src) // NOLINT
|
||||
: released_(src.is_released()), f_(src.release()) {}
|
||||
|
||||
// Implicitly move-constructible from any compatible Cleanup<G>.
|
||||
// The source will be released as if src.release() were called.
|
||||
// A moved-from Cleanup can be safely destroyed or reassigned.
|
||||
template <typename G>
|
||||
Cleanup(Cleanup<G>&& src) // NOLINT
|
||||
: released_(src.is_released()), f_(src.release()) {}
|
||||
|
||||
// Assignment to a Cleanup object behaves like destroying it
|
||||
// and making a new one in its place, analogous to unique_ptr
|
||||
// semantics.
|
||||
Cleanup& operator=(Cleanup&& src) { // NOLINT
|
||||
if (!released_) f_();
|
||||
released_ = src.released_;
|
||||
f_ = src.release();
|
||||
return *this;
|
||||
}
|
||||
|
||||
~Cleanup() {
|
||||
if (!released_) f_();
|
||||
}
|
||||
|
||||
// Releases the cleanup function instead of running it.
|
||||
// Hint: use c.release()() to run early.
|
||||
F release() {
|
||||
released_ = true;
|
||||
return std::move(f_);
|
||||
}
|
||||
|
||||
bool is_released() const { return released_; }
|
||||
|
||||
private:
|
||||
static_assert(!std::is_reference<F>::value, "F must not be a reference");
|
||||
|
||||
bool released_ = false;
|
||||
F f_;
|
||||
};
|
||||
|
||||
template <int&... ExplicitParameterBarrier, typename F,
|
||||
typename DecayF = typename std::decay<F>::type>
|
||||
Cleanup<DecayF> MakeCleanup(F&& f) {
|
||||
return Cleanup<DecayF>(std::forward<F>(f));
|
||||
}
|
||||
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_
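A short usage sketch for the Cleanup helper above, covering both the normal scope-exit path and the release()() idiom the header comment mentions; the file-copying scenario is illustrative only:

```cpp
#include <cstdio>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h"

// Open two files, guarantee both get closed, and run one cleanup early.
void CopyFirstByte(const char* src_path, const char* dst_path) {
  FILE* src = std::fopen(src_path, "rb");
  if (src == nullptr) return;
  auto close_src = tf_gcs_filesystem::MakeCleanup([src] { std::fclose(src); });

  FILE* dst = std::fopen(dst_path, "wb");
  if (dst == nullptr) return;  // close_src still closes `src` on this path.
  auto close_dst = tf_gcs_filesystem::MakeCleanup([dst] { std::fclose(dst); });

  int c = std::fgetc(src);
  if (c != EOF) std::fputc(c, dst);

  // Run the destination cleanup now and cancel the deferred call.
  close_dst.release()();
  // close_src fires when the function returns.
}
```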
|
@ -1,8 +1,11 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
|
@ -0,0 +1,317 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
|
||||
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h"
|
||||
|
||||
namespace tf_gcs_filesystem {
|
||||
|
||||
bool RamFileBlockCache::BlockNotStale(const std::shared_ptr<Block>& block) {
|
||||
absl::MutexLock l(&block->mu);
|
||||
if (block->state != FetchState::FINISHED) {
|
||||
return true; // No need to check for staleness.
|
||||
}
|
||||
if (max_staleness_ == 0) return true; // Not enforcing staleness.
|
||||
return timer_seconds_() - block->timestamp <= max_staleness_;
|
||||
}
|
||||
|
||||
std::shared_ptr<RamFileBlockCache::Block> RamFileBlockCache::Lookup(
|
||||
const Key& key) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto entry = block_map_.find(key);
|
||||
if (entry != block_map_.end()) {
|
||||
if (BlockNotStale(entry->second)) {
|
||||
if (cache_stats_ != nullptr) {
|
||||
cache_stats_->RecordCacheHitBlockSize(entry->second->data.size());
|
||||
}
|
||||
return entry->second;
|
||||
} else {
|
||||
// Remove the stale block and continue.
|
||||
RemoveFile_Locked(key.first);
|
||||
}
|
||||
}
|
||||
|
||||
// Insert a new empty block, setting the bookkeeping to sentinel values
|
||||
// in order to update them as appropriate.
|
||||
auto new_entry = std::make_shared<Block>();
|
||||
lru_list_.push_front(key);
|
||||
lra_list_.push_front(key);
|
||||
new_entry->lru_iterator = lru_list_.begin();
|
||||
new_entry->lra_iterator = lra_list_.begin();
|
||||
new_entry->timestamp = timer_seconds_();
|
||||
block_map_.emplace(std::make_pair(key, new_entry));
|
||||
return new_entry;
|
||||
}
|
||||
|
||||
// Remove blocks from the cache until we do not exceed our maximum size.
|
||||
void RamFileBlockCache::Trim() {
|
||||
while (!lru_list_.empty() && cache_size_ > max_bytes_) {
|
||||
RemoveBlock(block_map_.find(lru_list_.back()));
|
||||
}
|
||||
}
|
||||
|
||||
/// Move the block to the front of the LRU list if it isn't already there.
|
||||
void RamFileBlockCache::UpdateLRU(const Key& key,
|
||||
const std::shared_ptr<Block>& block,
|
||||
TF_Status* status) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
if (block->timestamp == 0) {
|
||||
// The block was evicted from another thread. Allow it to remain evicted.
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
if (block->lru_iterator != lru_list_.begin()) {
|
||||
lru_list_.erase(block->lru_iterator);
|
||||
lru_list_.push_front(key);
|
||||
block->lru_iterator = lru_list_.begin();
|
||||
}
|
||||
|
||||
// Check for inconsistent state. If there is a block later in the same file
|
||||
// in the cache, and our current block is not block size, this likely means
|
||||
// we have inconsistent state within the cache. Note: it's possible some
|
||||
// incomplete reads may still go undetected.
|
||||
if (block->data.size() < block_size_) {
|
||||
Key fmax = std::make_pair(key.first, std::numeric_limits<size_t>::max());
|
||||
auto fcmp = block_map_.upper_bound(fmax);
|
||||
if (fcmp != block_map_.begin() && key < (--fcmp)->first) {
|
||||
return TF_SetStatus(status, TF_INTERNAL,
|
||||
"Block cache contents are inconsistent.");
|
||||
}
|
||||
}
|
||||
|
||||
Trim();
|
||||
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
void RamFileBlockCache::MaybeFetch(const Key& key,
|
||||
const std::shared_ptr<Block>& block,
|
||||
TF_Status* status) {
|
||||
bool downloaded_block = false;
|
||||
auto reconcile_state = MakeCleanup([this, &downloaded_block, &key, &block] {
|
||||
// Perform this action in a cleanup callback to avoid locking mu_ after
|
||||
// locking block->mu.
|
||||
if (downloaded_block) {
|
||||
absl::MutexLock l(&mu_);
|
||||
// Do not update state if the block is already to be evicted.
|
||||
if (block->timestamp != 0) {
|
||||
// Use capacity() instead of size() to account for all memory
|
||||
// used by the cache.
|
||||
cache_size_ += block->data.capacity();
|
||||
// Put to beginning of LRA list.
|
||||
lra_list_.erase(block->lra_iterator);
|
||||
lra_list_.push_front(key);
|
||||
block->lra_iterator = lra_list_.begin();
|
||||
block->timestamp = timer_seconds_();
|
||||
}
|
||||
}
|
||||
});
|
||||
// Loop until either block content is successfully fetched, or our request
|
||||
// encounters an error.
|
||||
absl::MutexLock l(&block->mu);
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
while (true) {
|
||||
switch (block->state) {
|
||||
case FetchState::ERROR:
|
||||
// TF_FALLTHROUGH_INTENDED
|
||||
case FetchState::CREATED:
|
||||
block->state = FetchState::FETCHING;
|
||||
block->mu.Unlock(); // Release the lock while making the API call.
|
||||
block->data.clear();
|
||||
block->data.resize(block_size_, 0);
|
||||
size_t bytes_transferred;
|
||||
block_fetcher_(key.first, key.second, block_size_, block->data.data(),
|
||||
&bytes_transferred, status);
|
||||
if (cache_stats_ != nullptr) {
|
||||
cache_stats_->RecordCacheMissBlockSize(bytes_transferred);
|
||||
}
|
||||
block->mu.Lock(); // Reacquire the lock immediately afterwards
|
||||
if (TF_GetCode(status) == TF_OK) {
|
||||
block->data.resize(bytes_transferred, 0);
|
||||
// Shrink the data capacity to the actual size used.
|
||||
// NOLINTNEXTLINE: shrink_to_fit() may not shrink the capacity.
|
||||
std::vector<char>(block->data).swap(block->data);
|
||||
downloaded_block = true;
|
||||
block->state = FetchState::FINISHED;
|
||||
} else {
|
||||
block->state = FetchState::ERROR;
|
||||
}
|
||||
block->cond_var.SignalAll();
|
||||
return;
|
||||
case FetchState::FETCHING:
|
||||
block->cond_var.WaitWithTimeout(&block->mu, absl::Minutes(1));
|
||||
if (block->state == FetchState::FINISHED) {
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
// Re-loop in case of errors.
|
||||
break;
|
||||
case FetchState::FINISHED:
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
}
|
||||
return TF_SetStatus(
|
||||
status, TF_INTERNAL,
|
||||
"Control flow should never reach the end of RamFileBlockCache::Fetch.");
|
||||
}
|
||||
|
||||
void RamFileBlockCache::Read(const std::string& filename, size_t offset,
|
||||
size_t n, char* buffer, size_t* bytes_transferred,
|
||||
TF_Status* status) {
|
||||
*bytes_transferred = 0;
|
||||
if (n == 0) {
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
if (!IsCacheEnabled() || (n > max_bytes_)) {
|
||||
// The cache is effectively disabled, so we pass the read through to the
|
||||
// fetcher without breaking it up into blocks.
|
||||
return block_fetcher_(filename, offset, n, buffer, bytes_transferred,
|
||||
status);
|
||||
}
|
||||
// Calculate the block-aligned start and end of the read.
|
||||
size_t start = block_size_ * (offset / block_size_);
|
||||
size_t finish = block_size_ * ((offset + n) / block_size_);
|
||||
if (finish < offset + n) {
|
||||
finish += block_size_;
|
||||
}
|
||||
size_t total_bytes_transferred = 0;
|
||||
// Now iterate through the blocks, reading them one at a time.
|
||||
for (size_t pos = start; pos < finish; pos += block_size_) {
|
||||
Key key = std::make_pair(filename, pos);
|
||||
// Look up the block, fetching and inserting it if necessary, and update the
|
||||
// LRU iterator for the key and block.
|
||||
std::shared_ptr<Block> block = Lookup(key);
|
||||
if (!block) {
|
||||
std::cerr << "No block for key " << key.first << "@" << key.second;
|
||||
abort();
|
||||
}
|
||||
MaybeFetch(key, block, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
UpdateLRU(key, block, status);
|
||||
if (TF_GetCode(status) != TF_OK) return;
|
||||
// Copy the relevant portion of the block into the result buffer.
|
||||
const auto& data = block->data;
|
||||
if (offset >= pos + data.size()) {
|
||||
// The requested offset is at or beyond the end of the file. This can
|
||||
// happen if `offset` is not block-aligned, and the read returns the last
|
||||
// block in the file, which does not extend all the way out to `offset`.
|
||||
*bytes_transferred = total_bytes_transferred;
|
||||
std::stringstream os;
|
||||
os << "EOF at offset " << offset << " in file " << filename
|
||||
<< " at position " << pos << " with data size " << data.size();
|
||||
return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
|
||||
}
|
||||
auto begin = data.begin();
|
||||
if (offset > pos) {
|
||||
// The block begins before the slice we're reading.
|
||||
begin += offset - pos;
|
||||
}
|
||||
auto end = data.end();
|
||||
if (pos + data.size() > offset + n) {
|
||||
// The block extends past the end of the slice we're reading.
|
||||
end -= (pos + data.size()) - (offset + n);
|
||||
}
|
||||
if (begin < end) {
|
||||
size_t bytes_to_copy = end - begin;
|
||||
memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy);
|
||||
total_bytes_transferred += bytes_to_copy;
|
||||
}
|
||||
if (data.size() < block_size_) {
|
||||
// The block was a partial block and thus signals EOF at its upper bound.
|
||||
break;
|
||||
}
|
||||
}
|
||||
*bytes_transferred = total_bytes_transferred;
|
||||
return TF_SetStatus(status, TF_OK, "");
|
||||
}
|
||||
|
||||
bool RamFileBlockCache::ValidateAndUpdateFileSignature(
|
||||
const std::string& filename, int64_t file_signature) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
auto it = file_signature_map_.find(filename);
|
||||
if (it != file_signature_map_.end()) {
|
||||
if (it->second == file_signature) {
|
||||
return true;
|
||||
}
|
||||
// Remove the file from cache if the signatures don't match.
|
||||
RemoveFile_Locked(filename);
|
||||
it->second = file_signature;
|
||||
return false;
|
||||
}
|
||||
file_signature_map_[filename] = file_signature;
|
||||
return true;
|
||||
}
|
||||
|
||||
size_t RamFileBlockCache::CacheSize() const {
|
||||
absl::MutexLock lock(&mu_);
|
||||
return cache_size_;
|
||||
}
|
||||
|
||||
void RamFileBlockCache::Prune() {
|
||||
while (!stop_pruning_thread_.WaitForNotificationWithTimeout(
|
||||
absl::Microseconds(1000000))) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
uint64_t now = timer_seconds_();
|
||||
while (!lra_list_.empty()) {
|
||||
auto it = block_map_.find(lra_list_.back());
|
||||
if (now - it->second->timestamp <= max_staleness_) {
|
||||
// The oldest block is not yet expired. Come back later.
|
||||
break;
|
||||
}
|
||||
// We need to make a copy of the filename here, since it could otherwise
|
||||
// be used within RemoveFile_Locked after `it` is deleted.
|
||||
RemoveFile_Locked(std::string(it->first.first));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RamFileBlockCache::Flush() {
|
||||
absl::MutexLock lock(&mu_);
|
||||
block_map_.clear();
|
||||
lru_list_.clear();
|
||||
lra_list_.clear();
|
||||
cache_size_ = 0;
|
||||
}
|
||||
|
||||
void RamFileBlockCache::RemoveFile(const std::string& filename) {
|
||||
absl::MutexLock lock(&mu_);
|
||||
RemoveFile_Locked(filename);
|
||||
}
|
||||
|
||||
void RamFileBlockCache::RemoveFile_Locked(const std::string& filename) {
|
||||
Key begin = std::make_pair(filename, 0);
|
||||
auto it = block_map_.lower_bound(begin);
|
||||
while (it != block_map_.end() && it->first.first == filename) {
|
||||
auto next = std::next(it);
|
||||
RemoveBlock(it);
|
||||
it = next;
|
||||
}
|
||||
}
|
||||
|
||||
void RamFileBlockCache::RemoveBlock(BlockMap::iterator entry) {
|
||||
// This signals that the block is removed, and should not be inadvertently
|
||||
// reinserted into the cache in UpdateLRU.
|
||||
entry->second->timestamp = 0;
|
||||
lru_list_.erase(entry->second->lru_iterator);
|
||||
lra_list_.erase(entry->second->lra_iterator);
|
||||
cache_size_ -= entry->second->data.capacity();
|
||||
block_map_.erase(entry);
|
||||
}
|
||||
|
||||
} // namespace tf_gcs_filesystem
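Read() above first rounds the requested byte range out to block boundaries. A standalone sketch of that alignment arithmetic with a worked example; the function name is mine, not part of the plugin:

```cpp
#include <cstddef>
#include <utility>

// Given a byte range [offset, offset + n) and a block size, return the
// block-aligned range [start, finish) that covers it, mirroring the
// computation at the top of RamFileBlockCache::Read.
std::pair<size_t, size_t> BlockAlignedRange(size_t offset, size_t n,
                                            size_t block_size) {
  size_t start = block_size * (offset / block_size);         // round down
  size_t finish = block_size * ((offset + n) / block_size);
  if (finish < offset + n) finish += block_size;             // round up
  return {start, finish};
}

// Example: offset = 300, n = 1000, block_size = 256
//   start  = 256 * (300 / 256)  = 256
//   finish = 256 * (1300 / 256) = 1280, then 1280 < 1300 so finish = 1536
// Blocks at offsets 256, 512, 768, 1024, and 1280 are consulted.
```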
|
@ -0,0 +1,267 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_
|
||||
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <list>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/base/thread_annotations.h"
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "absl/synchronization/notification.h"
|
||||
#include "tensorflow/c/env.h"
|
||||
#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h"
|
||||
#include "tensorflow/c/tf_status.h"
|
||||
|
||||
namespace tf_gcs_filesystem {
|
||||
|
||||
/// \brief An LRU block cache of file contents, keyed by {filename, offset}.
|
||||
///
|
||||
/// This class should be shared by read-only random access files on a remote
|
||||
/// filesystem (e.g. GCS).
|
||||
class RamFileBlockCache : public FileBlockCache {
|
||||
public:
|
||||
/// The callback executed when a block is not found in the cache, and needs to
|
||||
/// be fetched from the backing filesystem. This callback is provided when the
|
||||
/// cache is constructed. The `status` should be `TF_OK` as long as the
|
||||
/// read from the remote filesystem succeeded (similar to the semantics of the
|
||||
/// read(2) system call).
|
||||
typedef std::function<void(const std::string& filename, size_t offset,
|
||||
size_t buffer_size, char* buffer,
|
||||
size_t* bytes_transferred, TF_Status* status)>
|
||||
BlockFetcher;
|
||||
|
||||
RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
|
||||
BlockFetcher block_fetcher,
|
||||
std::function<uint64_t()> timer_seconds)
|
||||
: block_size_(block_size),
|
||||
max_bytes_(max_bytes),
|
||||
max_staleness_(max_staleness),
|
||||
block_fetcher_(block_fetcher),
|
||||
timer_seconds_(timer_seconds),
|
||||
pruning_thread_(nullptr,
|
||||
[](TF_Thread* thread) { TF_JoinThread(thread); }) {
|
||||
if (max_staleness_ > 0) {
|
||||
TF_ThreadOptions thread_options;
|
||||
TF_DefaultThreadOptions(&thread_options);
|
||||
pruning_thread_.reset(
|
||||
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
|
||||
}
|
||||
std::cout << "GCS file block cache is "
|
||||
<< (IsCacheEnabled() ? "enabled" : "disabled");
|
||||
}
|
||||
|
||||
~RamFileBlockCache() override {
|
||||
if (pruning_thread_) {
|
||||
stop_pruning_thread_.Notify();
|
||||
// Destroying pruning_thread_ will block until Prune() receives the above
|
||||
// notification and returns.
|
||||
pruning_thread_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
|
||||
/// method will set `status` to:
|
||||
///
|
||||
/// 1) The error from the remote filesystem, if the read from the remote
|
||||
/// filesystem failed.
|
||||
/// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem
|
||||
/// succeeded,
|
||||
/// but the read returned a partial block, and the LRU cache contained a
|
||||
/// block at a higher offset (indicating that the partial block should have
|
||||
/// been a full block).
|
||||
/// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but
|
||||
/// the file contents do not extend past `offset` and thus nothing was
|
||||
/// placed in `out`.
|
||||
/// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was
|
||||
/// placed
|
||||
/// in `buffer`).
|
||||
///
|
||||
/// Caller is responsible for allocating memory for `buffer`.
|
||||
/// `buffer` will be left unchanged in case of errors.
|
||||
void Read(const std::string& filename, size_t offset, size_t n, char* buffer,
|
||||
size_t* bytes_transferred, TF_Status* status) override;
|
||||
|
||||
// Validate the given file signature with the existing file signature in the
|
||||
// cache. Returns true if the signature doesn't change or the file doesn't
|
||||
// exist before. If the signature changes, update the existing signature with
|
||||
// the new one and remove the file from cache.
|
||||
bool ValidateAndUpdateFileSignature(const std::string& filename,
|
||||
int64_t file_signature) override
|
||||
ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
/// Remove all cached blocks for `filename`.
|
||||
void RemoveFile(const std::string& filename) override
|
||||
ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
/// Remove all cached data.
|
||||
void Flush() override ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
/// Accessors for cache parameters.
|
||||
size_t block_size() const override { return block_size_; }
|
||||
size_t max_bytes() const override { return max_bytes_; }
|
||||
uint64_t max_staleness() const override { return max_staleness_; }
|
||||
|
||||
/// The current size (in bytes) of the cache.
|
||||
size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Returns true if the cache is enabled. If false, the BlockFetcher callback
|
||||
// is always executed during Read.
|
||||
bool IsCacheEnabled() const override {
|
||||
return block_size_ > 0 && max_bytes_ > 0;
|
||||
}
|
||||
|
||||
// We can not pass a lambda with capture as a function pointer to
|
||||
// `TF_StartThread`, so we have to wrap `Prune` inside a static function.
|
||||
static void PruneThread(void* param) {
|
||||
auto ram_file_block_cache = static_cast<RamFileBlockCache*>(param);
|
||||
ram_file_block_cache->Prune();
|
||||
}
|
||||
|
||||
private:
|
||||
/// The size of the blocks stored in the LRU cache, as well as the size of the
|
||||
/// reads from the underlying filesystem.
|
||||
const size_t block_size_;
|
||||
/// The maximum number of bytes (sum of block sizes) allowed in the LRU cache.
|
||||
const size_t max_bytes_;
|
||||
/// The maximum staleness of any block in the LRU cache, in seconds.
|
||||
const uint64_t max_staleness_;
|
||||
/// The callback to read a block from the underlying filesystem.
|
||||
const BlockFetcher block_fetcher_;
|
||||
/// The callback to read timestamps.
|
||||
const std::function<uint64_t()> timer_seconds_;
|
||||
|
||||
/// \brief The key type for the file block cache.
|
||||
///
|
||||
/// The file block cache key is a {filename, offset} pair.
|
||||
typedef std::pair<std::string, size_t> Key;
|
||||
|
||||
/// \brief The state of a block.
|
||||
///
|
||||
/// A block begins in the CREATED stage. The first thread will attempt to read
|
||||
/// the block from the filesystem, transitioning the state of the block to
|
||||
/// FETCHING. After completing, if the read was successful the state should
|
||||
/// be FINISHED. Otherwise the state should be ERROR. A subsequent read can
|
||||
/// re-fetch the block if the state is ERROR.
|
||||
enum class FetchState {
|
||||
CREATED,
|
||||
FETCHING,
|
||||
FINISHED,
|
||||
ERROR,
|
||||
};
|
||||
|
||||
/// \brief A block of a file.
|
||||
///
|
||||
/// A file block consists of the block data, the block's current position in
|
||||
/// the LRU cache, the timestamp (seconds since epoch) at which the block
|
||||
/// was cached, a coordination lock, and state & condition variables.
|
||||
///
|
||||
/// Thread safety:
|
||||
/// The iterator and timestamp fields should only be accessed while holding
|
||||
/// the block-cache-wide mu_ instance variable. The state variable should only
|
||||
/// be accessed while holding the Block's mu lock. The data vector should only
|
||||
/// be accessed after state == FINISHED, and it should never be modified.
|
||||
///
|
||||
/// In order to prevent deadlocks, never grab the block-cache-wide mu_ lock
|
||||
/// AFTER grabbing any block's mu lock. It is safe to grab mu without locking
|
||||
/// mu_.
|
||||
struct Block {
|
||||
/// The block data.
|
||||
std::vector<char> data;
|
||||
/// A list iterator pointing to the block's position in the LRU list.
|
||||
std::list<Key>::iterator lru_iterator;
|
||||
/// A list iterator pointing to the block's position in the LRA list.
|
||||
std::list<Key>::iterator lra_iterator;
|
||||
/// The timestamp (seconds since epoch) at which the block was cached.
|
||||
uint64_t timestamp;
|
||||
/// Mutex to guard state variable
|
||||
absl::Mutex mu;
|
||||
/// The state of the block.
|
||||
FetchState state ABSL_GUARDED_BY(mu) = FetchState::CREATED;
|
||||
/// Wait on cond_var if state is FETCHING.
|
||||
absl::CondVar cond_var;
|
||||
};
|
||||
|
||||
/// \brief The block map type for the file block cache.
|
||||
///
|
||||
/// The block map is an ordered map from Key to Block.
|
||||
typedef std::map<Key, std::shared_ptr<Block>> BlockMap;
|
||||
|
||||
/// Prune the cache by removing files with expired blocks.
|
||||
void Prune() ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
bool BlockNotStale(const std::shared_ptr<Block>& block)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
/// Look up a Key in the block cache.
|
||||
std::shared_ptr<Block> Lookup(const Key& key) ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
void MaybeFetch(const Key& key, const std::shared_ptr<Block>& block,
|
||||
TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
/// Trim the block cache to make room for another entry.
|
||||
void Trim() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
/// Update the LRU iterator for the block at `key`.
|
||||
void UpdateLRU(const Key& key, const std::shared_ptr<Block>& block,
|
||||
TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_);
|
||||
|
||||
/// Remove all blocks of a file, with mu_ already held.
|
||||
void RemoveFile_Locked(const std::string& filename)
|
||||
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
/// Remove the block `entry` from the block map and LRU list, and update the
|
||||
/// cache size accordingly.
|
||||
void RemoveBlock(BlockMap::iterator entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
|
||||
/// The cache pruning thread that removes files with expired blocks.
|
||||
std::unique_ptr<TF_Thread, std::function<void(TF_Thread*)>> pruning_thread_;
|
||||
|
||||
/// Notification for stopping the cache pruning thread.
|
||||
absl::Notification stop_pruning_thread_;
|
||||
|
||||
/// Guards access to the block map, LRU list, and cached byte count.
|
||||
mutable absl::Mutex mu_;
|
||||
|
||||
/// The block map (map from Key to Block).
|
||||
BlockMap block_map_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
/// The LRU list of block keys. The front of the list identifies the most
|
||||
/// recently accessed block.
|
||||
std::list<Key> lru_list_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
/// The LRA (least recently added) list of block keys. The front of the list
|
||||
/// identifies the most recently added block.
|
||||
///
|
||||
/// Note: blocks are added to lra_list_ only after they have successfully been
|
||||
/// fetched from the underlying block store.
|
||||
std::list<Key> lra_list_ ABSL_GUARDED_BY(mu_);
|
||||
|
||||
/// The combined number of bytes in all of the cached blocks.
|
||||
size_t cache_size_ ABSL_GUARDED_BY(mu_) = 0;
|
||||
|
||||
// A filename->file_signature map.
|
||||
std::map<std::string, int64_t> file_signature_map_ ABSL_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace tf_gcs_filesystem
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_
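A hedged construction sketch for the cache declared above: it wires a placeholder BlockFetcher lambda and a wall-clock timer into the constructor shown in the header. The block size, capacity, and staleness values are arbitrary illustration choices, and the fetcher body stands in for the real GCS download path:

```cpp
#include <cstdint>
#include <cstring>
#include <ctime>
#include <string>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h"

// Build a cache of 16 MiB blocks, 128 MiB total, with 5 minutes of allowed
// staleness. The fetcher zero-fills the buffer as a stand-in for a remote read.
tf_gcs_filesystem::RamFileBlockCache* MakeDemoCache() {
  auto fetcher = [](const std::string& filename, size_t offset,
                    size_t buffer_size, char* buffer,
                    size_t* bytes_transferred, TF_Status* status) {
    std::memset(buffer, 0, buffer_size);
    *bytes_transferred = buffer_size;
    TF_SetStatus(status, TF_OK, "");
  };
  auto timer = []() -> uint64_t {
    return static_cast<uint64_t>(std::time(nullptr));
  };
  return new tf_gcs_filesystem::RamFileBlockCache(
      /*block_size=*/16 * 1024 * 1024, /*max_bytes=*/128 * 1024 * 1024,
      /*max_staleness=*/300, fetcher, timer);
}
```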
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/protobuf/struct.pb.h"
|
||||
|
||||
|
@ -248,15 +248,22 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
|
||||
size_t len, TF_Status* status) {
|
||||
TF_SetStatus(status, TF_OK, "");
|
||||
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
|
||||
tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index);
|
||||
auto* allocator = cc_ctx->get_allocator(attr);
|
||||
void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator);
|
||||
TF_Tensor* result = TF_NewTensor(dtype, dims, num_dims, data, len,
|
||||
tensorflow::deallocate_buffer, allocator);
|
||||
TF_SetOutput(context, index, result, status);
|
||||
if (TF_GetCode(status) != TF_OK) {
|
||||
TF_DeleteTensor(result);
|
||||
|
||||
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
|
||||
"64-bit int types should match in size");
|
||||
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
|
||||
reinterpret_cast<tensorflow::int64*>(dims), num_dims);
|
||||
tensorflow::Tensor* tensor;
|
||||
tensorflow::Status s = cc_ctx->allocate_output(
|
||||
index, tensorflow::TensorShape(dimarray), &tensor);
|
||||
if (!s.ok()) {
|
||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return nullptr;
|
||||
}
|
||||
return result;
|
||||
TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s);
|
||||
if (!s.ok()) {
|
||||
::tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
return nullptr;
|
||||
}
|
||||
return tf_tensor;
|
||||
}
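For context, a sketch of how a plugin kernel might call the reworked TF_AllocateOutput from its compute callback. This assumes the full public signature in tensorflow/c/kernels.h (context, index, dtype, dims, num_dims, len, status), which the hunk header truncates; the kernel itself is hypothetical:

```cpp
#include <stdint.h>

#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"

// Hypothetical compute callback that allocates a 1-D float output of eight
// elements through the C API wrapper patched above, then fills it with zeros.
static void DemoCompute(void* kernel, TF_OpKernelContext* ctx) {
  TF_Status* status = TF_NewStatus();
  int64_t dims[1] = {8};
  TF_Tensor* out = TF_AllocateOutput(ctx, /*index=*/0, TF_FLOAT, dims,
                                     /*num_dims=*/1, 8 * sizeof(float), status);
  if (TF_GetCode(status) == TF_OK && out != nullptr) {
    float* data = static_cast<float*>(TF_TensorData(out));
    for (int i = 0; i < 8; ++i) data[i] = 0.0f;
  }
  // Safe: the tensor is already registered as the kernel's output.
  TF_DeleteTensor(out);
  TF_DeleteStatus(status);
}
```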
|
||||
|
@ -1096,33 +1096,33 @@ StatusOr<bool> IsIdentityDrivingConstsInLoop(Node* node) {
|
||||
return true;
|
||||
}
|
||||
|
||||
absl::flat_hash_set<string> GetOrCreateWhitelist() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* whitelist_table =
|
||||
tensorflow::GetWhitelistTable();
|
||||
absl::flat_hash_set<string> GetOrCreateAllowlist() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* allowlist_table =
|
||||
tensorflow::GetAllowlistTable();
|
||||
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
|
||||
absl::flat_hash_set<string> whitelist;
|
||||
absl::flat_hash_set<string> allowlist;
|
||||
|
||||
for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) {
|
||||
if (s == "FUSIBLE") {
|
||||
for (auto pair : *whitelist_table) {
|
||||
whitelist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto pair : *allowlist_table) {
|
||||
allowlist.insert(pair.second.begin(), pair.second.end());
|
||||
}
|
||||
} else if (whitelist_table->contains(s)) {
|
||||
auto v = whitelist_table->at(s);
|
||||
whitelist.insert(v.begin(), v.end());
|
||||
} else if (allowlist_table->contains(s)) {
|
||||
auto v = allowlist_table->at(s);
|
||||
allowlist.insert(v.begin(), v.end());
|
||||
} else if (!s.empty()) {
|
||||
// Should be a user provided TF operation.
|
||||
whitelist.insert(string(s));
|
||||
allowlist.insert(string(s));
|
||||
}
|
||||
}
|
||||
|
||||
if (VLOG_IS_ON(2) && !whitelist.empty()) {
|
||||
std::vector<string> vwhitelist(whitelist.begin(), whitelist.end());
|
||||
absl::c_sort(vwhitelist);
|
||||
if (VLOG_IS_ON(2) && !allowlist.empty()) {
|
||||
std::vector<string> vallowlist(allowlist.begin(), allowlist.end());
|
||||
absl::c_sort(vallowlist);
|
||||
VLOG(2) << "XLA clustering will only consider the following TF operations: "
|
||||
<< absl::StrJoin(vwhitelist, " ");
|
||||
<< absl::StrJoin(vallowlist, " ");
|
||||
}
|
||||
return whitelist;
|
||||
return allowlist;
|
||||
}
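The function above expands the comma-separated --tf_xla_ops_to_cluster flag into an op allowlist. A simplified standalone sketch of the same expansion (it omits the special FUSIBLE handling, and the helper name is mine):

```cpp
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"

// Category names expand to their member ops; anything else is treated as a
// literal, user-provided TF op name (empty entries are ignored).
absl::flat_hash_set<std::string> ExpandOpsFlag(
    const std::string& flag,
    const absl::flat_hash_map<std::string, std::vector<std::string>>& table) {
  absl::flat_hash_set<std::string> allowlist;
  for (absl::string_view s : absl::StrSplit(flag, ',')) {
    auto it = table.find(std::string(s));
    if (it != table.end()) {
      allowlist.insert(it->second.begin(), it->second.end());
    } else if (!s.empty()) {
      allowlist.insert(std::string(s));
    }
  }
  return allowlist;
}
```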
|
||||
|
||||
Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
@ -1156,12 +1156,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
|
||||
VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size();
|
||||
|
||||
auto whitelist = GetOrCreateWhitelist();
|
||||
auto allowlist = GetOrCreateAllowlist();
|
||||
|
||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||
// Check that user's provided TF operation really exists.
|
||||
for (const auto& s : whitelist) {
|
||||
for (const auto& s : allowlist) {
|
||||
if (!all_ops.contains(string(s))) {
|
||||
return errors::InvalidArgument(
|
||||
"The operation '", s,
|
||||
@ -1206,7 +1206,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!whitelist.empty() && !whitelist.contains(node->def().op())) {
|
||||
if (!allowlist.empty() && !allowlist.contains(node->def().op())) {
|
||||
VLOG(1) << "Rejecting TF operation " << node->def().op()
|
||||
<< " as it is not listed in --tf_xla_ops_to_cluster.";
|
||||
continue;
|
||||
@ -1781,7 +1781,7 @@ Status MarkForCompilationPass::RunForTest(
|
||||
return MarkForCompilation(options, debug_options);
|
||||
}
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable() {
|
||||
// Table format: category name: {list of TF operations in that category}
|
||||
static absl::flat_hash_map<string, std::vector<string>>* result =
|
||||
new absl::flat_hash_map<string, std::vector<string>>{
|
||||
@ -1845,7 +1845,7 @@ absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable() {
|
||||
namespace testing {
|
||||
void ResetClusterSequenceNumber() { cluster_sequence_num = 0; }
|
||||
|
||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp() {
|
||||
absl::flat_hash_set<string> GetKnownXLAAllowlistOp() {
|
||||
absl::flat_hash_set<string> result{"AdjustContrastv2",
|
||||
"AdjustHue",
|
||||
"AdjustSaturation",
|
||||
|
@ -58,7 +58,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef,
|
||||
RecursiveCompilabilityChecker::UncompilableNodesMap*
|
||||
uncompilable_node_info = nullptr);
|
||||
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetWhitelistTable();
|
||||
absl::flat_hash_map<string, std::vector<string>>* GetAllowlistTable();
|
||||
|
||||
namespace testing {
|
||||
// DO NOT USE IN PRODUCTION.
|
||||
@ -66,8 +66,8 @@ namespace testing {
|
||||
// Resets some internal state to let us write reliable unit tests.
|
||||
void ResetClusterSequenceNumber();
|
||||
|
||||
// Return a list of operation that we choose not to put into the whitelist.
|
||||
absl::flat_hash_set<string> GetKnownXLAWhitelistOp();
|
||||
// Return a list of operation that we choose not to put into the allowlist.
|
||||
absl::flat_hash_set<string> GetKnownXLAAllowlistOp();
|
||||
} // namespace testing
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -1802,34 +1802,34 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) {
|
||||
EXPECT_NE(clusters["relu0"], clusters["relu1"]);
|
||||
}
|
||||
}
|
||||
TEST(XlaCompilationTest, XLALiteWhitelist) {
|
||||
auto* whitelist_table = tensorflow::GetWhitelistTable();
|
||||
absl::flat_hash_set<string> hwhitelist;
|
||||
TEST(XlaCompilationTest, XLALiteAllowlist) {
|
||||
auto* allowlist_table = tensorflow::GetAllowlistTable();
|
||||
absl::flat_hash_set<string> hallowlist;
|
||||
std::vector<string> vall_ops = XlaOpRegistry::GetAllRegisteredOps();
|
||||
absl::flat_hash_set<string> all_ops(vall_ops.begin(), vall_ops.end());
|
||||
|
||||
// Check that all the operations in the table are existing TF operations
|
||||
for (auto pair : *whitelist_table) {
|
||||
hwhitelist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto pair : *allowlist_table) {
|
||||
hallowlist.insert(pair.second.begin(), pair.second.end());
|
||||
for (auto op : pair.second) {
|
||||
ASSERT_TRUE(all_ops.contains(op));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that all registered XLA operation are in the whitelist
|
||||
// Check that all registered XLA operation are in the allowlist
|
||||
// table or are known to not be in it.
|
||||
|
||||
absl::flat_hash_set<string> known_not_in_list =
|
||||
tensorflow::testing::GetKnownXLAWhitelistOp();
|
||||
tensorflow::testing::GetKnownXLAAllowlistOp();
|
||||
std::vector<string> unknow_op;
|
||||
for (string op : vall_ops) {
|
||||
if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) {
|
||||
if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) {
|
||||
unknow_op.push_back(op);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(unknow_op.empty())
|
||||
<< "Someone added support for a new TF opeations inside XLA. They must "
|
||||
"be included in the XLALite whitelist or blacklist:\n"
|
||||
"be included in the XLALite allowlist or blacklist:\n"
|
||||
<< absl::StrJoin(unknow_op, "\n");
|
||||
}
|
||||
} // namespace
|
||||
|
@ -669,6 +669,8 @@ cc_library(
|
||||
":lhlo_legalize_to_llvm", # build-cleaner: keep
|
||||
":xla_materialize_broadcasts", # build-cleaner: keep
|
||||
":xla_unfuse_batch_norm", # build-cleaner: keep
|
||||
"@llvm-project//mlir:AffineToStandardTransforms",
|
||||
"@llvm-project//mlir:CFGTransforms",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:InferTypeOpInterface",
|
||||
"@llvm-project//mlir:LLVMDialect",
|
||||
|
@ -28,18 +28,18 @@ limitations under the License.
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_chlo {
|
||||
namespace chlo {
|
||||
|
||||
class XlaHloClientDialect : public Dialect {
|
||||
class HloClientDialect : public Dialect {
|
||||
public:
|
||||
explicit XlaHloClientDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "xla_chlo"; }
|
||||
explicit HloClientDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "chlo"; }
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
|
||||
|
||||
} // namespace xla_chlo
|
||||
} // namespace chlo
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_
|
||||
|
@ -22,7 +22,7 @@ limitations under the License.
|
||||
//
|
||||
// The typical use of this dialect is for client libraries to be able to emit
|
||||
// less constrained ops and rely on the conversion framework to lower any
|
||||
// xla_chlo ops to canonical mhlo ops.
|
||||
// chlo ops to canonical mhlo ops.
|
||||
//
|
||||
// See: https://www.tensorflow.org/xla/operation_semantics
|
||||
|
||||
@ -35,8 +35,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
def HLOClient_Dialect : Dialect {
|
||||
let name = "xla_chlo";
|
||||
let cppNamespace = "xla_chlo";
|
||||
let name = "chlo";
|
||||
let cppNamespace = "chlo";
|
||||
let summary = [{
|
||||
XLA Client HLO Ops
|
||||
}];
|
||||
|
@ -39,9 +39,9 @@ class OpBuilder;
|
||||
|
||||
namespace mhlo {
|
||||
|
||||
class XlaHloDialect : public Dialect {
|
||||
class MhloDialect : public Dialect {
|
||||
public:
|
||||
explicit XlaHloDialect(MLIRContext *context);
|
||||
explicit MhloDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "mhlo"; }
|
||||
|
||||
// Registered hook to materialize a constant operation from a given attribute
|
||||
|
@ -35,18 +35,18 @@ class OpBuilder;
|
||||
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc"
|
||||
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
|
||||
class XlaLhloDialect : public Dialect {
|
||||
class LmhloDialect : public Dialect {
|
||||
public:
|
||||
explicit XlaLhloDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "xla_lhlo"; }
|
||||
explicit LmhloDialect(MLIRContext *context);
|
||||
static StringRef getDialectNamespace() { return "lmhlo"; }
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc"
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_
|
||||
|
@ -38,8 +38,8 @@ include "mlir/Interfaces/ViewLikeInterface.td"
|
||||
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
|
||||
|
||||
def LHLO_Dialect : Dialect {
|
||||
let name = "xla_lhlo";
|
||||
let cppNamespace = "xla_lhlo";
|
||||
let name = "lmhlo";
|
||||
let cppNamespace = "lmhlo";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [
|
||||
|
||||
// TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example,
|
||||
// A tuple-like pattern match syntax could work:
|
||||
// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
|
||||
// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) {
|
||||
// ...
|
||||
// }, {
|
||||
// ...
|
||||
@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op<LHLO_Dialect, "static_memref_cast",
|
||||
Example:
|
||||
```mlir
|
||||
%buf_transformed =
|
||||
xla_lhlo.static_memref_cast %buf
|
||||
lmhlo.static_memref_cast %buf
|
||||
: memref<1x5xf32> -> memref<5xf32, offset: 2, strides: [1]>
|
||||
|
||||
// The result of the op is a rank-1 memref with `[5]` shape, stride 1 and
|
||||
@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op<LHLO_Dialect, "dynamic_memref_cast",
|
||||
Example:
|
||||
```mlir
|
||||
%buf_transformed =
|
||||
xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
|
||||
lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%step_X, %step_Y]
|
||||
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
|
||||
// The result of the op is a type-erased memref with `[%size_X, %size_Y]`
|
||||
// shape and `[%step_X, %step_Y]` strides. The offset will be inherited
|
||||
@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
|
||||
);
|
||||
let results = (outs AnyRankedOrUnrankedMemRef:$result);
|
||||
|
||||
let builders = [OpBuilder<
|
||||
"OpBuilder &builder, OperationState &result, MemRefType resultType, " #
|
||||
"Value operand, Value shape", [{
|
||||
result.addOperands(operand);
|
||||
result.addOperands(shape);
|
||||
result.types.push_back(resultType);
|
||||
}]>];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
MemRefType getType() { return getResult().getType().cast<MemRefType>(); }
|
||||
}];
|
||||
|
@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
|
||||
#define MAP_HLO_TO_LHLO(OpName) \
|
||||
template <> \
|
||||
struct HloToLhloOpImpl<mhlo::OpName> { \
|
||||
using Type = xla_lhlo::OpName; \
|
||||
using Type = lmhlo::OpName; \
|
||||
}
|
||||
|
||||
MAP_HLO_TO_LHLO(AbsOp);
|
||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace impl {
|
||||
|
||||
// A struct to map LhloBinaryOpTy type to the corresponding floating-point and
|
||||
@ -33,32 +33,32 @@ template <typename LhloBinaryOpTy>
|
||||
struct LhloToScalarOp;
|
||||
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::AddOp> {
|
||||
struct LhloToScalarOp<lmhlo::AddOp> {
|
||||
using FOp = ::mlir::AddFOp;
|
||||
using IOp = ::mlir::AddIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::CompareOp> {
|
||||
struct LhloToScalarOp<lmhlo::CompareOp> {
|
||||
using FOp = ::mlir::CmpFOp;
|
||||
using IOp = ::mlir::CmpIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::DivOp> {
|
||||
struct LhloToScalarOp<lmhlo::DivOp> {
|
||||
using FOp = ::mlir::DivFOp;
|
||||
using IOp = ::mlir::SignedDivIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::MulOp> {
|
||||
struct LhloToScalarOp<lmhlo::MulOp> {
|
||||
using FOp = ::mlir::MulFOp;
|
||||
using IOp = ::mlir::MulIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::RemOp> {
|
||||
struct LhloToScalarOp<lmhlo::RemOp> {
|
||||
using FOp = ::mlir::RemFOp;
|
||||
using IOp = ::mlir::SignedRemIOp;
|
||||
};
|
||||
template <>
|
||||
struct LhloToScalarOp<xla_lhlo::SubOp> {
|
||||
struct LhloToScalarOp<lmhlo::SubOp> {
|
||||
using FOp = ::mlir::SubFOp;
|
||||
using IOp = ::mlir::SubIOp;
|
||||
};
|
||||
@ -116,16 +116,17 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef<Type> result_types,
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AbsOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
Type element_type = args.front().getType();
|
||||
if (element_type.isa<FloatType>()) {
|
||||
return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::AbsFOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
if (element_type.isa<IntegerType>()) {
|
||||
// xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||
// lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x))
|
||||
Value lhs = args[0];
|
||||
auto integer_type = element_type.dyn_cast<IntegerType>();
|
||||
|
||||
@ -133,16 +134,17 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::AbsOp>(
|
||||
b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
|
||||
auto lhs_gt_zero = b->create<ScalarIOp<CompareOp>>(loc, CmpIPredicate::sge,
|
||||
lhs, zero_intval);
|
||||
auto neg_val = b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
auto neg_val = b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
|
||||
return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline Value MapLhloOpToStdScalarOp<xla_lhlo::AndOp>(
|
||||
Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
inline Value MapLhloOpToStdScalarOp<lmhlo::AndOp>(Location loc,
|
||||
ArrayRef<Type> result_types,
|
||||
ArrayRef<Value> args,
|
||||
OpBuilder* b) {
|
||||
return MapLhloOpToStdScalarOpImpl<IntegerType, ::mlir::AndOp>{}(
|
||||
loc, result_types, args, b);
|
||||
}
|
||||
@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc,
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CopyOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::CopyOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return args.front();
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ExpOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::ExpOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::ExpOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CeilOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::CeilOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CeilFOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::ComplexOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<CreateComplexOp>{}(loc, result_types, args,
@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ComplexOp>(
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RealOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::RealOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<ReOp>{}(loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ImagOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::ImagOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<ImOp>{}(loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::ConvertOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  Type sourceType = args.front().getType();
@ -288,9 +295,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::ConvertOp>(
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::DotOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  // Dot Op converter from lhlo to affine only accepts float and integer types.
  const auto& lhs = args[0];
  const auto& rhs = args[1];
@ -312,17 +320,19 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::DotOp>(
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::CosOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::CosOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::CosOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SinOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::SinOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SinOp>{}(
      loc, result_types, args, b);
}
@ -361,66 +371,69 @@ struct XlaCompareSelectOpToStdScalarOp<SupportedType, StdCompareOp, Predicate,
};

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::LogOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::LogOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::LogOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MaxOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::MaxOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return XlaCompareSelectOpToStdScalarOp<
      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "GT",
          result_types, args,
          b);
      IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "GT", result_types,
          args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::MinOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::MinOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return XlaCompareSelectOpToStdScalarOp<
      IntegerType, ScalarIOp<xla_lhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<xla_lhlo::CompareOp>, CmpFPredicate>::map(loc, "LT",
          result_types, args,
          b);
      IntegerType, ScalarIOp<lmhlo::CompareOp>, CmpIPredicate, FloatType,
      ScalarFOp<lmhlo::CompareOp>, CmpFPredicate>::map(loc, "LT", result_types,
          args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::NegOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::NegOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::NegFOp>{}(
        loc, result_types, args, b);
  }
  if (element_type.isa<IntegerType>()) {
    // xla_lhlo.neg(x, result) -> result = sub(0, x)
    // lmhlo.neg(x, result) -> result = sub(0, x)
    Value lhs = args[0];
    auto integer_type = element_type.dyn_cast<IntegerType>();

    auto zero_intval =
        b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth());
    return b->create<ScalarIOp<xla_lhlo::SubOp>>(loc, zero_intval, lhs);
    return b->create<ScalarIOp<lmhlo::SubOp>>(loc, zero_intval, lhs);
  }
  return nullptr;
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::RsqrtOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::RsqrtOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::RsqrtOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
inline Value MapLhloOpToStdScalarOp<lmhlo::SelectOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args,
@ -428,9 +441,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SelectOp>(
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::SignOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  Type element_type = args.front().getType();
  if (element_type.isa<FloatType>()) {
    FloatType float_type = element_type.cast<FloatType>();
@ -442,17 +456,19 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::SignOp>(
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::SqrtOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::SqrtOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::SqrtOp>{}(
      loc, result_types, args, b);
}

template <>
inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
    Location loc, ArrayRef<Type> result_types, ArrayRef<Value> args,
    OpBuilder* b) {
inline Value MapLhloOpToStdScalarOp<lmhlo::TanhOp>(Location loc,
    ArrayRef<Type> result_types,
    ArrayRef<Value> args,
    OpBuilder* b) {
  return MapLhloOpToStdScalarOpImpl<FloatType, ::mlir::TanhOp>{}(
      loc, result_types, args, b);
}
@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp<xla_lhlo::TanhOp>(
} // namespace impl

struct XlaOpToStdScalarOp {
  // Implementation for LHLO ops except xla_lhlo::CompareOp.
  // Implementation for LHLO ops except lmhlo::CompareOp.
  template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
      typename = std::enable_if_t<
          !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
          !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
          std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
              std::false_type>::value>>
  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp {
  // Implementation for HLO ops except mhlo::CompareOp.
  template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
      typename = std::enable_if_t<
          !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
          !std::is_same<LhloOpTy, lmhlo::CompareOp>::value &&
          !std::is_same<LhloOpTy, std::false_type>::value>>
  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
      ArrayRef<Value> args, OpBuilder* b, int i = 0) {
@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp {
        args, b);
  }

  // Implementation for xla_lhlo::CompareOp.
  // Implementation for lmhlo::CompareOp.
  template <typename LhloOpTy, typename = std::enable_if_t<std::is_same<
      LhloOpTy, xla_lhlo::CompareOp>::value>>
  static Value map(xla_lhlo::CompareOp op, ArrayRef<Type> result_types,
      LhloOpTy, lmhlo::CompareOp>::value>>
  static Value map(lmhlo::CompareOp op, ArrayRef<Type> result_types,
      ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
    return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
  }

@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp {
  static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
      ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(
    return impl::MapXlaCompareOpToStdScalarOp<lmhlo::CompareOp>(
        op.getLoc(), comparison_direction, result_types, args, b);
  }
};

} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir

#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_
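The header above is easiest to read next to a call site. Below is a minimal sketch (not part of this commit) of how a lowering pattern typically invokes the mapper after the xla_lhlo -> lmhlo rename; the helper name LowerAddToStd and the choice of lmhlo::AddOp are illustrative assumptions only.

// Sketch only: hypothetical call into the XlaOpToStdScalarOp helper.
static Value LowerAddToStd(lmhlo::AddOp op, ArrayRef<Value> scalar_args,
                           OpBuilder* b) {
  // The mapper selects the Standard-dialect add that matches the element
  // type (AddFOp for floats, AddIOp for integers).
  Type element_type = scalar_args.front().getType();
  return lmhlo::XlaOpToStdScalarOp::map<lmhlo::AddOp>(op, element_type,
                                                      scalar_args, b);
}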
@ -60,7 +60,7 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();

} // namespace mhlo

namespace xla_lhlo {
namespace lmhlo {

// Lowers from LHLO dialect to Affine dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass();
@ -92,7 +92,7 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass();

// Lowers from LHLO dialect to parallel loops.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToParallelLoopsPass();

} // namespace xla_lhlo
} // namespace lmhlo

namespace xla {

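For orientation, a sketch of how the renamed lmhlo pass-creation functions declared here might be chained in a pass manager; the helper name and the pass ordering are assumptions, not part of this commit.

// Sketch only: hypothetical pipeline wiring for the lmhlo passes above.
#include "mlir/Pass/PassManager.h"

static void AddLhloLoweringPasses(mlir::PassManager& pm) {
  // Lower LHLO ops to affine loops, then drop redundant buffer copies.
  pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLegalizeToAffinePass());
  pm.addPass(mlir::lmhlo::createLhloCopyRemovalPass());
}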
@ -75,23 +75,23 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context,

} // namespace mhlo

namespace xla_lhlo {
namespace lmhlo {

/// Collect a set of patterns to convert from the LHLO dialect to LLVM.
void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
    LLVMTypeConverter *converter,
    OwningRewritePatternList *patterns);

} // namespace xla_lhlo
} // namespace lmhlo

namespace xla_chlo {
namespace chlo {

// Populates a collection of conversion patterns for legalizing client-HLO to
// HLO.
void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
    OwningRewritePatternList *patterns);

} // namespace xla_chlo
} // namespace chlo

namespace xla {

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"

namespace mlir {
namespace xla_chlo {
namespace chlo {

template <typename T>
static LogicalResult Verify(T op) {
@ -263,10 +263,10 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

//===----------------------------------------------------------------------===//
// xla_chlo Dialect Constructor
// chlo Dialect Constructor
//===----------------------------------------------------------------------===//

XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
HloClientDialect::HloClientDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
@ -274,5 +274,5 @@ XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context)
      >();
}

} // namespace xla_chlo
} // namespace chlo
} // namespace mlir
@ -18,7 +18,6 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
    xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;
static mlir::DialectRegistration<mlir::mhlo::MhloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::chlo::HloClientDialect> chlo_ops;
static mlir::DialectRegistration<mlir::lmhlo::LmhloDialect> lmhlo_ops;
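The static DialectRegistration objects above are what most tools rely on; the explicit-loading sketch below is only for context. The exact MLIRContext API for loading dialects varies by MLIR revision, so treat this as an assumption rather than the commit's approach.

// Sketch only: explicit loading of the renamed dialects, assuming an MLIR
// revision that exposes MLIRContext::getOrLoadDialect.
static void LoadHloDialects(mlir::MLIRContext& context) {
  context.getOrLoadDialect<mlir::mhlo::MhloDialect>();
  context.getOrLoadDialect<mlir::chlo::HloClientDialect>();
  context.getOrLoadDialect<mlir::lmhlo::LmhloDialect>();
}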
@ -62,9 +62,8 @@ namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace mhlo {

Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
    Attribute value, Type type,
    Location loc) {
Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value,
    Type type, Location loc) {
  // HLO dialect constants only support ElementsAttr unlike standard dialect
  // constant which supports all attributes.
  if (value.isa<ElementsAttr>())
@ -2128,7 +2127,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//

XlaHloDialect::XlaHloDialect(MLIRContext* context)
MhloDialect::MhloDialect(MLIRContext* context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
@ -2140,7 +2139,7 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context)
  // allowUnknownOperations();
}

Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
Type MhloDialect::parseType(DialectAsmParser& parser) const {
  StringRef data_type;
  if (parser.parseKeyword(&data_type)) return Type();

@ -2149,7 +2148,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
  return nullptr;
}

void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
void MhloDialect::printType(Type type, DialectAsmPrinter& os) const {
  if (type.isa<TokenType>()) {
    os << "token";
    return;
@ -46,9 +46,9 @@ limitations under the License.

namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc"
namespace xla_lhlo {
namespace lmhlo {

XlaLhloDialect::XlaLhloDialect(MLIRContext *context)
LmhloDialect::LmhloDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<
#define GET_OP_LIST
@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result,
  FusionOp::ensureTerminator(*bodyRegion, builder, result.location);
}

} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"

namespace mlir {
namespace xla_chlo {
namespace chlo {

namespace {

@ -235,5 +235,5 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
      context, patterns);
}

} // namespace xla_chlo
} // namespace chlo
} // namespace mlir
@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_chlo {
namespace chlo {

namespace {

@ -31,9 +31,9 @@ struct TestChloLegalizeToHloPass
    ConversionTarget conversionTarget(getContext());
    OwningRewritePatternList conversionPatterns;

    conversionTarget.addIllegalDialect<XlaHloClientDialect>();
    conversionTarget.addIllegalDialect<HloClientDialect>();
    // Consider the mhlo dialect legal for tests.
    conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
    conversionTarget.addLegalDialect<mhlo::MhloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
    conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
@ -49,9 +49,9 @@ struct TestChloLegalizeToHloPass

} // namespace

} // namespace xla_chlo
} // namespace chlo
} // namespace mlir

static mlir::PassRegistration<mlir::xla_chlo::TestChloLegalizeToHloPass> pass(
static mlir::PassRegistration<mlir::chlo::TestChloLegalizeToHloPass> pass(
    "test-xla-chlo-legalize-to-hlo",
    "Test pass for applying chlo -> hlo legalization patterns");
@ -44,7 +44,7 @@ template <typename T>
|
||||
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
|
||||
using StdReturnOpConverter =
|
||||
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
|
||||
xla_lhlo::CopyOp, true>;
|
||||
lmhlo::CopyOp, true>;
|
||||
|
||||
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
|
||||
Value shape_operand,
|
||||
@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
|
||||
Value transformed_operand =
|
||||
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
|
||||
rewriter.create<xla_lhlo::BroadcastInDimOp>(
|
||||
rewriter.create<lmhlo::BroadcastInDimOp>(
|
||||
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
|
||||
|
||||
rewriter.replaceOp(op, {resultBuffer});
|
||||
@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
// Inserts dynamic memref to change the layout of the memref to put 0-stride
|
||||
// and size of the target dimension if size-1 dimension expansion is
|
||||
// necessary.
|
||||
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||
lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
|
||||
mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
|
||||
auto loc = op.getLoc();
|
||||
auto operand_type = operand.getType().cast<MemRefType>();
|
||||
@ -214,12 +214,37 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
|
||||
makeStridedLinearLayoutMap(dynamic_layout,
|
||||
/*offset=*/0, b->getContext()));
|
||||
|
||||
auto transformed_operand = b->create<xla_lhlo::DynamicMemRefCastOp>(
|
||||
auto transformed_operand = b->create<lmhlo::DynamicMemRefCastOp>(
|
||||
loc, type_erased_memref_type, operand, sizes, strides);
|
||||
return transformed_operand;
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloDynamicReshapeConverter
|
||||
: public BaseOpConversion<mhlo::DynamicReshapeOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::DynamicReshapeOp>::BaseOpConversion;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
mhlo::DynamicReshapeOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
Type result_type;
|
||||
if (auto ranked_type = op.getType().dyn_cast<RankedTensorType>()) {
|
||||
result_type =
|
||||
MemRefType::get(ranked_type.getShape(), ranked_type.getElementType());
|
||||
} else if (auto unranked_type =
|
||||
op.getType().dyn_cast<UnrankedTensorType>()) {
|
||||
result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0);
|
||||
} else {
|
||||
return failure();
|
||||
}
|
||||
mhlo::DynamicReshapeOp::Adaptor adaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<lmhlo::ReshapeMemRefCastOp>(
|
||||
op, result_type, adaptor.operand(), adaptor.output_shape());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
public:
|
||||
using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
|
||||
@ -241,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
buffer_args.push_back(
|
||||
InsertAlloc(loc, result, this->bufferAssignment, &rewriter));
|
||||
}
|
||||
auto new_op = rewriter.create<xla_lhlo::ReduceOp>(
|
||||
loc, llvm::None, buffer_args, op.getAttrs());
|
||||
auto new_op = rewriter.create<lmhlo::ReduceOp>(loc, llvm::None, buffer_args,
|
||||
op.getAttrs());
|
||||
|
||||
// Copy over the operations inside the region.
|
||||
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
|
||||
@ -267,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
|
||||
}
|
||||
// Insert terminator at the end.
|
||||
rewriter.setInsertionPointToEnd(&entry_block);
|
||||
rewriter.create<xla_lhlo::TerminatorOp>(loc);
|
||||
rewriter.create<lmhlo::TerminatorOp>(loc);
|
||||
|
||||
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
|
||||
|
||||
@ -296,8 +321,8 @@ class HloToLhloTensorStoreOpConverter
|
||||
LogicalResult matchAndRewrite(
|
||||
mlir::TensorStoreOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
rewriter.replaceOpWithNewOp<xla_lhlo::CopyOp>(
|
||||
op, llvm::None, operands.front(), operands.back());
|
||||
rewriter.replaceOpWithNewOp<lmhlo::CopyOp>(op, llvm::None, operands.front(),
|
||||
operands.back());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@ -311,7 +336,7 @@ class HloToLhloTensorStoreOpConverter
|
||||
// %arg1: memref<2x2xf32>,
|
||||
// %arg2: memref<2x2xf32>,
|
||||
// %arg3: memref<2x2xf32>) {
|
||||
// "xla_lhlo.fusion"() ({
|
||||
// "lmhlo.fusion"() ({
|
||||
// %0 = tensor_load %arg1 : memref<2x2xf32>
|
||||
// %1 = tensor_load %arg2 : memref<2x2xf32>
|
||||
// %2 = "mhlo.add"(%0, %1) :
|
||||
@ -320,7 +345,7 @@ class HloToLhloTensorStoreOpConverter
|
||||
// %4 = "mhlo.multiply"(%2, %3) :
|
||||
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// tensor_store %4, %arg3 : memref<2x2xf32>
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }) : () -> ()
|
||||
// return
|
||||
// }
|
||||
@ -330,13 +355,13 @@ class HloToLhloTensorStoreOpConverter
|
||||
// %arg1: memref<2x2xf32>,
|
||||
// %arg2: memref<2x2xf32>,
|
||||
// %arg3: memref<2x2xf32>) {
|
||||
// "xla_lhlo.fusion"() ( {
|
||||
// "lmhlo.fusion"() ( {
|
||||
// %0 = alloc() : memref<2x2xf32>
|
||||
// "xla_lhlo.add"(%arg1, %arg2, %0) :
|
||||
// "lmhlo.add"(%arg1, %arg2, %0) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// "xla_lhlo.multiply"(%0, %arg0, %arg3) :
|
||||
// "lmhlo.multiply"(%0, %arg0, %arg3) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }) : () -> ()
|
||||
// return
|
||||
// }
|
||||
@ -357,13 +382,13 @@ class HloToLhloTensorStoreOpConverter
|
||||
// %arg2: memref<4xf32>) {
|
||||
// %0 = alloc() : memref<4xf32>
|
||||
|
||||
// "xla_lhlo.maximum"(%arg0, %arg1, %0) :
|
||||
// "lmhlo.maximum"(%arg0, %arg1, %0) :
|
||||
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||
// %1 = alloc() : memref<4xf32>
|
||||
// "xla_lhlo.add"(%arg0, %0, %1) :
|
||||
// "lmhlo.add"(%arg0, %0, %1) :
|
||||
// (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }
|
||||
|
||||
struct HloLegalizeToLhlo
|
||||
@ -381,26 +406,31 @@ struct HloLegalizeToLhlo
|
||||
OwningRewritePatternList patterns;
|
||||
auto& context = getContext();
|
||||
ConversionTarget target(context);
|
||||
target.addLegalDialect<xla_lhlo::XlaLhloDialect>();
|
||||
target.addLegalDialect<lmhlo::LmhloDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addIllegalOp<mlir::TensorLoadOp>();
|
||||
target.addIllegalOp<mlir::TensorStoreOp>();
|
||||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
target.addLegalOp<TensorFromElementsOp>();
|
||||
target.addIllegalDialect<mhlo::XlaHloDialect>();
|
||||
target.addIllegalDialect<mhlo::MhloDialect>();
|
||||
|
||||
BufferAssignmentTypeConverter converter;
|
||||
auto isMemRefType = [](Type type) { return type.isa<BaseMemRefType>(); };
|
||||
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
|
||||
auto inputs = op.getType().getInputs();
|
||||
return llvm::all_of(inputs,
|
||||
[](Type input) { return input.isa<MemRefType>(); }) &&
|
||||
return llvm::all_of(inputs, isMemRefType) &&
|
||||
converter.isLegal(&op.getBody());
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp returnOp) {
|
||||
return std::all_of(returnOp.operand_type_begin(),
|
||||
returnOp.operand_type_end(),
|
||||
[](Type type) { return type.isa<MemRefType>(); });
|
||||
target.addDynamicallyLegalOp<CallOp>([&](CallOp op) {
|
||||
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
|
||||
isMemRefType) &&
|
||||
std::all_of(op.result_type_begin(), op.result_type_end(),
|
||||
isMemRefType);
|
||||
});
|
||||
target.addDynamicallyLegalOp<mlir::ReturnOp>([&](mlir::ReturnOp op) {
|
||||
return std::all_of(op.operand_type_begin(), op.operand_type_end(),
|
||||
isMemRefType);
|
||||
});
|
||||
|
||||
auto module = getOperation();
|
||||
@ -411,12 +441,12 @@ struct HloLegalizeToLhlo
|
||||
&converter, &patterns);
|
||||
if (results_escape_function) {
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
|
||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
|
||||
/*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment,
|
||||
&converter, &patterns);
|
||||
} else {
|
||||
populateWithBufferAssignmentOpConversionPatterns<
|
||||
mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp,
|
||||
mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp,
|
||||
/*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment,
|
||||
&converter, &patterns);
|
||||
}
|
||||
@ -442,6 +472,7 @@ void populateHLOToLHLOConversionPattern(
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
HloToLhloDynamicBroadcastInDimOpConverter,
|
||||
HloToLhloDynamicReshapeConverter,
|
||||
HloToLhloOpConverter<mhlo::AbsOp>,
|
||||
HloToLhloOpConverter<mhlo::AddOp>,
|
||||
HloToLhloOpConverter<mhlo::AndOp>,
|
||||
|
@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"

namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {

// Removes LHLO copy operations that copy from allocated buffers to block
@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper<LhloCopyRemoval, OperationPass<>> {
  void runOnOperation() override {
    llvm::SmallVector<mlir::Operation*, 2> eraseList;
    auto operation = getOperation();
    operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) {
    operation->walk([&](mlir::lmhlo::CopyOp copyOp) {
      // If this region contains more than one block, then ignore this copy
      // operation.
      if (copyOp.getParentRegion()->getBlocks().size() > 1) {
@ -101,5 +101,5 @@ std::unique_ptr<Pass> createLhloCopyRemovalPass() {
static PassRegistration<LhloCopyRemoval> copy_removal_pass(
    "lhlo-copy-removal", "Removes redundant LHLO copy operations");

} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
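The core rewrite performed by the copy-removal pass can be summarized with generic mlir::Operation APIs. The sketch below omits the block- and lifetime-safety checks the real pass performs, and the helper name is hypothetical.

// Sketch only: forward users of the copy target to the source buffer and
// erase the now-dead lmhlo.copy; the real pass runs its safety checks first.
static void RemoveTrivialCopy(mlir::lmhlo::CopyOp copyOp) {
  mlir::Operation* op = copyOp.getOperation();
  mlir::Value source = op->getOperand(0);
  mlir::Value target = op->getOperand(1);
  // Every read of the copy target can instead read the source buffer.
  target.replaceAllUsesWith(source);
  op->erase();
}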
@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"

namespace mlir {
namespace xla_lhlo {
namespace lmhlo {
namespace {

using linalg::LinalgOp;
@ -147,5 +147,5 @@ static PassRegistration<LhloFuseLinalg> legalize_pass(
    "lhlo-fuse-linalg",
    "Greedily fuse linalg ops obtained after LHLO lowering.");

} // namespace xla_lhlo
} // namespace lmhlo
} // namespace mlir
@ -28,7 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// Builds an affine loop nest iterating from zeros to "upper_bounds" with unit
|
||||
@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern<DotOp> {
|
||||
auto r = builder.create<AffineLoadOp>(loc, rhs, rhs_indices);
|
||||
auto result =
|
||||
rewriter.create<AffineLoadOp>(loc, op.output(), result_indices);
|
||||
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<DotOp>(
|
||||
Value op_result = lmhlo::XlaOpToStdScalarOp::map<DotOp>(
|
||||
op, element_type, {l, r, result}, &builder);
|
||||
map_status = success(op_result != nullptr);
|
||||
if (failed(map_status)) return;
|
||||
@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern<LhloOpTy> {
|
||||
ValueRange induction_vars) {
|
||||
auto l = builder.create<AffineLoadOp>(loc, lhs, induction_vars);
|
||||
auto r = builder.create<AffineLoadOp>(loc, rhs, induction_vars);
|
||||
Value op_result = xla_lhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||
Value op_result = lmhlo::XlaOpToStdScalarOp::map<LhloOpTy>(
|
||||
op, element_type, {l, r}, &builder);
|
||||
map_status = success(op_result != nullptr);
|
||||
if (failed(map_status)) return;
|
||||
@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<
|
||||
BinaryOpConverter<xla_lhlo::AddOp>,
|
||||
BinaryOpConverter<xla_lhlo::AndOp>,
|
||||
BinaryOpConverter<xla_lhlo::DivOp>,
|
||||
BinaryOpConverter<xla_lhlo::MaxOp>,
|
||||
BinaryOpConverter<xla_lhlo::MinOp>,
|
||||
BinaryOpConverter<xla_lhlo::MulOp>,
|
||||
BinaryOpConverter<xla_lhlo::SubOp>,
|
||||
BinaryOpConverter<lmhlo::AddOp>,
|
||||
BinaryOpConverter<lmhlo::AndOp>,
|
||||
BinaryOpConverter<lmhlo::DivOp>,
|
||||
BinaryOpConverter<lmhlo::MaxOp>,
|
||||
BinaryOpConverter<lmhlo::MinOp>,
|
||||
BinaryOpConverter<lmhlo::MulOp>,
|
||||
BinaryOpConverter<lmhlo::SubOp>,
|
||||
DotOpConverter>(context);
|
||||
// clang-format on
|
||||
}
|
||||
@ -157,5 +157,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToAffinePass() {
|
||||
static PassRegistration<LhloLegalizeToAffine> legalize_pass(
|
||||
"lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
@ -38,7 +38,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// A simple translation of LHLO reduce operations to a corresponding gpu
|
||||
@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper<LhloLegalizeToGpu, FunctionPass> {
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
gpu::GPUDialect, scf::SCFDialect, XlaLhloDialect>();
|
||||
gpu::GPUDialect, scf::SCFDialect, LmhloDialect>();
|
||||
target.addIllegalOp<ReduceOp>();
|
||||
auto func = getFunction();
|
||||
patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
|
||||
@ -192,5 +192,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeToGpuPass() {
|
||||
static PassRegistration<LhloLegalizeToGpu> legalize_pass(
|
||||
"lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
@ -21,7 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
struct StaticMemRefCastOpConverter
|
||||
@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options,
|
||||
*converter, options);
|
||||
}
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project
|
||||
@ -23,7 +25,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
class TestLhloToLLVMPass
|
||||
@ -38,11 +40,14 @@ class TestLhloToLLVMPass
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
PopulateLhloToLLVMConversionPatterns(
|
||||
LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns);
|
||||
mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
|
||||
|
||||
mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
target.addIllegalDialect<XlaLhloDialect>();
|
||||
target.addIllegalDialect<LmhloDialect>();
|
||||
|
||||
if (failed(applyFullConversion(m, target, patterns))) {
|
||||
signalPassFailure();
|
||||
@ -55,5 +60,5 @@ class TestLhloToLLVMPass
|
||||
static PassRegistration<TestLhloToLLVMPass> legalize_lhlo_pass(
|
||||
"test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM.");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
@ -26,7 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
namespace {
|
||||
|
||||
// Clones and adapts the code in `lhlo_block` that works on buffers and has a
|
||||
@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
|
||||
return b->create<scf::ParallelOp>(loc, lower, upper, step);
|
||||
}
|
||||
|
||||
// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
|
||||
// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp.
|
||||
// The outer `ParallelOp` refers to the parallel loops if there are
|
||||
// any. The inner `ParallelOp` refers to the reduction loops and `ReduceOp`
|
||||
// contains the reduction operator.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( {
|
||||
// "lmhlo.reduce"(%buffer, %init_buf, %result) ( {
|
||||
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
// <LHLO ops>
|
||||
// } ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||
@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value,
|
||||
// } : f32
|
||||
// scf.yield
|
||||
// }
|
||||
class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
class ReduceOpConverter : public OpConversionPattern<lmhlo::ReduceOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ReduceOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ReduceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ReduceOp xla_reduce_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// TODO(b/137624192) Implement variadic reduce.
|
||||
if (xla_reduce_op.out().size() != 1) return failure();
|
||||
@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
// scf.yield
|
||||
// }
|
||||
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||
xla_lhlo::ReduceOp xla_reduce_op,
|
||||
lmhlo::ReduceOp xla_reduce_op,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = xla_reduce_op.getLoc();
|
||||
DenseSet<int> reducing_dims;
|
||||
@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
// accumulator = reduction_operator(output[O], value)
|
||||
// output[O] = accumulator
|
||||
//
|
||||
// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
||||
// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a
|
||||
// scf::ReduceOp.
|
||||
// The outer `ParallelOp` refers to the parallel loops that traverse output
|
||||
// buffer. The inner `ParallelOp` refers to the reduction loops that traverse
|
||||
@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
// func @reduce_window(%arg: memref<112x112xf32>,
|
||||
// %init: memref<f32>,
|
||||
// %result: memref<56x56xf32>) {
|
||||
// "xla_lhlo.reduce_window"(%arg, %init, %result) ( {
|
||||
// "lmhlo.reduce_window"(%arg, %init, %result) ( {
|
||||
// ^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
// "xla_lhlo.maximum"(%lhs, %rhs, %res)
|
||||
// "lmhlo.maximum"(%lhs, %rhs, %res)
|
||||
// : (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
// "xla_lhlo.terminator"() : () -> ()
|
||||
// "lmhlo.terminator"() : () -> ()
|
||||
// }) {
|
||||
// padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
// window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||
@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern<xla_lhlo::ReduceOp> {
|
||||
// return
|
||||
// }
|
||||
class ReduceWindowOpConverter
|
||||
: public OpConversionPattern<xla_lhlo::ReduceWindowOp> {
|
||||
: public OpConversionPattern<lmhlo::ReduceWindowOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ReduceWindowOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ReduceWindowOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
scf::ParallelOp output_loop, window_loop;
|
||||
std::tie(output_loop, window_loop) =
|
||||
@ -383,7 +383,7 @@ class ReduceWindowOpConverter
|
||||
private:
|
||||
std::pair<scf::ParallelOp, scf::ParallelOp>
|
||||
CreateParallelLoopsToTraverseOutputAndWindow(
|
||||
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
auto loc = xla_reduce_window_op.getLoc();
|
||||
Value init_value =
|
||||
@ -415,9 +415,8 @@ class ReduceWindowOpConverter
|
||||
}
|
||||
|
||||
scf::ReduceOp CreateReduceOpInNestedParallelLoops(
|
||||
xla_lhlo::ReduceWindowOp xla_reduce_window_op,
|
||||
scf::ParallelOp output_loop, scf::ParallelOp window_loop,
|
||||
ConversionPatternRewriter* rewriter) const {
|
||||
lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop,
|
||||
scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const {
|
||||
rewriter->setInsertionPointToStart(window_loop.getBody());
|
||||
auto loc = xla_reduce_window_op.getLoc();
|
||||
|
||||
@ -481,12 +480,12 @@ class ReduceWindowOpConverter
|
||||
// initialized_flag = true
|
||||
// output(selected_index) = scatter(output(selected_index), source(S))
|
||||
class SelectAndScatterOpConverter
|
||||
: public OpConversionPattern<xla_lhlo::SelectAndScatterOp> {
|
||||
: public OpConversionPattern<lmhlo::SelectAndScatterOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::SelectAndScatterOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::SelectAndScatterOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
|
||||
lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> /*args*/,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
InitializeOutput(s_and_s_op, &rewriter);
|
||||
@ -515,7 +514,7 @@ class SelectAndScatterOpConverter
|
||||
}
|
||||
|
||||
private:
|
||||
void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||
void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
Value init_value = b->create<LoadOp>(loc, s_and_s_op.init_value());
|
||||
@ -533,7 +532,7 @@ class SelectAndScatterOpConverter
|
||||
SmallVector<Value, 2> window_ivs;
|
||||
scf::ForOp inner_loop;
|
||||
};
|
||||
WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||
WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
scf::ParallelOp loop_over_src,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
@ -598,7 +597,7 @@ class SelectAndScatterOpConverter
|
||||
SmallVector<Value, 4> ivs_val_flag_;
|
||||
};
|
||||
|
||||
SmallVector<Value, 2> SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op,
|
||||
SmallVector<Value, 2> SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
scf::ParallelOp loop_over_src,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
@ -636,9 +635,10 @@ class SelectAndScatterOpConverter
|
||||
return window_loops.selected_ivs;
|
||||
}
|
||||
|
||||
SmallVector<Value, 4> SelectOrInitialize(
|
||||
xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef<Value> operand_ivs,
|
||||
IterArgs* ivs_val_flag, OpBuilder* b) const {
|
||||
SmallVector<Value, 4> SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op,
|
||||
ArrayRef<Value> operand_ivs,
|
||||
IterArgs* ivs_val_flag,
|
||||
OpBuilder* b) const {
|
||||
auto loc = s_and_s_op.getLoc();
|
||||
Value true_i1 = b->create<mlir::ConstantOp>(
|
||||
loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1));
|
||||
@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops
|
||||
|
||||
ConversionTarget target(getContext());
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect,
|
||||
scf::SCFDialect, XlaLhloDialect>();
|
||||
target.addIllegalOp<xla_lhlo::ReduceOp, xla_lhlo::ReduceWindowOp,
|
||||
xla_lhlo::SelectAndScatterOp>();
|
||||
scf::SCFDialect, LmhloDialect>();
|
||||
target.addIllegalOp<lmhlo::ReduceOp, lmhlo::ReduceWindowOp,
|
||||
lmhlo::SelectAndScatterOp>();
|
||||
|
||||
if (failed(applyPartialConversion(func, target, patterns))) {
|
||||
signalPassFailure();
|
||||
@ -727,5 +727,5 @@ static PassRegistration<LhloLegalizeToParallelLoops> legalize_lhlo_pass(
|
||||
"lhlo-legalize-to-parallel-loops",
|
||||
"Legalize from LHLO dialect to parallel loops.");
|
||||
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
} // namespace mlir
|
||||
|
@ -34,7 +34,7 @@ struct TestMaterializeBroadcastsPass
|
||||
OwningRewritePatternList conversionPatterns;
|
||||
|
||||
// Consider the mhlo dialect legal for tests.
|
||||
conversionTarget.addLegalDialect<XlaHloDialect>();
|
||||
conversionTarget.addLegalDialect<MhloDialect>();
|
||||
// The conversion uses helpers from the Standard dialect.
|
||||
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
|
||||
|
||||
|
@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
|
||||
loc, opResultTypes, args, args_count, results_count, indexing_maps,
|
||||
GetNParallelLoopsAttrs(nloops),
|
||||
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
|
||||
// TODO(ravishankarm) : For now use the method in xla_lhlo namespace.
|
||||
// TODO(ravishankarm) : For now use the method in lmhlo namespace.
|
||||
// That method needs to be moved out of there.
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||
Value opResult = lmhlo::XlaOpToStdScalarOp::map<OpTy>(
|
||||
op, bodyResultTypes,
|
||||
llvm::to_vector<2>(args.take_front(args_count)), &rewriter);
|
||||
nestedBuilder.create<linalg::YieldOp>(loc, opResult);
|
||||
@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||
// Create two loads from the input.
|
||||
auto lhs = rewriter.create<LoadOp>(loc, lhlo_op.lhs());
|
||||
auto rhs = rewriter.create<LoadOp>(loc, lhlo_op.rhs());
|
||||
// TODO(ravishankarm) : Move this method out of xla_lhlo namespace.
|
||||
Value opResult = xla_lhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||
// TODO(ravishankarm) : Move this method out of lmhlo namespace.
|
||||
Value opResult = lmhlo::XlaOpToStdScalarOp::map<LhloOp>(
|
||||
lhlo_op, argType.getElementType(), llvm::ArrayRef<Value>{lhs, rhs},
|
||||
&rewriter);
|
||||
rewriter.create<StoreOp>(loc, opResult, lhlo_op.out());
|
||||
@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern<LhloOp> {
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// xla_lhlo.convolution conversion pattern.
|
||||
// lmhlo.convolution conversion pattern.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Converts xla_lhlo.convolution operation to a linalg.conv op.
|
||||
struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
||||
/// Converts lmhlo.convolution operation to a linalg.conv op.
|
||||
struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
// This code has been adapted from IREE's
|
||||
// (https://github.com/google/iree/) mhlo -> linalg conversion.
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||
lmhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
// Check validity of dimension information.
|
||||
if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
if (const lmhlo::ConvDimensionNumbers& dimensionNumbers =
|
||||
op.dimension_numbers()) {
|
||||
const int inputSpatialRank =
|
||||
llvm::size(dimensionNumbers.input_spatial_dimensions());
|
||||
@ -388,14 +388,14 @@ class HloBroadcastInDimConverter
|
||||
};
|
||||
|
||||
class LhloBroadcastInDimConverter
|
||||
: public OpConversionPattern<xla_lhlo::BroadcastInDimOp> {
|
||||
: public OpConversionPattern<lmhlo::BroadcastInDimOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::BroadcastInDimOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
auto result_type = operand_adaptor.output().getType().cast<MemRefType>();
|
||||
auto result_shape = result_type.getShape();
|
||||
|
||||
@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter
|
||||
|
||||
// Inserts 'linalg.reshape' if there is a size-1 dim expansion.
|
||||
std::pair<Value, SmallVector<int64_t, 2>> InsertReshapeIfNecessary(
|
||||
xla_lhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
lmhlo::BroadcastInDimOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const {
|
||||
xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args);
|
||||
Value operand = operand_adaptor.operand();
|
||||
auto operand_type = operand_adaptor.operand().getType().cast<MemRefType>();
|
||||
auto operand_shape = operand_type.getShape();
|
||||
@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter
|
||||
return std::make_pair(operand, broadcast_dims);
|
||||
}
|
||||
|
||||
SmallVector<AffineMap, 2> getIndexingMaps(xla_lhlo::BroadcastInDimOp op,
|
||||
SmallVector<AffineMap, 2> getIndexingMaps(lmhlo::BroadcastInDimOp op,
|
||||
ArrayRef<int64_t> broadcastDims,
|
||||
ArrayRef<int64_t> resultShape,
|
||||
MemRefType operandType,
|
||||
@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern<OpTy> {
|
||||
}
|
||||
};
|
||||
|
||||
class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
||||
class IotaConverter : public OpConversionPattern<lmhlo::IotaOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::IotaOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::IotaOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
||||
lmhlo::IotaOp iotaOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto resultMemrefType =
|
||||
iotaOp.getOperand().getType().dyn_cast<MemRefType>();
|
||||
@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern<xla_lhlo::IotaOp> {
|
||||
}
|
||||
};
|
||||
|
||||
class ConstConverter : public OpConversionPattern<xla_lhlo::ConstOp> {
|
||||
class ConstConverter : public OpConversionPattern<lmhlo::ConstOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::ConstOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::ConstOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConstOp constOp, ArrayRef<Value> args,
|
||||
lmhlo::ConstOp constOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = constOp.getLoc();
|
||||
auto valueAttr = constOp.value().cast<DenseElementsAttr>();
|
||||
@ -726,12 +726,12 @@ class ReverseConverter
|
||||
}
|
||||
};
|
||||
|
||||
class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
|
||||
class SliceConverter : public OpConversionPattern<lmhlo::SliceOp> {
|
||||
public:
|
||||
using OpConversionPattern<xla_lhlo::SliceOp>::OpConversionPattern;
|
||||
using OpConversionPattern<lmhlo::SliceOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::SliceOp sliceOp, ArrayRef<Value> args,
|
||||
lmhlo::SliceOp sliceOp, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
auto loc = sliceOp.getLoc();
|
||||
auto argType =
|
||||
@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern<xla_lhlo::SliceOp> {
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
patterns->insert<BroadcastConverter<xla_lhlo::BroadcastOp>,
|
||||
patterns->insert<BroadcastConverter<lmhlo::BroadcastOp>,
|
||||
ConstConverter,
|
||||
ConvToLinalgConverter,
|
||||
IotaConverter,
|
||||
LhloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AbsOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::AndOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CeilOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CompareOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ComplexOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ConvertOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AbsOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::AndOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CeilOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CompareOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ComplexOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ConvertOp>,
|
||||
// TODO(ataei): Remove this pattern, CopyOp is folded away.
|
||||
PointwiseToLinalgConverter<xla_lhlo::CopyOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::LogOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MaxOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MinOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::MulOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::RsqrtOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SelectOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SignOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SinOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<xla_lhlo::TanhOp>,
|
||||
ReshapeOpConverter<xla_lhlo::ReshapeOp>,
|
||||
ReverseConverter<xla_lhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<xla_lhlo::AddOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CopyOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::CosOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::DivOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ExpOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::ImagOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::LogOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MaxOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MinOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::MulOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::NegOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RealOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RemOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::RsqrtOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SelectOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SignOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SinOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SqrtOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::SubOp>,
|
||||
PointwiseToLinalgConverter<lmhlo::TanhOp>,
|
||||
ReshapeOpConverter<lmhlo::ReshapeOp>,
|
||||
ReverseConverter<lmhlo::ReverseOp>,
|
||||
ScalarPointwiseToStandardConverter<lmhlo::AddOp>,
|
||||
SliceConverter
|
||||
>(context);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// Converts LHLO ops to Linalg generic.
|
||||
// Sample result for xla_lhlo::AddOp.
|
||||
// Sample result for lmhlo::AddOp.
|
||||
//
|
||||
// "xla_lhlo.add"(%arg1, %arg2, %out) :
|
||||
// "lmhlo.add"(%arg1, %arg2, %out) :
|
||||
// (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
//
|
||||
// will be converted to
|
||||
@ -854,14 +854,14 @@ struct HloLegalizeToLinalg
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace xla_lhlo {
|
||||
namespace lmhlo {
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeLhloToLinalgPass() {
|
||||
return absl::make_unique<LhloLegalizeToLinalg>();
|
||||
}
|
||||
|
||||
static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
|
||||
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
|
||||
} // namespace xla_lhlo
|
||||
} // namespace lmhlo
|
||||
|
||||
namespace mhlo {
|
||||
|
||||
|
@ -152,7 +152,7 @@ struct TransformUnrankedHloPass
    // Setup conversion target.
    MLIRContext &ctx = getContext();
    ConversionTarget target(ctx);
    target.addLegalDialect<XlaHloDialect, StandardOpsDialect,
    target.addLegalDialect<MhloDialect, StandardOpsDialect,
        shape::ShapeDialect>();
    target.addLegalOp<FuncOp>();
    AddLegalOpOnRankedTensor<SqrtOp>(&target);
@ -11,7 +11,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = "xla_test.reify_return_type_shapes"(%0) : (tensor<?xf32>) -> tensor<1xindex>
return %1 : tensor<1xindex>
}
@ -19,7 +19,7 @@ func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xinde
// -----
// CHECK-LABEL: @complex_ranked_components
func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> {
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex<f32>}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xcomplex<f32>>) -> tensor<?x?xcomplex<f32>>
return %1 : tensor<?x?xcomplex<f32>>
@ -28,7 +28,7 @@ func @complex_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// -----
// CHECK-LABEL: @compare_ranked_components
func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xi1> {
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?x?xi1>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
@ -37,7 +37,7 @@ func @compare_ranked_components(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) ->
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1
func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32}
%1 = "xla_test.get_return_type_components"(%0) : (tensor<?xf32>) -> tensor<?xf32>
return %1 : tensor<?xf32>
@ -46,7 +46,7 @@ func @broadcast_add_ranked_components_r1(%arg0: tensor<?xf32>, %arg1: tensor<?xf
// -----
// CHECK-LABEL: @broadcast_add_ranked_components_r1x2
func @broadcast_add_ranked_components_r1x2(%arg0: tensor<?xf32>, %arg1: tensor<?x3xf32>) -> tensor<?x3xf32> {
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
// TODO: Overly broad shapes are being returned. Tighten the calculation
// and update/extend these tests.
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32}
@ -5,7 +5,7 @@
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -26,7 +26,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}

@ -47,7 +47,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
return %0 : tensor<?x?xcomplex<f32>>
}

@ -68,7 +68,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
return %0 : tensor<?x?xi1>
}

@ -77,7 +77,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

@ -86,7 +86,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

@ -95,7 +95,7 @@ func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1:
func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

@ -104,7 +104,7 @@ func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %a
func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}}
// expected-error @+1 {{failed to legalize operation}}
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

@ -114,7 +114,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
%0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

@ -122,7 +122,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -130,7 +130,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

@ -138,7 +138,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}

@ -146,7 +146,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -154,7 +154,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -162,7 +162,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -170,7 +170,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -178,7 +178,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
%0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}

@ -186,7 +186,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -194,7 +194,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -202,7 +202,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -210,7 +210,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -218,7 +218,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -226,7 +226,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

@ -234,6 +234,6 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -0,0 +1,34 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s

// CHECK-LABEL: func @func_op_unranked_arg_result
func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
return %arg0 : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
// CHECK-NEXT: return [[ARG]] : memref<*xf32>

// -----

// CHECK-LABEL: func @dynamic_reshape_from_unranked
func @dynamic_reshape_from_unranked(
%operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor<?xf32> {
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
: (tensor<*xf32>, tensor<1xi32>) -> tensor<?xf32>
return %reshaped : tensor<?xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>

// -----

// CHECK-LABEL: func @dynamic_reshape_to_unranked
func @dynamic_reshape_to_unranked(
%operand: tensor<?xf32>, %shape: tensor<?xi32>) -> tensor<*xf32> {
%reshaped = "mhlo.dynamic_reshape"(%operand, %shape)
: (tensor<?xf32>, tensor<?xi32>) -> tensor<*xf32>
return %reshaped : tensor<*xf32>
}
// CHECK-SAME: ([[ARG:%.*]]: memref<?xf32>, [[SHAPE:%.*]]: memref<?xi32>)
// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]])
// CHECK-SAME: : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
tensor_store %tensor_result, %result : memref<2x2xf32>
return
}
@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
return %arg0 : tensor<4xf32>
}
// PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> ()
// PRE-NEXT: return
// ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// ESC-NOT: "xla_lhlo.copy"
// ESC-NOT: "lmhlo.copy"
// ESC-NEXT: return %[[ARG0]]

// -----
@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32>
|
||||
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
|
||||
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32>
|
||||
// BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]])
|
||||
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]])
|
||||
// BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32>
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> ()
|
||||
// PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32>
|
||||
// PRE-NEXT: return
|
||||
// ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32>
|
||||
@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
|
||||
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
|
||||
%sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
|
||||
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
|
||||
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
|
||||
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
// BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]])
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
// BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32>
|
||||
// BOTH-NEXT: return
|
||||
@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.copy"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.exponential"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.log"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.log"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
|
||||
%tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
|
||||
{comparison_direction = "EQ"}
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
|
||||
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
// BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
|
||||
tensor_store %tensor_result, %result : memref<2x2xi1>
|
||||
return
|
||||
}
|
||||
@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
|
||||
%tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
|
||||
{broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
: (tensor<5xf32>) -> tensor<10x5xf32>
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
tensor_store %tensor_result, %result : memref<10x5xf32>
|
||||
return
|
||||
}
|
||||
@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64>
|
||||
func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
|
||||
%tensor_operand = tensor_load %operand : memref<?x?xf32>
|
||||
%shape = call @external_func() : () -> tensor<3xi64>
|
||||
%c1 = constant 1 : i64
|
||||
%shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64>
|
||||
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
|
||||
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
|
||||
// BOTH: %[[SHAPE:.*]] = call @external_func()
|
||||
// BOTH: %[[SHAPE:.*]] = tensor_from_elements
|
||||
// BOTH: %[[C0:.*]] = constant 0 : index
|
||||
// BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
|
||||
// BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
|
||||
@ -204,12 +205,12 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
|
||||
// BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
|
||||
// BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
|
||||
|
||||
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast
|
||||
// BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast
|
||||
// BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]])
|
||||
// BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
|
||||
// BOTH-SAME: : memref<?x?xf32> -> memref<?x?xf32, #map0>
|
||||
|
||||
// BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
|
||||
// BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
|
||||
// BOTH-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
|
||||
|
||||
@ -228,7 +229,7 @@ func @complex(%real: memref<2x2xf32>,
|
||||
%tensor_imag = tensor_load %imag : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
|
||||
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
|
||||
return
|
||||
}
|
||||
@ -240,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "mhlo.real"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.real"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -252,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
|
||||
%tensor_result = "mhlo.imag"(%tensor_operand)
|
||||
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -263,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
|
||||
func @iota(%result: memref<10xi32>) {
|
||||
%tensor_result = "mhlo.iota"()
|
||||
{iota_dimension = 0 : i64} : () -> tensor<10xi32>
|
||||
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
// BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
|
||||
tensor_store %tensor_result, %result : memref<10xi32>
|
||||
return
|
||||
}
|
||||
@ -275,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.abs"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -287,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.ceil"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -299,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.convert"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}})
|
||||
// BOTH-NOT: tensor_store
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
@ -312,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.cosine"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -324,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.negate"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -336,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.rsqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -348,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.sign"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -360,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.sqrt"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -372,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
|
||||
%tensor_operand = tensor_load %operand : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.tanh"(%tensor_operand)
|
||||
: (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -385,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
|
||||
%tensor_rhs = tensor_load %rhs : memref<2x2xf32>
|
||||
%tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
|
||||
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
// BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
|
||||
tensor_store %tensor_result, %result : memref<2x2xf32>
|
||||
return
|
||||
}
|
||||
@ -411,7 +412,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
|
||||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -436,7 +437,7 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
|
||||
// BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
|
||||
// BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
|
||||
// BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
|
||||
// BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
// BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -447,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||
// PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]])
|
||||
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
|
||||
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
|
||||
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
|
||||
%dot = "mhlo.dot"(%arg0, %arg0)
|
||||
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
|
||||
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]])
|
||||
// ESC: return %[[ALLOC]]
|
||||
return %dot : tensor<1024x1024xf32>
|
||||
}
|
||||
@ -461,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
|
||||
func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> {
|
||||
%c0 = constant 0 : index
|
||||
// BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32>
|
||||
// BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]])
|
||||
// BOTH-SAME: padding = dense<[
|
||||
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
|
||||
// BOTH-SAME: rhs_dilation = dense<[1, 2]>
|
||||
|
@ -3,10 +3,10 @@
|
||||
// CHECK-LABEL: func @remove_simple
|
||||
func @remove_simple(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) {
|
||||
// CHECK-LABEL: func @remove_without_dealloc
|
||||
func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) {
|
||||
// CHECK-LABEL: func @replace_dependency
|
||||
func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @keep_copies
|
||||
func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.terminator"() : () -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
// CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
// CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32>
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
||||
// -----
|
||||
@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
|
||||
%arg1: memref<2x2xf32>,
|
||||
%arg2: memref<2x2xf32>) {
|
||||
%0 = alloc() {temp = true} : memref<2x2xf32>
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
// CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
"lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
|
||||
dealloc %0 : memref<2x2xf32>
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
|
@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||
%src: memref<56x56xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<112x112xf32>) {
|
||||
"xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
|
||||
"lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( {
|
||||
// select
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %pred: memref<i1>):
|
||||
"xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
|
||||
"lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} :
|
||||
(memref<f32>, memref<f32>, memref<i1>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}, {
|
||||
// scatter
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %out: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %out) :
|
||||
"lmhlo.add"(%lhs, %rhs, %out) :
|
||||
(memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}) {
|
||||
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
|
||||
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
|
||||
@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||
} : (memref<112x112xf32>,
|
||||
memref<56x56xf32>,
|
||||
memref<f32>, memref<112x112xf32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
}
|
||||
// CHECK-LABEL: func @select_and_scatter(
|
||||
// CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>,
|
||||
@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||
// CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref<f32>
|
||||
|
||||
// Compute PRED.
|
||||
// CHECK: "xla_lhlo.compare"(
|
||||
// CHECK: "lmhlo.compare"(
|
||||
// CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]])
|
||||
// CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref<i1>
|
||||
|
||||
@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>,
|
||||
// CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref<f32>
|
||||
|
||||
// Compute scatter value.
|
||||
// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
|
||||
// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) :
|
||||
// CHECK-SAME: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
// CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref<f32>
|
||||
|
||||
|
@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
||||
// CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32
|
||||
// CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32>
|
||||
// CHECK: return
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} :
|
||||
(memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>,
|
||||
func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: addf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: addi %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
"lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: and %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
|
||||
"lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: divf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -60,7 +60,7 @@ func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: divi_signed %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
"lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
|
||||
"xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
"lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32
|
||||
// CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32
|
||||
"xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
"lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: mulf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: muli %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
"lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
%result: memref<7xf32>) -> () {
|
||||
// CHECK: subf %{{.*}}, %{{.*}} : f32
|
||||
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
: (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>,
|
||||
func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>,
|
||||
%result: memref<7xi32>) -> () {
|
||||
// CHECK: subi %{{.*}}, %{{.*}} : i32
|
||||
"xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
"lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"}
|
||||
: (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs:
|
||||
// CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32
|
||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32>
|
||||
// CHECK: return
|
||||
"xla_lhlo.dot"(%lhs, %rhs, %result) :
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
||||
(memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs:
|
||||
// CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32
|
||||
// CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32>
|
||||
// CHECK: return
|
||||
"xla_lhlo.dot"(%lhs, %rhs, %result) :
|
||||
"lmhlo.dot"(%lhs, %rhs, %result) :
|
||||
(memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
@ -3,11 +3,11 @@
|
||||
func @reduce(%arg: memref<100x10xf32>,
|
||||
%init: memref<f32>,
|
||||
%result: memref<100xf32>) {
|
||||
"xla_lhlo.reduce"(%arg, %init, %result) ( {
|
||||
"lmhlo.reduce"(%arg, %init, %result) ( {
|
||||
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
|
||||
"xla_lhlo.add"(%lhs, %rhs, %res)
|
||||
"lmhlo.add"(%lhs, %rhs, %res)
|
||||
: (memref<f32>, memref<f32>, memref<f32>) -> ()
|
||||
"xla_lhlo.terminator"() : () -> ()
|
||||
"lmhlo.terminator"() : () -> ()
|
||||
} ) {dimensions = dense<[1]> : tensor<1xi64>}
|
||||
: (memref<100x10xf32>, memref<f32>, memref<100xf32>) -> ()
|
||||
return
|
||||
@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>,
|
||||
// CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] {
|
||||
// CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref<f32, #map0>
|
||||
// CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref<f32, #map0>
|
||||
// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
|
||||
// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref<f32, {{.*}}>, memref<f32, {{.*}}>, memref<f32, {{.*}}>) -> ()
|
||||
// CHECK: }
|
||||
// CHECK: gpu.terminator
|
||||
// CHECK: }
|
||||
|
@ -4,7 +4,7 @@
// CHECK-LABEL: func @element_wise
func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
%rhs: memref<?x?xf32>,
%result: memref<?x?xf32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
return
}
@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref<?x?xf32>,
// CHECK-LABEL: func @element_wise_scalar
func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
%result: memref<f32>) {
"xla_lhlo.add"(%lhs, %rhs, %result)
"lmhlo.add"(%lhs, %rhs, %result)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
return
}
@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref<f32>, %rhs: memref<f32>,
// CHECK-LABEL: func @minf
func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.minimum"(%lhs, %rhs, %result)
"lmhlo.minimum"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @maxi
func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.maximum"(%lhs, %rhs, %result)
"lmhlo.maximum"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @and
func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.and"(%lhs, %rhs, %result)
"lmhlo.and"(%lhs, %rhs, %result)
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,

// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result)
"lmhlo.exponential"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @log
func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @copy
func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
"xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
"lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) {
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"}
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> ()
return
}
@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
%result: memref<2x2xi1>) {
"xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
"lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"}
: (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> ()
return
}
@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>,
// CHECK-LABEL: func @select
func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.select"(%pred, %lhs, %rhs, %result)
"lmhlo.select"(%pred, %lhs, %rhs, %result)
: (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
// CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @iota
func @iota(%out: memref<7x10xf32>) {
"xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
"lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> ()
return
}
// CHECK: linalg.indexed_generic
@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
"lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<f32>, memref<4x2x1xf32>) -> ()
return
@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref<f32>, %result: memref<4x2x1xf32>) {
// CHECK-LABEL: func @broadcast
func @broadcast(%operand: memref<4x?x16xf32>,
%result: memref<4x2x1x4x?x16xf32>) {
"xla_lhlo.broadcast"(%operand, %result) {
"lmhlo.broadcast"(%operand, %result) {
broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>
} : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> ()
return
@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>,
// CHECK-LABEL: func @dynamic_broadcast_in_dim
func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
%result: memref<?x?x?x?x?xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>
} : (memref<?x?x?xf32>, memref<?x?x?x?x?xf32>) -> ()
return
@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref<?x?x?xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_no_expansion
func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<5xf32>, memref<5x10xf32>) -> ()
return
@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_expansion
func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
%result: memref<5x10x100xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[2, 0]> : tensor<2xi64>
} : (memref<1x5xf32>, memref<5x10x100xf32>) -> ()
return
@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_scalar
func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
%result: memref<5x10xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[]> : tensor<0xi64>
} : (memref<f32>, memref<5x10xf32>) -> ()
return
@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref<f32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one
func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
%result: memref<1x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[0]> : tensor<1xi64>
} : (memref<1xf32>, memref<1x5xf32>) -> ()
return
@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>,
// CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many
func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,
%result: memref<5x5xf32>) {
"xla_lhlo.broadcast_in_dim"(%operand, %result) {
"lmhlo.broadcast_in_dim"(%operand, %result) {
broadcast_dimensions = dense<[1]> : tensor<1xi64>
} : (memref<1xf32>, memref<5x5xf32>) -> ()
return
@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>,

// CHECK-LABEL: func @constant
func @constant(%value: memref<i32>) {
"xla_lhlo.constant"(%value) {
"lmhlo.constant"(%value) {
value = dense<10> : tensor<i32>
} : (memref<i32>) -> ()
return
@ -335,7 +335,7 @@ func @constant(%value: memref<i32>) {

// CHECK-LABEL: func @absf
func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @absi
func @absi(%input: memref<2x2xi32>,
%result: memref<2x2xi32>) {
"xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}

@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>,

// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: memref<2x2xi16>,
%result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>,

// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> ()
return
}
// CHECK: linalg.generic
@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) {

// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> ()
return
}
// CHECK: linalg.generic
@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) {

// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @convert_i32_to_i32
func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {

// CHECK-LABEL: func @convert_f32_to_f32
func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {
"xla_lhlo.convert"(%input, %result)
"lmhlo.convert"(%input, %result)
: (memref<2x2xf32>, memref<2x2xi32>) -> ()
return
}
@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) {

// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result)
"lmhlo.sine"(%input, %result)
: (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>,

// CHECK-LABEL: func @negf
func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @negi
func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
"xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// CHECK-LABEL: func @rem
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,
%result: memref<2x2xf32>) {
"xla_lhlo.remainder"(%lhs, %rhs, %result)
"lmhlo.remainder"(%lhs, %rhs, %result)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>,

// CHECK-LABEL: func @rsqrt
func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @sign
func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @sqrt
func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @tanh
func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}
// CHECK: linalg.generic
@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

func @complex(%real: memref<2x2xf32>,
%imag: memref<2x2xf32>,
%cplx: memref<2x2xcomplex<f32>>) {
"xla_lhlo.complex"(%real, %imag, %cplx)
"lmhlo.complex"(%real, %imag, %cplx)
: (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex<f32>>) -> ()
return
}
@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>,
// CHECK-LABEL: func @real
func @real(%cplx: memref<2x2xcomplex<f32>>,
%real: memref<2x2xf32>) {
"xla_lhlo.real"(%cplx, %real)
"lmhlo.real"(%cplx, %real)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex<f32>>,
// CHECK-LABEL: func @imag
func @imag(%cplx: memref<2x2xcomplex<f32>>,
%imag: memref<2x2xf32>) {
"xla_lhlo.imag"(%cplx, %imag)
"lmhlo.imag"(%cplx, %imag)
: (memref<2x2xcomplex<f32>>, memref<2x2xf32>) -> ()
return
}
@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex<f32>>,

// CHECK: func @slice(%[[IN:.*]]: memref<?x?xf32>, %[[OUT:.*]]: memref<?x?xf32>)
func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
"xla_lhlo.slice"(%operand, %result) {
"lmhlo.slice"(%operand, %result) {
start_indices = dense<[0,1]> : tensor<2xi64>,
limit_indices = dense<[2,3]> : tensor<2xi64>,
strides = dense<[1,1]> : tensor<2xi64>
@ -653,7 +653,7 @@ func @slice(%operand: memref<?x?xf32>, %result: memref<?x?xf32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x1x42xi32>, memref<12x42xi32>) -> ()
return
}
@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42x1x1xi32>, memref<12x42xi32>) -> ()
return
}
@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) {
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
"xla_lhlo.reshape"(%arg0, %arg1)
"lmhlo.reshape"(%arg0, %arg1)
: (memref<12x42xi32>, memref<12x1x42x1xi32>) -> ()
return
}
@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) {
"xla_lhlo.reverse"(%arg0, %arg1) {
"lmhlo.reverse"(%arg0, %arg1) {
dimensions = dense<1> : tensor<1xi64>
} : (memref<2x3xf32>, memref<2x3xf32>) -> ()
return
@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m
// CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64>
// CHECK-SAME: strides = [2, 1]}
// With all attributes explicitly specified.
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()

// Dilation left unspecified, sets default dilation since linalg expects it.
// CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}})
// CHECK-SAME: dilations = [1, 1]
// Padding is not set if it's zero.
// CHECK-NOT: padding
"xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> ()

"xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> ()
"lmhlo.terminator"() : () -> ()
}

@ -2,7 +2,7 @@

// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
%0 = xla_lhlo.static_memref_cast %buf
%0 = lmhlo.static_memref_cast %buf
: memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]>
return
}
@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref<?x?xf32>) {
%size_Y = constant 50 : index
%stride_X = constant 1 : index
%stride_Y = constant 0 : index
%0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
%0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y]
: memref<?x?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}

@ -3,11 +3,11 @@
func @reduce(%arg: memref<100x10x5xf32>,
%init: memref<f32>,
%result: memref<100x5xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<100x10x5xf32>, memref<f32>, memref<100x5xf32>) -> ()
return
@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>,
func @reduce_no_outer_loop(%arg: memref<100xf32>,
%init: memref<f32>,
%result: memref<1xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>}
: (memref<100xf32>, memref<f32>, memref<1xf32>) -> ()
return
@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]]
// CHECK: }
@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>,
func @dynamic_reduce(%arg: memref<?x?x?xf32>,
%init: memref<f32>,
%result: memref<?x?xf32>) {
"xla_lhlo.reduce"(%arg, %init, %result) ( {
"lmhlo.reduce"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.add"(%lhs, %rhs, %res)
"lmhlo.add"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[1]> : tensor<1xi64>}
: (memref<?x?x?xf32>, memref<f32>, memref<?x?xf32>) -> ()
return
@ -121,7 +121,7 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }
@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref<?x?x?xf32>,
func @reduce_window(%arg: memref<112x112xf32>,
%init: memref<f32>,
%result: memref<56x56xf32>) {
"xla_lhlo.reduce_window"(%arg, %init, %result) ( {
"lmhlo.reduce_window"(%arg, %init, %result) ( {
^bb0(%lhs: memref<f32>, %rhs: memref<f32>, %res: memref<f32>):
"xla_lhlo.maximum"(%lhs, %rhs, %res)
"lmhlo.maximum"(%lhs, %rhs, %res)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
}) {
padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>,
window_dimensions = dense<[3, 3]> : tensor<2xi64>,
@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>,
// CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref<f32>
// CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref<f32>
// CHECK: store [[ACC]], [[ACC_BUF]][] : memref<f32>
// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]])
// CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref<f32>
// CHECK: scf.reduce.return [[ACC_RESULT]] : f32
// CHECK: }

@ -4,7 +4,7 @@

// CHECK-LABEL: func @ceil
func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}

@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point values}}
"xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}

@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {

// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}

@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @cos
func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}

@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {

func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}

@ -44,7 +44,7 @@ func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {

// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}

@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @sin
func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}

@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {

func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}

@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {

// CHECK-LABEL: func @add_memrefs
func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
"xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
"lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
return
}

@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1

// CHECK-LABEL: func @abs_memref
func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @convert_memref
func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> ()
return
}

@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () {

func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
"lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> ()
return
}

@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () {

// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()
return
}

@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) {

// CHECK-LABEL: func @exp
func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex<f32>>, memref<2x2xcomplex<f32>>) -> ()
return
}

@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex<f32>>, %result: memref<2x2xcomplex<f32>>) {

func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
"lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> ()
return
}

@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) {

// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @log_memref
func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}

@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) ->

func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}

@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {

// CHECK-LABEL: func @neg_memref
func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @rsqrt_memref
func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}

@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>)

func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}

@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {

// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @sqrt_memref
func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}

@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -

func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}

@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {

// CHECK-LABEL: func @sign_memref
func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () {

// CHECK-LABEL: func @tanh_memref
func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -> () {
"xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xcomplex<f32>>, memref<10xcomplex<f32>>) -> ()
return
}

@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex<f32>>, %out: memref<10xcomplex<f32>>) -

func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
// expected-error@+1{{must be memref of floating-point or complex-type values}}
"xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
return
}

// -----

func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {
// expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}}
"xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
// expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}}
"lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> ()
return
}

@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () {

// CHECK-LABEL: func @add_memref
func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @div_memref
func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @max_memref
func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @min_memref
func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @mul_memref
func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @sub_memref
func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}

@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32

// CHECK-LABEL: func @and_memref
func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}

@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)

func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}

@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>

// CHECK-LABEL: func @or_memref
func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}

@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -

func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>

// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> ()
return
}

@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32

// CHECK-LABEL: func @xor_memref
func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () {
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> ()
return
}

@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>)

func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () {
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
"xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
"lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> ()
return
}

@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32

// CHECK-LABEL: func @broadcast_in_dim_memref
func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> ()
return
}

@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -

// CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref
func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi32>) -> () {
"xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
"lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref<i32>, memref<1x2x3xi32>) -> ()
return
}

@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref<i32>, %out: memref<1x2x3xi

// CHECK-LABEL: func @reduce_memref
func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf32>) -> () {
"xla_lhlo.reduce"(%input, %init, %out) ( {
"lmhlo.reduce"(%input, %init, %out) ( {
^bb0(%arg1: memref<f32>, %arg2: memref<f32>, %result: memref<f32>):
"xla_lhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.add"(%arg1, %arg2, %result) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
} ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref<f32>, memref<1xf32>) -> ()
return
}
@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref<f32>, %out: memref<1xf

// CHECK-LABEL: func @fusion_memref
func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () {
"xla_lhlo.fusion"() ( {
"lmhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
"lmhlo.terminator"() : () -> ()
} ) : () -> ()
return
}
@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m

// CHECK-LABEL: func @case_memref
func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memref<f32>, %operand_3: memref<f32>, %out: memref<f32>) -> () {
"xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
"lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( {
^bb0(%arg0: memref<f32>):
"xla_lhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.negate"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"xla_lhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.copy"(%arg0, %out) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}, {
^bb0(%arg0: memref<f32>):
"xla_lhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"xla_lhlo.terminator"() : () -> ()
"lmhlo.add"(%arg0, %arg0, %out) : (memref<f32>, memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}
) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>}
: (memref<i32>, memref<f32>, memref<f32>, memref<f32>, memref<f32>) -> ()
@ -430,7 +430,7 @@ func @case_memref(%index: memref<i32>, %operand_1: memref<f32>, %operand_2: memr
// -----

func @static_memref_cast(%in: memref<10x1xf32>) {
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]>
return
}
@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) {

func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {
// expected-error @+1 {{operand must have static shape}}
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]>
return
}
@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) {

func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
// expected-error @+1 {{result must have static shape}}
%out = xla_lhlo.static_memref_cast %in
%out = lmhlo.static_memref_cast %in
: memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]>
return
}
@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) {
func @dynamic_memref_cast(%in: memref<?xf32>) {
%size = constant 10 : index
%step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?xf32, offset: 0, strides: [?]>
return
}
@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref<?xf32>) {
// expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}}
%size = constant 10 : index
%step = constant 1 : index
%out = xla_lhlo.dynamic_memref_cast %in(%size)[%step]
%out = lmhlo.dynamic_memref_cast %in(%size)[%step]
: memref<?xf32> -> memref<?x?xf32, offset: 0, strides: [?, ?]>
return
}
@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
// CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>,
// CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref<?xi32>

// CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]]
// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
%dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1)
%dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1)
: (memref<*xf32>, memref<1xi32>) -> memref<?xf32>

// CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]]
// CHECK-SAME: : (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>
%dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2)
%dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2)
: (memref<?xf32>, memref<2xi32>) -> memref<?x?xf32>

// CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]]
// CHECK-SAME: : (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
%new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3)
%new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3)
: (memref<?x?xf32>, memref<?xi32>) -> memref<*xf32>
return
}
@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>,
func @reshape_memref_cast_element_type_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{element types of source and destination memref types should be the same}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?xi32>
|
||||
}
|
||||
|
||||
@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch(
|
||||
func @reshape_memref_cast_dst_ranked_shape_unranked(
|
||||
%buf: memref<*xf32>, %shape: memref<?xi32>) {
|
||||
// expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<?xi32>) -> memref<?xf32>
|
||||
return
|
||||
}
|
||||
@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked(
|
||||
func @reshape_memref_cast_dst_shape_rank_mismatch(
|
||||
%buf: memref<*xf32>, %shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{length of shape operand differs from the result's memref rank}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<*xf32>, memref<1xi32>) -> memref<?x?xf32>
|
||||
return
|
||||
}
|
||||
@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
|
||||
%buf: memref<4x4xf32, offset: 0, strides: [3, 2]>,
|
||||
%shape: memref<1xi32>) {
|
||||
// expected-error @+1 {{operand memref type should have identity affine map}}
|
||||
xla_lhlo.reshape_memref_cast %buf(%shape)
|
||||
lmhlo.reshape_memref_cast %buf(%shape)
|
||||
: (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>)
|
||||
-> memref<8xf32>
|
||||
return
|
||||
@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity(
|
||||
|
||||
// CHECK-LABEL: func @atan2_memrefs
|
||||
func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref
|
||||
|
||||
// CHECK-LABEL: func @atan2_memrefs
|
||||
func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: memref<1xcomplex<f32>>, %arg1: memref<1xcomplex<f32>>
|
||||
|
||||
func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref
|
||||
|
||||
// CHECK-LABEL: func @bitcast_convert_memrefs
|
||||
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
|
||||
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) ->
|
||||
|
||||
func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
|
||||
"lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) ->
|
||||
|
||||
// CHECK-LABEL: func @clz_memrefs
|
||||
func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @expm1_memrefs
|
||||
func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @expm1_memrefs
|
||||
func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
"lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
|
||||
|
||||
// CHECK-LABEL: func @floor_memrefs
|
||||
func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point values}}
|
||||
"xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @imag_memrefs
|
||||
func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
|
||||
|
||||
func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error@+1{{must be memref of complex-type values}}
|
||||
"xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @real_memrefs
|
||||
func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
"lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xf32>) -> ()
|
||||
|
||||
func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error@+1{{must be memref of complex-type values}}
|
||||
"xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @is_finite_memrefs
|
||||
func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
|
||||
"xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
|
||||
"lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @log1p_memrefs
|
||||
func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @log1p_memrefs
|
||||
func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
"lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex<f32>>, memref<1xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex<f32>>, %arg_out: memref<1xcomplex<f3
|
||||
|
||||
func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point or complex-type values}}
|
||||
"xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
"lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @not_memrefs
|
||||
func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @not_memrefs
|
||||
func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
|
||||
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
|
||||
"lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () {
|
||||
|
||||
func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}}
|
||||
"xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @popcnt_memrefs
|
||||
func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
|
||||
func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @reduce_precision_memrefs
|
||||
func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) ->
|
||||
|
||||
// CHECK-LABEL: func @round_memrefs
|
||||
func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
|
||||
func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
// expected-error@+1{{must be memref of floating-point values}}
|
||||
"xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @shift_left_memrefs
|
||||
func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m
|
||||
|
||||
func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m
|
||||
|
||||
// CHECK-LABEL: func @shift_right_arithmetic_memrefs
|
||||
func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>,
|
||||
|
||||
func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>,
|
||||
|
||||
// CHECK-LABEL: func @shift_right_logical_memrefs
|
||||
func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () {
|
||||
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a
|
||||
|
||||
func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () {
|
||||
// expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}}
|
||||
"xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
"lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
|
||||
|
||||
// CHECK-LABEL: func @all_reduce_memrefs
|
||||
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
|
||||
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
"lmhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%max) : (tensor<f32>) -> ()
|
||||
})
|
||||
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
|
||||
|
||||
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
"lmhlo.all_reduce"(%arg0, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
|
||||
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%max) : (tensor<f32>) -> ()
|
||||
@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> ()
|
||||
|
||||
// CHECK-LABEL: func @collective_permute_memrefs
|
||||
func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () {
|
||||
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
|
||||
"xla_lhlo.collective_permute"(%arg0, %arg_out) {
|
||||
"lmhlo.collective_permute"(%arg0, %arg_out) {
|
||||
source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
|
||||
channel_id = { handle = 5 : i64, type = 2 : i64 }
|
||||
} : (memref<128x32xf32>, memref<128x32xf32>) -> ()
|
||||
@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128
|
||||
|
||||
// CHECK-LABEL: func @fft_memrefs
|
||||
func @fft_memrefs(%arg0: memref<3x9xf32>, %arg_out: memref<3x5xcomplex<f32>>) -> () {
|
||||
"xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
|
||||
"lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex<f32>>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
|
||||
%arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>,
|
||||
%grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>,
|
||||
%grad_offset: memref<8xf32>) -> () {
|
||||
"xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
"lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>,
|
||||
memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>,
|
||||
// CHECK-LABEL: func @batch_norm_inference_memrefs
|
||||
func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () {
|
||||
"xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
"lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf
|
||||
func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>,
|
||||
%output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>,
|
||||
%batch_var: memref<8xf32>) -> () {
|
||||
"xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
"lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64}
|
||||
: (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3
|
||||
|
||||
// CHECK-LABEL: func @cholesky_memrefs
|
||||
func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () {
|
||||
"xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
"xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
"lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
"lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x
|
||||
|
||||
// CHECK-LABEL: func @infeed_memrefs
|
||||
func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
|
||||
"xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
|
||||
"lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @outfeed_memrefs
|
||||
func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
|
||||
"xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
|
||||
"lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @replica_id_memrefs
|
||||
func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
|
||||
"xla_lhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
|
||||
"lmhlo.replica_id"(%arg_out) : (memref<ui32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref<ui32>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @triangular_solve_memrefs
|
||||
func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () {
|
||||
"xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
|
||||
"lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true}
|
||||
: (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %
|
||||
|
||||
// CHECK-LABEL: func @while_memrefs
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
|
||||
"xla_lhlo.while"(%arg0, %arg_out) (
|
||||
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "xla_lhlo.terminator"() : () -> () }
|
||||
"lmhlo.while"(%arg0, %arg_out) (
|
||||
{ ^bb0(%arg: memref<i64>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%arg: memref<i64>, %body_out: memref<i64>): "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i64>, memref<i64>) -> ()
|
||||
return
|
||||
}
|
||||
@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref<i64>, %arg_out: memref<i64>) -> () {
|
||||
|
||||
// CHECK-LABEL: func @while_memrefs
|
||||
func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<i64>, %arg1_out: memref<5xf32>) -> () {
|
||||
"xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "xla_lhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () }
|
||||
"lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) (
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %cond: memref<i1>): "lmhlo.terminator"() : () -> () },
|
||||
{ ^bb0(%cur0: memref<i64>, %cur1: memref<5xf32>, %body_out0: memref<i64>, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () }
|
||||
) : (memref<i64>, memref<5xf32>, memref<i64>, memref<5xf32>) -> ()
|
||||
return
|
||||
}
|
||||
@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref<i64>, %arg1: memref<5xf32>, %arg0_out: memref<
|
||||
|
||||
// CHECK-LABEL: func @bitcast_memrefs
|
||||
func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
||||
"xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
||||
"lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () {
|
||||
// CHECK-LABEL: func @scatter_memrefs
|
||||
func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>,
|
||||
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
|
||||
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
|
||||
"lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({
|
||||
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
|
||||
%add = mhlo.add %lhs, %rhs : tensor<f32>
|
||||
"mhlo.return"(%add) : (tensor<f32>) -> ()
|
||||
@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
|
||||
|
||||
// CHECK-LABEL: func @map_memrefs
|
||||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = mhlo.add %a, %b : tensor<f32>
|
||||
"mhlo.return"(%c) : (tensor<f32>) -> ()
|
||||
@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
|
||||
|
||||
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () {
|
||||
// expected-error@+1{{requires the same shape for all operands}}
|
||||
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
"lmhlo.map"(%arg0, %arg1, %arg_out) ({
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>):
|
||||
%c = mhlo.add %a, %b : tensor<f32>
|
||||
"mhlo.return"(%c) : (tensor<f32>) -> ()
|
||||
@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
|
||||
|
||||
// CHECK-LABEL: func @rng_get_and_update_state_memrefs
|
||||
func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
|
||||
"xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
|
||||
"lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () {
|
||||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
// CHECK-LABEL: func @sort_memrefs
|
||||
func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
|
||||
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
|
||||
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
"lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
|
||||
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
|
||||
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
|
||||
"mhlo.return"(%7) : (tensor<i1>) -> ()
|
||||
|
@ -30,7 +30,7 @@ struct PassConfig {
|
||||
explicit PassConfig(QuantizationSpecs specs)
|
||||
: emit_builtin_tflite_ops(true),
|
||||
lower_tensor_list_ops(false),
|
||||
trim_functions_whitelist({}),
|
||||
trim_functions_allowlist({}),
|
||||
quant_specs(std::move(specs)),
|
||||
form_clusters(false),
|
||||
unfold_batch_matmul(true),
|
||||
@ -44,8 +44,8 @@ struct PassConfig {
|
||||
// If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic
|
||||
// TF ops before legalization to TF Lite dialect.
|
||||
bool lower_tensor_list_ops;
|
||||
// The whitelist of functions that would be preserved after trimming.
|
||||
llvm::ArrayRef<std::string> trim_functions_whitelist;
|
||||
// The allowlist of functions that would be preserved after trimming.
|
||||
llvm::ArrayRef<std::string> trim_functions_allowlist;
|
||||
// All information about quantization.
|
||||
QuantizationSpecs quant_specs;
|
||||
// If `form_clusters` is true, clusters are formed by grouping consecutive
|
||||
|
@ -71,7 +71,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h"
|
||||
#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
|
||||
#include "tensorflow/lite/kernels/internal/kernel_utils.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/string_util.h"
|
||||
@ -101,7 +101,7 @@ using mlir::Value;
|
||||
using tensorflow::OpOrArgLocNameMapper;
|
||||
using tensorflow::OpOrArgNameMapper;
|
||||
using tensorflow::Status;
|
||||
using tflite::flex::IsWhitelistedFlexOp;
|
||||
using tflite::flex::IsAllowlistedFlexOp;
|
||||
using xla::StatusOr;
|
||||
|
||||
template <typename T>
|
||||
@ -972,7 +972,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
// model is of an open op system.
|
||||
//
|
||||
// The following algorithm is followed:
|
||||
// if flex is enabled and the op is whitelisted as flex
|
||||
// if flex is enabled and the op is allowlisted as flex
|
||||
// we emit op as flex.
|
||||
// if custom is enabled
|
||||
// we emit the op as custom.
|
||||
@ -982,11 +982,11 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
}
|
||||
|
||||
// Flex op case
|
||||
// Eventually, the whitelist will go away and we will rely on some TF op
|
||||
// Eventually, the allowlist will go away and we will rely on some TF op
|
||||
// trait (e.g. No side effect) to determine if it is a supported "Flex"
|
||||
// op or not.
|
||||
if (enabled_op_types_.contains(OpType::kSelectTf) &&
|
||||
IsWhitelistedFlexOp(node_def->op())) {
|
||||
IsAllowlistedFlexOp(node_def->op())) {
|
||||
// Construct ops as flex op encoding TensorFlow node definition
|
||||
// as custom options.
|
||||
// Flex ops are named with the kFlexOpNamePrefix prefix to the actual
|
||||
@ -1037,7 +1037,7 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
|
||||
}
|
||||
|
||||
// Insert failed op to `flex_ops` or `custom_ops`.
|
||||
if (IsWhitelistedFlexOp(node_def->op())) {
|
||||
if (IsAllowlistedFlexOp(node_def->op())) {
|
||||
failed_flex_ops_.insert(os.str());
|
||||
} else {
|
||||
failed_custom_ops_.insert(os.str());
|
||||
|
@ -443,8 +443,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
|
||||
TF_ASSIGN_OR_RETURN(value,
|
||||
ConvertFloatBuffer(shaped_type, float_type, buffer));
|
||||
} else if (elem_type.isa<mlir::IntegerType>() ||
|
||||
elem_type.isa<QuantizedType>()) {
|
||||
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
|
||||
TF_ASSIGN_OR_RETURN(value,
|
||||
ConvertIntBuffer(shaped_type, elem_type, buffer));
|
||||
} else if (elem_type.isa<mlir::TF::StringType>()) {
|
||||
@ -456,8 +455,7 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
|
||||
refs.push_back({ref.data(), ref.size()});
|
||||
|
||||
value = mlir::DenseStringElementsAttr::get(shaped_type, refs);
|
||||
} else if (elem_type.isa<mlir::ComplexType>() ||
|
||||
elem_type.isa<mlir::TF::TensorFlowType>()) {
|
||||
} else if (elem_type.isa<mlir::ComplexType, mlir::TF::TensorFlowType>()) {
|
||||
auto dialect = elem_type.getContext()->getRegisteredDialect("tf");
|
||||
tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer);
|
||||
std::string mangled = tensorflow::mangling_util::MangleTensor(repr);
|
||||
|
@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() {
|
||||
fn_.walk([&](Operation *op) {
|
||||
if (op->isKnownTerminator() ||
|
||||
op->hasTrait<OpTrait::quant::NoQuantizableResult>() ||
|
||||
llvm::isa<quant::DequantizeCastOp>(op) ||
|
||||
llvm::isa<quant::QuantizeCastOp>(op))
|
||||
llvm::isa<quant::DequantizeCastOp, quant::QuantizeCastOp>(op))
|
||||
return;
|
||||
work_list_.push_back(op);
|
||||
|
||||
|
@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern<RQ> {
|
||||
|
||||
Operation* def = pre_quantized.getDefiningOp();
|
||||
if (!def) return failure();
|
||||
if (llvm::isa<FixedOutputRangeInterface>(def) ||
|
||||
llvm::isa<SameScalesOpInterface>(def) ||
|
||||
if (llvm::isa<FixedOutputRangeInterface, SameScalesOpInterface>(def) ||
|
||||
def->hasTrait<OpTrait::quant::NoQuantizableResult>()) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
||||
// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-allowlist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: quantize_float_placeholder_only
|
||||
func @quantize_float_placeholder_only(%arg0: tensor<f32>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor<f32>, tensor<2x3xi32>, tensor<2x3xf32>) {
|
||||
|
@ -1,4 +1,4 @@
|
||||
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s
|
||||
// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" %s | FileCheck %s
|
||||
|
||||
func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> {
|
||||
return %arg0 : tensor<1x4xf32>
|
||||
|
@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
return failure();
|
||||
ShapedType filter_type = filter_cst.getType();
|
||||
|
||||
if (llvm::isa<AddOp>(binary_op) || llvm::isa<SubOp>(binary_op)) {
|
||||
if (llvm::isa<AddOp, SubOp>(binary_op)) {
|
||||
auto padding = fc_op.template getAttrOfType<StringAttr>("padding");
|
||||
if (padding && padding.getValue() != "VALID") return failure();
|
||||
|
||||
@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
|
||||
rewriter.create<ConstOp>(fc_op.getLoc(), new_bias_type, new_bias);
|
||||
fc_op.setOperand(0, binary_op->getOperand(0));
|
||||
fc_op.setOperand(2, new_bias_op);
|
||||
} else if (llvm::isa<MulOp>(binary_op) || llvm::isa<DivOp>(binary_op)) {
|
||||
} else if (llvm::isa<MulOp, DivOp>(binary_op)) {
|
||||
// The fusion of mul/div is actually applying the following
|
||||
// transformation:
|
||||
// w * (x ' c) + b => (w ' c) x + b
|
||||
|
@ -61,7 +61,7 @@ std::unique_ptr<OperationPass<FuncOp>> CreatePostQuantizePass(
|
||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||
// pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist);
|
||||
llvm::ArrayRef<std::string> trim_funcs_allowlist);
|
||||
|
||||
// Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions
|
||||
// pass.
|
||||
|
@ -35,9 +35,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> quantize_whitelist(
|
||||
"tfl-test-quantize-whitelist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of whitelisted functions to be "
|
||||
static llvm::cl::list<std::string> quantize_allowlist(
|
||||
"tfl-test-quantize-allowlist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of allowlisted functions to be "
|
||||
"quantized. Only used in tests"),
|
||||
llvm::cl::CommaSeparated);
|
||||
|
||||
@ -108,7 +108,7 @@ class PrepareQuantizePass

// Get the min and max values from the quantization specification for the
// current function and argument index. Uses default values if
// the function is specified in the `quantize_whitelist`.
// the function is specified in the `quantize_allowlist`.
std::pair<llvm::Optional<double>, llvm::Optional<double>>
GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) {
if (func_name == quant_specs_.target_func) {
@ -132,7 +132,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) {
// Skip this function because it isn't the target function from the spec or
// in the function allowlist.
if (target_func != func_name &&
!llvm::is_contained(quantize_whitelist, func_name)) {
!llvm::is_contained(quantize_allowlist, func_name)) {
return false;
}
|
||||
|
@ -29,12 +29,12 @@ limitations under the License.
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
|
||||
// The cmd line flag to specify the whitelist of functions. Rest are trimmed
|
||||
// The cmd line flag to specify the allowlist of functions. Rest are trimmed
|
||||
// after this pass is run.
|
||||
// NOLINTNEXTLINE
|
||||
static llvm::cl::list<std::string> trim_funcs_whitelist(
|
||||
"tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of whitelisted functions. The first "
|
||||
static llvm::cl::list<std::string> trim_funcs_allowlist(
|
||||
"tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"),
|
||||
llvm::cl::desc("comma separated list of allowlisted functions. The first "
|
||||
"function specified will be used as main."),
|
||||
llvm::cl::CommaSeparated);
|
||||
|
||||
@ -43,25 +43,25 @@ namespace TFL {
|
||||
namespace {
|
||||
|
||||
// The pass to trim functions before we legalize to TFL
|
||||
// dialect using the specified whitelist.
|
||||
// dialect using the specified allowlist.
|
||||
class TrimFunctionsPass
|
||||
: public mlir::PassWrapper<TrimFunctionsPass, OperationPass<ModuleOp>> {
|
||||
public:
|
||||
explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_whitelist)
|
||||
: trim_funcs_whitelist_(trim_funcs_whitelist) {}
|
||||
explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {}
|
||||
explicit TrimFunctionsPass(llvm::ArrayRef<std::string> trim_funcs_allowlist)
|
||||
: trim_funcs_allowlist_(trim_funcs_allowlist) {}
|
||||
|
||||
private:
|
||||
void runOnOperation() override;
|
||||
bool TrimModule();
|
||||
void Verify();
|
||||
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist_;
|
||||
llvm::ArrayRef<std::string> trim_funcs_allowlist_;
|
||||
};
|
||||
|
||||
void TrimFunctionsPass::runOnOperation() {
|
||||
// trim the functions in the module using the trim_funcs_whitelist_
|
||||
// by removing functions not in the whitelist.
|
||||
// trim the functions in the module using the trim_funcs_allowlist_
|
||||
// by removing functions not in the allowlist.
|
||||
if (TrimModule()) {
|
||||
// verify the updated module is still valid, if not signal the
|
||||
// pass as failed.
|
||||
@ -70,20 +70,20 @@ void TrimFunctionsPass::runOnOperation() {
|
||||
}
|
||||
|
||||
bool TrimFunctionsPass::TrimModule() {
|
||||
// if no trim_funcs_whitelist_ is specified, this pass is a no-op.
|
||||
if (trim_funcs_whitelist_.empty()) return false;
|
||||
// if no trim_funcs_allowlist_ is specified, this pass is a no-op.
|
||||
if (trim_funcs_allowlist_.empty()) return false;
|
||||
|
||||
llvm::SmallVector<FuncOp, 4> funcs_to_trim;
|
||||
for (auto func : getOperation().getOps<FuncOp>()) {
|
||||
if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) {
|
||||
// If no main is specified in the whitelist, use the 1st func
|
||||
// in trim_funcs_whitelist as the main.
|
||||
if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) {
|
||||
// If no main is specified in the allowlist, use the 1st func
|
||||
// in trim_funcs_allowlist as the main.
|
||||
// TODO(ashwinm): Currently tflite flatbuffer export assumes there is
|
||||
// always a main. This is strictly not required for TFlite. We need to
|
||||
// remove that restriction once we have support to attribute the main
|
||||
// tensorflow function in MLIR TF import using an entry_point attr.
|
||||
if (!llvm::is_contained(trim_funcs_whitelist_, "main") &&
|
||||
func.getName() == trim_funcs_whitelist_[0]) {
|
||||
if (!llvm::is_contained(trim_funcs_allowlist_, "main") &&
|
||||
func.getName() == trim_funcs_allowlist_[0]) {
|
||||
func.setName("main");
|
||||
}
|
||||
} else {
|
||||
@ -99,7 +99,7 @@ bool TrimFunctionsPass::TrimModule() {
|
||||
}
|
||||
|
||||
// validate that all reachable functions from the remaining functions are
|
||||
// also in the whitelist.
|
||||
// also in the allowlist.
|
||||
void TrimFunctionsPass::Verify() {
|
||||
// TODO(ashwinm): Instead, we should make sure that references to all
|
||||
// SymbolRefAttrs of all ops are present.
|
||||
@ -109,7 +109,7 @@ void TrimFunctionsPass::Verify() {
|
||||
auto walk_result = func.walk([&](CallOp op) -> WalkResult {
|
||||
if (!symbol_table.lookup<FuncOp>(op.getCallee()))
|
||||
return getOperation().emitError()
|
||||
<< func.getName() << " is not in the funcs whitelist";
|
||||
<< func.getName() << " is not in the funcs allowlist";
|
||||
return WalkResult::advance();
|
||||
});
|
||||
if (walk_result.wasInterrupted()) return signalPassFailure();
|
||||
@ -121,13 +121,13 @@ void TrimFunctionsPass::Verify() {
|
||||
// Creates an instance of the TensorFlow Lite dialect TrimFunctions
|
||||
// pass.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> CreateTrimFunctionsPass(
|
||||
llvm::ArrayRef<std::string> trim_funcs_whitelist) {
|
||||
return std::make_unique<TrimFunctionsPass>(trim_funcs_whitelist);
|
||||
llvm::ArrayRef<std::string> trim_funcs_allowlist) {
|
||||
return std::make_unique<TrimFunctionsPass>(trim_funcs_allowlist);
|
||||
}
|
||||
|
||||
static PassRegistration<TrimFunctionsPass> pass(
|
||||
"tfl-trim-funcs-tf",
|
||||
"Trim functions to restrict them to a specified whitelist prior to "
|
||||
"Trim functions to restrict them to a specified allowlist prior to "
|
||||
"legalization to TensorFlow lite dialect");
|
||||
|
||||
} // namespace TFL
|
||||
|
@ -624,6 +624,7 @@ cc_library(
|
||||
"transforms/tpu_rewrite_pass.cc",
|
||||
"transforms/tpu_sharding_identification_pass.cc",
|
||||
"transforms/tpu_space_to_depth_pass.cc",
|
||||
"transforms/tpu_update_embedding_enqueue_op_inputs.cc",
|
||||
"transforms/tpu_variable_runtime_reformatting.cc",
|
||||
"translate/breakup-islands.cc",
|
||||
"translate/tf_executor_to_functional.cc",
|
||||
|
@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) {
|
||||
var_handle.resource(),
|
||||
GetOrCreateIdForVarHandle(var_handle, &next_unique_id,
|
||||
&var_handle_name_id_map));
|
||||
} else if (llvm::isa<TF::IdentityNOp>(op) ||
|
||||
llvm::isa<TF::IdentityOp>(op)) {
|
||||
} else if (llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op)) {
|
||||
for (auto operand_and_result :
|
||||
llvm::zip(op->getOperands(), op->getResults())) {
|
||||
forward_input_to_output(std::get<0>(operand_and_result),
|
||||
@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op,
|
||||
const ResourceAliasAnalysis& alias_analysis) {
|
||||
// TODO(yuanzx): Add other types of resources.
|
||||
return llvm::isa<TF::VarHandleOp>(op) ||
|
||||
((llvm::isa<TF::IdentityNOp>(op) || llvm::isa<TF::IdentityOp>(op)) &&
|
||||
(llvm::isa<TF::IdentityNOp, TF::IdentityOp>(op) &&
|
||||
!FindAccessedResources(op, alias_analysis).empty());
|
||||
}
|
||||
|
||||
|
@ -569,8 +569,10 @@ void BuildReplicateOp(
|
||||
|
||||
// Add derived `operand_segment_sizes` attribute.
|
||||
int32_t num_replicated_inputs = replicated_inputs.size() * n;
|
||||
auto operand_segment_sizes = DenseIntElementsAttr::get(
|
||||
VectorType::get({2}, builder->getI32Type()), {num_replicated_inputs, 0});
|
||||
int32_t num_packed_inputs = packed_inputs.size();
|
||||
auto operand_segment_sizes =
|
||||
DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()),
|
||||
{num_replicated_inputs, num_packed_inputs});
|
||||
state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);
|
||||
|
||||
for (const auto& output_type : replica_output_types)
|
||||
@ -600,6 +602,65 @@ void ReplicateOp::build(
|
||||
packed_inputs, replica_output_types);
|
||||
}
|
||||
|
||||
// Returns the number of packed block arguments.
|
||||
unsigned ReplicateOp::GetNumPackedBlockArguments() {
|
||||
return packed_inputs().size();
|
||||
}
|
||||
|
||||
// Returns the number of replicated block arguments.
|
||||
unsigned ReplicateOp::GetNumReplicatedBlockArguments() {
|
||||
return GetBody().getNumArguments() - GetNumPackedBlockArguments();
|
||||
}
|
||||
|
||||
// Returns the replicated block arguments. A copy should be made if the
|
||||
// replicate op is being modified.
|
||||
llvm::ArrayRef<BlockArgument> ReplicateOp::GetReplicatedBlockArguments() {
|
||||
return GetBody().getArguments().drop_back(GetNumPackedBlockArguments());
|
||||
}
|
||||
|
||||
// Returns the packed block arguments. A copy should be made if the replicate op
|
||||
// is being modified.
|
||||
llvm::ArrayRef<BlockArgument> ReplicateOp::GetPackedBlockArguments() {
|
||||
return GetBody().getArguments().take_back(GetNumPackedBlockArguments());
|
||||
}
|
||||
|
||||
// Checks if a block argument is replicated (forwarding replicated inputs).
|
||||
bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) {
|
||||
assert(block_arg.getOwner() == &GetBody());
|
||||
return block_arg.getArgNumber() < GetNumReplicatedBlockArguments();
|
||||
}
|
||||
|
||||
// Checks if a block argument is packed (forwarding a packed input).
|
||||
bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) {
|
||||
return !IsReplicatedBlockArgument(block_arg);
|
||||
}
|
||||
|
||||
// Returns the operand index of the operand being forwarded as a
|
||||
// replicated/packed block argument for a given replica. This assumes a valid
|
||||
// block argument (of the replicate op) and a valid replica is provided.
|
||||
unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument(
|
||||
BlockArgument block_arg, unsigned replica) {
|
||||
const int32_t num_replicas = nAttr().getInt();
|
||||
assert(replica < num_replicas && block_arg.getOwner() == &GetBody());
|
||||
|
||||
const unsigned num_replicated_args = GetNumReplicatedBlockArguments();
|
||||
if (block_arg.getArgNumber() < num_replicated_args)
|
||||
return block_arg.getArgNumber() * num_replicas + replica;
|
||||
|
||||
return block_arg.getArgNumber() - num_replicated_args +
|
||||
replicated_inputs().size();
|
||||
}
|
||||
|
||||
// Returns the operand being forwarded as a replicated/packed block argument for
|
||||
// a given replica. This assumes a valid block argument (of the replicate op)
|
||||
// and a valid replica is provided.
|
||||
Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg,
|
||||
unsigned replica) {
|
||||
const unsigned operand_index =
|
||||
GetReplicaOperandIndexForBlockArgument(block_arg, replica);
|
||||
return getOperand(operand_index);
|
||||
}
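To make the index mapping above concrete, a small worked example; the replica count, argument names, and sizes are hypothetical and not taken from the change itself.
// Illustrative sketch (hypothetical sizes): with n = 2 replicas, two
// replicated block arguments (%ri0, %ri1) and one packed block argument (%p),
// replicated_inputs().size() == 4 and the operands are laid out as
//   [ri0@replica0, ri0@replica1, ri1@replica0, ri1@replica1, p].
// GetReplicaOperandIndexForBlockArgument(%ri1, /*replica=*/1) = 1 * 2 + 1 = 3,
// and for the packed argument %p (argument number 2) it is 2 - 2 + 4 = 4,
// i.e. every replica reads the same shared operand.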
||||
//===----------------------------------------------------------------------===//
|
||||
// Canonicalization patterns
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -283,6 +283,14 @@ For example:
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
Block &GetBody() { return getOperation()->getRegion(0).front(); }
|
||||
unsigned GetNumReplicatedBlockArguments();
|
||||
unsigned GetNumPackedBlockArguments();
|
||||
llvm::ArrayRef<BlockArgument> GetPackedBlockArguments();
|
||||
llvm::ArrayRef<BlockArgument> GetReplicatedBlockArguments();
|
||||
bool IsReplicatedBlockArgument(BlockArgument block_arg);
|
||||
bool IsPackedBlockArgument(BlockArgument block_arg);
|
||||
unsigned GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg, unsigned replica);
|
||||
Value GetReplicaOperandForBlockArgument(BlockArgument block_arg, unsigned replica);
|
||||
}];
|
||||
|
||||
let builders = [
|
||||
|
@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface {
|
||||
// Allow inlining into tf.island regions if the incoming region has a single
|
||||
// block.
|
||||
return llvm::isa<tf_executor::IslandOp>(dest->getParentOp()) &&
|
||||
std::next(src->begin()) == src->end();
|
||||
llvm::hasSingleElement(*src);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -7270,6 +7270,8 @@ reshape(t, []) ==> 7
|
||||
let verifier = [{
|
||||
return Verify(*this);
|
||||
}];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def TF_ResizeBilinearOp : TF_Op<"ResizeBilinear", [NoSideEffect]> {
|
||||
|
@ -506,28 +506,52 @@ LogicalResult FoldOperandsPermutation(
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// Folder that returns LHS of an Arithmetic Op if the RHS is a constant
|
||||
// known to be Identity (e.g X+0)
|
||||
// Fold Arithmetic Op if one of the operands is a constant known to be an
// Identity (e.g. X+0, X*1, etc...). For commutative operations, fold when the
// known identity value is either the lhs or the rhs.
|
||||
template <
|
||||
typename OpT,
|
||||
typename std::enable_if<llvm::is_one_of<
|
||||
OpT, AddV2Op, SubOp, MulOp, DivOp, RealDivOp>::value>::type * = nullptr>
|
||||
OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
ArrayRef<Attribute> operands) {
|
||||
auto result_op_type = arithmetic_op.getResult().getType();
|
||||
auto lhs_type = arithmetic_op.x().getType().template cast<ShapedType>();
|
||||
if (!result_op_type.template cast<ShapedType>().hasStaticShape()) return {};
|
||||
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
|
||||
auto result_type =
|
||||
arithmetic_op.getResult().getType().template cast<ShapedType>();
|
||||
|
||||
// We only handle non-broadcastable case.
|
||||
if (result_op_type != lhs_type) {
|
||||
return {};
|
||||
}
// We can fold an arithmetic operation only if we can prove that we will not
// accidentally hide a broadcasting error.
|
||||
auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty,
|
||||
ShapedType result_ty) -> bool {
|
||||
// Scalar identity is broadcastable to any operand shape; we only need to
// check that the operand has the same shape as the result.
|
||||
bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0;
|
||||
if (scalar_identity) return operand_ty == result_ty;
|
||||
|
||||
// If identity is not a scalar, we must verify that all shapes are equal
|
||||
// and statically known.
|
||||
//
// TODO(ezhulenev): Fold if identity shape is statically known to be
// broadcastable to the operand shape.
|
||||
return operand_ty == result_ty && identity_ty == result_ty &&
|
||||
result_ty.hasStaticShape();
|
||||
};
|
||||
|
||||
// Check that we have a constant operand on one side (candidate for identity).
|
||||
const bool is_commutative =
|
||||
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||
auto lhs_attr = operands[0].dyn_cast_or_null<DenseElementsAttr>();
|
||||
auto rhs_attr = operands[1].dyn_cast_or_null<DenseElementsAttr>();
|
||||
if (!rhs_attr && !(is_commutative && lhs_attr)) return {};
|
||||
|
||||
// Mul and Div ops have identity value one while AddV2 and SubOp have identity
|
||||
// value zero.
|
||||
int identity =
|
||||
const int identity =
|
||||
(std::is_same<OpT, MulOp>::value || std::is_same<OpT, DivOp>::value ||
|
||||
std::is_same<OpT, RealDivOp>::value);
|
||||
std::is_same<OpT, RealDivOp>::value)
|
||||
? 1
|
||||
: 0;
|
||||
|
||||
Type element_ty = lhs_type.getElementType();
|
||||
Attribute identity_attr;
|
||||
@ -539,23 +563,19 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op,
|
||||
return {};
|
||||
}
|
||||
|
||||
if (auto attr = operands[1].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (attr.isSplat() && attr.getSplatValue() == identity_attr)
|
||||
// Fold: Op(Operand, Identity) -> Operand.
|
||||
if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) {
|
||||
if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr)
|
||||
return arithmetic_op.x();
|
||||
}
|
||||
|
||||
auto rhs_type = arithmetic_op.y().getType().template cast<ShapedType>();
|
||||
// TODO(chhe): we could fold and add an identity to force the broadcast.
|
||||
if (result_op_type != rhs_type) {
|
||||
return {};
|
||||
}
|
||||
|
||||
bool is_symmetric =
|
||||
(std::is_same<OpT, AddV2Op>::value || std::is_same<OpT, MulOp>::value);
|
||||
if (auto attr = operands[0].dyn_cast_or_null<DenseElementsAttr>()) {
|
||||
if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr)
|
||||
// Fold: Op(Identity, Operand) -> Operand for commutative operations.
|
||||
if (lhs_attr && is_commutative &&
|
||||
is_valid_broadcasting(rhs_type, lhs_type, result_type)) {
|
||||
if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr)
|
||||
return arithmetic_op.y();
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
} // namespace
|
||||
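To make the shape-safety rule used by `IdentityArithmeticOpFolder` concrete, here is a minimal standalone sketch (plain C++, independent of MLIR; the `Shape` alias and helper name are illustrative, not part of the TensorFlow sources) of when folding `Op(x, identity)` to `x` is safe. The static-shape requirement is elided, since this toy `Shape` has no dynamic dimensions.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Illustrative stand-in for a ranked tensor shape; an empty vector is a scalar.
    using Shape = std::vector<int64_t>;

    // Mirrors the `is_valid_broadcasting` lambda: a scalar identity is safe as
    // long as the operand already has the result shape; a non-scalar identity
    // must match the operand and result shapes exactly.
    bool IsValidBroadcasting(const Shape& operand, const Shape& identity,
                             const Shape& result) {
      if (identity.empty()) return operand == result;  // scalar identity
      return operand == result && identity == result;  // all shapes must agree
    }

    int main() {
      // tf.AddV2(%arg : 4x1, 0.0 : scalar) -> 4x1 : foldable to %arg.
      assert(IsValidBroadcasting({4, 1}, {}, {4, 1}));
      // tf.Sub(%arg : 4x1, zeros : 1x4) -> 4x4 : not foldable; removing the op
      // would change the result shape and hide a broadcasting error.
      assert(!IsValidBroadcasting({4, 1}, {1, 4}, {4, 4}));
      return 0;
    }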
@ -1168,8 +1188,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result,
  ShapedType type;
  if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
    return ConstOp::build(builder, result, elem_attr);
  } else if (value.isa<BoolAttr>() || value.isa<FloatAttr>() ||
             value.isa<IntegerAttr>()) {
  } else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
    // All TensorFlow types must be tensor types. In the build() method,
    // we want to provide more flexibility by allowing attributes of scalar
    // types. But we need to wrap it up with ElementsAttr to construct

@ -2870,6 +2889,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor,
  return unranked();
}

void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<RedundantReshape>(context);
}

//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
@ -1165,4 +1165,35 @@ array([0, 2, 2])
  );
}

def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> {
  let summary = "Calls a function placed on a specified TPU device.";

  let arguments = (ins
    Variadic<TF_Tensor>:$args,
    I32Tensor:$device_ordinal,

    SymbolRefAttr:$f,
    DefaultValuedAttr<I64Attr, "0">:$autotuner_thresh
  );

  let results = (outs
    Variadic<TF_Tensor>:$output
  );

  TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>;
  TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>;

  let extraClassDeclaration = [{
    // Gets the argument operands to the called function.
    operand_range getArgOperands() { return args(); }

    // Returns the callee of this operation.
    CallInterfaceCallable getCallableForCallee() {
      return getAttrOfType<SymbolRefAttr>("f");
    }
  }];

  let verifier = [{ return VerifyPartitionedCall(*this); }];
}

#endif // TF_OPS
@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) {
LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
    Operation *op, NamedAttribute named_attr) {
  if (named_attr.first == "tf_saved_model.exported_names") {
    if (!isa<FuncOp>(op) && !isa<GlobalTensorOp>(op)) {
    if (!isa<FuncOp, GlobalTensorOp>(op)) {
      return op->emitError() << "'tf_saved_model.exported_names' must be on a "
                                "'func' or 'tf_saved_model.global_tensor' op";
    }

@ -90,8 +90,7 @@ class TensorFlowType : public Type {

// Returns true if the specified type is a valid TensorFlow element type.
static inline bool IsValidTFElementType(Type type) {
  return type.isa<ComplexType>() || type.isa<FloatType>() ||
         type.isa<IntegerType>() || type.isa<TensorFlowType>();
  return type.isa<ComplexType, FloatType, IntegerType, TensorFlowType>();
}

// Returns true if this is a valid TensorFlow tensor type.
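Several hunks in this change collapse chained `isa<A>() || isa<B>()` checks into the variadic `isa<A, B, ...>()` form. The following self-contained sketch (plain C++, not the actual LLVM/MLIR casting machinery) illustrates the equivalence the variadic form relies on:

    #include <cassert>

    struct Type { virtual ~Type() = default; };
    struct ComplexType : Type {};
    struct FloatType : Type {};
    struct IntegerType : Type {};

    // Single-type check. Real LLVM-style casting uses classof() rather than
    // dynamic_cast; this is only a dependency-free stand-in.
    template <typename T>
    bool isa(const Type& t) { return dynamic_cast<const T*>(&t) != nullptr; }

    // Variadic form: true if the value is any of the listed types, i.e.
    // equivalent to isa<T1>(t) || isa<T2>(t) || ...
    template <typename T1, typename T2, typename... Rest>
    bool isa(const Type& t) { return isa<T1>(t) || isa<T2, Rest...>(t); }

    int main() {
      FloatType f;
      assert((isa<ComplexType, FloatType, IntegerType>(f)));
      assert(!(isa<ComplexType, IntegerType>(f)));
      return 0;
    }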
@ -190,6 +190,27 @@ func @testSubOfNeg(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSubOfZero
|
||||
func @testSubOfZero(%arg0: tensor<?x1xf32>, %arg1: tensor<4x1xf32>) -> (tensor<?x1xf32>, tensor<4x1xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%1 = "tf.Sub"(%arg0, %0) : (tensor<?x1xf32>, tensor<f32>) -> tensor<?x1xf32>
|
||||
%2 = "tf.Sub"(%arg1, %0) : (tensor<4x1xf32>, tensor<f32>) -> tensor<4x1xf32>
|
||||
return %1, %2: tensor<?x1xf32>, tensor<4x1xf32>
|
||||
|
||||
// CHECK: return %arg0, %arg1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSubOfZeroWithBroadcasting
|
||||
func @testSubOfZeroWithBroadcasting(%arg0: tensor<4x1xf32>) -> tensor<4x4xf32> {
|
||||
// This is an identity arithmetic operation, however we do not currently fold
|
||||
// it because it has a broadcasting.
|
||||
%0 = "tf.Const"() {value = dense<[[0.0, 0.0, 0.0, 0.0]]> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
|
||||
%1 = "tf.Sub"(%arg0, %0) : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32>
|
||||
return %1 : tensor<4x4xf32>
|
||||
|
||||
// CHECK: return %1
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSquareOfSub
|
||||
func @testSquareOfSub(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> {
|
||||
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32>
|
||||
@ -257,6 +278,46 @@ func @testAddV2OfNegRight(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> t
|
||||
// CHECK: return %0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddV2IdentityScalar
|
||||
func @testAddV2IdentityScalar(%arg0: tensor<f32>, %arg1: tensor<?xf32>, %arg2: tensor<4xf32>) -> (tensor<f32>, tensor<?xf32>, tensor<4xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
|
||||
|
||||
// Identity scalar (0.0) is foldable with operand of any shape because
|
||||
// scalar is safely broadcastable to any shape.
|
||||
|
||||
%1 = "tf.AddV2"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%2 = "tf.AddV2"(%arg1, %0) : (tensor<?xf32>, tensor<f32>) -> tensor<?xf32>
|
||||
%3 = "tf.AddV2"(%arg2, %0) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
|
||||
%4 = "tf.AddV2"(%0, %1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%5 = "tf.AddV2"(%0, %2) : (tensor<f32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%6 = "tf.AddV2"(%0, %3) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: return %arg0, %arg1, %arg2
|
||||
return %4, %5, %6: tensor<f32>, tensor<?xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testAddV2IdentityTensor
|
||||
func @testAddV2IdentityTensor(%arg0: tensor<f32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
|
||||
%0 = "tf.Const"() {value = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf32>} : () -> tensor<4xf32>
|
||||
|
||||
// If operand is a scalar, then the identity value (0.0 for addition) can
|
||||
// be of any shape, because operand is safely broadcastable to any shape.
|
||||
//
|
||||
// However we can't fold this arithmetic operation because the operand
|
||||
// shape does not match the result shape.
|
||||
|
||||
%1 = "tf.AddV2"(%arg0, %0) : (tensor<f32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%2 = "tf.AddV2"(%0, %arg0) : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
|
||||
// If operand has the same shape as a result, we can fold it.
|
||||
%3 = "tf.AddV2"(%arg1, %0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%4 = "tf.AddV2"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK: return %1, %2, %arg1, %arg1
|
||||
return %1, %2, %3, %4: tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testDoubleConj
|
||||
func @testDoubleConj(%arg0: tensor<8x16x32x64xcomplex<f32>>) -> tensor<8x16x32x64xcomplex<f32>> {
|
||||
%0 = "tf.Conj"(%arg0) : (tensor<8x16x32x64xcomplex<f32>>) -> tensor<8x16x32x64xcomplex<f32>>
|
||||
@ -302,6 +363,20 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi
|
||||
// CHECK: return %arg0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testRedundantReshape
|
||||
func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> {
|
||||
%0 = "tf.Const"() {value = dense<[8, 2]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%1 = "tf.Const"() {value = dense<[2, 8]> : tensor<2xi32>} : () -> tensor<2xi32>
|
||||
%2 = "tf.Reshape"(%arg0, %0) : (tensor<4x4xi32>, tensor<2xi32>) -> tensor<8x2xi32>
|
||||
%3 = "tf.Reshape"(%2, %1) : (tensor<8x2xi32>, tensor<2xi32>) -> tensor<2x8xi32>
|
||||
return %3: tensor<2x8xi32>
|
||||
|
||||
// CHECK: %0 = "tf.Const"
|
||||
// CHECK-SAME: value = dense<[2, 8]> : tensor<2xi32>
|
||||
// CHECK: %1 = "tf.Reshape"(%arg0, %0)
|
||||
// CHECK: return %1 : tensor<2x8xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: testSelectScalarPred
|
||||
func @testSelectScalarPred(%arg0: tensor<i1>, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> {
|
||||
// CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16>
|
||||
|
@ -2,17 +2,17 @@
|
||||
|
||||
|
||||
func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
return %0 : tensor<1x32x10x32xi32>
|
||||
}
|
||||
|
||||
func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> {
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32>
|
||||
return %0 : tensor<1x32x10x32xi32>
|
||||
}
|
||||
|
||||
func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?x?x?xi32> {
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
|
||||
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<?x?x?x?xi32>, tensor<?xi32>) -> tensor<?x?x?x?xi32>
|
||||
return %0 : tensor<?x?x?x?xi32>
|
||||
}
|
||||
|
||||
@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> {
|
||||
%0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
|
||||
%0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32>
|
||||
return %0 : tensor<4x4x4x4xi32>
|
||||
}
|
||||
|
||||
@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi32> {
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
return %0 : tensor<?x?xi32>
|
||||
}
|
||||
|
||||
@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> {
|
||||
%0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
%0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32>
|
||||
return %0 : tensor<1x2xi32>
|
||||
}
|
||||
|
||||
@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> {
|
||||
%0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
%0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
return %0 : tensor<2x4xi32>
|
||||
}
|
||||
|
||||
@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
%0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
%0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi1>, tensor<1xi1>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
%0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
return %0 : tensor<1x4xi8>
|
||||
}
|
||||
|
||||
func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
|
||||
%0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
%0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
|
||||
}
|
||||
|
||||
func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> {
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
%0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8>
|
||||
return %0 : tensor<1x4xi8>
|
||||
}
|
||||
|
||||
func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi32> {
|
||||
%0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
%0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
|
||||
return %0 : tensor<?xi32>
|
||||
}
|
||||
|
||||
@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
|
||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
%0 = mhlo.constant dense<0> : tensor<2x3xi32>
|
||||
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%2 = mhlo.constant dense<0> : tensor<3xi32>
|
||||
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||
%4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
|
||||
%8 = mhlo.constant dense<1> : tensor<3xi32>
|
||||
%9 = mhlo.subtract %7, %8 : tensor<3xi32>
|
||||
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
|
||||
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
%14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %14 : tensor<2x3xi32>
|
||||
}
|
||||
@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
|
||||
%0 = mhlo.constant dense<0> : tensor<3xi32>
|
||||
%1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
|
||||
%2 = mhlo.constant dense<0> : tensor<2x3xi32>
|
||||
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
|
||||
%4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
|
||||
%5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
|
||||
%7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%8 = mhlo.constant dense<1> : tensor<2x3xi32>
|
||||
%9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
|
||||
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
%13 = mhlo.divide %11, %12 : tensor<2x3xi32>
|
||||
@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
}
|
||||
|
||||
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
|
||||
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
%2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
|
||||
return %2 : tensor<2x3xf16>
|
||||
}
|
||||
@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @equal_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi1>
|
||||
return %0 : tensor<?xi1>
|
||||
}
|
||||
|
||||
@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
|
||||
}
|
||||
|
||||
func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> {
|
||||
%0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
%0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1>
|
||||
return %0 : tensor<1x2xi1>
|
||||
}
|
||||
|
||||
@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> {
|
||||
|
||||
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
return %1 : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
%1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<6> : tensor<i32>
|
||||
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
%2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
%3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
|
||||
return %3 : tensor<1xi32>
|
||||
}
|
||||
|
||||
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = mhlo.constant dense<0> : tensor<i32>
|
||||
%1 = mhlo.constant dense<6> : tensor<i32>
|
||||
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
%2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
%3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
|
||||
return %3 : tensor<?xi32>
|
||||
}
|
||||
|
||||
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
|
||||
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
%1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
%2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
|
||||
%3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
|
||||
return %3 : tensor<4x8xf32>
|
||||
|
@ -86,6 +86,15 @@ func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32
|
||||
return %0 : tensor<2x3xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @is_nan
|
||||
func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> {
|
||||
// CHECK: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[RESULT:.*]] = "tf.Equal"(%arg0, %[[NAN]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor<f32>) -> tensor<3x4xi1>
|
||||
%0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1>
|
||||
// CHECK: return %[[RESULT]]
|
||||
return %0 : tensor<3x4xi1>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @fill
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: tensor<*xi64>, %[[ARG1:.*]]: tensor<*xf32>)
|
||||
func @fill(%arg0: tensor<*xi64>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
|
||||
|
@ -0,0 +1,79 @@
|
||||
// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-update-embedding-enqueue-op-inputs | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @check_enqueue_ops_update_for_eval
|
||||
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<?x2xi32>
|
||||
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x2xi32>
|
||||
// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor<?x2xi32>
|
||||
// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
|
||||
func @check_enqueue_ops_update_for_eval(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @check_enqueue_ops_update_for_training
|
||||
// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor<?x2xi32>
|
||||
// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor<?x2xi32>
|
||||
// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor<?x2xi32>
|
||||
// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor<?xi32>
|
||||
// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor<!tf.string>
|
||||
// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor<i1>
|
||||
func @check_enqueue_ops_update_for_training(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
// CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"()
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
|
||||
%2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
|
||||
%3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
|
||||
"tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> ()
|
||||
|
||||
// CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]])
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @check_enqueue_ops_with_different_attr_disallowed(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
// expected-error @+1 {{'tf.EnqueueTPUEmbeddingSparseTensorBatch' op must have a corresponding 'tf.RecvTPUEmbeddingActivations' op}}
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
%2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @check_embedding_ops_with_missing_attribute_disallowed(%arg0: tensor<?x2xi32>, %arg1: tensor<?x2xi32>,
|
||||
%arg2 :tensor<?x2xi32>, %arg3: tensor<?xi32>, %arg4: tensor<?xi32>, %arg5: tensor<?xi32>,
|
||||
%arg6: tensor<!tf.string>, %arg7: tensor<!tf.string>, %arg8: tensor<i1>) -> () {
|
||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32>
|
||||
%1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor<i1>, tensor<!tf.string>, tensor<!tf.string>) -> tensor<!tf.string>
|
||||
"tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor<?x2xi32>, tensor<?x2xi32>, tensor<?x2xi32>, tensor<?xi32>, tensor<?xi32>, tensor<?xi32>, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor<!tf.string>) -> ()
|
||||
// expected-error @+1 {{'tf.RecvTPUEmbeddingActivations' op requires attribute '_tpu_embedding_layer'}}
|
||||
%2:2 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>)
|
||||
return
|
||||
}
|
@ -107,6 +107,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
}

void CreateTPUBridgePipelineV1(OpPassManager &pm) {
  pm.addPass(TF::CreateTFShapeInferencePass());
  // For V1 compatibility, we process a module where the graph does not have
  // feeds and fetches. We first extract the TPU computation in a submodule,
  // where it'll be in a function with args and returned values, much more like

@ -194,6 +194,13 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)),
def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)),
                           (replaceWithValue $arg)>;

//===----------------------------------------------------------------------===//
// Reshape op patterns.
//===----------------------------------------------------------------------===//

def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape),
                           (TF_ReshapeOp $arg, $shape)>;

//===----------------------------------------------------------------------===//
// Select op patterns.
//===----------------------------------------------------------------------===//
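The `RedundantReshape` DRR rule above could equally be written as an imperative rewrite pattern. The sketch below is illustrative only (it assumes the `tensor()`/`shape()` operand accessors generated for `TF_ReshapeOp` and the MLIR `OpRewritePattern` API of this period; it is not code from this change):

    #include "mlir/IR/PatternMatch.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

    namespace mlir {
    namespace TF {

    // Reshape(Reshape(x, unused_shape), shape) => Reshape(x, shape)
    struct RedundantReshapePattern : public OpRewritePattern<ReshapeOp> {
      using OpRewritePattern<ReshapeOp>::OpRewritePattern;

      LogicalResult matchAndRewrite(ReshapeOp op,
                                    PatternRewriter &rewriter) const override {
        // Only fires when the reshaped value is itself produced by a tf.Reshape.
        auto inner = dyn_cast_or_null<ReshapeOp>(op.tensor().getDefiningOp());
        if (!inner) return failure();
        rewriter.replaceOpWithNewOp<ReshapeOp>(op, op.getType(), inner.tensor(),
                                               op.shape());
        return success();
      }
    };

    }  // namespace TF
    }  // namespace mlir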
@ -154,6 +154,14 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp],
def LowerFillOp : Pat<(TF_FillOp $dims, $value),
                      (TF_BroadcastToOp $value, $dims)>;

//===----------------------------------------------------------------------===//
// NaN op patterns.
//===----------------------------------------------------------------------===//

def LowerIsNanOp : Pat<(TF_IsNanOp $x),
                       (TF_EqualOp $x, (TF_ConstOp:$nan (GetScalarNanOfType $x)),
                                       /*incompatible_shape_error*/ConstBoolAttrTrue)>;

//===----------------------------------------------------------------------===//
// L2Loss op patterns.
//===----------------------------------------------------------------------===//
@ -287,6 +287,11 @@ CreateTPUExtractHeadTailOutsideCompilationPass();
// that are only used for host computation.
std::unique_ptr<OperationPass<FuncOp>> CreateTPUHostComputationExpansionPass();

// Creates a pass that updates inputs to TPU embedding layer enqueue ops so that
// correct ops are invoked during training and evaluation.
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass();

// Creates a pass that extracts outside compilation (CPU ops inside TPU cluster)
// ops to a separate parallel_execute region to run on CPU.
std::unique_ptr<OperationPass<ModuleOp>>
@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo(
      info.data_type = assign.value().getType();
      continue;
    }
    if (isa<TF::StackPushV2Op>(user) || isa<TF::StackPopV2Op>(user)) {
    if (isa<TF::StackPushV2Op, TF::StackPopV2Op>(user)) {
      // Stacks will be handled by a separate pass.
      do_not_touch = true;
      break;

@ -205,9 +205,9 @@ GetSubtypes(Type type) {
// Returns whether type can be further refined.
bool CanBeRefined(Type type) {
  auto shape_type = type.dyn_cast<ShapedType>();
  return shape_type && (!shape_type.hasStaticShape() ||
                        shape_type.getElementType().isa<TF::ResourceType>() ||
                        shape_type.getElementType().isa<TF::VariantType>());
  return shape_type &&
         (!shape_type.hasStaticShape() ||
          shape_type.getElementType().isa<TF::ResourceType, TF::VariantType>());
}

// Infers the shape from a (Stateful)PartitionedCall operation by looking up the

@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
  // The shape function of these ops sometimes does not propagate subtypes
  // (handle shapes) for resource and variant types. We use a simple passthrough
  // to make sure they are preserved in the output.
  if (isa<TF::IdentityOp>(op) || isa<TF::IdentityNOp>(op) ||
      isa<TF::ZerosLikeOp>(op) || isa<TF::WhileOp>(op)) {
  if (isa<TF::IdentityOp, TF::IdentityNOp, TF::ZerosLikeOp, TF::WhileOp>(op)) {
    return RefineTypeForPassThroughOperands(op, op->getOperands(),
                                            op->getResults());
  }

@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {

  // Handle call operations by looking up the callee and inferring the return
  // shape as needed.
  if (isa<PartitionedCallOp>(op) || isa<StatefulPartitionedCallOp>(op))
  if (isa<PartitionedCallOp, StatefulPartitionedCallOp>(op))
    return InferShapeForCall(op);

  // tf.Cast are only inferred if they have at least one user in the TF dialect

@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) {
  };
  auto new_element_type = shaped_type.getElementType();
  // Populate the handle shapes for a resource/variant.
  if (new_element_type.isa<TF::ResourceType>() ||
      new_element_type.isa<TF::VariantType>()) {
  if (new_element_type.isa<TF::ResourceType, TF::VariantType>()) {
    auto handle_shapes_types = c.output_handle_shapes_and_types(output);
    if (handle_shapes_types) {
      SmallVector<TensorType, 1> subtypes;

@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal(
    llvm::StringMap<PartitionedCallStackOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      // Removes identity nodes in the block. The device computation does not
      // need such nodes to carry information.
      op.replaceAllUsesWith(op.getOperands());

@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps(
    llvm::StringMap<PartitionedCallTensorArrayOpsInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {

@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() {

  // Remove TPUReplicatedInput and TPUReplicatedOutput nodes.
  auto remove_result = getFunction().walk([&](Operation* op) {
    if (!llvm::isa<TF::TPUReplicatedInputOp>(op) &&
        !llvm::isa<TF::TPUReplicatedOutputOp>(op))
    if (!llvm::isa<TF::TPUReplicatedInputOp, TF::TPUReplicatedOutputOp>(op))
      return WalkResult::advance();

    // Forward operand to result. When `num_replicas` attribute is 1, no
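The decomposition and cluster-formation hunks above iterate with `llvm::make_early_inc_range` precisely so the current op can be erased without invalidating the traversal. A minimal standalone analogue of that idiom (plain C++, no LLVM dependency) is shown below:

    #include <iostream>
    #include <list>

    int main() {
      std::list<int> ops = {1, 2, 3, 4};

      // Advance the iterator before touching the current element so erasing it
      // does not invalidate the position we continue from; this is what
      // llvm::make_early_inc_range automates for range-based for loops.
      for (auto it = ops.begin(); it != ops.end();) {
        auto current = it++;
        if (*current % 2 == 0) ops.erase(current);  // "erase the op" safely
      }

      for (int v : ops) std::cout << v << ' ';  // prints: 1 3
      std::cout << '\n';
      return 0;
    }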
@ -0,0 +1,162 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";

struct TPUUpdateEmbeddingEnqueueOpInputs
    : public PassWrapper<TPUUpdateEmbeddingEnqueueOpInputs, FunctionPass> {
  void runOnFunction() override;
};

// Extracts the `_tpu_embedding_layer` attribute from TPU embedding ops and
// clears the attribute from the operation. This ensures that future
// optimization passes do not trigger additional logic due to the presence of
// this attribute.
LogicalResult ExtractEmbeddingAttribute(
    Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
  auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
  if (!embedding_attr)
    return op->emitOpError("requires attribute '_tpu_embedding_layer'");

  if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
    return op->emitOpError(
        "found duplicate TPU embedding ops potentially from multiple "
        "TPUEmbedding layers");

  op->removeAttr(kTPUEmbeddingAttr);
  return success();
}

LogicalResult FindTPUEmbeddingOps(
    FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
    llvm::StringMap<Operation*>* recv_activation_op_map,
    llvm::StringMap<Operation*>* send_gradient_op_map) {
  auto walk_result = func_op.walk([&](Operation* op) {
    if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
      if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
        return WalkResult::interrupt();

    if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
      if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
        return WalkResult::interrupt();

    if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                  TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op))
      if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
        return WalkResult::interrupt();

    return WalkResult::advance();
  });
  return failure(walk_result.wasInterrupted());
}

// Updates the operands of TPU embedding enqueue ops depending on whether
// the graph is in training mode or in non-training mode.
// If a SendTPUEmbeddingGradients op is present, the graph is in training
// mode. If so, correctly feed the `then` branch value of the SelectV2
// operand as input to the TPU embedding enqueue ops.
LogicalResult UpdateEmbeddingEnqueueOpInput(
    const llvm::StringMap<Operation*>& enqueue_op_map,
    const llvm::StringMap<Operation*>& recv_activation_op_map,
    const llvm::StringMap<Operation*>& send_gradient_op_map) {
  for (const auto& it : enqueue_op_map) {
    const auto& embedding_attr = it.getKey();
    Operation* embedding_op = it.second;
    if (!recv_activation_op_map.count(embedding_attr))
      return embedding_op->emitOpError()
             << "must have a corresponding '"
             << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";

    // TPU embedding enqueue ops take different inputs depending on whether the
    // graph is in training mode or in eval/prediction mode. The inputs to the
    // enqueue ops are listed as operands of a SelectV2 op. The `then` branch
    // operand of the SelectV2 op represents the input to take during training,
    // and the `else` branch operand represents the input to take during
    // prediction/evaluation. If a SendTPUEmbeddingGradients op exists in the
    // graph, then the graph is in training mode, so correctly forward the input
    // of the SelectV2 op as the operand to the TPU embedding enqueue op.
    bool is_training = send_gradient_op_map.count(embedding_attr);
    for (auto enqueue_operand : embedding_op->getOperands()) {
      if (auto select = llvm::dyn_cast_or_null<TF::SelectV2Op>(
              enqueue_operand.getDefiningOp())) {
        enqueue_operand.replaceAllUsesWith(is_training ? select.t()
                                                       : select.e());
      }
    }
  }

  return success();
}

void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
  OpBuilder builder(&getContext());
  auto func_op = getFunction();

  // All TPU embedding layer related ops are annotated with the
  // `_tpu_embedding_layer` attribute along with a corresponding string value.
  // Store all TPU embedding layer related ops with the value of the
  // `_tpu_embedding_layer` attribute as the map key.
  llvm::StringMap<Operation*> enqueue_op_map;
  llvm::StringMap<Operation*> recv_activation_op_map;
  llvm::StringMap<Operation*> send_gradient_op_map;
  if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
                                 &recv_activation_op_map,
                                 &send_gradient_op_map)))
    return signalPassFailure();

  if (enqueue_op_map.size() != recv_activation_op_map.size()) {
    func_op.emitError() << "expects the number of embedding enqueue ops to "
                           "match the number of '"
                        << TF::RecvTPUEmbeddingActivationsOp::getOperationName()
                        << "' ops";
    return signalPassFailure();
  }

  if (failed(UpdateEmbeddingEnqueueOpInput(
          enqueue_op_map, recv_activation_op_map, send_gradient_op_map)))
    return signalPassFailure();
}

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
  return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputs>();
}

static PassRegistration<TPUUpdateEmbeddingEnqueueOpInputs> pass(
    "tf-tpu-update-embedding-enqueue-op-inputs",
    "Updates inputs to TPU embedding enqueue ops depending on whether graph "
    "is in training mode or in evaluation mode.");

}  // namespace TFTPU
}  // namespace mlir
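The new pass is exposed through the factory declared in passes.h above and registered under the `tf-tpu-update-embedding-enqueue-op-inputs` flag used in the test file's RUN line. As a hedged sketch (following the pipeline-construction style seen elsewhere in this change, not code from the change itself), it would be added to a pipeline like so:

    #include "mlir/Pass/PassManager.h"
    #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"

    // Illustrative only: appends the embedding-enqueue update pass so that it
    // runs on every function in the module.
    void AddEmbeddingEnqueueUpdate(mlir::OpPassManager &pm) {
      pm.addNestedPass<mlir::FuncOp>(
          mlir::TFTPU::CreateTPUUpdateEmbeddingEnqueueOpInputsPass());
    }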
@ -576,9 +576,8 @@ StatusOr<std::unique_ptr<Graph>> Exporter::Convert(
|
||||
// Adds nodes for operations.
|
||||
for (Operation& inst : graph_op.GetBody()) {
|
||||
for (auto type : inst.getResultTypes())
|
||||
if (!type.isa<mlir::TensorType>() &&
|
||||
!type.isa<mlir::tf_executor::ControlType>() &&
|
||||
!type.isa<mlir::tf_executor::TokenType>())
|
||||
if (!type.isa<mlir::TensorType, mlir::tf_executor::ControlType,
|
||||
mlir::tf_executor::TokenType>())
|
||||
return errors::InvalidArgument(
|
||||
"Values must be of tensor type, TensorFlow control type, or "
|
||||
"TensorFlow token type. Found ",
|
||||
|
@ -253,7 +253,7 @@ static void RegisterDialects() {
|
||||
mlir::registerDialect<mlir::TF::TensorFlowDialect>();
|
||||
mlir::registerDialect<mlir::shape::ShapeDialect>();
|
||||
mlir::registerDialect<mlir::tf_executor::TensorFlowExecutorDialect>();
|
||||
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
|
||||
mlir::registerDialect<mlir::mhlo::MhloDialect>();
|
||||
return true;
|
||||
}();
|
||||
(void)init_once;
|
||||
|
@ -88,7 +88,7 @@ struct MaterializeBroadcastsPass
    mlir::OwningRewritePatternList conversionPatterns;

    // Consider the mhlo dialect legal for tests.
    conversionTarget.addLegalDialect<mlir::mhlo::XlaHloDialect>();
    conversionTarget.addLegalDialect<mlir::mhlo::MhloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@ -128,7 +128,7 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) {
  pm.addNestedPass<mlir::FuncOp>(absl::make_unique<UnfuseBatchNormPass>());
  pm.addPass(mlir::mhlo::createLegalizeToLhloPass(
      /*results_escape_functions=*/true));
  pm.addNestedPass<mlir::FuncOp>(mlir::xla_lhlo::createLhloCopyRemovalPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::lmhlo::createLhloCopyRemovalPass());

  if (failed(pm.run(module))) {
    return InternalError("Lowering TF to LHLO failed.");
@ -33,7 +33,7 @@ namespace xla {
static std::string GetMlirOpName(HloOpcode opcode) {
  std::string op_name = HloOpcodeString(opcode);
  absl::c_replace(op_name, '-', '_');
  return mlir::mhlo::XlaHloDialect::getDialectNamespace().str() + "." + op_name;
  return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name;
}

static std::string ToString(mlir::Type ty) {
@ -1,11 +1,11 @@
// RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope %s

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
  // CHECK: lhlo.abs
  // CHECK: lmhlo.abs
  // CHECK-SAME: %[[ARG0]], %[[VIEW]]
  %abs = "mhlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %abs : tensor<2x2xf32>
@ -14,12 +14,12 @@ func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
// -----

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
  // CHECK: lhlo.add
  // CHECK: lmhlo.add
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
  // CHECK-NEXT: return
  %res = "mhlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
@ -29,12 +29,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
// -----

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
  // CHECK: lhlo.and
  // CHECK: lmhlo.and
  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
  // CHECK-NEXT: return
  %res = "mhlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
@ -44,11 +44,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
// -----

// CHECK-LABEL: func @main
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
  // CHECK: lhlo.ceil
  // CHECK: lmhlo.ceil
  // CHECK-SAME: %[[ARG0]], %[[VIEW]]
  %res = "mhlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %res : tensor<2x2xf32>
@ -57,12 +57,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcomplex<f32>> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex<f32>>
|
||||
// CHECK: lhlo.complex
|
||||
// CHECK: lmhlo.complex
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex<f32>>)
|
||||
@ -72,11 +72,11 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex<f32>>
|
||||
// CHECK: lhlo.cosine
|
||||
// CHECK: lmhlo.cosine
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.cosine"(%value0) : (tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>>
|
||||
@ -86,12 +86,12 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xcomplex<f32>> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.divide
|
||||
// CHECK: lmhlo.divide
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
@ -101,11 +101,11 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.exponential
|
||||
// CHECK: lmhlo.exponential
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -114,11 +114,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.log
|
||||
// CHECK: lmhlo.log
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -127,12 +127,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.maximum
|
||||
// CHECK: lmhlo.maximum
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
@ -142,12 +142,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.minimum
|
||||
// CHECK: lmhlo.minimum
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
@ -157,12 +157,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.multiply
|
||||
// CHECK: lmhlo.multiply
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
@ -172,11 +172,11 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.negate
|
||||
// CHECK: lmhlo.negate
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -185,11 +185,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
|
||||
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
|
||||
// CHECK: lhlo.real
|
||||
// CHECK: lmhlo.real
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.real"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
return %res : tensor<1x2xf32>
|
||||
@ -198,11 +198,11 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex<f32>> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<8xi8>
|
||||
func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32>
|
||||
// CHECK: lhlo.imag
|
||||
// CHECK: lmhlo.imag
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.imag"(%value0) : (tensor<1x2xcomplex<f32>>) -> (tensor<1x2xf32>)
|
||||
return %res : tensor<1x2xf32>
|
||||
@ -211,12 +211,12 @@ func @main(%value0: tensor<1x2xcomplex<f32>>) -> tensor<1x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lhlo.remainder
|
||||
// CHECK: lmhlo.remainder
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
@ -226,11 +226,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.rsqrt
|
||||
// CHECK: lmhlo.rsqrt
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -239,13 +239,13 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {xla_lhlo.params = 2
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.params = 2
|
||||
// CHECK-SAME: %[[ARG3:.*]]: memref<16xi8>
|
||||
func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.select
|
||||
// CHECK: lmhlo.select
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
|
||||
@ -255,11 +255,11 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>)
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.sign
|
||||
// CHECK: lmhlo.sign
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -268,11 +268,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.sqrt
|
||||
// CHECK: lmhlo.sqrt
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -281,12 +281,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32>
|
||||
// CHECK: lhlo.subtract
|
||||
// CHECK: lmhlo.subtract
|
||||
// CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]]
|
||||
// CHECK-NEXT: return
|
||||
%res = "mhlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
|
||||
@ -296,11 +296,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8>
|
||||
func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: lhlo.tanh
|
||||
// CHECK: lmhlo.tanh
|
||||
// CHECK-SAME: %[[ARG0]], %[[VIEW]]
|
||||
%res = "mhlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
|
||||
return %res : tensor<2x2xf32>
|
||||
@ -311,11 +311,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32>
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32>
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {xla_lhlo.alloc = 0
|
||||
// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {xla_lhlo.alloc = 1
|
||||
// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {lmhlo.alloc = 0
|
||||
// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {lmhlo.alloc = 1
|
||||
// CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32>
|
||||
// CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32>
|
||||
// CHECK: "xla_lhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
|
||||
// CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]])
|
||||
func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple<tensor<5x5xi32>, tensor<5x5xf32>> {
|
||||
%res = "mhlo.sort"(%key, %value) ({
|
||||
^bb0(%a: tensor<i32>, %b: tensor<i32>, %c: tensor<f32>, %d: tensor<f32>):
|
||||
|
@ -3,14 +3,14 @@
|
||||
// Current allocation will lead to one buffer argument for the "value" and
|
||||
// another one for the output, an no returned values.
|
||||
// CHECK-LABEL: func @main
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index},
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true}
|
||||
// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 : index},
|
||||
// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true}
|
||||
// CHECK-SAME: ) {
|
||||
func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> {
|
||||
// The only expected instruction is a copy from the input into the output.
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C0]]][] : memref<16xi8> to memref<2x2xf32>
|
||||
// CHECK: xla_lhlo.copy
|
||||
// CHECK: lmhlo.copy
|
||||
// CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
|
||||
return %value : tensor<2x2xf32>
|
||||
}
|
||||
|
@ -23,8 +23,8 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: not_whitelisted_op
|
||||
func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
// CHECK-LABEL: not_allowlisted_op
|
||||
func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor<i32>, %arg2: tensor<i32>) -> tensor<?x?x?xf32> {
|
||||
// CHECK: tf.TensorListReserve
|
||||
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?x?xf32>>>
|
||||
// CHECK: tf.TensorListGetItem
|
||||
|
@ -54,7 +54,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>
|
||||
// CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
|
||||
// CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple<tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32>
|
||||
// CHECK: mhlo.constant
|
||||
// CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
return %0#0 : tensor<8x8x8x8xf32>
|
||||
}
|
||||
|
||||
@ -75,18 +75,18 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>,
|
||||
// CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32}
|
||||
|
||||
// CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694>
|
||||
// CHECK: %[[CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]
|
||||
// CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]]
|
||||
|
||||
// CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988>
|
||||
// CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01>
|
||||
|
||||
// CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3
|
||||
// CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
|
||||
// CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]
|
||||
// CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3
|
||||
// CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]]
|
||||
// CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]]
|
||||
|
||||
// CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4
|
||||
// CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
|
||||
// CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]
|
||||
// CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4
|
||||
// CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]]
|
||||
// CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]]
|
||||
|
||||
// CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]]
|
||||
return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>
|
||||
@ -134,7 +134,7 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x
|
||||
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
|
||||
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
|
||||
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
|
||||
@ -193,7 +193,7 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<
|
||||
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
|
||||
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
|
||||
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
|
||||
@ -280,7 +280,7 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<
|
||||
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
|
||||
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
|
||||
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
|
||||
@ -367,7 +367,7 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te
|
||||
// CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32>
|
||||
// CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
|
||||
|
||||
// CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor<f32>) -> tensor<8xf32>
|
||||
// CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32>
|
||||
|
||||
// CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32>
|
||||
@ -498,19 +498,19 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
|
||||
// CHECK-LABEL: func @floordiv_broadcast_i32
|
||||
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
|
||||
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
|
||||
// CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
|
||||
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
|
||||
// CHECK: return [[SELECT]]
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
@ -520,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te
|
||||
// CHECK-LABEL: func @floordiv_reverse_broadcast_i32
|
||||
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
|
||||
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
|
||||
// CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
|
||||
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]]
|
||||
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]]
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
|
||||
// CHECK: return [[SELECT]]
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
@ -541,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32
|
||||
|
||||
// CHECK-LABEL: func @floordiv_f32
|
||||
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
// CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0
|
||||
// CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0
|
||||
// CHECK-NEXT: %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]])
|
||||
// CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32>
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
@ -552,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
|
||||
// CHECK-NEXT: mhlo.convert
|
||||
// CHECK-NEXT: mhlo.convert
|
||||
// CHECK-NEXT: xla_chlo.broadcast_divide
|
||||
// CHECK-NEXT: chlo.broadcast_divide
|
||||
// CHECK-NEXT: mhlo.floor
|
||||
// CHECK-NEXT: mhlo.convert
|
||||
// CHECK-NEXT: return
|
||||
@ -562,7 +562,7 @@ func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> {
|
||||
|
||||
// CHECK-LABEL: func @floordiv_f16_broadcast
|
||||
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
|
||||
// CHECK-NEXT: xla_chlo.broadcast_divide
|
||||
// CHECK-NEXT: chlo.broadcast_divide
|
||||
// CHECK-NEXT: mhlo.floor
|
||||
// CHECK-NEXT: return
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
|
||||
@ -572,19 +572,19 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te
|
||||
// CHECK-LABEL: func @floordiv_dynamic
|
||||
func @floordiv_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"}
|
||||
// CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0)
|
||||
// CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1>
|
||||
// CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]])
|
||||
// CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1)
|
||||
// CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]])
|
||||
// CHECK: return [[SELECT]]
|
||||
%0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
|
||||
@ -600,15 +600,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x
|
||||
|
||||
// CHECK-LABEL: func @floormod_broadcast_numerator
|
||||
func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]]
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]]
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
|
||||
// CHECK-NEXT: return [[SELECT]]
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
@ -617,15 +617,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>)
|
||||
|
||||
// CHECK-LABEL: func @floormod_broadcast_denominator
|
||||
func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
|
||||
// CHECK-NEXT: return [[SELECT]]
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
|
||||
@ -634,15 +634,15 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32
|
||||
|
||||
// CHECK-LABEL: func @floormod_dynamic
|
||||
func @floormod_dynamic(%arg0: tensor<?x?xi32>, %arg1: tensor<?xi32>) -> tensor<?x?xi32> {
|
||||
// CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0>
|
||||
// CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"}
|
||||
// CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"}
|
||||
// CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]]
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>}
|
||||
// CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]])
|
||||
// CHECK-NEXT: return [[SELECT]]
|
||||
%0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<?x?xi32>, tensor<?xi32>) -> tensor<?x?xi32>
|
||||
@ -979,10 +979,10 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor<i64>, %arg2: ten
|
||||
// CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16>
|
||||
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16>
|
||||
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16>
|
||||
// CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1>
|
||||
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<64x64xbf16>) -> tensor<64x64xi1>
|
||||
|
||||
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
|
||||
// CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1>
|
||||
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor<bf16>) -> tensor<64x64xi1>
|
||||
|
||||
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1>
|
||||
|
||||
@ -1000,10 +1000,10 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor<i64>, %arg2
|
||||
// CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16>
|
||||
// CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16>
|
||||
|
||||
// CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1>
|
||||
// CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor<bf16>, tensor<24x48xbf16>) -> tensor<24x48xi1>
|
||||
|
||||
// CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor<i64>) -> tensor<bf16>
|
||||
// CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1>
|
||||
// CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor<bf16>) -> tensor<24x48xi1>
|
||||
// CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1>
|
||||
|
||||
// CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16>
|
||||
@ -1315,7 +1315,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (te
|
||||
// CHECK-LABEL: func @relu
|
||||
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
|
||||
%0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32>
|
||||
return %0: tensor<1xi32>
|
||||
}
|
||||
@ -1323,7 +1323,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
|
||||
// CHECK-LABEL: func @relu_unranked
|
||||
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor<i32>
|
||||
// CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
|
||||
%0 = "tf.Relu"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
return %0: tensor<?xi32>
|
||||
}
|
||||
@ -1351,7 +1351,7 @@ func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor<?x?xf32>) -> tensor<4x8xf32> {
|
||||
// CHECK-DAG: %[[ZERO_SCALAR:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
|
||||
// CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
|
||||
// CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
|
||||
// CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32>
|
||||
%2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor<?x?xf32>) -> tensor<4x8xf32>
|
||||
@ -2473,10 +2473,10 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor<i32>, %arg1: tensor<32x1
|
||||
// CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>,
|
||||
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32>
|
||||
// CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
|
||||
// CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]]
|
||||
// CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor<i32>, tensor<i32>) -> tensor<i1>
|
||||
// CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor<i32>
|
||||
// CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[INDEX3:.*]] = "mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) :
|
||||
// CHECK-DAG-SAME: (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
// CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice"
|
||||
@ -2605,7 +2605,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> {
|
||||
// CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor<f32>) -> ()
|
||||
// CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<8.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16>
|
||||
// CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16>
|
||||
// CHECK: return %[[RESULT]] : tensor<4x1xf16>
|
||||
@ -2909,8 +2909,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> {
|
||||
func @range(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<5xf32> {
|
||||
%1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"
|
||||
// CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
%3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<5xf32>
|
||||
return %3 : tensor<5xf32>
|
||||
}
|
||||
@ -2929,8 +2929,8 @@ func @range_dynamic(%arg0: tensor<f32>, %arg1: tensor<f32>, %arg2: tensor<f32>)
|
||||
// CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
|
||||
// CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
|
||||
// CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
|
||||
// CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
%2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<f32>, tensor<f32>, tensor<f32>) -> tensor<?xf32>
|
||||
|
||||
// CHECK: return [[ADD]]
|
||||
@ -2951,8 +2951,8 @@ func @range_int_dynamic(%arg0: tensor<i32>, %arg1: tensor<i32>, %arg2: tensor<i3
|
||||
// CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64}
|
||||
// CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0)
|
||||
// CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2)
|
||||
// CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
%2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<?xi32>
|
||||
|
||||
// CHECK: return [[ADD]]
|
||||
@ -2966,12 +2966,12 @@ func @linspace_static(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<4xf32> {
|
||||
// CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]]
|
||||
// CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM_CAST]])
|
||||
// CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00>
|
||||
// CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
|
||||
// CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]]
|
||||
// CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
|
||||
// CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]]
|
||||
// CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]]
|
||||
// CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]]
|
||||
// CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64}
|
||||
// CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>}
|
||||
// CHECK: return [[LINSPACE]]
|
||||
%0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor<i32>} : () -> tensor<i32>
|
||||
%1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor<f32>, tensor<f32>, tensor<i32>) -> tensor<4xf32>
|
||||
@ -3266,13 +3266,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor<i32>) {
// CHECK: %[[CONST:.*]] = mhlo.constant dense<1>
// CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 0
// CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]]
// CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]]
// CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 1
// CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]]
// CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]]
// CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]])
// CHECK-SAME: dimension = 2
// CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]]
// CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]]
%size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor<i32>
// CHECK: return %[[MUL_2]]
return %size : tensor<i32>
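The size_ranked test checks that tf.Size is lowered to a running product: start from a constant 1 and multiply in each get_dimension_size result. A small C++ sketch of the same computation (names are illustrative):

    #include <cstdint>
    #include <vector>

    // Reference for the tf.Size lowering: the element count is the product of
    // all dimension sizes, including dimensions only known at runtime.
    int32_t SizeReference(const std::vector<int32_t>& dims) {
      int32_t size = 1;
      for (int32_t d : dims) size *= d;
      return size;
    }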
@ -3789,7 +3789,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> {
// CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
// CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor<i32>) -> tensor<4xi32>
// CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor<i32>
// CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]]
// CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]]
// CHECK: [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]])
// CHECK: "mhlo.return"([[NEW_TUPLE]])
// CHECK: }) : (tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>) -> tuple<tensor<i32>, tensor<4xi32>, tensor<4xi32>>
@ -3822,7 +3822,7 @@ func @avgpool_valid_padding(%arg0: tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
// CHECK: "mhlo.return"([[ADD]])
// CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
// CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
// CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor<f32>) -> tensor<2x3x5x7xf32>
// CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16>
// CHECK: return [[CONV16]]
%0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16>
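The avgpool_valid_padding test checks a sum-then-divide lowering: a reduce_window sum over each 2x2 window followed by division by the window element count (the 4.0 COUNT constant). A scalar C++ sketch for one window, illustrative only:

    // Reference for one pooling window of the AvgPool lowering above.
    float AvgPoolWindowReference(const float window[4]) {
      float sum = 0.0f;
      for (int i = 0; i < 4; ++i) sum += window[i];  // reduce_window with mhlo.add
      return sum / 4.0f;                             // divide by the window count
    }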
@ -3844,7 +3844,7 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32>
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK_SAME: broadcast_dimensions = dense<[]>
// CHECK_SAME: -> tensor<10x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
@ -3876,7 +3876,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<f32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor<f32>) -> tensor<10x8x12x16x64xf32>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
// CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]>
// CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]>
@ -4059,7 +4059,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8
// CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> {
// CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<bf16>
// CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor<bf16>
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]]
// CHECK-SAME: broadcast_dimensions = dense<[]>
// CHECK-SAME: -> tensor<10x12x16x64xbf16>
// CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]])
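The AvgPoolGrad tests follow the reverse pattern: the incoming gradient is divided by the window element count, padded, and summed back over windows, so each input cell receives an equal share from every window that covers it. A 1-D C++ sketch with window 2 and stride 2, which only illustrates that structure and is not the actual pass:

    #include <vector>

    std::vector<float> AvgPoolGrad1DReference(const std::vector<float>& out_grad) {
      std::vector<float> in_grad(out_grad.size() * 2, 0.0f);
      for (size_t i = 0; i < out_grad.size(); ++i) {
        float share = out_grad[i] / 2.0f;  // divide by the window element count
        in_grad[2 * i] += share;           // scatter the share over the window
        in_grad[2 * i + 1] += share;
      }
      return in_grad;
    }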
@ -4236,10 +4236,10 @@ func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor<f16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
@ -4256,10 +4256,10 @@ func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor<bf16>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<bf16>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
@ -4276,10 +4276,10 @@ func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor<f32>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f32>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
@ -4296,10 +4296,10 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> {
// CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor<f64>
// CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]])
// CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor<f64>
// CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]]
// CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]])
// CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"}
// CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"}
// CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]])
// CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]])
// CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]])
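All four softplus tests encode the same numerically stable formulation: threshold = log(epsilon) + 2, return the input unchanged when it is above -threshold, exp(x) when it is below threshold, and log1p(exp(x)) in between. A float C++ sketch of that selection logic (illustrative; epsilon here is the f32 machine epsilon from the test):

    #include <cmath>

    float SoftplusReference(float x) {
      const float threshold = std::log(1.1920929e-7f) + 2.0f;  // a negative value
      if (x > -threshold) return x;              // large x: softplus(x) ~= x
      if (x < threshold) return std::exp(x);     // very negative x: ~= exp(x)
      return std::log1p(std::exp(x));            // mid range: log(1 + e^x)
    }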
@ -48,7 +48,7 @@ class XlaBuilderTest : public ::testing::Test {
xla_builder_(name_, builder_, module_->getLoc()) {}

string SetupTest() {
mlir::registerDialect<mlir::mhlo::XlaHloDialect>();
mlir::registerDialect<mlir::mhlo::MhloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}

@ -715,7 +715,7 @@ static void CreateWhile32(Location loc, int num_iterations,
auto one =
builder->create<mhlo::ConstOp>(loc, builder->getI32IntegerAttr(1));
auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder);
auto plus_one = builder->create<xla_chlo::BroadcastAddOp>(
auto plus_one = builder->create<chlo::BroadcastAddOp>(
loc, old_values[0], one, scalar_broadcast_dims);
// Prepend with the updated loop induction variable.
new_values.insert(new_values.begin(), plus_one);
@ -1483,7 +1483,7 @@ class ConvertFusedBatchNormGradBase
RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type);
auto epsilon = rewriter.create<ConstOp>(
loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()}));
auto add_op = rewriter.create<xla_chlo::BroadcastAddOp>(
auto add_op = rewriter.create<chlo::BroadcastAddOp>(
loc, var, epsilon.getResult(), scalar_broadcast_dims);

Value scratch1 = rewriter.create<RsqrtOp>(loc, add_op);
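The value built just above is the usual batch-norm rescaling factor: scratch1 = rsqrt(variance + epsilon). A one-line C++ sketch of the per-channel scalar it represents:

    #include <cmath>

    float Scratch1Reference(float variance, float epsilon) {
      return 1.0f / std::sqrt(variance + epsilon);  // rsqrt(var + epsilon)
    }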
@ -1601,7 +1601,7 @@ class ConvertFusedBatchNormV3Op
auto factor_const_op = rewriter.create<mhlo::ConstOp>(
op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor));

Value corrected_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
Value corrected_variance = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), batch_variance.getType(), batch_variance,
factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr());

@ -1621,26 +1621,24 @@ class ConvertFusedBatchNormV3Op
rewriter.getFloatAttr(mean_element_type, exponential_avg_factor));

// new_running_mean = alpha * old_mean + beta * batch_mean.
auto alpha_mul_old_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
auto alpha_mul_old_mean = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), op.mean().getType(), alpha, op.mean(),
/*broadcast_dimensions=*/DenseIntElementsAttr());
auto beta_mul_batch_mean = rewriter.create<xla_chlo::BroadcastMulOp>(
auto beta_mul_batch_mean = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), batch_mean.getType(), beta, batch_mean,
/*broadcast_dimensions=*/DenseIntElementsAttr());
batch_mean = rewriter.create<xla_chlo::BroadcastAddOp>(
batch_mean = rewriter.create<chlo::BroadcastAddOp>(
op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean,
/*broadcast_dimensions=*/DenseIntElementsAttr());

// new_running_variance = alpha * old_variance + beta * batch_variance.
auto alpha_mul_old_variance = rewriter.create<xla_chlo::BroadcastMulOp>(
auto alpha_mul_old_variance = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), op.variance().getType(), alpha, op.variance(),
/*broadcast_dimensions=*/DenseIntElementsAttr());
auto beta_mul_batch_variance =
rewriter.create<xla_chlo::BroadcastMulOp>(
op.getLoc(), corrected_variance.getType(), beta,
corrected_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
corrected_variance = rewriter.create<xla_chlo::BroadcastAddOp>(
auto beta_mul_batch_variance = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), corrected_variance.getType(), beta, corrected_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
corrected_variance = rewriter.create<chlo::BroadcastAddOp>(
op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance,
/*broadcast_dimensions=*/DenseIntElementsAttr());
}
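The two comment lines above spell out the moving-average update these ops implement. A C++ sketch of the same update; the relationship alpha = 1 - exponential_avg_factor and beta = exponential_avg_factor is an assumption made here for illustration and is not shown in this hunk:

    struct RunningStats { float mean; float variance; };

    RunningStats UpdateRunningStats(RunningStats old_stats, float batch_mean,
                                    float batch_variance,
                                    float exponential_avg_factor) {
      const float alpha = 1.0f - exponential_avg_factor;  // assumed definition
      const float beta = exponential_avg_factor;          // assumed definition
      return {alpha * old_stats.mean + beta * batch_mean,
              alpha * old_stats.variance + beta * batch_variance};
    }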
@ -1810,7 +1808,7 @@ class ConvertAvgPoolOp : public OpRewritePattern<TF::AvgPoolOp> {
Value divisor =
GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
Value result = rewriter.create<xla_chlo::BroadcastDivOp>(
Value result = rewriter.create<chlo::BroadcastDivOp>(
op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims);

// Convert back if we enlarged the element type's bitwidth.
@ -1914,7 +1912,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern<OpTy> {
Value divisor =
GetScalarConstOfType(element_type, loc, window_count, &rewriter);
auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter);
out_grad_divided = rewriter.create<xla_chlo::BroadcastDivOp>(
out_grad_divided = rewriter.create<chlo::BroadcastDivOp>(
loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims);
} else {
assert(op.padding() == "SAME");
@ -2335,7 +2333,7 @@ class ConvertSizeOp : public OpRewritePattern<TF::SizeOp> {
auto dim = rewriter.create<GetDimensionSizeOp>(
op.getLoc(), result_type, input,
rewriter.getIntegerAttr(rewriter.getIntegerType(32), i));
size = rewriter.create<xla_chlo::BroadcastMulOp>(
size = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), size->getResult(0), dim.getResult(),
/*DenseIntElementsAttr=*/DenseIntElementsAttr());
}
@ -3021,10 +3019,10 @@ class ConvertRangeOp : public OpRewritePattern<TF::RangeOp> {

auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
auto scaled = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, op.delta(),
xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta()));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, result_type, scaled, op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return success();
@ -3101,10 +3099,10 @@ class ConvertDynamicRangeOp : public OpRewritePattern<TF::RangeOp> {

auto iota = rewriter.create<DynamicIotaOp>(
op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
auto scaled = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, delta_out_cast,
xla::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, result_type, scaled, start_out_cast,
xla::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast));
return success();
@ -3152,7 +3150,7 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
int64_t num = (*num_attr.begin()).getSExtValue();

// Calculate the scaling that needs to be applied to the iota.
auto step_numerator = rewriter.create<xla_chlo::BroadcastSubOp>(
auto step_numerator = rewriter.create<chlo::BroadcastSubOp>(
op.getLoc(), op.start().getType(), op.stop(), op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start()));
Value step_denominator = rewriter.create<ConvertOp>(
@ -3160,11 +3158,11 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
if (num > 1) {
Value one = GetScalarConstOfType(result_type.getElementType(),
op.getLoc(), 1, &rewriter);
step_denominator = rewriter.create<xla_chlo::BroadcastSubOp>(
step_denominator = rewriter.create<chlo::BroadcastSubOp>(
op.getLoc(), step_denominator.getType(), step_denominator, one,
xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one));
}
auto step = rewriter.create<xla_chlo::BroadcastDivOp>(
auto step = rewriter.create<chlo::BroadcastDivOp>(
op.getLoc(), step_numerator.getType(), step_numerator, step_denominator,
xla::getBroadcastDimensionsAttr(&rewriter, step_numerator,
step_denominator));
@ -3172,10 +3170,10 @@ class ConvertLinSpaceOp : public OpRewritePattern<TF::LinSpaceOp> {
// Scale the iota and add the offset.
auto iota = rewriter.create<IotaOp>(op.getLoc(), result_type,
rewriter.getI64IntegerAttr(0));
auto scaled = rewriter.create<xla_chlo::BroadcastMulOp>(
auto scaled = rewriter.create<chlo::BroadcastMulOp>(
op.getLoc(), result_type, iota, step,
xla::getBroadcastDimensionsAttr(&rewriter, iota, step));
rewriter.replaceOpWithNewOp<xla_chlo::BroadcastAddOp>(
rewriter.replaceOpWithNewOp<chlo::BroadcastAddOp>(
op, result_type, scaled, op.start(),
xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start()));
return success();
@ -3251,7 +3249,7 @@ class GenericConvertReductionOp : public OpRewritePattern<OpTy> {
auto divisor = GetScalarConstOfType(reduce_element_type, loc,
divisor_count, &rewriter);
auto broadcast_dims = GetI64ElementsAttr({}, &rewriter);
result = rewriter.create<xla_chlo::BroadcastDivOp>(
result = rewriter.create<chlo::BroadcastDivOp>(
loc, result, divisor.getResult(), broadcast_dims);
}
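The hunk above finishes a mean-style reduction: the summed result is divided by the number of reduced elements (divisor_count). A C++ sketch of the scalar equivalent, with illustrative names:

    #include <vector>

    float MeanReference(const std::vector<float>& values) {
      float sum = 0.0f;
      for (float v : values) sum += v;                 // the reduction result
      return sum / static_cast<float>(values.size());  // divide by divisor_count
    }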
@ -5008,11 +5006,11 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
Value iota = builder->create<IotaOp>(
loc, RankedTensorType::get({m}, builder->getIntegerType(32)),
builder->getI64IntegerAttr(0));
Value gtk = builder->create<xla_chlo::BroadcastCompareOp>(
Value gtk = builder->create<chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder),
StringAttr::get("GT", builder->getContext()));
gtk = builder->create<ConvertOp>(loc, gtk, x_type.getElementType());
Value x_after_k = builder->create<xla_chlo::BroadcastMulOp>(
Value x_after_k = builder->create<chlo::BroadcastMulOp>(
loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder));
Value x_after_k_sq = builder->create<MulOp>(loc, x_after_k, x_after_k);
// sigma = np.dot(x[k+1:], x[k+1:])
@ -5024,15 +5022,15 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
Value mu = builder->create<SqrtOp>(
loc, builder->create<AddOp>(loc, alpha_sq, sigma.getResult(0)));

Value sigma_is_zero = builder->create<xla_chlo::BroadcastCompareOp>(
Value sigma_is_zero = builder->create<chlo::BroadcastCompareOp>(
loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
Value alpha_is_negative = builder->create<xla_chlo::BroadcastCompareOp>(
Value alpha_is_negative = builder->create<chlo::BroadcastCompareOp>(
loc, alpha, zero, GetI64ElementsAttr({}, builder),
StringAttr::get("LT", builder->getContext()));
auto batch_size_one = builder->create<BroadcastOp>(
loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder));
Value signed_mu = builder->create<xla_chlo::BroadcastMulOp>(
Value signed_mu = builder->create<chlo::BroadcastMulOp>(
loc,
builder->create<SelectOp>(loc, mu.getType(), alpha_is_negative,
batch_size_one,
@ -5050,7 +5048,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
divisor = builder->create<SelectOp>(loc, divisor.getType(), sigma_is_zero,
batch_size_one, divisor);

Value eqk = builder->create<xla_chlo::BroadcastCompareOp>(
Value eqk = builder->create<chlo::BroadcastCompareOp>(
loc, iota, k, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
eqk = builder->create<ConvertOp>(loc, eqk, x_type.getElementType());
@ -5064,7 +5062,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
// Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
// If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
// Note that the add performs a degenerate broadcast.
*v = builder->create<xla_chlo::BroadcastAddOp>(
*v = builder->create<chlo::BroadcastAddOp>(
loc, e_k,
StaticBinaryBroadcast<DivOp>(loc, x_after_k, divisor,
GetI64ElementsAttr(batch_dim_ids, builder),
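The QR hunks above assemble one Householder step: sigma = dot(x[k+1:], x[k+1:]), mu = sqrt(alpha^2 + sigma) with alpha = x[k], and v = e_k ++ x[k+1:] / divisor, where the divisor is replaced by an arbitrary non-zero value when sigma is zero. A scalar C++ sketch; the sign convention used for the divisor below is an assumption (the standard cancellation-avoiding choice), while the pass derives it from alpha and the signed mu:

    #include <cmath>
    #include <vector>

    std::vector<double> HouseholderVectorSketch(const std::vector<double>& x, int k) {
      const int m = static_cast<int>(x.size());
      double alpha = x[k], sigma = 0.0;
      for (int i = k + 1; i < m; ++i) sigma += x[i] * x[i];  // dot(x[k+1:], x[k+1:])
      double mu = std::sqrt(alpha * alpha + sigma);
      double divisor = alpha + std::copysign(mu, alpha);     // assumed sign choice
      if (sigma == 0.0) divisor = 1.0;  // x[k+1:] is all zero; any non-zero divisor works
      std::vector<double> v(m, 0.0);
      v[k] = 1.0;                                            // the e_k part
      for (int i = k + 1; i < m; ++i) v[i] = x[i] / divisor;
      return v;
    }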
@ -5154,12 +5152,12 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
auto iota = builder->create<IotaOp>(
loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)),
builder->getI64IntegerAttr(0));
Value predecessor_mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value predecessor_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota, j, GetI64ElementsAttr({}, builder),
StringAttr::get("LT", builder->getContext()));
predecessor_mask = builder->create<ConvertOp>(loc, predecessor_mask,
a_type.getElementType());
Value mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
mask = builder->create<ConvertOp>(loc, mask, a_type.getElementType());
@ -5189,7 +5187,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc,
RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)),
builder->getI64IntegerAttr(minor_dim + 1));
Value xa_mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value xa_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
a = builder->create<SelectOp>(loc, a_type, xa_mask, new_x, a);
@ -5226,7 +5224,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, taus.getType(), taus_zeros,
GetI64ElementsAttr(taus.getType().cast<RankedTensorType>().getShape(),
builder));
Value taus_mask = builder->create<xla_chlo::BroadcastCompareOp>(
Value taus_mask = builder->create<chlo::BroadcastCompareOp>(
loc, iota_n, j, GetI64ElementsAttr({}, builder),
StringAttr::get("EQ", builder->getContext()));
auto taus_update = builder->create<SelectOp>(
@ -5311,7 +5309,7 @@ class ConvertQrOp : public OpRewritePattern<TF::QrOp> {
loc, vs.getType(), zero,
GetI64ElementsAttr(vs.getType().cast<RankedTensorType>().getShape(),
builder));
auto compare = builder->create<xla_chlo::BroadcastCompareOp>(
auto compare = builder->create<chlo::BroadcastCompareOp>(
loc, iota_mn, j, GetI64ElementsAttr({}, builder),
StringAttr::get("GE", builder->getContext()));
auto y = builder->create<SelectOp>(loc, vs.getType(), compare, zero, vs);
@ -5459,16 +5457,16 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion,
// Populate with CHLO->HLO lowerings to account for TF ops legalized to
// CHLO first.
if (legalize_chlo) {
xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
chlo::PopulateLegalizeChloToHloPatterns(context, &patterns);
}

ConversionTarget target(*context);
if (legalize_chlo) {
target.addIllegalDialect<xla_chlo::XlaHloClientDialect>();
target.addIllegalDialect<chlo::HloClientDialect>();
} else {
target.addLegalDialect<xla_chlo::XlaHloClientDialect>();
target.addLegalDialect<chlo::HloClientDialect>();
}
target.addLegalDialect<XlaHloDialect>();
target.addLegalDialect<MhloDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<shape::ShapeDialect>();
target.addLegalOp<CallOp>();
@ -75,10 +75,10 @@ namespace {
template <typename T, size_t N>
using InlinedVector = tensorflow::gtl::InlinedVector<T, N>; // non-absl ok

static bool IsOpWhitelisted(Operation* op) {
// White-listed TensorFlow ops are known to have well behaved tf2xla kernels
static bool IsOpAllowlisted(Operation* op) {
// Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels
// building valid MLIR using MlirHloBuilder.
// TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for
// TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for
// all tf2xla kernels.
// clang-format off
static llvm::SmallDenseSet<mlir::TypeID, 512> ops = {
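The rename above keeps the mechanism unchanged: a fixed set of op type IDs decides whether the tf2xla fallback may handle an op. A self-contained C++ sketch of that kind of membership check; the op names and the use of std::string instead of mlir::TypeID are illustrative only:

    #include <set>
    #include <string>

    bool IsAllowlistedSketch(const std::string& op_name) {
      static const std::set<std::string> kAllowlist = {"tf.OpA", "tf.OpB"};
      return kAllowlist.count(op_name) > 0;
    }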
@ -342,7 +342,7 @@ LogicalResult FuncLegalizer::Legalize() {
}

LogicalResult FuncLegalizer::LegalizeOp(Operation* op) {
if (!IsOpWhitelisted(op)) return success();
if (!IsOpAllowlisted(op)) return success();

// Only static shaped operands are supported in XLA builders for now.
for (Type ty : op->getOperandTypes()) {
@ -190,51 +190,51 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
using ::xla::HloOpcode;
switch (instr->opcode()) {
case HloOpcode::kAbs:
return CreateOpWithoutAttrs<xla_lhlo::AbsOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr).status();
case HloOpcode::kAdd:
return CreateOpWithoutAttrs<xla_lhlo::AddOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::AddOp>(instr).status();
case HloOpcode::kAnd:
return CreateOpWithoutAttrs<xla_lhlo::AndOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::AndOp>(instr).status();
case HloOpcode::kCeil:
return CreateOpWithoutAttrs<xla_lhlo::CeilOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr).status();
case HloOpcode::kComplex:
return CreateOpWithoutAttrs<xla_lhlo::ComplexOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr).status();
case HloOpcode::kCopy:
return CreateOpWithoutAttrs<xla_lhlo::CopyOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr).status();
case HloOpcode::kCos:
return CreateOpWithoutAttrs<xla_lhlo::CosOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::CosOp>(instr).status();
case HloOpcode::kDivide:
return CreateOpWithoutAttrs<xla_lhlo::DivOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::DivOp>(instr).status();
case HloOpcode::kExp:
return CreateOpWithoutAttrs<xla_lhlo::ExpOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr).status();
case HloOpcode::kImag:
return CreateOpWithoutAttrs<xla_lhlo::ImagOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr).status();
case HloOpcode::kLog:
return CreateOpWithoutAttrs<xla_lhlo::LogOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::LogOp>(instr).status();
case HloOpcode::kMaximum:
return CreateOpWithoutAttrs<xla_lhlo::MaxOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr).status();
case HloOpcode::kMinimum:
return CreateOpWithoutAttrs<xla_lhlo::MinOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::MinOp>(instr).status();
case HloOpcode::kMultiply:
return CreateOpWithoutAttrs<xla_lhlo::MulOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::MulOp>(instr).status();
case HloOpcode::kNegate:
return CreateOpWithoutAttrs<xla_lhlo::NegOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::NegOp>(instr).status();
case HloOpcode::kReal:
return CreateOpWithoutAttrs<xla_lhlo::RealOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::RealOp>(instr).status();
case HloOpcode::kRemainder:
return CreateOpWithoutAttrs<xla_lhlo::RemOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::RemOp>(instr).status();
case HloOpcode::kRsqrt:
return CreateOpWithoutAttrs<xla_lhlo::RsqrtOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr).status();
case HloOpcode::kSelect:
return CreateOpWithoutAttrs<xla_lhlo::SelectOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr).status();
case HloOpcode::kSign:
return CreateOpWithoutAttrs<xla_lhlo::SignOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SignOp>(instr).status();
case HloOpcode::kSqrt:
return CreateOpWithoutAttrs<xla_lhlo::SqrtOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr).status();
case HloOpcode::kSubtract:
return CreateOpWithoutAttrs<xla_lhlo::SubOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::SubOp>(instr).status();
case HloOpcode::kTanh:
return CreateOpWithoutAttrs<xla_lhlo::TanhOp>(instr).status();
return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr).status();
default:
llvm::errs() << instr->ToString();
return tensorflow::errors::Internal(
@ -246,7 +246,7 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {

StatusOr<mlir::Operation*> LhloDialectEmitter::EmitSortOp(
HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<xla_lhlo::SortOp>(instr));
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr);
sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
@ -379,16 +379,16 @@ Status LhloDialectEmitter::Initialize() {
block->addArgument(arg_type);
allocations_[alloc] = block->getArguments().back();
args_attrs.emplace_back();
args_attrs.back().set(builder_.getIdentifier("xla_lhlo.params"),
args_attrs.back().set(builder_.getIdentifier("lmhlo.params"),
builder_.getIndexAttr(alloc->parameter_number()));
} else {
block->addArgument(MemRefType::get({alloc->size()}, i8_type_));
allocations_[alloc] = block->getArguments().back();
args_attrs.emplace_back();
args_attrs.back().set(builder_.getIdentifier("xla_lhlo.alloc"),
args_attrs.back().set(builder_.getIdentifier("lmhlo.alloc"),
builder_.getIndexAttr(alloc->index()));
if (alloc->maybe_live_out())
args_attrs.back().set(builder_.getIdentifier("xla_lhlo.liveout"),
args_attrs.back().set(builder_.getIdentifier("lmhlo.liveout"),
builder_.getBoolAttr(true));
}
}