Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/ubuntu-onednn-partials

Abolfazl Shahbazi, 2020-07-07 21:03:21 -07:00
commit f7dabcae30
388 changed files with 10149 additions and 6583 deletions

View File

@@ -243,8 +243,10 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                                const char* target_device_name,
                                                TF_Status* status,
                                                void* device_info) {
-  TF_SetStatus(status, TF_INTERNAL,
-               "Trying to copy a tensor out of a parallel device.");
+  TF_SetStatus(status, TF_UNIMPLEMENTED,
+               "Trying to copy a tensor out of a parallel device. Since there "
+               "are multiple components to parallel tensors, they must be "
+               "unpacked explicitly.");
   return nullptr;
 }

View File

@@ -157,7 +157,7 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
   // Copies off of parallel devices must be explicit.
   TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
       device_value.get(), context.get(), first_device_name, status.get()));
-  ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
+  ASSERT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);
 }
 
 TEST(PARALLEL_DEVICE, TestDifferentShapes) {

View File

@@ -73,6 +73,14 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok,
   }
 }
 
+/// Appends a trailing slash if the name doesn't already have one.
+static void MaybeAppendSlash(std::string* name) {
+  if (name->empty())
+    *name = "/";
+  else if (name->back() != '/')
+    name->push_back('/');
+}
+
 // SECTION 1. Implementation for `TF_RandomAccessFile`
 // ----------------------------------------------------------------------------
 namespace tf_random_access_file {
@@ -410,6 +418,70 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
   }
 }
 
+void CreateDir(const TF_Filesystem* filesystem, const char* path,
+               TF_Status* status) {
+  std::string bucket, object;
+  ParseGCSPath(path, true, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
+  if (object.empty()) {
+    auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
+    TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
+    return;
+  }
+  MaybeAppendSlash(&object);
+  auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
+  TF_SetStatusFromGCSStatus(object_metadata.status(), status);
+  if (TF_GetCode(status) == TF_NOT_FOUND) {
+    auto insert_metadata =
+        gcs_file->gcs_client.InsertObject(bucket, object, "");
+    TF_SetStatusFromGCSStatus(insert_metadata.status(), status);
+  } else if (TF_GetCode(status) == TF_OK) {
+    TF_SetStatus(status, TF_ALREADY_EXISTS, path);
+  }
+}
+
+void DeleteFile(const TF_Filesystem* filesystem, const char* path,
+                TF_Status* status) {
+  std::string bucket, object;
+  ParseGCSPath(path, false, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
+  auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
+  TF_SetStatusFromGCSStatus(gcs_status, status);
+}
+
+void DeleteDir(const TF_Filesystem* filesystem, const char* path,
+               TF_Status* status) {
+  std::string bucket, object;
+  ParseGCSPath(path, false, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  MaybeAppendSlash(&object);
+  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
+  int object_count = 0;
+  for (auto&& metadata :
+       gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
+    if (!metadata) {
+      TF_SetStatusFromGCSStatus(metadata.status(), status);
+      return;
+    }
+    ++object_count;
+    // We consider a path to be a non-empty directory in two cases:
+    // - There is more than one object whose key starts with the name of this
+    //   directory.
+    // - There is exactly one such object, and its key contains the name of
+    //   this directory (but is not equal to it).
+    if (object_count > 1 || metadata->name() != object) {
+      TF_SetStatus(status, TF_FAILED_PRECONDITION,
+                   "Cannot delete a non-empty directory.");
+      return;
+    }
+  }
+  auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
+  TF_SetStatusFromGCSStatus(gcs_status, status);
+}
+
 } // namespace tf_gcs_filesystem
 
 static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
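For intuition about the directory handling above, here is a minimal standalone sketch (not part of this commit; the main() harness and assertions are illustrative only) exercising the same MaybeAppendSlash normalization:

#include <cassert>
#include <string>

// Same behavior as the helper above: normalize a name into a directory key.
static void MaybeAppendSlash(std::string* name) {
  if (name->empty())
    *name = "/";
  else if (name->back() != '/')
    name->push_back('/');
}

int main() {
  std::string empty, plain = "path/to/dir", slashed = "path/to/dir/";
  MaybeAppendSlash(&empty);    // ""            -> "/"
  MaybeAppendSlash(&plain);    // "path/to/dir" -> "path/to/dir/"
  MaybeAppendSlash(&slashed);  // already slash-terminated -> unchanged
  assert(empty == "/");
  assert(plain == "path/to/dir/" && slashed == "path/to/dir/");
  return 0;
}

This normalization is what lets CreateDir and DeleteDir treat the GCS object "path/to/dir/" as the directory marker for "path/to/dir".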

View File

@@ -19,9 +19,6 @@ package(
 cc_library(
     name = "concrete_function",
-    srcs = [
-        "concrete_function.cc",
-    ],
     hdrs = [
         "concrete_function.h",
     ],
@@ -29,7 +26,6 @@ cc_library(
         ":function_metadata",
         "//tensorflow/c/eager:immediate_execution_operation",
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
-        "//tensorflow/core:protos_all_cc",
     ],
 )
@@ -60,10 +56,13 @@ cc_library(
         "saved_model_utils.h",
     ],
     deps = [
+        ":function_metadata",
         "//tensorflow/c:tf_tensor_internal",
         "//tensorflow/c/eager:immediate_execution_context",
         "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
+        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
         "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
+        "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
     ],
 )
@@ -91,6 +90,18 @@ cc_library(
     ],
 )
+
+cc_library(
+    name = "tf_concrete_function_test_protos",
+    testonly = True,
+    srcs = ["tf_concrete_function_test_protos.cc"],
+    hdrs = ["tf_concrete_function_test_protos.h"],
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "tf_saved_model_impl",
     srcs = [
@@ -114,12 +125,16 @@ cc_library(
         "saved_model_api.h",
     ],
     visibility = ["//tensorflow/python:__pkg__"],
+    deps = [
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/core:lib",
+    ],
 )
 
 filegroup(
     name = "mobile_srcs_only_runtime",
     srcs = [
-        "concrete_function.cc",
         "concrete_function.h",
         "function_metadata.h",
         "saved_model_api.h",
@@ -172,3 +187,28 @@ tf_cc_test(
         "//tensorflow/core/common_runtime/eager:core",
     ],
 )
+
+tf_cc_test(
+    name = "tf_concrete_function_loading_test",
+    srcs = [
+        "tf_concrete_function_loading_test.cc",
+    ],
+    deps = [
+        ":saved_model_utils",
+        ":test_utils",
+        ":tf_concrete_function_test_protos",
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
+        "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
+        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/common_runtime:core_cpu_lib",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/common_runtime/eager:core",
+    ],
+)

View File

@@ -16,12 +16,12 @@ limitations under the License.
 #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
 #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
 
-#include <memory>
 #include <vector>
 
 #include "tensorflow/c/eager/immediate_execution_operation.h"
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
 #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
-#include "tensorflow/core/framework/function.pb.h"
 
 namespace tensorflow {
@@ -35,19 +35,14 @@ namespace tensorflow {
 // and have only a single implementation.
 class ConcreteFunction {
  public:
-  virtual ~ConcreteFunction() = 0;
+  virtual ~ConcreteFunction() = default;
 
   // This method returns the "Call" Op used to execute the function.
-  virtual ImmediateExecutionOperation* GetCallOp() = 0;
+  virtual Status GetCallOp(ImmediateOpPtr* out) = 0;
 
-  const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
-      const;
-  const FunctionMetadata& GetFunctionMetadata() const;
-
- private:
-  FunctionMetadata metadata_;
-  std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
-  FunctionDef* function_;
+  virtual const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
+      const = 0;
+  virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
 };
 
 } // namespace tensorflow

View File

@@ -14,6 +14,27 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )
 
+cc_library(
+    name = "restore_ops",
+    srcs = [
+        "restore_ops.cc",
+    ],
+    hdrs = [
+        "restore_ops.h",
+    ],
+    deps = [
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/lib/llvm_rtti",
+    ],
+)
+
 cc_library(
     name = "variable_ops",
     srcs = [
@@ -37,16 +58,45 @@ cc_library(
 )
 
 tf_cc_test(
-    name = "variable_ops_test",
+    name = "restore_ops_test",
     srcs = [
-        "variable_ops_test.cc",
+        "restore_ops_test.cc",
+    ],
+    data = [
+        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
     ],
     deps = [
-        ":variable_ops",
+        ":restore_ops",
         "//tensorflow/c:tensor_interface",
         "//tensorflow/c/eager:abstract_tensor_handle",
         "//tensorflow/c/eager:immediate_execution_context",
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core:test_utils",
+        "//tensorflow/cc/saved_model:constants",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/common_runtime:core_cpu_lib",
+        "//tensorflow/core/common_runtime/eager:context",
+        "//tensorflow/core/common_runtime/eager:core",
+    ],
+)
+
+tf_cc_test(
+    name = "variable_ops_test",
+    srcs = [
+        "variable_ops_test.cc",
+    ],
+    deps = [
+        ":variable_ops",
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core:test_utils",
         "//tensorflow/core:all_kernels",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",

View File

@@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace internal {
namespace {
// Creates a scalar string tensorhandle containing a single string `s`
Status CreateStringScalarTensorHandle(ImmediateExecutionContext* ctx,
const std::string& s,
ImmediateTensorHandlePtr* out) {
AbstractTensorPtr tensor(ctx->CreateStringScalar(s));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create scalar string tensor for checkpoint restore");
}
out->reset(ctx->CreateLocalHandle(tensor.get()));
return Status();
}
// Creates a Rank 1 string tensorhandle containing a single string `s`
Status CreateStringVectorTensorHandle(ImmediateExecutionContext* ctx,
const std::string& s,
ImmediateTensorHandlePtr* out) {
int64 flat_shape[] = {1};
AbstractTensorPtr tensor(ctx->CreateTensor(DT_STRING, flat_shape));
if (tensor.get() == nullptr) {
return errors::Internal(
"Failed to create vector string tensor for checkpoint restore");
}
// Use placement new to construct the string, since we don't have
// access to Tensor::flat. This is conceptually equivalent to:
// tensor.flat<tstring>()(0) = s
new (tensor->Data()) tstring(s);
out->reset(ctx->CreateLocalHandle(tensor.get()));
return Status();
}
} // namespace
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
const std::string& checkpoint_key, DataType dtype,
ImmediateTensorHandlePtr* out) {
// Create the EagerOp
ImmediateOpPtr restore_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(restore_op->Reset("RestoreV2", "/cpu:0"));
TF_RETURN_IF_ERROR(restore_op->SetAttrTypeList("dtypes", &dtype, 1));
ImmediateTensorHandlePtr prefix_handle;
TF_RETURN_IF_ERROR(
CreateStringScalarTensorHandle(ctx, prefix, &prefix_handle));
ImmediateTensorHandlePtr names_handle;
TF_RETURN_IF_ERROR(
CreateStringVectorTensorHandle(ctx, checkpoint_key, &names_handle));
// Note that empty string is the slice spec used for a non-partitioned
// ResourceVariable:
// https://github.com/tensorflow/tensorflow/blob/06ff30f7ea35098cb68a231a9eb7ff3ff4be4e1e/tensorflow/python/training/saving/saveable_object_util.py#L194
ImmediateTensorHandlePtr shapes_and_slices_handle;
TF_RETURN_IF_ERROR(
CreateStringVectorTensorHandle(ctx, "", &shapes_and_slices_handle));
TF_RETURN_IF_ERROR(restore_op->AddInput(prefix_handle.get()));
TF_RETURN_IF_ERROR(restore_op->AddInput(names_handle.get()));
TF_RETURN_IF_ERROR(restore_op->AddInput(shapes_and_slices_handle.get()));
AbstractTensorHandle* restored_handle = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(restore_op->Execute(
absl::MakeSpan(&restored_handle, num_retvals), &num_retvals));
AbstractTensorHandlePtr owned_restored_handle(restored_handle);
if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
owned_restored_handle.get())) {
return errors::Internal("Unexpected tensor handle kind.");
}
out->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
owned_restored_handle.release()));
return Status();
}
} // namespace internal
} // namespace tensorflow
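The placement-new line in CreateStringVectorTensorHandle is worth isolating. Below is a minimal standalone sketch (not part of this commit; plain std::string stands in for tstring, and the buffer name is assumed) of constructing an object directly into raw storage:

#include <memory>
#include <new>
#include <string>

int main() {
  // Raw, suitably aligned storage, standing in for tensor->Data().
  alignas(std::string) unsigned char buf[sizeof(std::string)];
  // Construct the string in place, as the code above does with tstring.
  std::string* s = new (buf) std::string("x/.ATTRIBUTES/VARIABLE_VALUE");
  // ... use *s ...
  std::destroy_at(s);  // placement new requires an explicit destructor call
  return 0;
}

In the restore op the tensor owns and later frees the buffer; only the in-place construction is the trick being used.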

View File

@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
#include <string>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace internal {
// TODO(bmzhao): Add a function to restore multiple tensors in one call.
// Restores a single non-partitioned tensorhandle of dtype `dtype`, using
// checkpoint at `prefix`, with a value stored in `checkpoint_key`.
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
const std::string& checkpoint_key, DataType dtype,
ImmediateTensorHandlePtr* out);
} // namespace internal
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_

View File

@@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
std::string CheckpointPrefix(StringPiece saved_model_dir) {
return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
saved_model_dir, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
}
class RestoreOpsTest : public ::testing::Test {
public:
RestoreOpsTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
// One way of obtaining a checkpoint's tensor names is:
// bazel run //tensorflow/python/tools:inspect_checkpoint -- --all_tensors
// --file_name="$CKPT_PREFIX".
// Here are the values for VarsAndArithmeticObjectGraph:
// tensor: child/z/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 3.0
// tensor: x/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 1.0
// tensor: y/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 2.0
TEST_F(RestoreOpsTest, RestoreSuccessful) {
ImmediateTensorHandlePtr x_handle;
TF_EXPECT_OK(internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle));
AbstractTensorPtr x = testing::TensorHandleToTensor(x_handle.get());
EXPECT_EQ(x->Type(), DT_FLOAT);
EXPECT_EQ(x->NumElements(), 1);
EXPECT_EQ(x->NumDims(), 0);
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(x->Data()), 1.0f);
ImmediateTensorHandlePtr y_handle;
TF_EXPECT_OK(internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"y/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &y_handle));
AbstractTensorPtr y = testing::TensorHandleToTensor(y_handle.get());
EXPECT_EQ(y->Type(), DT_FLOAT);
EXPECT_EQ(y->NumElements(), 1);
EXPECT_EQ(y->NumDims(), 0);
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(y->Data()), 2.0f);
ImmediateTensorHandlePtr z_handle;
TF_EXPECT_OK(internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"child/z/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &z_handle));
AbstractTensorPtr z = testing::TensorHandleToTensor(z_handle.get());
EXPECT_EQ(z->Type(), DT_FLOAT);
EXPECT_EQ(z->NumElements(), 1);
EXPECT_EQ(z->NumDims(), 0);
EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(z->Data()), 3.0f);
}
TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) {
ImmediateTensorHandlePtr x_handle;
Status status = internal::SingleRestore(
context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"),
"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle);
EXPECT_FALSE(status.ok()) << status.error_message();
}
TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) {
ImmediateTensorHandlePtr x_handle;
Status status = internal::SingleRestore(
context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
"bad_checkpoint_key", DT_FLOAT, &x_handle);
EXPECT_FALSE(status.ok()) << status.error_message();
}
} // namespace
} // namespace tensorflow

View File

@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>
 
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
 #include "tensorflow/c/tensor_interface.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
@@ -39,17 +40,8 @@ ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
 class VariableOpsTest : public ::testing::Test {
  public:
   VariableOpsTest()
-      : device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
-            "CPU", {}, "/job:localhost/replica:0/task:0"))),
-        ctx_(new EagerContext(
-            SessionOptions(),
-            tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-            tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
-            /* async= */ false,
-            /* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
-            /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
-            /* custom_kernel_creator= */ nullptr,
-            /* cluster_flr= */ nullptr)) {}
+      : device_mgr_(testing::CreateTestingDeviceMgr()),
+        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
 
   EagerContext* context() { return ctx_.get(); }

View File

@@ -58,3 +58,24 @@ cc_library(
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
     ],
 )
+
+cc_library(
+    name = "tf_concrete_function",
+    srcs = [
+        "tf_concrete_function.cc",
+    ],
+    hdrs = [
+        "tf_concrete_function.h",
+    ],
+    deps = [
+        ":tensorhandle_convertible",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core:concrete_function",
+        "//tensorflow/c/experimental/saved_model/core:function_metadata",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:context",
+    ],
+)

View File

@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include <memory>
#include <string>
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
TFConcreteFunction::TFConcreteFunction(
const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx)
: name_(name),
captures_(std::move(captures)),
metadata_(std::move(metadata)),
ctx_(ctx) {}
TFConcreteFunction::~TFConcreteFunction() {
Status status = ctx_->RemoveFunction(name_);
if (!status.ok()) {
LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
<< status.error_message();
}
}
Status TFConcreteFunction::Create(
const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction>* out) {
TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
out->reset(new TFConcreteFunction(function_def->signature().name(),
std::move(captures), std::move(metadata),
ctx));
return Status();
}
const std::vector<ImmediateExecutionTensorHandle*>&
TFConcreteFunction::GetCaptures() const {
return captures_;
}
const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
return metadata_;
}
Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) {
out->reset(ctx_->CreateOperation());
// In eager mode, TF2 python executes functions by constructing an op with
// the name of the functiondef:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
// In graph mode, we create a PartitionedCallOp instead:
// https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573
// TODO(bmzhao): After discussing with Allen, we should execute this via a
// PartitionedCallOp for compatibility with "tooling that assumes functions in
// graphs are PartitionedCallOps".
TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
return Status();
}
} // namespace tensorflow
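For orientation, a hedged sketch (not in this commit; `CallFunction` and its signature are assumed names) of how a caller might drive the returned call op, mirroring the AddInput/Execute pattern SingleRestore uses earlier in this commit:

// Executes `fn`, feeding its captured handles as inputs; user-supplied
// arguments would be added the same way. `outputs` must point to storage
// for `*num_outputs` handles.
Status CallFunction(TFConcreteFunction* fn, AbstractTensorHandle** outputs,
                    int* num_outputs) {
  ImmediateOpPtr call_op(nullptr);
  TF_RETURN_IF_ERROR(fn->GetCallOp(&call_op));
  for (ImmediateExecutionTensorHandle* capture : fn->GetCaptures()) {
    TF_RETURN_IF_ERROR(call_op->AddInput(capture));
  }
  return call_op->Execute(absl::MakeSpan(outputs, *num_outputs), num_outputs);
}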

View File

@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a
// saved model.
class TFConcreteFunction : public ConcreteFunction {
public:
// Factory function for creating a TFConcreteFunction.
//
// Params:
// function_def - The function_def associated with the created
// TFConcreteFunction. TFConcreteFunction will register this
// function_def with `ctx` on creation, and de-register it on
// destruction. function_def must be non-null, but
// otherwise has no lifetime requirements.
// captures - The captured TensorHandles associated with this
// TFConcreteFunction.
// metadata - The FunctionMetadata associated with this TFConcreteFunction.
// ctx - A handle to the Tensorflow runtime. This MUST be non-null and
// outlive TFConcreteFunction.
// out - The output TFConcreteFunction.
static Status Create(const FunctionDef* function_def,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata,
ImmediateExecutionContext* ctx,
std::unique_ptr<TFConcreteFunction>* out);
// This method returns the "Call" Op used to execute the function.
Status GetCallOp(ImmediateOpPtr* out) override;
const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
const override;
const FunctionMetadata& GetFunctionMetadata() const override;
~TFConcreteFunction() override;
private:
TFConcreteFunction(const std::string& name,
std::vector<ImmediateExecutionTensorHandle*> captures,
FunctionMetadata metadata, ImmediateExecutionContext* ctx);
TFConcreteFunction(const TFConcreteFunction&) = delete;
TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;
// Name of the FunctionDef corresponding to this TFConcreteFunction
std::string name_;
std::vector<ImmediateExecutionTensorHandle*> captures_;
FunctionMetadata metadata_;
ImmediateExecutionContext* ctx_;
};
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_

View File

@@ -17,14 +17,125 @@ limitations under the License.
 #include <memory>
 
+#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
 #include "tensorflow/c/tf_tensor_internal.h"
 #include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
+#include "tensorflow/core/protobuf/struct.pb.h"
 
 namespace tensorflow {
 namespace internal {
+namespace {
+
+// This returns the size of `tf.nest.flatten(value)`, on values that are
+// used in tf.function's input_signatures.
+int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) {
+  // This follows the logic from
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
+  switch (value.kind_case()) {
+    case StructuredValue::kDictValue: {
+      const DictValue& dict = value.dict_value();
+      int size = 0;
+      for (const auto& field : dict.fields()) {
+        size += FlattenedSize(field.second, status);
+      }
+      return size;
+    }
+    case StructuredValue::kTupleValue: {
+      const TupleValue& tuple = value.tuple_value();
+      int size = 0;
+      for (const StructuredValue& value : tuple.values()) {
+        size += FlattenedSize(value, status);
+      }
+      return size;
+    }
+    case StructuredValue::kListValue: {
+      const ListValue& list = value.list_value();
+      int size = 0;
+      for (const StructuredValue& value : list.values()) {
+        size += FlattenedSize(value, status);
+      }
+      return size;
+    }
+    case StructuredValue::kTensorSpecValue: {
+      return 1;
+    }
+    case StructuredValue::kNoneValue: {
+      // Base case: do nothing.
+      // This arises, for example, as the top-level object of an output
+      // signature when there are no return values.
+      return 0;
+    }
+    default: {
+      status->Update(errors::Internal("Unhandled structured value kind ",
+                                      value.kind_case()));
+      return 0;
+    }
+  }
+}
+
+// Perform some basic sanity checks on SavedConcreteFunction's input and
+// output signatures with respect to the corresponding FunctionDef's input
+// and output args.
+Status ValidateSavedFunctionCompatibleWithFunctionDef(
+    const SavedConcreteFunction& saved_concrete_function,
+    const FunctionDef* function_def) {
+  // tf.functions go through many transformations before becoming FunctionDefs
+  // 1. flatten user-provided inputs:
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675
+  // 2. convert user-provided inputs to tensors:
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688
+  // 3. filter any non-tensor, non-variable inputs:
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841
+  // 4. concatenate any captured inputs:
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912
+  // Since our API is limited to tf.functions annotated with input signatures,
+  // conditions 2 and 3 are trivially satisfied.
+  // We need to ensure that:
+  // flatten(input_signature).size() + captures.size() = fdef.signature().size()
+  // A concrete function's serialized "canonicalized_input_signature" comes
+  // from encoding its "structured_input_signature" field:
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71
+  // The "structured_input_signature" is guaranteed to be a tuple of the python
+  // args, kwargs that correspond to the tf.function:
+  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
+  const std::string& name = function_def->signature().name();
+  const StructuredValue& input_signature =
+      saved_concrete_function.canonicalized_input_signature();
+  Status status;
+  int input_signature_size = FlattenedSize(input_signature, &status);
+  TF_RETURN_IF_ERROR(status);
+  if (input_signature_size + saved_concrete_function.bound_inputs_size() !=
+      function_def->signature().input_arg_size()) {
+    return errors::FailedPrecondition(
+        "FunctionDef ", name, " has ",
+        function_def->signature().input_arg_size(),
+        " inputs, but the SavedConcreteFunction has ", input_signature_size,
+        " flattened user inputs and ",
+        saved_concrete_function.bound_inputs_size(), " captured inputs.");
+  }
+
+  const StructuredValue& output_signature =
+      saved_concrete_function.output_signature();
+  int output_signature_size = FlattenedSize(output_signature, &status);
+  TF_RETURN_IF_ERROR(status);
+  if (output_signature_size != function_def->signature().output_arg_size()) {
+    return errors::FailedPrecondition(
+        "FunctionDef ", name, " has ",
+        function_def->signature().output_arg_size(),
+        " outputs, but the SavedConcreteFunction has ", output_signature_size,
+        " flattened outputs.");
+  }
+
+  return status;
+}
+
+} // namespace
 
 Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
                              const TensorProto& proto,
@@ -54,5 +165,31 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
   return Status();
 }
 
+Status LoadTFConcreteFunction(
+    const SavedConcreteFunction& saved_concrete_function,
+    const FunctionDef* function_def,
+    const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
+        captured_objects,
+    ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out) {
+  TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef(
+      saved_concrete_function, function_def));
+
+  // Copy over captures
+  std::vector<ImmediateExecutionTensorHandle*> captures;
+  captures.reserve(saved_concrete_function.bound_inputs_size());
+  for (int bound_input : saved_concrete_function.bound_inputs()) {
+    auto iter = captured_objects.find(bound_input);
+    if (iter == captured_objects.end()) {
+      return errors::FailedPrecondition("Failed to find bound_input ",
+                                        bound_input,
+                                        " for SavedConcreteFunction");
+    }
+    captures.push_back(iter->second->handle());
+  }
+
+  return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx,
+                                    out);
+}
+
 } // namespace internal
 } // namespace tensorflow
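To make the arity check concrete: in the SuccessfulLoad case exercised by the loading test later in this commit, the canonicalized input signature tuple(tuple(TensorSpec x), dict{}) flattens to 1 tensor and bound_inputs contributes 2 captures, so the FunctionDef must declare 1 + 2 = 3 input args; its three-tensor output signature must likewise match 3 output args.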

View File

@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/eager/immediate_execution_context.h"
 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
+#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@@ -43,6 +44,14 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
                          const SavedVariable& variable,
                          std::unique_ptr<Variable>* output);
 
+// Creates a TFConcreteFunction from a SavedConcreteFunction.
+Status LoadTFConcreteFunction(
+    const SavedConcreteFunction& saved_concrete_function,
+    const FunctionDef* function_def,
+    const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
+        captured_objects,
+    ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);
+
 } // namespace internal
 } // namespace tensorflow

View File

@@ -139,5 +139,13 @@ void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
   }
 }
 
+AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) {
+  Status status;
+  AbstractTensorPtr tensor(handle->Resolve(&status));
+  CHECK(status.ok()) << status.error_message();
+  CHECK_NE(tensor.get(), nullptr);
+  return tensor;
+}
+
 } // namespace testing
 } // namespace tensorflow

View File

@@ -69,6 +69,10 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
 void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
                             void* b);
 
+// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should
+// only be used for testing purposes.
+AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle);
+
 } // namespace testing
 } // namespace tensorflow

View File

@@ -0,0 +1,271 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include <unordered_map>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
namespace tensorflow {
namespace {
class SavedConcreteFunctionLoadingTest : public ::testing::Test {
public:
SavedConcreteFunctionLoadingTest()
: device_mgr_(testing::CreateTestingDeviceMgr()),
ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}
EagerContext* context() { return ctx_.get(); }
private:
std::unique_ptr<StaticDeviceMgr> device_mgr_;
EagerContextPtr ctx_;
};
class DummyCapture : public TensorHandleConvertible {
public:
DummyCapture(ImmediateExecutionContext* ctx, int8 value)
: TensorHandleConvertible(
testing::CreateTensorHandle(ctx, DT_FLOAT, {2, 4}, value)) {}
};
FunctionDef FuncDefWithNumInputsOutputs(int num_inputs, int num_outputs) {
FunctionDef func;
OpDef* signature = func.mutable_signature();
for (int i = 0; i < num_inputs; ++i) {
signature->add_input_arg();
}
for (int i = 0; i < num_outputs; ++i) {
signature->add_output_arg();
}
return func;
}
// A SavedConcreteFunction whose canonicalized input signature
// has fewer inputs than its corresponding FunctionDef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) {
// `saved` has 1 input
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::SingleArgInputSignature();
*saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
// `func` has 2 inputs
FunctionDef func = FuncDefWithNumInputsOutputs(2, 0);
std::unique_ptr<TFConcreteFunction> result;
Status status =
internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction whose canonicalized input signature length plus
// its number of captures is less than its corresponding FunctionDef's input
// count should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest,
TooFewInputsWithCapturesInSavedConcreteFunction) {
// `saved` has 1 input, and 1 capture, for a total of 2 inputs
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::SingleArgInputSignature();
*saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
saved.add_bound_inputs(5);
// `func` has 3 inputs
FunctionDef func = FuncDefWithNumInputsOutputs(3, 0);
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
captures[5] = std::make_unique<DummyCapture>(context(), 10);
std::unique_ptr<TFConcreteFunction> result;
Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction whose canonicalized input signature
// has more inputs than its corresponding FunctionDef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) {
// `saved` has 3 inputs
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::ThreeArgInputSignature();
*saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
// `func` has 2 inputs
FunctionDef func = FuncDefWithNumInputsOutputs(2, 0);
std::unique_ptr<TFConcreteFunction> result;
Status status =
internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction whose canonicalized input signature
// has the same number of inputs as its corresponding FunctionDef, but which
// has additional captures, should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest,
TooManyInputsWithCaptureInSavedConcreteFunction) {
// `saved` has 3 inputs, and 1 capture, for a total of 4 inputs.
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::ThreeArgInputSignature();
*saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
saved.add_bound_inputs(5);
// `func` has 3 inputs.
FunctionDef func = FuncDefWithNumInputsOutputs(3, 0);
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
captures[5] = std::make_unique<DummyCapture>(context(), 10);
std::unique_ptr<TFConcreteFunction> result;
Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction whose capture refers to an index not in the capture
// map should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) {
// `saved` has 3 inputs, 1 capture, for a total of 4 inputs
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::ThreeArgInputSignature();
*saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
// Capture is at index "10"
saved.add_bound_inputs(10);
// `func` has 4 inputs
FunctionDef func = FuncDefWithNumInputsOutputs(4, 0);
// `captures` only has a capture for index 5
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
captures[5] = std::make_unique<DummyCapture>(context(), 10);
std::unique_ptr<TFConcreteFunction> result;
Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction with fewer outputs than its corresponding
// functiondef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) {
// `saved` has 0 inputs, 1 output
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::ZeroArgInputSignature();
*saved.mutable_output_signature() = testing::SingleReturnOutputSignature();
// `func` has 0 inputs, 2 outputs
FunctionDef func = FuncDefWithNumInputsOutputs(0, 2);
std::unique_ptr<TFConcreteFunction> result;
Status status =
internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction with more outputs than its corresponding
// functiondef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest,
TooManyOutputsInSavedConcreteFunction) {
// `saved` has 1 input, 3 outputs
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::SingleArgInputSignature();
*saved.mutable_output_signature() = testing::ThreeReturnOutputSignature();
// `func` has 1 input, 2 outputs
FunctionDef func = FuncDefWithNumInputsOutputs(1, 2);
std::unique_ptr<TFConcreteFunction> result;
Status status =
internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
<< status.error_message();
}
// A SavedConcreteFunction whose (inputs + captures) = functiondef inputs,
// and whose outputs = functiondef outputs should successfully load.
TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) {
// `saved` has 1 input, 2 captures, 3 outputs
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::SingleArgInputSignature();
*saved.mutable_output_signature() = testing::ThreeReturnOutputSignature();
saved.add_bound_inputs(2);
saved.add_bound_inputs(5);
// `func` has 3 inputs, 3 outputs
FunctionDef func = FuncDefWithNumInputsOutputs(3, 3);
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
captures[2] = std::make_unique<DummyCapture>(context(), 1);
captures[5] = std::make_unique<DummyCapture>(context(), 10);
std::unique_ptr<TFConcreteFunction> result;
Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
context(), &result);
TF_EXPECT_OK(status) << status.error_message();
}
// A TFConcreteFunction should register functiondefs on creation, and
// remove them upon deletion.
TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) {
std::string func_name = "FooBarBazWombatFunction";
SavedConcreteFunction saved;
*saved.mutable_canonicalized_input_signature() =
testing::ZeroArgInputSignature();
*saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
FunctionDef func = FuncDefWithNumInputsOutputs(0, 0);
*func.mutable_signature()->mutable_name() = func_name;
{
std::unique_ptr<TFConcreteFunction> result;
Status status =
internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
TF_EXPECT_OK(status) << status.error_message();
// The function should be registered with context.
EXPECT_TRUE(context()->FindFunctionByName(func_name));
}
// After `result`'s destructor runs, the function should no longer be
// registered with context.
EXPECT_FALSE(context()->FindFunctionByName(func_name));
}
} // namespace
} // namespace tensorflow

View File

@@ -0,0 +1,212 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace testing {
namespace {
constexpr absl::string_view kZeroArgInputSignatureTextProto = R"(
tuple_value: {
values: {
tuple_value: {
}
}
values: {
dict_value: {
}
}
}
)";
constexpr absl::string_view kSingleArgInputSignatureTextProto = R"(
tuple_value: {
values: {
tuple_value: {
values: {
tensor_spec_value: {
name : "x"
shape: {
dim: {
size: 1
}
dim: {
size: 10
}
}
dtype: DT_FLOAT
}
}
}
}
values: {
dict_value: {
}
}
}
)";
constexpr absl::string_view kThreeArgInputSignatureTextProto = R"(
tuple_value: {
values: {
tuple_value: {
values: {
tensor_spec_value: {
name : "x"
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
}
values: {
tensor_spec_value: {
name : "y"
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
}
values: {
tensor_spec_value: {
name : "z"
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
}
}
}
values: {
dict_value: {
}
}
}
)";
constexpr absl::string_view kZeroReturnOutputSignatureTextProto = R"(
none_value: {}
)";
constexpr absl::string_view kSingleReturnOutputSignatureTextProto = R"(
tensor_spec_value: {
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
)";
constexpr absl::string_view kThreeReturnOutputSignatureTextProto = R"(
tuple_value: {
values: {
dict_value: {
fields: {
key : "a"
value: {
tensor_spec_value: {
name : "0/a"
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
}
}
fields: {
key : "b"
value: {
tensor_spec_value: {
name : "0/b"
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
}
}
}
}
values: {
tensor_spec_value: {
name : "1"
shape: {
dim: {
size: 1
}
}
dtype: DT_FLOAT
}
}
}
)";
StructuredValue ParseStructuredValue(absl::string_view text_proto) {
StructuredValue value;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto),
&value));
return value;
}
} // namespace
StructuredValue ZeroArgInputSignature() {
return ParseStructuredValue(kZeroArgInputSignatureTextProto);
}
StructuredValue SingleArgInputSignature() {
return ParseStructuredValue(kSingleArgInputSignatureTextProto);
}
StructuredValue ThreeArgInputSignature() {
return ParseStructuredValue(kThreeArgInputSignatureTextProto);
}
StructuredValue ZeroReturnOutputSignature() {
return ParseStructuredValue(kZeroReturnOutputSignatureTextProto);
}
StructuredValue SingleReturnOutputSignature() {
return ParseStructuredValue(kSingleReturnOutputSignatureTextProto);
}
StructuredValue ThreeReturnOutputSignature() {
return ParseStructuredValue(kThreeReturnOutputSignatureTextProto);
}
} // namespace testing
} // namespace tensorflow

View File

@@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace testing {
// Returns a StructuredValue corresponding to the serialized InputSignature of a
// tf.function with 0 inputs
StructuredValue ZeroArgInputSignature();
// Returns a StructuredValue corresponding to the serialized InputSignature of a
// tf.function with 1 input
StructuredValue SingleArgInputSignature();
// Returns a StructuredValue corresponding to the serialized InputSignature of a
// tf.function with 3 inputs
StructuredValue ThreeArgInputSignature();
// Returns a StructuredValue corresponding to the serialized OutputSignature of
// a tf.function with no return values
StructuredValue ZeroReturnOutputSignature();
// Returns a StructuredValue corresponding to the serialized OutputSignature of
// a tf.function with a single tensor output
StructuredValue SingleReturnOutputSignature();
// Returns a StructuredValue corresponding to the serialized OutputSignature of
// a tf.function with three tensor outputs
StructuredValue ThreeReturnOutputSignature();
} // namespace testing
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_

View File

@ -41,11 +41,13 @@ cc_library(
":tensorhandle_list", ":tensorhandle_list",
":tensorhandle_list_type", ":tensorhandle_list_type",
"//tensorflow/c:c_api_macros", "//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:immediate_execution_operation",
"//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/eager:tfe_op_internal",
"//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:concrete_function",
"//tensorflow/c/experimental/saved_model/core:function_metadata", "//tensorflow/c/experimental/saved_model/core:function_metadata",
"//tensorflow/core:lib",
], ],
) )
@ -205,9 +207,13 @@ tf_cc_test(
],
deps = [
"//tensorflow/c:tf_status",
+ "//tensorflow/c:tf_tensor",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
+ "//tensorflow/c/eager:c_api_test_util",
+ "//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
+ "//tensorflow/c/experimental/saved_model/public:tensorhandle_list",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -15,12 +15,15 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h" #include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"
extern "C" { extern "C" {
@ -34,8 +37,11 @@ const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
}
- TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
- return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
+ TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
+ TF_Status* status) {
+ tensorflow::ImmediateOpPtr call_op(nullptr);
+ status->status = tensorflow::unwrap(func)->GetCallOp(&call_op);
+ return tensorflow::wrap(call_op.release());
}
} // end extern "C"

View File

@ -41,7 +41,7 @@ TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
- TF_ConcreteFunction* func);
+ TF_ConcreteFunction* func, TF_Status* status);
#ifdef __cplusplus
} // end extern "C"

View File

@ -244,9 +244,7 @@ static bool MustAliasOutput(
if (input_output_alias.shape().tuple_shapes_size() == 0) {
return false;
}
- return input_output_alias.OutputHasAlias(output_index) &&
- input_output_alias.GetAliasedParameter(output_index).value().kind ==
- xla::HloInputOutputAliasConfig::kUserAlias;
+ return input_output_alias.OutputHasAlias(output_index);
}
// Returns an aliased tensor if it exists, nullptr otherwise.

View File

@ -482,8 +482,8 @@ tf_cc_test(
)
cc_library(
- name = "xla_hlo_fusion",
- srcs = ["lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc"],
+ name = "mhlo_fusion",
+ srcs = ["lib/Dialect/mhlo/transforms/mhlo_fusion.cc"],
deps = [
":cycle_detector",
":hlo",
@ -680,3 +680,40 @@ cc_library(
],
alwayslink = 1,
)
cc_library(
name = "all_xla_passes_for_testing",
visibility = [
"//tensorflow/compiler/mlir:__subpackages__",
],
deps = [
":chlo_legalize_to_hlo",
":hlo_dialect_registration",
":hlo_legalize_to_lhlo",
":lhlo",
":lhlo_copy_removal",
":lhlo_fuse_linalg",
":lhlo_legalize_to_affine",
":lhlo_legalize_to_gpu",
":lhlo_legalize_to_parallel_loops",
":mhlo_fusion",
":xla_legalize_control_flow",
":xla_legalize_tanh_to_approximation",
":xla_legalize_to_linalg",
":xla_legalize_to_standard",
":xla_lower",
":xla_sink_constants_to_control_flow",
":xla_test_passes",
":xla_transform_unranked_hlo",
],
)
cc_binary(
name = "mlir-hlo-opt",
deps = [
":all_xla_passes_for_testing",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:MlirOptLib",
"@llvm-project//mlir:MlirOptMain",
],
)

View File

@ -17,12 +17,12 @@ limitations under the License.
// These ops are not necessarily orthogonal or optimized for transformation but
// for ease of expression in certain cases deemed important for client
// libraries (i.e. implicit broadcasting, helper ops, etc).
- // This dialect is considered to exist in addition to augment the xla_hlo
+ // This dialect is considered to exist in addition to augment the mhlo
// dialect for ergonomic needs, not duplicate/replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
- // xla_chlo ops to canonical xla_hlo ops.
+ // xla_chlo ops to canonical mhlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics
@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect {
let description = [{
This dialect contains ops that align closely with the API surface area
of the XlaBuilder C++ API, where such ops have semantics that go beyond
- what exists in the lower level dialects (such as `xla_hlo`). Essentially,
+ what exists in the lower level dialects (such as `mhlo`). Essentially,
whenever the client library uses syntactic sugar or composition
of multiple ops for an API call, this dialect tries to model the API call
and provide conversion patterns to fully materialize into lower level
@ -65,7 +65,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
- // These correspond to operations in the xla_hlo dialect without the
+ // These correspond to operations in the mhlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//
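As an illustration (op spellings and shapes assumed, not part of the diff): a client-level broadcast_add accepts differently shaped operands, while the mhlo op it lowers to requires the broadcasts to be materialized first:

// %0 = "xla_chlo.broadcast_add"(%lhs, %rhs)
//     : (tensor<4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>
// lowers, conceptually, to
// %b = "mhlo.broadcast_in_dim"(%lhs) {broadcast_dimensions = dense<1> : tensor<1xi64>}
//     : (tensor<4xf32>) -> tensor<2x4xf32>
// %0 = "mhlo.add"(%b, %rhs) : (tensor<2x4xf32>, tensor<2x4xf32>) -> tensor<2x4xf32>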

View File

@ -37,12 +37,12 @@ class OpBuilder;
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"
namespace xla_hlo { namespace mhlo {
class XlaHloDialect : public Dialect { class XlaHloDialect : public Dialect {
public: public:
explicit XlaHloDialect(MLIRContext *context); explicit XlaHloDialect(MLIRContext *context);
static StringRef getDialectNamespace() { return "xla_hlo"; } static StringRef getDialectNamespace() { return "mhlo"; }
// Registered hook to materialize a constant operation from a given attribute // Registered hook to materialize a constant operation from a given attribute
// value with the desired resultant type. // value with the desired resultant type.
@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
// %1 = index_cast %0 : index to i64
// %2 = dim %arg0, 1 : memref<?x?xf32>
// %3 = index_cast %2 : index to i64
- // %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
+ // %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3)
// : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.
@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand(
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"
- } // end namespace xla_hlo
+ } // end namespace mhlo
} // end namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

View File

@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td" include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect { def HLO_Dialect : Dialect {
let name = "xla_hlo"; let name = "mhlo";
let cppNamespace = "xla_hlo"; let cppNamespace = "mhlo";
} }
class HLO_Op<string mnemonic, list<OpTrait> traits> : class HLO_Op<string mnemonic, list<OpTrait> traits> :
@ -78,6 +78,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {
// TODO(b/130357376): Iota has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
+ let hasCanonicalizer = 1;
}
def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {

View File

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
template <typename HloOpTy> template <typename HloOpTy>
struct HloToLhloOpImpl { struct HloToLhloOpImpl {
@ -33,7 +33,7 @@ using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;
#define MAP_HLO_TO_LHLO(OpName) \
template <> \
- struct HloToLhloOpImpl<xla_hlo::OpName> { \
+ struct HloToLhloOpImpl<mhlo::OpName> { \
using Type = xla_lhlo::OpName; \
}
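For illustration (not part of the diff): after the rename, an invocation such as MAP_HLO_TO_LHLO(AbsOp); expands to roughly

template <>
struct HloToLhloOpImpl<mhlo::AbsOp> {
  using Type = xla_lhlo::AbsOp;
};

so HloToLhloOp<mhlo::AbsOp> resolves to xla_lhlo::AbsOp at compile time.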
@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp);
#undef MAP_HLO_TO_LHLO
- } // namespace xla_hlo
+ } // namespace mhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

View File

@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp {
template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
- std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
+ std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
std::false_type>::value>>
static Value map(XlaOpTy op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {
@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp {
args, b);
}
- // Implementation for HLO ops except xla_hlo::CompareOp.
- template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
+ // Implementation for HLO ops except mhlo::CompareOp.
+ template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
typename = std::enable_if_t<
!std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
!std::is_same<LhloOpTy, std::false_type>::value>>
@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp {
op.getLoc(), comparison_direction, result_types, args, b);
}
- // Implementation for xla_hlo::CompareOp.
- template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
- HloOpTy, xla_hlo::CompareOp>::value>>
- static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
+ // Implementation for mhlo::CompareOp.
+ template <typename HloOpTy,
+ typename =
+ std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
+ static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
ArrayRef<Value> args, OpBuilder* b) {
auto comparison_direction = op.comparison_direction();
return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(

View File

@ -29,7 +29,7 @@ template <typename T>
class OperationPass;
class Pass;
- namespace xla_hlo {
+ namespace mhlo {
/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();
@ -55,10 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();
- // fuse xla_hlo ops to kLoop/kInput fusion patterns
+ // fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();
- } // namespace xla_hlo
+ } // namespace mhlo
namespace xla_lhlo {

View File

@ -27,7 +27,7 @@ class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
class BufferAssignmentPlacer;
- namespace xla_hlo {
+ namespace mhlo {
// Collection of rewrite patterns for lowering a general dot product.
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
- } // namespace xla_hlo
+ } // namespace mhlo
namespace xla_lhlo {

View File

@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
// Static initialization for XLA dialect registration. // Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops; static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect> static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
xla_chlo_ops; xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops; static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;

View File

@ -60,7 +60,7 @@ limitations under the License.
namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
- namespace xla_hlo {
+ namespace mhlo {
Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
Attribute value, Type type,
@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
// HLO dialect constants only support ElementsAttr unlike standard dialect
// constant which supports all attributes.
if (value.isa<ElementsAttr>())
- return builder.create<xla_hlo::ConstOp>(loc, type,
- value.cast<ElementsAttr>());
+ return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
return nullptr;
}
@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result,
}
// TODO: support other XLA specific types.
- assert(type && "unsupported attribute type for building xla_hlo.constant");
+ assert(type && "unsupported attribute type for building mhlo.constant");
result.types.push_back(type);
result.addAttribute("value", value);
}
@ -215,6 +214,41 @@ static LogicalResult Verify(IotaOp op) {
return success();
}
// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
using OpRewritePattern<IotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(IotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
}
auto iota_dimension = iota.iota_dimension();
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension.getLimitedValue())},
result_ty.getElementType());
auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
rewriter.getI64IntegerAttr(0));
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
{iota_dimension});
rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
broadcast_attr);
return success();
}
};
void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<IotaBroadcast>(context);
}
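For illustration (shapes assumed, not part of the diff), the IotaBroadcast pattern above rewrites a rank-2 iota along dimension 1 into a 1-D iota plus a ranked broadcast:

// %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<3x4xi64>
// becomes
// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi64>
// %0 = "mhlo.broadcast_in_dim"(%1) {broadcast_dimensions = dense<1> : tensor<1xi64>}
//     : (tensor<4xi64>) -> tensor<3x4xi64>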
//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//
@ -236,11 +270,63 @@
}
};
// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DynamicIotaOp iota,
PatternRewriter& rewriter) const override {
auto result_ty = iota.getType().cast<ShapedType>();
if (!result_ty.hasRank() || result_ty.getRank() < 2) {
return failure();
}
auto iota_dimension = iota.iota_dimension();
auto iota_dimension_int = iota_dimension.getLimitedValue();
auto converted_shape = rewriter.create<IndexCastOp>(
iota.getLoc(),
RankedTensorType::get(
iota.output_shape().getType().cast<ShapedType>().getShape(),
rewriter.getI64Type()),
iota.output_shape());
auto sliced_shape = rewriter.create<SliceOp>(
iota.getLoc(), converted_shape,
GetI64ElementsAttr(iota_dimension_int, &rewriter),
GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
GetI64ElementsAttr(1, &rewriter));
auto converted_sliced_shape = rewriter.create<IndexCastOp>(
iota.getLoc(),
RankedTensorType::get(
{1},
iota.output_shape().getType().cast<ShapedType>().getElementType()),
sliced_shape);
auto iota_type = RankedTensorType::get(
{result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());
auto new_iota = rewriter.create<DynamicIotaOp>(
iota.getLoc(), iota_type, converted_sliced_shape,
rewriter.getI64IntegerAttr(0));
auto broadcast_attr = DenseIntElementsAttr::get(
RankedTensorType::get({1}, rewriter.getIntegerType(64)),
{iota_dimension});
rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
return success();
}
};
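A sketch of the effect (shapes assumed, index_cast steps elided, not part of the diff): a rank-2 dynamic_iota along dimension 1 becomes a 1-D dynamic_iota over the sliced extent of that dimension, broadcast back to the full shape:

// %0 = "mhlo.dynamic_iota"(%shape) {iota_dimension = 1 : i64}
//     : (tensor<2xi64>) -> tensor<?x?xi64>
// becomes, roughly,
// %dim = "mhlo.slice"(%shape) ... : (tensor<2xi64>) -> tensor<1xi64>  // extent of dim 1
// %1 = "mhlo.dynamic_iota"(%dim) {iota_dimension = 0 : i64} : (tensor<1xi64>) -> tensor<?xi64>
// %0 = "mhlo.dynamic_broadcast_in_dim"(%1, %shape) {broadcast_dimensions = dense<1> : tensor<1xi64>}
//     : (tensor<?xi64>, tensor<2xi64>) -> tensor<?x?xi64>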
} // namespace
void DynamicIotaOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicIotaIsStatic>(context);
results.insert<DynamicIotaBroadcast>(context);
}
//===----------------------------------------------------------------------===//
@ -387,7 +473,7 @@ static LogicalResult Verify(GetTupleElementOp op) {
OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
if (auto tupleOp =
- dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
+ dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
return tupleOp.getOperand(index().getLimitedValue());
}
@ -693,10 +779,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
}
OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
- auto real_op =
- dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
- auto imag_op =
- dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
+ auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
+ auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
return real_op.getOperand();
}
@ -727,7 +811,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {
OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
- dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
+ dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(1);
}
@ -740,7 +824,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {
OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
if (auto complex_op =
- dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
+ dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
return complex_op.getOperand(0);
}
@ -1148,7 +1232,7 @@ static LogicalResult Verify(MapOp op) {
// RecvOp
//===----------------------------------------------------------------------===//
- // Checks that the result type is of the form `tuple<any_type, xla_hlo::token>`
+ // Checks that the result type is of the form `tuple<any_type, mhlo::token>`
static LogicalResult Verify(RecvOp op) {
auto result_ty = op.getResult().getType().cast<TupleType>();
auto subtypes = result_ty.getTypes();
@ -2020,7 +2104,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// xla_hlo Dialect Interfaces // mhlo Dialect Interfaces
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
namespace { namespace {
@ -2032,7 +2116,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
BlockAndValueMapping& valueMapping) const final {
return true;
}
- // Operations in xla_hlo dialect are always legal to inline since they are
+ // Operations in mhlo dialect are always legal to inline since they are
// pure.
bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
return true;
@ -2041,7 +2125,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
} // end anonymous namespace
//===----------------------------------------------------------------------===//
- // xla_hlo Dialect Constructor
+ // mhlo Dialect Constructor
//===----------------------------------------------------------------------===//
XlaHloDialect::XlaHloDialect(MLIRContext* context)
@ -2061,8 +2145,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
if (parser.parseKeyword(&data_type)) return Type();
if (data_type == "token") return TokenType::get(getContext());
- parser.emitError(parser.getNameLoc())
- << "unknown xla_hlo type: " << data_type;
+ parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
return nullptr;
}
@ -2071,7 +2154,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
os << "token"; os << "token";
return; return;
} }
os << "<unknown xla_hlo type>"; os << "<unknown mhlo type>";
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -2106,5 +2189,5 @@ LogicalResult deriveShapeFromFirstOperand(
return success();
}
- } // namespace xla_hlo
+ } // namespace mhlo
} // namespace mlir

View File

@ -30,7 +30,7 @@ namespace xla_chlo {
namespace {
// Converts binary ops that statically are determined to not broadcast directly
- // to the corresponding xla_hlo non-broadcasting op.
+ // to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
};
// Converts a binary op with ranked broadcasting operands to explicitly
- // broadcast and invoke the corresponding xla_hlo non-broadcasting op.
+ // broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html
@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
// properly.
auto lhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
- Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
+ Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
lhs_type.getElementType()),
@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
auto rhs_broadcast_dimensions = llvm::to_vector<4>(
llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
- Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
+ Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
loc,
RankedTensorType::get(result_type.getShape(),
rhs_type.getElementType()),
@ -182,21 +182,19 @@ struct HloBinaryElementwiseAdaptor {
};
struct HloComplexAdaptor {
- static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
- Type result_type, Value broadcasted_lhs,
- Value broadcasted_rhs,
- OpBuilder &builder) {
- return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
- broadcasted_lhs, broadcasted_rhs);
+ static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
+ Value broadcasted_lhs, Value broadcasted_rhs,
+ OpBuilder &builder) {
+ return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
+ broadcasted_lhs, broadcasted_rhs);
}
};
struct HloCompareAdaptor {
- static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
- Type result_type, Value broadcasted_lhs,
- Value broadcasted_rhs,
- OpBuilder &builder) {
- return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
- broadcasted_lhs, broadcasted_rhs,
- from_op.comparison_direction());
+ static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
+ Value broadcasted_lhs, Value broadcasted_rhs,
+ OpBuilder &builder) {
+ return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
+ broadcasted_lhs, broadcasted_rhs,
+ from_op.comparison_direction());
}
};
@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
patterns);
- POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
- POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
- POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
- POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
- POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
- POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
- POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
- POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
- POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
- POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
- POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
- POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
- xla_hlo::ShiftRightArithmeticOp);
- POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
- POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
- POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
+ POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
+ POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
+ POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
+ POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
+ POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
+ POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
+ POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
+ POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
+ POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
+ POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
+ POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
+ POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
+ POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
+ POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
+ POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);
// Broadcasting ops requiring special construction.
- PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
- HloComplexAdaptor>(context, patterns);
- PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
- HloCompareAdaptor>(context, patterns);
+ PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
+ context, patterns);
+ PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
+ context, patterns);
}
} // namespace xla_chlo

View File

@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass
OwningRewritePatternList conversionPatterns;
conversionTarget.addIllegalDialect<XlaHloClientDialect>();
- // Consider the xla_hlo dialect legal for tests.
- conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
+ // Consider the mhlo dialect legal for tests.
+ conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();

View File

@ -37,7 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
template <typename T> template <typename T>
@ -128,7 +128,7 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
- rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
+ rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
return success();
@ -136,12 +136,12 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
};
struct HloToLhloDynamicBroadcastInDimOpConverter
- : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
+ : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
public:
- using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
+ using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
- xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
+ mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
Value resultBuffer = InsertDynamicAllocAndDealloc(
@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
// and size of the target dimension if size-1 dimension expansion is
// necessary.
xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
- xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
+ mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
}
};
- struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
+ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
public:
- using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
+ using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;
LogicalResult matchAndRewrite(
- xla_hlo::ReduceOp op, ArrayRef<Value> operands,
+ mhlo::ReduceOp op, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
auto loc = op.getLoc();
// TODO(b/137624192) Implement variadic reduce.
@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter
// "xla_lhlo.fusion"() ({ // "xla_lhlo.fusion"() ({
// %0 = tensor_load %arg1 : memref<2x2xf32> // %0 = tensor_load %arg1 : memref<2x2xf32>
// %1 = tensor_load %arg2 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32>
// %2 = "xla_hlo.add"(%0, %1) : // %2 = "mhlo.add"(%0, %1) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// %3 = tensor_load %arg0 : memref<2x2xf32> // %3 = tensor_load %arg0 : memref<2x2xf32>
// %4 = "xla_hlo.multiply"(%2, %3) : // %4 = "mhlo.multiply"(%2, %3) :
// (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// tensor_store %4, %arg3 : memref<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32>
// "xla_lhlo.terminator"() : () -> () // "xla_lhlo.terminator"() : () -> ()
@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
- // %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
- // tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
+ // %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
+ // tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>,
// tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
// }
//
@ -388,7 +388,7 @@ struct HloLegalizeToLhlo
target.addIllegalOp<mlir::TensorStoreOp>();
target.addLegalOp<ModuleTerminatorOp>();
target.addLegalOp<TensorFromElementsOp>();
- target.addIllegalDialect<xla_hlo::XlaHloDialect>();
+ target.addIllegalDialect<mhlo::XlaHloDialect>();
BufferAssignmentTypeConverter converter;
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern(
// clang-format off
patterns->insert<
HloToLhloDynamicBroadcastInDimOpConverter,
- HloToLhloOpConverter<xla_hlo::AbsOp>,
- HloToLhloOpConverter<xla_hlo::AddOp>,
- HloToLhloOpConverter<xla_hlo::AndOp>,
- HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
- HloToLhloOpConverter<xla_hlo::CeilOp>,
- HloToLhloOpConverter<xla_hlo::CompareOp>,
- HloToLhloOpConverter<xla_hlo::ComplexOp>,
- HloToLhloOpConverter<xla_hlo::ConstOp>,
- HloToLhloOpConverter<xla_hlo::ConvOp>,
- HloToLhloOpConverter<xla_hlo::ConvertOp>,
- HloToLhloOpConverter<xla_hlo::CopyOp>,
- HloToLhloOpConverter<xla_hlo::CosOp>,
- HloToLhloOpConverter<xla_hlo::DivOp>,
- HloToLhloOpConverter<xla_hlo::DotOp>,
- HloToLhloOpConverter<xla_hlo::ExpOp>,
- HloToLhloOpConverter<xla_hlo::GatherOp>,
- HloToLhloOpConverter<xla_hlo::ImagOp>,
- HloToLhloOpConverter<xla_hlo::IotaOp>,
- HloToLhloOpConverter<xla_hlo::LogOp>,
- HloToLhloOpConverter<xla_hlo::MaxOp>,
- HloToLhloOpConverter<xla_hlo::MinOp>,
- HloToLhloOpConverter<xla_hlo::MulOp>,
- HloToLhloOpConverter<xla_hlo::NegOp>,
- HloToLhloOpConverter<xla_hlo::RealOp>,
- HloToLhloOpConverter<xla_hlo::RemOp>,
- HloToLhloOpConverter<xla_hlo::RsqrtOp>,
- HloToLhloOpConverter<xla_hlo::ReshapeOp>,
- HloToLhloOpConverter<xla_hlo::SelectOp>,
- HloToLhloOpConverter<xla_hlo::SignOp>,
- HloToLhloOpConverter<xla_hlo::SqrtOp>,
- HloToLhloOpConverter<xla_hlo::SubOp>,
- HloToLhloOpConverter<xla_hlo::TanhOp>,
+ HloToLhloOpConverter<mhlo::AbsOp>,
+ HloToLhloOpConverter<mhlo::AddOp>,
+ HloToLhloOpConverter<mhlo::AndOp>,
+ HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
+ HloToLhloOpConverter<mhlo::CeilOp>,
+ HloToLhloOpConverter<mhlo::CompareOp>,
+ HloToLhloOpConverter<mhlo::ComplexOp>,
+ HloToLhloOpConverter<mhlo::ConstOp>,
+ HloToLhloOpConverter<mhlo::ConvOp>,
+ HloToLhloOpConverter<mhlo::ConvertOp>,
+ HloToLhloOpConverter<mhlo::CopyOp>,
+ HloToLhloOpConverter<mhlo::CosOp>,
+ HloToLhloOpConverter<mhlo::DivOp>,
+ HloToLhloOpConverter<mhlo::DotOp>,
+ HloToLhloOpConverter<mhlo::ExpOp>,
+ HloToLhloOpConverter<mhlo::GatherOp>,
+ HloToLhloOpConverter<mhlo::ImagOp>,
+ HloToLhloOpConverter<mhlo::IotaOp>,
+ HloToLhloOpConverter<mhlo::LogOp>,
+ HloToLhloOpConverter<mhlo::MaxOp>,
+ HloToLhloOpConverter<mhlo::MinOp>,
+ HloToLhloOpConverter<mhlo::MulOp>,
+ HloToLhloOpConverter<mhlo::NegOp>,
+ HloToLhloOpConverter<mhlo::RealOp>,
+ HloToLhloOpConverter<mhlo::RemOp>,
+ HloToLhloOpConverter<mhlo::RsqrtOp>,
+ HloToLhloOpConverter<mhlo::ReshapeOp>,
+ HloToLhloOpConverter<mhlo::SelectOp>,
+ HloToLhloOpConverter<mhlo::SignOp>,
+ HloToLhloOpConverter<mhlo::SqrtOp>,
+ HloToLhloOpConverter<mhlo::SubOp>,
+ HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloReduceOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
@ -489,5 +489,5 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
"hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");
- } // namespace xla_hlo
+ } // namespace mhlo
} // namespace mlir

View File

@ -35,7 +35,7 @@ limitations under the License.
using mlir::PassRegistration;
namespace mlir {
- namespace xla_hlo {
+ namespace mhlo {
namespace {
struct LegalizeControlFlow
: public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {
@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
OpBuilder* builder) {
for (auto& old_block : region->getBlocks()) {
Block* block = mapper.lookup(&old_block);
- auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
+ auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
if (!return_op) continue;
builder->setInsertionPointToEnd(block);
builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());
@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
return success();
}
- LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
+ LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
Operation* op_inst = if_op.getOperation();
mlir::OpBuilder builder(if_op);
auto orig_block = op_inst->getBlock();
@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
return success();
}
- LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
+ LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
// Converts an XLA while loop into control flow. This generates a set of MLIR
// blocks and branches, along with inlining the regions provided by the XLA
// while loop. The structure should be similar to below:
//
// <prior operations>
- // %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
+ // %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
// <post operations>
auto* op_inst = while_op.getOperation();
mlir::OpBuilder builder(while_op);
@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// extract_element and conditional branch. This changes the block below:
// ^cond(%0):
// <inlined conditional region>
- // "xla_hlo".return(%1)
+ // "mhlo".return(%1)
//
// Into:
// ^cond(%0):
@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// cond_br %2, ^body(%0), ^tail(%0) // Branch.
builder.setInsertionPointToStart(cond_block);
- // Replace the xla_hlo::ReturnOp with a branch back to the condition block.
- // This is required as the xla_hlo::ReturnOp is used to mark the end of a
+ // Replace the mhlo::ReturnOp with a branch back to the condition block.
+ // This is required as the mhlo::ReturnOp is used to mark the end of a
// block for regions nested inside of operations (MLIR ReturnOp cannot be
// nested within a non-function region).
for (auto& block : while_op.cond()) {
auto new_block = mapper.lookup(&block);
- auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
+ auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// conditional block. This changes the block below:
// ^body(%0):
// <inlined body block>
- // "xla_hlo".return(%1)
+ // "mhlo".return(%1)
//
// Into:
// ^body(%0):
@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
// br ^cond(%0) // Branch.
for (auto& block : while_op.body()) {
auto new_block = mapper.lookup(&block);
- auto return_op =
- dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
+ auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
if (!return_op) continue;
builder.setInsertionPointToEnd(new_block);
builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());
@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() {
}
}
} // namespace
- } // namespace xla_hlo
+ } // namespace mhlo
} // namespace mlir
std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
- mlir::xla_hlo::createLegalizeControlFlowPass() {
+ mlir::mhlo::createLegalizeControlFlowPass() {
return std::make_unique<LegalizeControlFlow>();
}
- static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
+ static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
"xla-legalize-control-flow",
"Legalize from XLA control flow to MLIR control flow");

View File

@ -28,14 +28,14 @@ namespace mlir {
namespace {
#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
} // end anonymous namespace
- namespace xla_hlo {
+ namespace mhlo {
namespace {
- class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
+ class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
+ LogicalResult matchAndRewrite(mhlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
}
};
- class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
+ class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
public:
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
+ LogicalResult matchAndRewrite(mhlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto lhs = op.lhs();
auto rhs = op.rhs();
@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
// convert the integer constant to iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with zero tensor.
- class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
+ class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
public:
using OpRewritePattern::OpRewritePattern;
- LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
+ LogicalResult matchAndRewrite(mhlo::IotaOp op,
PatternRewriter &rewriter) const override {
auto output_type = op.getType().cast<ShapedType>();
auto output_size = output_type.getNumElements();
@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
auto imag_zeroes =
rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
- rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
- imag_zeroes);
+ rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
return success();
}
};
@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
OwningRewritePatternList patterns;
- mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
+ mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<LegalizeToStandard> legalize_pass(
"xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");
- } // end namespace xla_hlo
+ } // end namespace mhlo
} // end namespace mlir

View File

@ -84,13 +84,13 @@ Value TransposeReshape(Value arg, mlir::Location loc,
transposed_shape.push_back(arg_shape[val]);
}
auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
- auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
+ auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
loc, transpose_type, arg, transpose_permutation_attr);
// Return the final result.
auto reshaped_type =
RankedTensorType::get({left_size, right_size}, element_type);
- return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
- transpose_result);
+ return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
+ transpose_result);
}
@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}
- struct GeneralDotConvert
- : public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
+ struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
// Attempts to lower a General Dot operator to a standard Dot operator.
// General dots include batching dimensions and can have collapsing
// dimensions along any axis. Inserting correctly arrange transpose and
@ -138,7 +137,7 @@ struct GeneralDotConvert
explicit GeneralDotConvert(MLIRContext *context)
: OpRewritePattern(context) {}
- LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
+ LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
PatternRewriter &rewriter) const override {
auto dot_element_type = mlir::getElementTypeOrSelf(op);
@ -162,10 +161,10 @@ struct GeneralDotConvert
auto new_dot_type =
RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);
- auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
+ auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));
- rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
- new_dot_op);
+ rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
+ new_dot_op);
return success();
}
@ -176,15 +175,14 @@ struct LegalizeGeneralDot
/// Lower all general dots that can be represented as a non-batched matmul. /// Lower all general dots that can be represented as a non-batched matmul.
void runOnFunction() override { void runOnFunction() override {
OwningRewritePatternList patterns; OwningRewritePatternList patterns;
mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns, mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
&getContext());
applyPatternsAndFoldGreedily(getFunction(), patterns); applyPatternsAndFoldGreedily(getFunction(), patterns);
} }
}; };
} // namespace } // namespace
void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns( void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
OwningRewritePatternList *patterns, MLIRContext *ctx) { OwningRewritePatternList *patterns, MLIRContext *ctx) {
patterns->insert<GeneralDotConvert>(ctx); patterns->insert<GeneralDotConvert>(ctx);
} }
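For intuition, here is a minimal NumPy sketch of the lowering GeneralDotConvert performs above (not part of the patch; the shapes, axis choices, and the einsum reference are illustrative): contracting dimensions are moved to the operand boundary with a transpose, each operand is collapsed to 2-D with a reshape, a standard dot runs, and the result is reshaped back.

import numpy as np

# lhs contracts its axis 1 (size 4) against rhs axis 0 (size 4).
lhs = np.random.rand(2, 4, 3)
rhs = np.random.rand(4, 5)

# TransposeReshape(lhs): outer dims (0, 2) first, contracting dim last,
# then collapse to 2-D -> shape (2*3, 4).
lhs2d = lhs.transpose(0, 2, 1).reshape(6, 4)

# rhs already has its contracting dim first, so it is usable as-is.
out2d = lhs2d @ rhs              # the standard (non-batched) dot
result = out2d.reshape(2, 3, 5)  # reshape back to the general-dot result shape

# Reference contraction computed directly.
assert np.allclose(result, np.einsum('acb,cd->abd', lhs, rhs))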

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
patterns->insert<ClampWithBroadcastConvert>(context); patterns->insert<ClampWithBroadcastConvert>(context);
} }
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass
ConversionTarget conversionTarget(getContext()); ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns; OwningRewritePatternList conversionPatterns;
// Consider the xla_hlo dialect legal for tests. // Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>(); conversionTarget.addLegalDialect<XlaHloDialect>();
// The conversion uses helpers from the Standard dialect. // The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>(); conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass
} // namespace } // namespace
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass> static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
pass("test-xla-materialize-broadcasts", "test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes"); "Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -60,7 +60,7 @@ limitations under the License.
// shape dialect once it is ready. // shape dialect once it is ready.
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
using llvm::EquivalenceClasses; using llvm::EquivalenceClasses;
@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
} }
FusionOp fusion = FusionOp fusion =
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs); b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
Region& region = fusion.fused_computation(); Region& region = fusion.fused_computation();
region.push_back(new Block); region.push_back(new Block);
Block& block = region.front(); Block& block = region.front();
@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
op->moveBefore(&block, block.end()); op->moveBefore(&block, block.end());
} }
b.setInsertionPoint(&block, block.end()); b.setInsertionPoint(&block, block.end());
b.create<xla_hlo::ReturnOp>(fused_loc, outputs); b.create<mhlo::ReturnOp>(fused_loc, outputs);
for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) { for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
Value output = std::get<0>(output_and_result); Value output = std::get<0>(output_and_result);
@ -572,8 +572,8 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>(); return std::make_unique<XlaHloFusion>();
} }
static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass( static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns."); "xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
@ -81,5 +81,5 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>(); return std::make_unique<SinkConstantsToControlFlow>();
} }
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
@ -40,11 +40,11 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64)); auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim}); auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
if (shape_value) { if (shape_value) {
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>( return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
loc, result_type, value_1d, shape_value, dims); loc, result_type, value_1d, shape_value, dims);
} }
assert(result_type.hasStaticShape()); assert(result_type.hasStaticShape());
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d, return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims); dims);
} }
@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
auto epsilon_tensor_attr = auto epsilon_tensor_attr =
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()}); DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
Value epsilon = Value epsilon =
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr); rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64)); auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{}); auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
if (broadcast_to_type.hasStaticShape()) { if (broadcast_to_type.hasStaticShape()) {
return rewriter.create<xla_hlo::BroadcastInDimOp>( return rewriter.create<mhlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims); op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
} }
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter); Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>( return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, shape_value, op->getLoc(), broadcast_to_type, epsilon, shape_value,
/*broadcast_dims=*/dims); /*broadcast_dims=*/dims);
} }
class UnfuseBatchNormInferencePattern class UnfuseBatchNormInferencePattern
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> { : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
public: public:
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern; using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op, LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
PatternRewriter& rewriter) const override { PatternRewriter& rewriter) const override {
// Enforce type invariants. // Enforce type invariants.
// Note that we deduce the actual element type from the variance, // Note that we deduce the actual element type from the variance,
@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern
if (!epsilon) { if (!epsilon) {
return failure(); return failure();
} }
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(), Value stddev =
bn_op.variance(), epsilon); rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev); stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);
// Broadcast all terms. // Broadcast all terms.
Value shape_value; Value shape_value;
@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern
// Compute: // Compute:
// scale * (input - mean) / stddev + offset // scale * (input - mean) / stddev + offset
Value result = rewriter.create<xla_hlo::SubOp>( Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
bn_op.getLoc(), bn_op.operand(), broadcast_mean); broadcast_mean);
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result, result =
broadcast_scale); rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result, result =
broadcast_stddev); rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result, rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);
broadcast_offset);
return success(); return success();
} }
@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
patterns->insert<UnfuseBatchNormInferencePattern>(context); patterns->insert<UnfuseBatchNormInferencePattern>(context);
} }
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir
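For reference, the computation that UnfuseBatchNormInferencePattern materializes above is the standard batch-norm inference formula (a restatement of the code, not an addition to it):

y = \mathrm{scale} \cdot \frac{x - \mathrm{mean}}{\sqrt{\mathrm{variance} + \epsilon}} + \mathrm{offset}

The pattern first forms stddev = sqrt(variance + epsilon), then broadcasts each term to the operand shape before emitting the elementwise subtract, multiply, divide, and add.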

View File

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass
} // namespace } // namespace
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir
static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass( static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm", "test-xla-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes"); "Test pass for materializing 'broadcast_dimensions' attributes");

View File

@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern; using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's // This code has been adapted from IREE's
// (https://github.com/google/iree/) xla_hlo -> linalg conversion. // (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite( LogicalResult matchAndRewrite(
xla_lhlo::ConvOp op, ArrayRef<Value> args, xla_lhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final { ConversionPatternRewriter& rewriter) const final {
@ -348,14 +348,14 @@ class BroadcastConverter
class HloBroadcastInDimConverter class HloBroadcastInDimConverter
: public DataMovementOpConverter<HloBroadcastInDimConverter, : public DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp, false> { mhlo::BroadcastInDimOp, false> {
public: public:
using DataMovementOpConverter<HloBroadcastInDimConverter, using DataMovementOpConverter<HloBroadcastInDimConverter,
xla_hlo::BroadcastInDimOp, mhlo::BroadcastInDimOp,
false>::DataMovementOpConverter; false>::DataMovementOpConverter;
static SmallVector<AffineMap, 2> getIndexingMaps( static SmallVector<AffineMap, 2> getIndexingMaps(
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) { mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
auto resultType = getXLAOpResultType<false>(broadcastOp); auto resultType = getXLAOpResultType<false>(broadcastOp);
auto operandType = auto operandType =
broadcastOp.operand().getType().template cast<ShapedType>(); broadcastOp.operand().getType().template cast<ShapedType>();
@ -845,7 +845,7 @@ struct HloLegalizeToLinalg
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>(); target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
auto func = getFunction(); auto func = getFunction();
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns); mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
if (failed(applyPartialConversion(func, target, patterns, nullptr))) { if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
signalPassFailure(); signalPassFailure();
} }
@ -863,40 +863,40 @@ static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
} // namespace xla_lhlo } // namespace xla_lhlo
namespace xla_hlo { namespace mhlo {
void populateHLOToLinalgConversionPattern(MLIRContext* context, void populateHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) { OwningRewritePatternList* patterns) {
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>, patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
HloBroadcastInDimConverter, HloBroadcastInDimConverter,
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>, PointwiseToLinalgConverter<mhlo::AbsOp, false>,
PointwiseToLinalgConverter<xla_hlo::AddOp, false>, PointwiseToLinalgConverter<mhlo::AddOp, false>,
PointwiseToLinalgConverter<xla_hlo::AndOp, false>, PointwiseToLinalgConverter<mhlo::AndOp, false>,
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>, PointwiseToLinalgConverter<mhlo::CeilOp, false>,
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>, PointwiseToLinalgConverter<mhlo::CompareOp, false>,
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>, PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>, PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>, PointwiseToLinalgConverter<mhlo::CopyOp, false>,
PointwiseToLinalgConverter<xla_hlo::CosOp, false>, PointwiseToLinalgConverter<mhlo::CosOp, false>,
PointwiseToLinalgConverter<xla_hlo::DivOp, false>, PointwiseToLinalgConverter<mhlo::DivOp, false>,
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>, PointwiseToLinalgConverter<mhlo::ExpOp, false>,
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>, PointwiseToLinalgConverter<mhlo::ImagOp, false>,
PointwiseToLinalgConverter<xla_hlo::LogOp, false>, PointwiseToLinalgConverter<mhlo::LogOp, false>,
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>, PointwiseToLinalgConverter<mhlo::MaxOp, false>,
PointwiseToLinalgConverter<xla_hlo::MinOp, false>, PointwiseToLinalgConverter<mhlo::MinOp, false>,
PointwiseToLinalgConverter<xla_hlo::MulOp, false>, PointwiseToLinalgConverter<mhlo::MulOp, false>,
PointwiseToLinalgConverter<xla_hlo::NegOp, false>, PointwiseToLinalgConverter<mhlo::NegOp, false>,
PointwiseToLinalgConverter<xla_hlo::RealOp, false>, PointwiseToLinalgConverter<mhlo::RealOp, false>,
PointwiseToLinalgConverter<xla_hlo::RemOp, false>, PointwiseToLinalgConverter<mhlo::RemOp, false>,
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>, PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>, PointwiseToLinalgConverter<mhlo::SelectOp, false>,
PointwiseToLinalgConverter<xla_hlo::SinOp, false>, PointwiseToLinalgConverter<mhlo::SinOp, false>,
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>, PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
PointwiseToLinalgConverter<xla_hlo::SubOp, false>, PointwiseToLinalgConverter<mhlo::SubOp, false>,
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>, PointwiseToLinalgConverter<mhlo::TanhOp, false>,
ReshapeOpConverter<xla_hlo::ReshapeOp, false>, ReshapeOpConverter<mhlo::ReshapeOp, false>,
ReverseConverter<xla_hlo::ReverseOp, false>, ReverseConverter<mhlo::ReverseOp, false>,
TransposeConverter<xla_hlo::TransposeOp, false>>(context); TransposeConverter<mhlo::TransposeOp, false>>(context);
} }
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() { std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
@ -905,5 +905,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass( static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect"); "hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir { namespace mlir {
namespace xla_hlo { namespace mhlo {
namespace { namespace {
// TODO(frgossen): Make it variadic. // TODO(frgossen): Make it variadic.
@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex); rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize}, auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
operandTy.getElementType()); operandTy.getElementType());
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>( Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, flatTensorTy, operand, flatShapeAsDimTensor); loc, flatTensorTy, operand, flatShapeAsDimTensor);
// Generate IR for the actual operation. // Generate IR for the actual operation.
@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.getIndexType()); rewriter.getIndexType());
Value shapeAsExtentTensor = Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape); rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>( Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operandTy, flatResult, shapeAsExtentTensor); loc, operandTy, flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result); rewriter.replaceOp(op, result);
@ -184,5 +184,5 @@ static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
"transform-unranked-hlo", "transform-unranked-hlo",
"Realize element-wise operations on ranked tensors where possible"); "Realize element-wise operations on ranked tensors where possible");
} // namespace xla_hlo } // namespace mhlo
} // namespace mlir } // namespace mlir

View File

@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/hlo:mlir-hlo-opt",
"@llvm-project//llvm:FileCheck",
],
)
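A brief note on this new BUILD file (an explanation, not part of the change): glob_lit_tests generates one lit test per .mlir file in the directory, executed through the upstream run_lit.sh driver, while the test_utilities filegroup exposes the mlir-hlo-opt binary and FileCheck to each test's RUN lines. Assuming the usual lit target naming, a single test can then be run as, e.g., bazel test <package>:canonicalize.mlir.test (the package path is not shown in this diff).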

View File

@ -0,0 +1,499 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[6, 8, 10, 12]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<1> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<6>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<1> : tensor<4xi64>
// CHECK: mhlo.constant dense<4>
%2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<3> : tensor<4xi64>
// CHECK: mhlo.constant dense<15>
%2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<1>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<7>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<-5> : tensor<4xi64>
// CHECK: mhlo.constant dense<-5>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}
// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}
// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>
// CHECK: return [[ARG]]
return %0 : tensor<4xi32>
}
// CHECK-LABEL: concatenate_remove_operand
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>
// CHECK: return [[ARG0]]
return %0 : tensor<4xi32>
}
// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>
return %0 : tensor<0xi1>
}
// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>
return %0 : tensor<0xi32>
}
// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>
return %0 : tensor<0xf32>
}
// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]>
%0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>
// CHECK: return [[VAL]]
return %2 : tensor<4xi32>
}
// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
%0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>
// CHECK: return [[VAL]]
return %2 : tensor<4xf32>
}
// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: ]>
%0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
}
// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 2], [1, 3]
// CHECK-SAME: ]>
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>
// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
}
// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "mhlo.dynamic-slice"
%1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-SAME: limit_indices = dense<3> : tensor<1xi64>
// CHECK-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}
// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = mhlo.constant dense<0> : tensor<i64>
%2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}
// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
%0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x2xi64>
}
// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 9]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
%0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
return %1 : tensor<2xf32>
}
// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 10]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}
// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [6, 7],
// CHECK-SAME: [10, 11]
// CHECK-SAME: ]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
return %1 : tensor<2x2xi64>
}
// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [0, 1, 2, 3]
// CHECK-SAME: ]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
return %1 : tensor<1x4xi64>
}
// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [2], [6], [10], [14]
// CHECK-SAME: ]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
return %1 : tensor<4x1xi64>
}
// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg0
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg1
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x4xf32>
}
// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<1x5xf32>
}
// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
// CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)
// CHECK: return [[SLICE]]
return %1 : tensor<2x5xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
// CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "mhlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
// CHECK: return %arg0, %arg1
return %1, %2 : tensor<4xf32>, tensor<4xf32>
}
// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
%0 = "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}
// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @dynamic_iota_broadcast
func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
// CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
// CHECK: return [[BROADCAST]]
return %0 : tensor<5x?xi32>
}
// CHECK-LABEL: @dynamic_iota_broadcast_second
func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
// CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
// CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex>
// CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<?xi32>
// CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<2xindex>) -> tensor<5x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>
// CHECK: return [[BROADCAST]]
return %0 : tensor<5x?xi32>
}
// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}
// CHECK-LABEL: @iota_broadcast
func @iota_broadcast() -> tensor<5x4xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32>
return %0 : tensor<5x4xi32>
}
// CHECK-LABEL: @iota_broadcast_second
func @iota_broadcast_second() -> tensor<5x4xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>
return %0 : tensor<5x4xi32>
}
// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}
// CHECK-LABEL: func @fold_copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: return [[ARG]]
%0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: mhlo.reshape
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
return %0 : tensor<4x1xf32>
}
// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.create_token"() : () -> !mhlo.token
// The side-effecting outfeed op inside the while body prevents it from being DCE'd.
%2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !mhlo.token) -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>
}
// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NOT: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.create_token"() : () -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %arg0 : tensor<i64>
}
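As the RUN line at the top of this file states, each case above is verified by piping the file through mlir-hlo-opt with -pass-pipeline='func(canonicalize)' and matching the output against the CHECK lines via FileCheck; running mlir-hlo-opt canonicalize.mlir -pass-pipeline='func(canonicalize)' | FileCheck canonicalize.mlir by hand (filename assumed) reproduces one test invocation.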

View File

@ -1,4 +1,4 @@
// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s // RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// CHECK-LABEL: @broadcast_add // CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so // Note that all broadcast_ops are expanded from the same template, so

View File

@ -1,10 +1,10 @@
// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s // RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// Check the non-broadcast case for each registered op, then just check a // Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics. // representative op for detailed broadcast semantics.
// CHECK-LABEL: @addWithoutBroadcast // CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.add %arg0, %arg1 // CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]] // CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32> // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] // CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>> // CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>> // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]] // CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]] // CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32> // CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1> // CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: shape.assuming_yield %[[RESULT]] // CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: } // CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1> // CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// Verifies that valid broadcast_dimensions attributes are accepted. // Verifies that valid broadcast_dimensions attributes are accepted.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add // CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32> return %0 : tensor<1x4xf32>
} }
@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// Verifies that valid broadcast_dimensions attributes are accepted. // Verifies that valid broadcast_dimensions attributes are accepted.
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions // CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> { func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add // CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32> %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32> return %0 : tensor<1x4xf32>
} }
@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// expansions. Tests below merely verify that the op has an expansion. // expansions. Tests below merely verify that the op has an expansion.
// CHECK-LABEL: @andWithoutBroadcast // CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.and %arg0, %arg1 // CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> %0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1> return %0 : tensor<4xi1>
} }
@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// ----- // -----
// CHECK-LABEL: @atan2WithoutBroadcast // CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1 // CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// ----- // -----
// CHECK-LABEL: @compareWithoutBroadcast // CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> // CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1> return %0 : tensor<4xi1>
} }
@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// ----- // -----
// CHECK-LABEL: @complexWithoutBroadcast // CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> { func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>> // CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>> %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>> return %0 : tensor<4xcomplex<f32>>
} }
@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// ----- // -----
// CHECK-LABEL: @divideWithoutBroadcast // CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.divide %arg0, %arg1 // CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// ----- // -----
// CHECK-LABEL: @maximumWithoutBroadcast // CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1 // CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// ----- // -----
// CHECK-LABEL: @minimumWithoutBroadcast // CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1 // CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// ----- // -----
// CHECK-LABEL: @multiplyWithoutBroadcast // CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.multiply %arg0, %arg1 // CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// ----- // -----
// CHECK-LABEL: @orWithoutBroadcast // CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.or %arg0, %arg1 // CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> %0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1> return %0 : tensor<4xi1>
} }
@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// ----- // -----
// CHECK-LABEL: @powerWithoutBroadcast // CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.power %arg0, %arg1 // CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// ----- // -----
// CHECK-LABEL: @remainderWithoutBroadcast // CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.remainder %arg0, %arg1 // CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// ----- // -----
// CHECK-LABEL: @shift_leftWithoutBroadcast // CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1 // CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// ----- // -----
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast // CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1 // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// ----- // -----
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast // CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1 // CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// ----- // -----
// CHECK-LABEL: @subWithoutBroadcast // CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.subtract %arg0, %arg1 // CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> %0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// ----- // -----
// CHECK-LABEL: @xorWithoutBroadcast // CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.xor %arg0, %arg1 // CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> %0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1> return %0 : tensor<4xi1>
} }
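// Sketch (illustrative, not part of the diff): with equal static operand
// shapes, each xla_chlo.broadcast_* op above needs no broadcast and lowers
// directly to its mhlo counterpart, e.g.
//
//   %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
//   // lowers to
//   %0 = mhlo.add %arg0, %arg1 : tensor<4xf32>
//
// With mismatched shapes the lowering would first materialize an explicit
// broadcast of the operands.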


@ -1,9 +1,9 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @single_operand // CHECK-LABEL: func @single_operand
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> { func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32> %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK-NEXT: return [[ARG]] // CHECK-NEXT: return [[ARG]]
return %0 : tensor<1x2xf32> return %0 : tensor<1x2xf32>
} }
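// Sketch: concatenating a single operand is the identity, so the
// canonicalizer folds the op away and all uses see the operand directly:
//
//   %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
//   // folds to: %arg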


@ -0,0 +1,225 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// -----
// CHECK-LABEL: func @same_type
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @int_widening
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @int_narrowing
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i16>
}
// -----
// CHECK-LABEL: func @float_int
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
%0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @int_float
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
%0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @high_rank_tensor
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
%0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
// CHECK-NEXT: return [[RES]]
return %0 : tensor<2x3xf32>
}
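// Note: for non-constant operands only the type-preserving convert (the
// @same_type case) folds; widening, narrowing and int<->float converts must
// stay, since they change the value representation:
//
//   %0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>  // folds to %arg
//   %1 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>  // kept as-is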
// -----
// CHECK-LABEL: func @const_same_type
func @const_same_type() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_float_int
func @const_float_int() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_int_float
func @const_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_negative_int_float
func @const_negative_int_float() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<-4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_int_bf16
func @const_int_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
%cst = mhlo.constant dense<4> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
// -----
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
%cst = mhlo.constant dense<42.0> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i16>
}
// -----
// CHECK-LABEL: func @const_int_narrowing
func @const_int_narrowing() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<i64>
%0 = "mhlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
%cst = mhlo.constant dense<-42> : tensor<i32>
%0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @const_float_narrowing
func @const_float_narrowing() -> tensor<f32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
%cst = mhlo.constant dense<4.2> : tensor<f64>
%0 = "mhlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f32>
}
// -----
// CHECK-LABEL: func @const_f32_bf16
func @const_f32_bf16() -> tensor<bf16> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
%cst = mhlo.constant dense<42.0> : tensor<f32>
%0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<bf16>
}
// -----
// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
%cst = mhlo.constant dense<4.2> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<f64>
}
// -----
// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
%cst = mhlo.constant dense<42.0> : tensor<bf16>
%0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i64>
}
// -----
// CHECK-LABEL: func @const_high_rank_tensor
func @const_high_rank_tensor() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
%0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}
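// Sketch of the constant-folding rules exercised above: a convert of an
// mhlo.constant folds by converting the attribute, so float->int truncates
// (42.0 -> 42), int->float is exact (4 -> 4.0e+00), and narrowing through
// bf16 rounds to the nearest representable value (4.2 -> 4.1875):
//
//   %cst = mhlo.constant dense<42.0> : tensor<f32>
//   %0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
//   // folds to: mhlo.constant dense<42> : tensor<i32>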


@ -1,10 +1,10 @@
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
// RUN: xla-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s // RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s
// BOTH-LABEL: func @attrs // BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand) %tensor_result = "mhlo.exponential"(%tensor_operand)
{some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {
// BOTH-LABEL: func @func_op_long // BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
%2 = xla_hlo.add %arg0, %1 : tensor<4xf32> %2 = mhlo.add %arg0, %1 : tensor<4xf32>
%3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
%4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32> %4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
%5 = xla_hlo.multiply %2, %4 : tensor<4xf32> %5 = mhlo.multiply %2, %4 : tensor<4xf32>
return %5 : tensor<4xf32> return %5 : tensor<4xf32>
} }
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) // PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
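// Sketch (names illustrative): -hlo-legalize-to-lhlo rewrites tensor
// signatures to memrefs; in the default (PRE) mode results travel through an
// appended out-argument, so @func_op_long becomes roughly
//
//   func @func_op_long(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %result: memref<4xf32>) {
//     %max = alloc() : memref<4xf32>
//     "xla_lhlo.maximum"(%arg0, %arg1, %max) : (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> ()
//     ... // one alloc + lhlo op per intermediate, deallocated by buffer-placement
//     "xla_lhlo.copy"(%mul, %result) : (memref<4xf32>, memref<4xf32>) -> ()
//     return
//   }
//
// With results-escape-function=true (ESC) the function returns the final
// allocation instead of copying into an out-argument.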
@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32> // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32> %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
%tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
%sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2) %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
// BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
%tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
%tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier) %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
// BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-LABEL: func @copy // BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.copy"(%tensor_operand) %tensor_result = "mhlo.copy"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @exp // BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.exponential"(%tensor_operand) %tensor_result = "mhlo.exponential"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @log // BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.log"(%tensor_operand) %tensor_result = "mhlo.log"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
%tensor_pred = tensor_load %pred : memref<2x2xi1> %tensor_pred = tensor_load %pred : memref<2x2xi1>
%tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
{comparison_direction = "EQ"} {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
// BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
// BOTH-LABEL: func @broadcast // BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
%tensor_operand = tensor_load %operand : memref<5xf32> %tensor_operand = tensor_load %operand : memref<5xf32>
%tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand) %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
{broadcast_dimensions = dense<1> : tensor<1xi64>} {broadcast_dimensions = dense<1> : tensor<1xi64>}
: (tensor<5xf32>) -> tensor<10x5xf32> : (tensor<5xf32>) -> tensor<10x5xf32>
// BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
// BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>) // BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
%tensor_operand = tensor_load %operand : memref<?x?xf32> %tensor_operand = tensor_load %operand : memref<?x?xf32>
%shape = call @external_func() : () -> tensor<3xi64> %shape = call @external_func() : () -> tensor<3xi64>
%tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32> } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// BOTH: %[[SHAPE:.*]] = call @external_func() // BOTH: %[[SHAPE:.*]] = call @external_func()
@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>,
%result: memref<2x2xcomplex<f32>>) { %result: memref<2x2xcomplex<f32>>) {
%tensor_real = tensor_load %real : memref<2x2xf32> %tensor_real = tensor_load %real : memref<2x2xf32>
%tensor_imag = tensor_load %imag : memref<2x2xf32> %tensor_imag = tensor_load %imag : memref<2x2xf32>
%tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag) %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
// BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>> tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>,
// BOTH-LABEL: func @real // BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) { func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>> %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.real"(%tensor_operand) %tensor_result = "mhlo.real"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @imag // BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) { func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>> %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
%tensor_result = "xla_hlo.imag"(%tensor_operand) %tensor_result = "mhlo.imag"(%tensor_operand)
: (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32> : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @iota // BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) { func @iota(%result: memref<10xi32>) {
%tensor_result = "xla_hlo.iota"() %tensor_result = "mhlo.iota"()
{iota_dimension = 0 : i64} : () -> tensor<10xi32> {iota_dimension = 0 : i64} : () -> tensor<10xi32>
// BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
tensor_store %tensor_result, %result : memref<10xi32> tensor_store %tensor_result, %result : memref<10xi32>
@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) {
// BOTH-LABEL: func @abs // BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.abs"(%tensor_operand) %tensor_result = "mhlo.abs"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @ceil // BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.ceil"(%tensor_operand) %tensor_result = "mhlo.ceil"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @convert // BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.convert"(%tensor_operand) %tensor_result = "mhlo.convert"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
// BOTH-NOT: tensor_store // BOTH-NOT: tensor_store
@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @cos // BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.cosine"(%tensor_operand) %tensor_result = "mhlo.cosine"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @neg // BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.negate"(%tensor_operand) %tensor_result = "mhlo.negate"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @rsqrt // BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.rsqrt"(%tensor_operand) %tensor_result = "mhlo.rsqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sign // BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sign"(%tensor_operand) %tensor_result = "mhlo.sign"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sqrt // BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.sqrt"(%tensor_operand) %tensor_result = "mhlo.sqrt"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @tanh // BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_operand = tensor_load %operand : memref<2x2xf32>
%tensor_result = "xla_hlo.tanh"(%tensor_operand) %tensor_result = "mhlo.tanh"(%tensor_operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
%tensor_lhs = tensor_load %lhs : memref<2x2xf32> %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
%tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
%tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs) %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
tensor_store %tensor_result, %result : memref<2x2xf32> tensor_store %tensor_result, %result : memref<2x2xf32>
@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// Dynamic shape binary element-wise operation. // Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn // BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) { func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
%result = "xla_hlo.add"(%lhs, %rhs) %result = "mhlo.add"(%lhs, %rhs)
: (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32> : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32> // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
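// Sketch: with dynamic shapes the result buffer cannot be a static alloc();
// the lowering reads the operand extents and allocates from them:
//
//   %c0 = constant 0 : index
//   %dim0 = dim %arg0, %c0 : memref<?x?xf32>
//   %c1 = constant 1 : index
//   %dim1 = dim %arg0, %c1 : memref<?x?xf32>
//   %result = alloc(%dim0, %dim1) : memref<?x?xf32>
//   "xla_lhlo.add"(%arg0, %arg1, %result) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()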
@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// Dynamic shape unary element-wise operation. // Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn // BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) { func @tanh_dyn(%arg0: tensor<?x?xf32>) {
%result = "xla_hlo.tanh"(%arg0) %result = "mhlo.tanh"(%arg0)
: (tensor<?x?xf32>) -> tensor<?x?xf32> : (tensor<?x?xf32>) -> tensor<?x?xf32>
// BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[C0:.*]] = constant 0 : index
// BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32> // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc // BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
%dot = "xla_hlo.dot"(%arg0, %arg0) %dot = "mhlo.dot"(%arg0, %arg0)
: (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]]) // PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
// ESC: return %[[ALLOC]] // ESC: return %[[ALLOC]]
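// Sketch: the two modes differ only in how the result leaves the function.
// PRE copies the local allocation into the appended out-argument; ESC
// returns it:
//
//   %alloc = alloc() : memref<1024x1024xf32>
//   "xla_lhlo.dot"(%arg0, %arg0, %alloc) : (memref<1024x1024xf32>, memref<1024x1024xf32>, memref<1024x1024xf32>) -> ()
//   // PRE: "xla_lhlo.copy"(%alloc, %result) ... return
//   // ESC: return %alloc : memref<1024x1024xf32>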
@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
// BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
// BOTH-SAME: rhs_dilation = dense<[1, 2]> // BOTH-SAME: rhs_dilation = dense<[1, 2]>
// BOTH-SAME: window_strides = dense<[2, 1]> // BOTH-SAME: window_strides = dense<[2, 1]>
%out = "xla_hlo.convolution"(%filter, %input) { %out = "mhlo.convolution"(%filter, %input) {
batch_group_count = 1 : i64, batch_group_count = 1 : i64,
dimension_numbers = { dimension_numbers = {
input_batch_dimension = 0 : i64, input_batch_dimension = 0 : i64,


@ -1,4 +1,4 @@
// RUN: xla-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s // RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)> // CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @float_add // CHECK-LABEL: func @float_add
@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
// CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]] // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
// CHECK: linalg.yield %[[RESULT]] // CHECK: linalg.yield %[[RESULT]]
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>, %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32> tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
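// Sketch (schematic; exact linalg.generic attributes elided): each
// elementwise mhlo op becomes a linalg.generic over identity indexing maps
// whose region applies the matching scalar std op:
//
//   %0 = linalg.generic { ...indexing_maps = [#map0, #map0, #map0]... } %lhs, %rhs {
//   ^bb0(%l: f32, %r: f32):
//     %sum = addf %l, %r : f32
//     linalg.yield %sum : f32
//   } : tensor<2x2xf32>, tensor<2x2xf32> -> tensor<2x2xf32>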
@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: addi // CHECK: addi
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>, %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32> tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: mulf // CHECK: mulf
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>, %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32> tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: muli // CHECK: muli
%0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>, %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32> tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: remf // CHECK: remf
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>, %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32> tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: remi_signed // CHECK: remi_signed
%0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>, %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32> tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_rsqrt // CHECK-LABEL: func @float_rsqrt
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
%tensor_result = "xla_hlo.rsqrt"(%operand) %tensor_result = "mhlo.rsqrt"(%operand)
: (tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: rsqrt // CHECK: rsqrt
@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: subf // CHECK: subf
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>, %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
tensor<2x2xf32>) -> tensor<2x2xf32> tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: subi // CHECK: subi
%0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>, %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32> tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: absf // CHECK: absf
%0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: exp // CHECK: exp
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: log // CHECK: log
%0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: ceilf // CHECK: ceilf
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: negf // CHECK: negf
%0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: tanh // CHECK: tanh
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: and // CHECK: and
%0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>, %0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
tensor<2x2xi32>) -> tensor<2x2xi32> tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_cmp // CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>, func @float_cmp(%lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) { %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"} %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1> return %0 : tensor<2x2xi1>
} }
@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// CHECK-LABEL: func @int_cmp // CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>, func @int_cmp(%lhs: tensor<2x2xi32>,
%rhs: tensor<2x2xi32>) -> tensor<2x2xi1> { %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
%0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"} %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
: (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>) : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
return %0 : tensor<2x2xi1> return %0 : tensor<2x2xi1>
} }
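// Sketch (predicate spellings assumed): mhlo.compare picks the std
// comparison matching the element type and comparison direction inside the
// generic body:
//
//   EQ on f32  ->  %0 = cmpf "oeq", %l, %r : f32
//   LT on i32  ->  %0 = cmpi "slt", %l, %r : i32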
@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: cos // CHECK: cos
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: linalg.generic // CHECK: linalg.generic
// CHECK: sin // CHECK: sin
%0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32> %0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @copy // CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> { func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
%0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>) %0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
return %0 : tensor<2x4x8xf32> return %0 : tensor<2x4x8xf32>
} }
// CHECK: return [[ARG]] : tensor<2x4x8xf32> // CHECK: return [[ARG]] : tensor<2x4x8xf32>
@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// CHECK-LABEL: func @select // CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
%rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.select"(%pred, %lhs, %rhs) %0 = "mhlo.select"(%pred, %lhs, %rhs)
: (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar // CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> { func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32> %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
return %0: tensor<4x2x1xf32> return %0: tensor<4x2x1xf32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)> // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast // CHECK-LABEL: func @broadcast
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> { func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
%0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
return %0: tensor<4x2x1x4x?x16xf32> return %0: tensor<4x2x1x4x?x16xf32>
} }
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]] // CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
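// Sketch (maps assumed): broadcasts move no data; linalg.generic expresses
// them entirely through the indexing maps, the operand map simply dropping
// the broadcasted dimensions, e.g. for the scalar case:
//
//   #operand_map = affine_map<(d0, d1, d2) -> ()>
//   #result_map  = affine_map<(d0, d1, d2) -> (d0, d1, d2)>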
@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim // CHECK-LABEL: func @broadcast_in_dim
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> { func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand) %0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>} {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
: (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
return %0 : tensor<7x10x6x4x5xf32> return %0 : tensor<7x10x6x4x5xf32>
@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one // CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one( func @broadcast_in_dim_with_one_to_one(
%operand: tensor<1xf32>) -> tensor<1x5xf32> { %operand: tensor<1xf32>) -> tensor<1x5xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand) %0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[0]> : tensor<1xi64>} {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
: (tensor<1xf32>) -> tensor<1x5xf32> : (tensor<1xf32>) -> tensor<1x5xf32>
return %0 : tensor<1x5xf32> return %0 : tensor<1x5xf32>
@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one(
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar // CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> { func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%operand) %0 = "mhlo.broadcast_in_dim"(%operand)
{broadcast_dimensions = dense<[]> : tensor<0xi64>} {broadcast_dimensions = dense<[]> : tensor<0xi64>}
: (tensor<f32>) -> tensor<7x10x6xf32> : (tensor<f32>) -> tensor<7x10x6xf32>
return %0 : tensor<7x10x6xf32> return %0 : tensor<7x10x6xf32>
@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @transpose // CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
%0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32> return %0 : tensor<3x2x5x9xi32>
} }
@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D // CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> { func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32> %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32> return %0 : tensor<12x42xi32>
} }
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
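// Sketch: a static mhlo.reshape becomes linalg.tensor_reshape, whose
// reassociation maps group source dims into result dims; for the 3D->2D case
// above, (d0, d1) collapse into the first result dim and d2 stays:
//
//   %0 = linalg.tensor_reshape %arg0 [#reshape_map1, #reshape_map2]
//        : tensor<12x1x42xi32> into tensor<12x42xi32>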
@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D // CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> { func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32> %0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
return %0 : tensor<12x42xi32> return %0 : tensor<12x42xi32>
} }
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D // CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> { func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32> %0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
return %0 : tensor<12x1x42x1xi32> return %0 : tensor<12x1x42x1xi32>
} }
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]] // CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
// CHECK-LABEL: func @minf // CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
%0 = "xla_hlo.minimum"(%lhs, %rhs) %0 = "mhlo.minimum"(%lhs, %rhs)
: (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32> return %0 : tensor<2x2xf32>
} }
@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @maxi // CHECK-LABEL: func @maxi
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> { func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
%0 = "xla_hlo.maximum"(%lhs, %rhs) %0 = "mhlo.maximum"(%lhs, %rhs)
: (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32> return %0 : tensor<2x2xi32>
} }
@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()> // CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @add_scalar // CHECK-LABEL: func @add_scalar
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> { func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
%0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32> %0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
func @reshape_collapse_single_dim func @reshape_collapse_single_dim
(%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> { (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32> %0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
return %0 : tensor<1x784xf32> return %0 : tensor<1x784xf32>
} }
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -428,7 +428,7 @@ func @reshape_collapse_single_dim
// ----- // -----
func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> { func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> %0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
return %0 : tensor<2x4x3xf32> return %0 : tensor<2x4x3xf32>
} }
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)> // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
// ----- // -----
func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> { func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32> %0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
return %0 : tensor<2x4x2xf32> return %0 : tensor<2x4x2xf32>
} }
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)> // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
// ----- // -----
func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> { func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32> %0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
return %0 : tensor<1x4x2xf32> return %0 : tensor<1x4x2xf32>
} }
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
func @reshape_multiple_collapse func @reshape_multiple_collapse
(%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> { (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> %0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
return %0 : tensor<1x4x5x6xf32> return %0 : tensor<1x4x5x6xf32>
} }
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)> // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
@ -476,7 +476,7 @@ func @reshape_multiple_collapse
// CHECK-LABEL: func @convert_i32_to_f32 // CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> { func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32> %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32> return %result : tensor<2x2xf32>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @convert_i16_to_i32 // CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> { func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32> %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32> return %result : tensor<2x2xi32>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
// CHECK-LABEL: func @convert_i32_to_i16 // CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> { func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16> %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
return %result : tensor<2x2xi16> return %result : tensor<2x2xi16>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
// CHECK-LABEL: func @convert_f32_to_f64 // CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> { func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64> %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
return %result : tensor<2x2xf64> return %result : tensor<2x2xf64>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
// CHECK-LABEL: func @convert_f64_to_f32 // CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> { func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32> %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
return %result : tensor<2x2xf32> return %result : tensor<2x2xf32>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @convert_f32_to_i32 // CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> { func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
%result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32> %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
return %result : tensor<2x2xi32> return %result : tensor<2x2xi32>
} }
// CHECK: linalg.generic // CHECK: linalg.generic
@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse // CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> { func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
%result = "xla_hlo.reverse"(%input) { %result = "mhlo.reverse"(%input) {
dimensions = dense<1> : tensor<1xi64> dimensions = dense<1> : tensor<1xi64>
} : (tensor<2x3xf32>) -> tensor<2x3xf32> } : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %result : tensor<2x3xf32> return %result : tensor<2x3xf32>


@@ -1,28 +1,28 @@
-// RUN: xla-opt %s -inline | FileCheck %s
+// RUN: mlir-hlo-opt %s -inline | FileCheck %s
-// Test case: Basic test of inlining into xla_hlo.while.
+// Test case: Basic test of inlining into mhlo.while.
// CHECK-LABEL: func @caller
-// CHECK: "xla_hlo.while"{{.*}}( {
+// CHECK: "mhlo.while"{{.*}}( {
// CHECK: }, {
-// CHECK: "xla_hlo.exponential"
+// CHECK: "mhlo.exponential"
// CHECK: })
// CHECK-LABEL: func @callee
func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
-  %0 = "xla_hlo.while"(%arg0) ( {
+  %0 = "mhlo.while"(%arg0) ( {
  ^entry(%unused: tensor<f32>):
-    "xla_hlo.return"(%pred) : (tensor<i1>) -> ()
+    "mhlo.return"(%pred) : (tensor<i1>) -> ()
  }, {
  ^entry(%0: tensor<f32>):
    %1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
-    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
+    "mhlo.return"(%1) : (tensor<f32>) -> ()
  } ) : (tensor<f32>) -> (tensor<f32>)
  return %0 : tensor<f32>
}
func @callee(%arg0: tensor<f32>) -> tensor<f32> {
-  %0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
+  %0 = "mhlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
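Read together, the CHECK lines above pin down the intended result of this test: after inlining, the call to @callee disappears and the exponential lands directly in the while body. A rough sketch of the expected output IR, inferred from the CHECK directives rather than copied from actual pass output:

%0 = "mhlo.while"(%arg0) ( {
^entry(%unused: tensor<f32>):
  "mhlo.return"(%pred) : (tensor<i1>) -> ()
}, {
^entry(%1: tensor<f32>):
  %2 = "mhlo.exponential"(%1) : (tensor<f32>) -> tensor<f32>
  "mhlo.return"(%2) : (tensor<f32>) -> ()
} ) : (tensor<f32>) -> (tensor<f32>)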


@@ -1,24 +1,24 @@
-// RUN: xla-opt -xla-legalize-control-flow %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s
// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
func @while(%arg0: tensor<i64>) -> tensor<i64> {
  //CHECK: br ^bb1(%arg0 : tensor<i64>)
  //CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
-  //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
+  //CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
  //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
  //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
  //CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
-  //CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
+  //CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
  //CHECK: br ^bb1([[VAL4]] : tensor<i64>)
  //CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
-  %0 = "xla_hlo.while"(%arg0) ( {
+  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor<i64>):
-    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
+    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg1: tensor<i64>):
-    %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
+    %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
-    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
+    "mhlo.return"(%1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
  // CHECK-NEXT: return [[VAL5]]
@@ -30,27 +30,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
  %cst = constant dense<1.000000e+01> : tensor<f32>
-  // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-  %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+  %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
  // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
-  %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
+  %1 = "mhlo.if"(%0, %arg0, %arg0) ( {
  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
-    // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
+    // CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL3]] : tensor<f32>)
-    %2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
+    %2 = "mhlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
-    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }, {
  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
-    // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
+    // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL5]] : tensor<f32>)
-    %2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
+    %2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
-    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  // CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
@@ -62,27 +62,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
-  // CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %2 = extract_element %1[] : tensor<i1>
  // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
  // CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
  // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
  // CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
-  // CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
+  // CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
  // CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
  // CHECK: ^[[EXIT]](%6: tensor<i64>):
  // CHECK: return %6 : tensor<i64>
  // CHECK: }
-  %0 = "xla_hlo.while"(%arg0) ( {
+  %0 = "mhlo.while"(%arg0) ( {
  ^cond_entry(%arg1: tensor<i64>):
-    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
+    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^body_entry(%arg1: tensor<i64>):
    br ^body_succ(%arg1: tensor<i64>)
  ^body_succ(%0: tensor<i64>):
-    %1 = xla_hlo.add %0, %0 : tensor<i64>
+    %1 = mhlo.add %0, %0 : tensor<i64>
-    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
+    "mhlo.return"(%1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
  return %0 : tensor<i64>
@@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
  // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
  // CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
-  // CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+  // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %3 = extract_element %2[] : tensor<i1>
  // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
  // CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
@@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: ^[[EXIT]](%5: tensor<i64>):
  // CHECK: return %5 : tensor<i64>
  // CHECK: }
-  %0 = "xla_hlo.while"(%arg0) ( {
+  %0 = "mhlo.while"(%arg0) ( {
  ^cond_entry(%arg1: tensor<i64>):
    br ^cond_succ(%arg1: tensor<i64>)
  ^cond_succ(%0: tensor<i64>):
-    %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
+    %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
-    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
+    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^body_entry(%arg1: tensor<i64>):
-    "xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
+    "mhlo.return"(%arg1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>
  return %0 : tensor<i64>
@@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %
  // CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
  // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
  // CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
-  // CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
  // CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
  // CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
-  // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
+  // CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
  // CHECK: br ^[[EXIT]](%5 : tensor<f32>)
  // CHECK: ^[[EXIT]](%6: tensor<f32>):
  // CHECK: return %6 : tensor<f32>
  // CHECK: }
-  %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
+  %1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
  ^then_entry(%arg2: tensor<f32>):
    br ^then_succ(%arg2: tensor<f32>)
  ^then_succ(%0: tensor<f32>):
-    %2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
+    %2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
-    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }, {
  ^else_entry(%arg2: tensor<f32>):
-    %2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
+    %2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
-    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
+    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}
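Every case in this file lowers region-based control flow the same way, and the CHECK lines spell the recipe out: the tensor<i1> produced by a condition region is scalarized with extract_element and then fed to cond_br. A minimal sketch of that idiom (value and block names here are illustrative, not taken from pass output):

%cond = "mhlo.compare"(%x, %x) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%cond_scalar = extract_element %cond[] : tensor<i1>
cond_br %cond_scalar, ^body(%x : tensor<i64>), ^exit(%x : tensor<i64>)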


@@ -1,21 +1,21 @@
-// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s
// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  // CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
-  %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
-  %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
-  %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
-  %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
-  %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
+  %4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
  // CHECK-NEXT: return %4 : tensor<4xf32>
  return %4 : tensor<4xf32>
@@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
  // CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
-  %0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
-  %1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
-  %2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
-  %3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
-  %4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
+  %4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
  // CHECK-NEXT: return %4 : tensor<4xi32>
  return %4 : tensor<4xi32>
@@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
  // CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
-  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
-  %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
-  %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
-  %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
-  %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
-  %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
+  %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
  // CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
  return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
@@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi
// CHECK-LABEL: func @compare_float
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
  // CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
-  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  // CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
-  %1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  // CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
-  %2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  // CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
-  %3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  // CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
-  %4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  // CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
-  %5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
+  %5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
  return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
  // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
-  %0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
+  %0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
  // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
-  %1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
+  %1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
  // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
-  %2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
+  %2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
  // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
  return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
@@ -92,11 +92,11 @@ func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
  // CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
-  %0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
+  %0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
  // CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
-  %1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
+  %1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
  // CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
-  %2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
+  %2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
  // CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
  return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
@@ -105,7 +105,7 @@ func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
func @iota.const.1() -> tensor<4xi32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
  // CHECK-NEXT: return %[[CST]] : tensor<4xi32>
  return %0 : tensor<4xi32>
}
@@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> {
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
func @iota.const.2() -> tensor<2x4xi32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
  // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
  return %0 : tensor<2x4xi32>
}
@@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
func @iota.const.3() -> tensor<2x4xi32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
+  %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
  // CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
  return %0 : tensor<2x4xi32>
}
@@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
func @iota.const.4() -> tensor<2x3x4xi32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
  return %0 : tensor<2x3x4xi32>
}
@@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
func @iota.const.5() -> tensor<2x3x4xi32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
+  %0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
  return %0 : tensor<2x3x4xi32>
}
@@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
func @iota.const.6() -> tensor<2x3x4xi32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
+  %0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
  // CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
  return %0 : tensor<2x3x4xi32>
}
@@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
  // CHECK-NEXT: return %[[CST]] : tensor<4xf32>
  return %0 : tensor<4xf32>
}
@@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> {
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
  // CHECK-NEXT: return %[[CST]] : tensor<4xf64>
  return %0 : tensor<4xf64>
}
@@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> {
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
  // CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
  // CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
  return %0 : tensor<4xbf16>
}
@@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> {
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
  // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
  // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
-  // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
+  // CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
  // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
  return %0 : tensor<4xcomplex<f32>>
}
@@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
  // CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
  // CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
-  // CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
+  // CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
-  %0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
+  %0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
  // CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
  return %0 : tensor<4xcomplex<f64>>
}
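All of the iota cases above fold to constants because the result shape and iota_dimension are static: the value at each output index is simply the index along iota_dimension (for iota_dimension = 1 on a 2-D result, output[i][j] = j), repeated across the other dimensions. For complex element types the folded values become the real operand of an mhlo.complex whose imaginary operand is a zero constant, as the last two tests check.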


@@ -1,4 +1,4 @@
-// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
+// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
func @tanh_f64(%arg0 : f64) -> f64 {
  %res = tanh %arg0 : f64


@@ -1,4 +1,4 @@
-// RUN: xla-opt -lhlo-copy-removal %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s
// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {


@@ -1,6 +1,6 @@
-// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
+// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
-// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
+// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
-// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
+// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}


@@ -4,7 +4,7 @@
// Lowering to STD dialect and store forwarding pass would be required to get
// rid of them. This is exactly what is done in the real MLIR GPU pipeline, but
// here we disable verification with `verify-each=0` to check the output IR.
-// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s
+// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s
func @select_and_scatter(%arg: memref<112x112xf32>,
                         %src: memref<56x56xf32>,


@@ -1,4 +1,4 @@
-// RUN: xla-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
+// RUN: mlir-hlo-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
// Smoke test.
// CHECK-LABEL: func @min_op


@@ -1,4 +1,4 @@
-// RUN: xla-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s
+// RUN: mlir-hlo-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s
func @reduce(%arg: memref<100x10xf32>,
             %init: memref<f32>,


@@ -1,4 +1,4 @@
-// RUN: xla-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s
+// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s
// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise


@@ -1,4 +1,4 @@
-// RUN: xla-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
+// RUN: mlir-hlo-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {


@@ -1,4 +1,4 @@
-// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s
+// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s
func @reduce(%arg: memref<100x10x5xf32>,
             %init: memref<f32>,


@@ -1,4 +1,4 @@
-// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
+// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s
// -----
@@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
  "xla_lhlo.fusion"() ( {
    %0 = tensor_load %input1 : memref<10xf32>
    %1 = tensor_load %input2 : memref<10xf32>
-    %2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+    %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    %3 = tensor_load %input3 : memref<10xf32>
-    %4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
+    %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    tensor_store %4, %out : memref<10xf32>
    "xla_lhlo.terminator"() : () -> ()
  } ) : () -> ()
@@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
  "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
-    %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
+    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
-    "xla_hlo.return"(%max) : (tensor<f32>) -> ()
+    "mhlo.return"(%max) : (tensor<f32>) -> ()
  })
  { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()
  "xla_lhlo.all_reduce"(%arg0, %arg_out) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
-    %max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
+    %max = mhlo.maximum %lhs, %rhs : tensor<f32>
-    "xla_hlo.return"(%max) : (tensor<f32>) -> ()
+    "mhlo.return"(%max) : (tensor<f32>) -> ()
  })
  {
    replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
                      %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
  "xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
  ^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
-    %add = xla_hlo.add %lhs, %rhs : tensor<f32>
+    %add = mhlo.add %lhs, %rhs : tensor<f32>
-    "xla_hlo.return"(%add) : (tensor<f32>) -> ()
+    "mhlo.return"(%add) : (tensor<f32>) -> ()
  }) {
    scatter_dimension_numbers = {
      update_window_dims = dense<[1]> : tensor<1xi64>,
@@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
  "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
-    %c = xla_hlo.add %a, %b : tensor<f32>
+    %c = mhlo.add %a, %b : tensor<f32>
-    "xla_hlo.return"(%c) : (tensor<f32>) -> ()
+    "mhlo.return"(%c) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
  return
}
@@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
  // expected-error@+1{{requires the same shape for all operands}}
  "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
  ^bb0(%a: tensor<f32>, %b: tensor<f32>):
-    %c = xla_hlo.add %a, %b : tensor<f32>
+    %c = mhlo.add %a, %b : tensor<f32>
-    "xla_hlo.return"(%c) : (tensor<f32>) -> ()
+    "mhlo.return"(%c) : (tensor<f32>) -> ()
  }) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
  return
}
@@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
                   %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
  "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
  ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
-    %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+    "mhlo.return"(%7) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
  return
}
@@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
                   %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
  "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
  ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
-    %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+    "mhlo.return"(%7) : (tensor<i1>) -> ()
  }) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
  return
}
@@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
                   %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
  "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
  ^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
-    %7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
+    %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
-    "xla_hlo.return"(%7) : (tensor<i1>) -> ()
+    "mhlo.return"(%7) : (tensor<i1>) -> ()
  }) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
  return
}


@@ -0,0 +1,224 @@
// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s
// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL2]], [[VAL5]] : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL2]], [[VAL5]] : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
// -----
// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)
// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]
// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]
// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]
// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}
// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
}
// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
}
// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>
}
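The expansions checked above follow ordinary complex arithmetic. As an
illustrative aside (not part of this commit), a short NumPy sketch can confirm
the mul/div/abs/exp decompositions the pass relies on:

```python
import numpy as np

# Hypothetical sanity check for the complex-lowering identities above.
a, b = 1.5, -0.5   # lhs.real, lhs.imag
c, d = 2.0, 3.0    # rhs.real, rhs.imag
lhs, rhs = complex(a, b), complex(c, d)

# mul: real = ac - bd, imag = ad + bc
assert np.isclose((lhs * rhs).real, a * c - b * d)
assert np.isclose((lhs * rhs).imag, a * d + b * c)

# div: numerator = lhs * conj(rhs), denominator = c^2 + d^2
den = c * c + d * d
assert np.isclose((lhs / rhs).real, (a * c + b * d) / den)
assert np.isclose((lhs / rhs).imag, (b * c - a * d) / den)

# abs: sqrt(a^2 + b^2); exp: e^a * (cos b + i sin b)
assert np.isclose(abs(lhs), np.sqrt(a * a + b * b))
assert np.isclose(np.exp(lhs).real, np.exp(a) * np.cos(b))
assert np.isclose(np.exp(lhs).imag, np.exp(a) * np.sin(b))
```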

View File

@ -0,0 +1,35 @@
// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>
return %0 : tensor<1x1x3xf32>
}
// -----
// CHECK-LABEL: @testDebatch2
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
// CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
return %0 : tensor<3x1x1xf32>
}
// -----
// CHECK-LABEL: @testBatchPassthrough
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
// CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1)
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}
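The debatching checked above is shape bookkeeping around a plain matrix
product. An illustrative NumPy check (hypothetical shapes matching
@testDebatch1, not part of the test) that reshape -> dot -> reshape
reproduces the general dot:

```python
import numpy as np

# dot_general of (1x1x2) . (2x3), contracting lhs dim 2 with rhs dim 0,
# versus the debatched reshape -> matmul -> reshape sequence.
lhs = np.random.rand(1, 1, 2).astype(np.float32)
rhs = np.random.rand(2, 3).astype(np.float32)

general = np.tensordot(lhs, rhs, axes=([2], [0]))        # (1, 1, 3)
debatched = (lhs.reshape(1, 2) @ rhs).reshape(1, 1, 3)   # same values

assert np.allclose(general, debatched)
```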

View File

@ -0,0 +1,11 @@
// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s
// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}

View File

@ -1,14 +1,14 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @noop // CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
// CHECK: return %[[ARG0]] // CHECK: return %[[ARG0]]
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> { func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32> %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "xla_hlo.reduce"(%arg0, %0) ( { %2 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32> %4 = mhlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> () "mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32> }) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32> return %2 : tensor<4x8xf32>
} }

View File

@ -0,0 +1,149 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @const_fold_collapse_to_scalar
func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<1x1xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}
// -----
// CHECK-LABEL: func @const_fold_collapse_to_tensor
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
%cst = mhlo.constant dense<42> : tensor<1x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}
// -----
// CHECK-LABEL: func @const_fold_expand
func @const_fold_expand() -> tensor<1xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<1xi32>
}
// -----
// CHECK-LABEL: func @const_fold_nontrivial
func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
// -----
// CHECK-LABEL: func @const_fold_flatten
func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}
// -----
// CHECK-LABEL: func @const_fold_6
func @const_fold_6() -> tensor<6xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<6xi32>
}
// -----
// CHECK-LABEL: func @const_fold_same_shape
func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @const_fold_float
func @const_fold_float() -> tensor<16xf64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xf64>
}
// -----
// CHECK-LABEL: func @non_const_same_shape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-NEXT: return [[ARG]]
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}
// -----
// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[RES]]
return %1 : tensor<6xi32>
}
// -----
// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[ARG]]
return %1 : tensor<2x3xi32>
}
// -----
// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: return [[RES]]
return %4 : tensor<1x2x4x3xi32>
}
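The foldings above rely on row-major (C-order) reshape semantics; a tiny NumPy
illustration (an aside, not part of the test) of the @const_fold_6 and
@const_fold_same_shape cases:

```python
import numpy as np

# @const_fold_6: [[1, 2], [3, 4], [5, 6]] flattens to [1, 2, 3, 4, 5, 6].
cst = np.array([[1, 2], [3, 4], [5, 6]], dtype=np.int32)
assert (cst.reshape(6) == np.array([1, 2, 3, 4, 5, 6])).all()

# @const_fold_same_shape: [1, 2, 3, 4, 5, 6] becomes [[1, 2, 3], [4, 5, 6]].
assert (np.arange(1, 7).reshape(2, 3) == np.array([[1, 2, 3],
                                                   [4, 5, 6]])).all()
```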

View File

@ -1,9 +1,9 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s // RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @noop // CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>) // CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> { func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32> %0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: return %[[ARG0]] // CHECK: return %[[ARG0]]
return %0 : tensor<1x2xf32> return %0 : tensor<1x2xf32>
} }

View File

@ -0,0 +1,60 @@
// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s
// Tests sinking constants to a while loop.
// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: mhlo.while
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
%2 = mhlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
%3 = mhlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
%4 = mhlo.add %c1, %3 : tensor<i64>
"mhlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
// Tests sinking constants to a conditional op.
// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: mhlo.if
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
%3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
%4 = mhlo.add %c0, %3 : tensor<i64>
%5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
%7 = mhlo.add %c1, %6 : tensor<i64>
%8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}

View File

@ -1,9 +1,9 @@
// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s // RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @remove_noop // CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> { func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> %0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]] // CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x3x9x5xi32> return %0 : tensor<2x3x9x5xi32>
} }
@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
// CHECK-LABEL: func @keep_real_transpose // CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> { func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) // CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32> return %0 : tensor<3x2x5x9xi32>
} }
@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-LABEL: func @keep_same_shape_real_transpose // CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]] // CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> { func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]]) // CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32> %0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
return %0 : tensor<4x4xi32> return %0 : tensor<4x4xi32>
} }

View File

@ -0,0 +1,10 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// CHECK-LABEL: func @fold_access
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return [[ARG]]
%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %element : tensor<i32>
}

View File

@ -1,4 +1,4 @@
// RUN: xla-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s // RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
// CHECK-LABEL: @batchNormInference_2D_inner_features // CHECK-LABEL: @batchNormInference_2D_inner_features
// CHECK-SAME: %[[X:[^:[:space:]]+]] // CHECK-SAME: %[[X:[^:[:space:]]+]]
@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>) %mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) { -> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32> // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32> // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32> // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<4x256xf32> tensor<256xf32>) -> tensor<4x256xf32>
@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features(
// the verifier to enforce the rest. // the verifier to enforce the rest.
// CHECK-SAME: %[[X:[^:]+]] // CHECK-SAME: %[[X:[^:]+]]
// CHECK-SAME: %[[SCALE:[^:]+]] // CHECK-SAME: %[[SCALE:[^:]+]]
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32> // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
func @batchNormInference_4D_middle_features( func @batchNormInference_4D_middle_features(
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, %x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>) %mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<3x4x256x6xf32>) { -> (tensor<3x4x256x6xf32>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} : {epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, (tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<3x4x256x6xf32> tensor<256xf32>) -> tensor<3x4x256x6xf32>
@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features(
// ----- // -----
// CHECK-LABEL: @batchNormInference_f64 // CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64 // Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
func @batchNormInference_f64( func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>, %x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>) %mean: tensor<256xf64>, %variance: tensor<256xf64>)
-> (tensor<4x256xf64>) { -> (tensor<4x256xf64>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} : {epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>, (tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
tensor<256xf64>) -> tensor<4x256xf64> tensor<256xf64>) -> tensor<4x256xf64>
@ -66,12 +66,12 @@ func @batchNormInference_f64(
// ----- // -----
// CHECK-LABEL: @batchNormInference_f16 // CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly promoted to f64 // Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
func @batchNormInference_f16( func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>, %x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>) %mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) { -> (tensor<4x256xf16>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} : {epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16> tensor<256xf16>) -> tensor<4x256xf16>
@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow(
%mean: tensor<256xf16>, %variance: tensor<256xf16>) %mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) { -> (tensor<4x256xf16>) {
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}} // expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} : {epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>, (tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16> tensor<256xf16>) -> tensor<4x256xf16>
@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C1:.*]] = constant 1 : index // CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index // CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index // CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32> // CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32> // CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex> // CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32> // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32> // CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex> // CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32> // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32> // CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} : {epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, (tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32> tensor<?xf32>) -> tensor<?x?x?x?xf32>
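The unfused sequence these CHECK lines verify is the standard batch-norm
inference formula, result = (x - mean) * scale / sqrt(variance + epsilon) +
offset, with the per-feature statistics broadcast to the input shape. A short
NumPy sketch of the same computation (illustrative shapes only):

```python
import numpy as np

# Reference for the unfused batch_norm_inference expansion: per-feature
# stats broadcast over the batch dimension, as the broadcast_in_dim ops do.
x = np.random.rand(4, 256).astype(np.float32)
scale = np.random.rand(256).astype(np.float32)
offset = np.random.rand(256).astype(np.float32)
mean = np.random.rand(256).astype(np.float32)
variance = np.random.rand(256).astype(np.float32)

eps = np.float32(1.001e-05)
stddev = np.sqrt(variance + eps)                   # shape (256,)
result = (x - mean) * scale / stddev + offset      # broadcasts to (4, 256)
assert result.shape == (4, 256)
```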

View File

@ -0,0 +1,97 @@
// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s
// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: mhlo.fusion
%1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}
// -----
// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.return
// Currently we do not support fusing arguments and ops without a direct
// producer-consumer relationship. Thus the reduce op should not be fused
// with the above two ops.
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// The above two ops should not be fused since the reduce op cannot be
// fused with its consumer.
// CHECK-NOT: mhlo.fusion
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}
// -----
// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.constant
// CHECK-NEXT: mhlo.reduce
// CHECK: mhlo.return
// The following op should not be fused with the above ops since the reduce
// op cannot be fused with its consumer.
// CHECK-NOT: mhlo.fusion
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

View File

@ -1,4 +1,4 @@
// RUN: xla-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s // RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
// Check the validity of expected IR. // Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result // CHECK-LABEL: @sqr_transform_result
@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%num_elements = shape.num_elements %shape %num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements %num_elements_as_index = shape.size_to_index %num_elements
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex> %flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape) %flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// Apply operation. // Apply operation.
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32> %flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
// Restore original shape. // Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex> %shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor) %b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
return %b : tensor<*xf32> return %b : tensor<*xf32>
@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] // CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> // CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> // CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32> // CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex> // CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> // CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32> // CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32> %b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32> return %b : tensor<*xf32>
} }
@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @sqrt_ranked // CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>) // CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> { func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32> // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32> // CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32> %b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32> return %b : tensor<3x?xf32>
} }
@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-LABEL: @sqrt_static // CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>) // CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> { func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32> // CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32> // CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32> %b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32> return %b : tensor<2x3xf32>
} }
@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]] // CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]] // CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex> // CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> // CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32> // CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32> // CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex> // CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32> // CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32> // CHECK: return %[[RESULT]] : tensor<*xf32>
%result = xla_hlo.add %a, %b : tensor<*xf32> %result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32> return %result : tensor<*xf32>
} }

View File

@ -115,12 +115,12 @@ Status MlirFunctionOptimizationPass::Run(
}); });
if (!is_enabled) { if (!is_enabled) {
VLOG(1) << "None of the MLIR optimization passes are enabled " VLOG(0) << "None of the MLIR optimization passes are enabled "
<< "(registered " << registry_->passes().size() << ")"; << "(registered " << registry_->passes().size() << ")";
return Status::OK(); return Status::OK();
} }
VLOG(1) << "Running MLIR Graph Optimization Passes " VLOG(0) << "Running MLIR Graph Optimization Passes "
<< "(registered " << registry_->passes().size() << " passes)"; << "(registered " << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info; GraphDebugInfo debug_info;
@ -187,12 +187,12 @@ Status MlirV1CompatGraphOptimizationPass::Run(
}); });
if (!is_enabled) { if (!is_enabled) {
VLOG(1) << "None of the MLIR optimization passes are enabled " VLOG(0) << "None of the MLIR optimization passes are enabled "
<< "(registered" << registry_->passes().size() << " passes)"; << "(registered" << registry_->passes().size() << " passes)";
return Status::OK(); return Status::OK();
} }
VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes " VLOG(0) << "Running MLIR Graph Optimization V1 Compat Passes "
<< "(registered" << registry_->passes().size() << " passes)"; << "(registered" << registry_->passes().size() << " passes)";
GraphDebugInfo debug_info; GraphDebugInfo debug_info;

View File

@ -70,7 +70,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [
config.mlir_tools_dir, config.llvm_tools_dir config.mlir_tools_dir, config.llvm_tools_dir
] ]
tool_names = [ tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate', 'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate', 'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile', 'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir' 'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'

View File

@ -42,6 +42,7 @@ config.suffixes = ['.td', '.mlir', '.pbtxt']
mlir_tf_tools_dirs = [ mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir', 'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/hlo',
'tensorflow/compiler/mlir/lite', 'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow', 'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/tfjs', 'tensorflow/compiler/mlir/tfjs',

View File

@ -144,6 +144,7 @@ gentbl(
td_srcs = [ td_srcs = [
"@llvm-project//mlir:include/mlir/IR/OpBase.td", "@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td", "@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
], ],
test = True, test = True,
) )
@ -786,7 +787,6 @@ cc_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin", "//tensorflow/core/grappler/utils:transitive_fanin",
"//tensorflow/core/platform:protobuf_internal",
"//tensorflow/core/platform:types", "//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib", "//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/algorithm:container",

Binary file not shown (new image, 180 KiB).

View File

@ -0,0 +1,196 @@
# Automatic Space to Depth Transform in MLIR Bridge
Author: wangtao@, yuanzx@, hinsu@, lyandy@, chiachenc@, aminim@, jpienaar@,
dehao@
## TL;DR
_This document describes an automatic space to depth transform for the first
convolution in the new MLIR bridge to improve MXU efficiency of low-batch-size
convolutions._
## Background
For image models, the first layer is usually not MXU friendly because it has a
feature size of 3. This results in poor performance, especially at small batch
sizes. One way to address this issue is the `space-to-depth` transform. This
optimization tiles the 2x2 space dimensions into the feature dimension so that
the feature dimension becomes 3\*4=12, which is more MXU friendly. To make this
optimization efficient, the shape of the weight needs to be padded and
transposed to the shape that the convolution emitter expects. The input also
needs to be transposed on the host and padded on the device to make the
convolution efficient. Although a 2x2 space-to-depth transform works only when
the first convolution has a stride of 2, many image models, ResNet-like ones in
particular, have a stride-2 convolution in the first layer.
Space to depth helped models such as MaskRCNN, SSD and I3D gain more than 2X
speedup and reduce memory usage in the first convolution.
The first convolution in many image models, ResNet and ResNet-like models
included, is a (kernel=7, stride=2) 2D convolution. Its input is a batch of
images, which usually have RGB channels. The input of this first convolution
has shape [batch\_size, height, width, 3] and its kernel has shape
[kernel\_size, kernel\_size, 3, out\_channel]. Space to depth transforms this
first convolution's input to [batch\_size, height // stride, width // stride,
3 \* stride \* stride] and its kernel to [kernel\_size // stride, kernel\_size
// stride, 3 \* stride \* stride, out\_channel] to improve TPU MXU utilization.
![drawings](images/space_to_depth_transform.png)
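As a concrete illustration of these shapes (a minimal sketch; the zero-valued
input below is just a placeholder), `tf.nn.space_to_depth` with `block_size=2`
maps a `[2, 224, 224, 3]` input to `[2, 112, 112, 12]`:
```python
import tensorflow as tf

# Hypothetical input batch: 2 RGB images of 224x224.
images = tf.zeros([2, 224, 224, 3])
# Fold each 2x2 spatial block into the feature dimension.
transformed = tf.nn.space_to_depth(images, block_size=2)
print(transformed.shape)  # (2, 112, 112, 12)
```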
This optimization can be done automatically by the graph optimizer, where the
weight transformation happens at variable loading time and the input
transformation happens for every inference invocation. A further optimization
can fuse this (on the host) with the double transpose to minimize memory
operations on the host.
## Proposed Method
**block\_size** is defined as the size of the spatial blocks transformed into
the depth dimension. _stride % block\_size == 0_ and _stride >= block\_size_
are required to do the transform. There are three parts to the automatic space
to depth transformation:
1. Transform input on the host.
Space-to-depth performs the following permutation, which is equivalent to
`tf.nn.space_to_depth`.
```python
images = tf.reshape(images, [batch, h // block_size, block_size,
w // block_size, block_size, c])
images = tf.transpose(images, [0, 1, 3, 2, 4, 5])
images = tf.reshape(images, [batch, h // block_size, w // block_size,
c * (block_size ** 2)])
```
`SpaceToDepthOp` can be called on the host to perform the transform (a
verification sketch follows this list).
1. Weight Transformation
Weight transformation is similar to the input transform. The weight transform
is needed to apply the space to depth optimization to a model that loads a
pre-trained checkpoint. This transform can be done on the host or the TPU
device based on the cost. As the size of the kernel is relatively small, this
won't add additional cost to TPU device time. Below is the logic to transform
the kernel of shape [7, 7, 3, 64] to [4, 4, 12, 64].
```python
conv0 = tf.compat.v1.layers.Conv2D(
filters=filters,
kernel_size=kernel_size,
strides=2,
padding=('SAME' if strides == 1 else 'VALID'),
use_bias=False,
kernel_initializer=tf.variance_scaling_initializer(),
data_format=data_format)
# Use the image size without space-to-depth transform as the input of conv0.
batch_size, h, w, channel = inputs.get_shape().as_list()
conv0.build([
batch_size, h * space_to_depth_block_size, w * space_to_depth_block_size,
channel // (space_to_depth_block_size**2)
])
kernel = conv0.weights[0]
# [7, 7, 3, 64] --> [8, 8, 3, 64]
kernel = tf.pad(
kernel,
paddings=tf.constant([[1, 0], [1, 0], [0, 0], [0, 0]]),
mode='CONSTANT',
constant_values=0.)
# The kernel transform follows the space-to-depth logic: https://www.tensorflow.org/api_docs/python/tf/nn/space_to_depth
kernel = tf.reshape(
kernel,
[4, space_to_depth_block_size, 4, space_to_depth_block_size, 3, filters])
kernel = tf.transpose(kernel, [0, 2, 1, 3, 4, 5])
kernel = tf.reshape(kernel, [4, 4, int(channel), filters])
kernel = tf.cast(kernel, inputs.dtype)
```
If kernel\_size % block\_size != 0, the weight needs to be padded before the
transform, and the input of the convolution needs to be padded as well.
1. Rewrite the first convolution
The first convolution's input shape needs to be rewritten from [batch\_size,
height, width, 3] to [batch\_size, height // block\_size, width //
block\_size, 3 \* block\_size \* block\_size] and its kernel shape from
[kernel\_size, kernel\_size, 3, out\_channel] to [kernel\_size //
block\_size, kernel\_size // block\_size, 3 \* block\_size \* block\_size,
out\_channel].
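The sketch below (an illustration assuming NHWC layout and block\_size = 2, not
part of the pass itself) checks that the reshape/transpose sequence from step 1
matches `tf.nn.space_to_depth`:
```python
import tensorflow as tf

batch, h, w, c, block_size = 2, 8, 8, 3, 2
images = tf.random.uniform([batch, h, w, c])

# Manual permutation from step 1.
manual = tf.reshape(images, [batch, h // block_size, block_size,
                             w // block_size, block_size, c])
manual = tf.transpose(manual, [0, 1, 3, 2, 4, 5])
manual = tf.reshape(manual, [batch, h // block_size, w // block_size,
                             c * (block_size ** 2)])

# Built-in op; both orderings are identical, so exact equality holds.
builtin = tf.nn.space_to_depth(images, block_size=block_size)
tf.debugging.assert_equal(manual, builtin)
```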
This is the proposed workflow for the automatic space to depth transformation.
All the transformations will be triggered in an MLIR SpaceToDepthRewritePass;
this rewrite pass will run before TPURewrite so that no metadata rewrite is
needed.
* First, the rewrite pass will walk through all the convolutions in the func of
tf\_device::LaunchOp and get the first convolution and its shape;
* Second, the rewrite pass will apply transformations to the first
convolution, the padding before the first convolution, the first convolution's
filters, and its Conv2DBackPropFilter;
* Finally, the rewrite pass will insert SpaceToDepthOp after IteratorGetNext
where the iterator's result has the same shape as the first convolution's
input.
#### Pseudo MLIR code before and after RewritePass
```mlir
// Example: original program:
//
module {
func @while_body {
%input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}:
-> tensor<2x224x224x3xf32>
%device_launch = "tf_device.launch_func"(%input,...) {func = @_func,...}
return ...
}
func @_func(%input: tensor<2x224x224x3xf32>,
%filter: tensor<7x7x3x64xf32>) {
%6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}:
(tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) ->
tensor<2x112x112x64xf32>
}
}
// With this pass, the program will be transformed into:
module {
func @while_body {
%input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
-> tensor<2x224x224x3xf32>
%space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}:
(tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
%device_launch = "tf_device.launch_func"(%space_to_depth,...) {func = @_func,...}
return ...
}
func @_func(%input: tensor<2x112x112x12xf32>,
%filter: tensor<7x7x3x64xf32>) {
%filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter):
(tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
%conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}:
(tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) ->
tensor<2x112x112x64xf32>
}
}
```
### SpaceToDepth Trigger Condition
Space to depth will only be triggered when the batch size is small and the
first convolution's channel size is small. The stride of the convolution must
also be bigger than 1. A cost model that takes the input shape and host cost
into consideration will be built to decide when to trigger the transformation.
There will be a flag to disable this feature as well.
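As a rough sketch of such a trigger (the function name and the thresholds below
are hypothetical placeholders, not the actual cost model):
```python
def should_apply_space_to_depth(batch_size, in_channels, stride,
                                block_size=2, enabled=True):
  """Hypothetical trigger predicate; real thresholds come from a cost model."""
  return (enabled                       # feature flag
          and stride > 1                # required by the transform
          and stride % block_size == 0  # block must evenly divide the stride
          and in_channels <= 4          # "small" channel size, e.g. RGB
          and batch_size <= 8)          # small batch, where MXU is underused
```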
### Fuse SpaceToDepth with Automatic Double Transpose
The transpose and reshape ops in SpaceToDepthOp on TPU hosts may cause image
models to become infeed bound. To reduce host time, the space to depth
transform can be fused with `automatic double transpose` to reduce extra
overhead on the host.
### Extend from Conv2D to Conv3D
SpaceToDepth not only helps 2D image models but also 3D image models such as
I3D. The plan is to apply automatic space to depth to Conv2D as the first step.
After Conv2D is well tested, this technique will be generalized to Conv3D.

View File

@ -229,7 +229,8 @@ namespace {
ParseResult ParseReplicateOpOperands( ParseResult ParseReplicateOpOperands(
OpAsmParser* parser, OperationState* state, OpAsmParser* parser, OperationState* state,
llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>* llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>*
operands, replicated_inputs,
llvm::SmallVectorImpl<OpAsmParser::OperandType>* packed_inputs,
llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args, llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args,
llvm::SmallVectorImpl<Type>* region_arg_types) { llvm::SmallVectorImpl<Type>* region_arg_types) {
// No operands or empty operand list. // No operands or empty operand list.
@ -238,26 +239,61 @@ ParseResult ParseReplicateOpOperands(
return success(); return success();
// Parse comma separated operands of the following format: // Parse comma separated operands of the following format:
// [%a, ...] as %block_arg: type // replicated_input
// [%a, ...] as %block_arg0: type
// packed_input
// %b as %block_arg1: type
//
// Replicated inputs are placed before packed inputs when forming the op.
llvm::SmallVector<OpAsmParser::OperandType, 8> replicated_region_args;
llvm::SmallVector<OpAsmParser::OperandType, 8> packed_region_args;
llvm::SmallVector<Type, 8> replicated_region_arg_types;
llvm::SmallVector<Type, 8> packed_region_arg_types;
do { do {
if (parser->parseOperandList(operands->emplace_back(), OpAsmParser::OperandType operand_type;
OpAsmParser::Delimiter::Square) || if (parser->parseOptionalOperand(operand_type).hasValue()) {
parser->parseKeyword("as", packed_inputs->emplace_back(operand_type);
" between replicated inputs and block argument") || if (parser->parseKeyword("as",
parser->parseRegionArgument(region_args->emplace_back()) || " between packed input and block argument") ||
parser->parseColonType(region_arg_types->emplace_back())) parser->parseRegionArgument(packed_region_args.emplace_back()) ||
parser->parseColonType(packed_region_arg_types.emplace_back()))
return failure(); return failure();
} else if (parser->parseOperandList(replicated_inputs->emplace_back(),
OpAsmParser::Delimiter::Square) ||
parser->parseKeyword(
"as", " between replicated inputs and block argument") ||
parser->parseRegionArgument(
replicated_region_args.emplace_back()) ||
parser->parseColonType(
replicated_region_arg_types.emplace_back())) {
return failure();
}
} while (succeeded(parser->parseOptionalComma())); } while (succeeded(parser->parseOptionalComma()));
region_args->reserve(replicated_region_args.size() +
packed_region_args.size());
region_args->append(replicated_region_args.begin(),
replicated_region_args.end());
region_args->append(packed_region_args.begin(), packed_region_args.end());
region_arg_types->reserve(replicated_region_arg_types.size() +
packed_region_arg_types.size());
region_arg_types->append(replicated_region_arg_types.begin(),
replicated_region_arg_types.end());
region_arg_types->append(packed_region_arg_types.begin(),
packed_region_arg_types.end());
// Parse remaining `)` surrounding operands. // Parse remaining `)` surrounding operands.
return parser->parseRParen(); return parser->parseRParen();
} }
ParseResult SetOperands( ParseResult SetReplicateOpOperands(
llvm::SMLoc loc, OpAsmParser* parser, OperationState* state, llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>> operands, llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>>
llvm::ArrayRef<Type> region_arg_types, int* n) { replicated_inputs,
if (operands.empty()) return success(); llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,
llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
if (replicated_inputs.empty() && packed_inputs.empty()) return success();
for (const auto& attr : state->attributes) for (const auto& attr : state->attributes)
if (attr.first.strref() == "n") if (attr.first.strref() == "n")
@ -267,38 +303,68 @@ ParseResult SetOperands(
if (*n < 2) if (*n < 2)
return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n; return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;
for (int i = 0, e = operands.size(); i < e; ++i) { for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
const auto& operand = operands[i]; const int32_t idx = replicated_input_and_idx.index();
const auto& replicated_input = replicated_input_and_idx.value();
// Check if replicated input matches `n`. // Check if replicated input matches `n`.
if (operand.size() != *n) if (replicated_input.size() != *n)
return parser->emitError(loc) return parser->emitError(loc)
<< "expects number of operands for replicated input " << i << "expects number of operands for replicated input " << idx
<< " to be 'n' (" << *n << "), got " << operand.size(); << " to be 'n' (" << *n << "), got " << replicated_input.size();
// Resolve replicated input and block argument type. // Resolve replicated input and block argument type.
if (parser->resolveOperands(operand, region_arg_types[i], state->operands)) if (parser->resolveOperands(replicated_input, region_arg_types[idx],
state->operands))
return failure();
}
const int32_t num_replicated_block_args = replicated_inputs.size();
for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
const int32_t idx = packed_input_and_idx.index();
const auto& packed_input = packed_input_and_idx.value();
// Resolve packed input and block argument type.
if (parser->resolveOperand(
packed_input, region_arg_types[idx + num_replicated_block_args],
state->operands))
return failure(); return failure();
} }
return success(); return success();
} }
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) { ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
llvm::SMLoc loc = parser->getCurrentLocation(); llvm::SMLoc loc = parser->getCurrentLocation();
// Parse operands, attributes, and region of op. // Parse operands, attributes, and region of op.
llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8> operands; llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8>
replicated_inputs;
llvm::SmallVector<OpAsmParser::OperandType, 8> packed_inputs;
llvm::SmallVector<OpAsmParser::OperandType, 8> region_args; llvm::SmallVector<OpAsmParser::OperandType, 8> region_args;
llvm::SmallVector<Type, 8> region_arg_types; llvm::SmallVector<Type, 8> region_arg_types;
int n = 0; int32_t n = 0;
Region& body = *state->addRegion(); Region& body = *state->addRegion();
if (ParseReplicateOpOperands(parser, state, &operands, &region_args, if (ParseReplicateOpOperands(parser, state, &replicated_inputs,
&packed_inputs, &region_args,
&region_arg_types) || &region_arg_types) ||
parser->parseOptionalAttrDict(state->attributes) || parser->parseOptionalAttrDict(state->attributes) ||
SetOperands(loc, parser, state, operands, region_arg_types, &n) || SetReplicateOpOperands(loc, parser, state, replicated_inputs,
packed_inputs, region_arg_types, &n) ||
parser->parseRegion(body, region_args, region_arg_types)) parser->parseRegion(body, region_args, region_arg_types))
return failure(); return failure();
// Add derived `operand_segment_sizes` attribute based on parsed operands.
if (!state->attributes.get(kOperandSegmentSizesAttr)) {
int32_t num_replicated_inputs = replicated_inputs.size() * n;
int32_t num_packed_inputs = packed_inputs.size();
auto attr = DenseIntElementsAttr::get(
VectorType::get({2}, parser->getBuilder().getI32Type()),
{num_replicated_inputs, num_packed_inputs});
state->addAttribute(kOperandSegmentSizesAttr, attr);
}
// Ensure that the region is well formed: it contains at least a block with // Ensure that the region is well formed: it contains at least a block with
// a ReturnOp terminator. // a ReturnOp terminator.
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location); ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
@ -323,22 +389,40 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
*p << op.getOperationName(); *p << op.getOperationName();
// Print comma separated operands of the following format: // Print comma separated operands of the following format:
// [%a, ...] as %block_arg: type // replicated_input
int n = op.getAttrOfType<IntegerAttr>("n").getInt(); // [%a, ...] as %block_arg0: type
// packed_input
// %b as %block_arg1: type
const int32_t n = op.n().getSExtValue();
const int32_t num_replicated_inputs =
(*op.operand_segment_sizes().int_value_begin()).getSExtValue();
const int32_t num_replicated_block_args = num_replicated_inputs / n;
if (op.getNumOperands()) { if (op.getNumOperands()) {
*p << '('; *p << '(';
Block& block = op.body().front(); Block& block = op.body().front();
interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) { interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
const int block_arg_num = arg.getArgNumber(); const int block_arg_num = arg.getArgNumber();
if (block_arg_num < num_replicated_block_args) {
*p << '['; *p << '[';
p->printOperands(std::next(op.operand_begin(), block_arg_num * n), p->printOperands(
std::next(op.operand_begin(), (block_arg_num + 1) * n)); std::next(op.replicated_inputs().begin(), block_arg_num * n),
*p << "] as " << arg << ": " << arg.getType(); std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n));
*p << "]";
} else {
p->printOperand(*std::next(op.packed_inputs().begin(),
block_arg_num - num_replicated_block_args));
}
*p << " as " << arg << ": " << arg.getType();
}); });
*p << ')'; *p << ')';
} }
p->printOptionalAttrDict(op.getAttrs()); // Skip derived `operand_segment_sizes` attribute as custom print format of
// operands holds enough information to calculate these variadic operand list
// lengths.
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
kOperandSegmentSizesAttr});
p->printRegion(op.body(), /*printEntryBlockArgs=*/false); p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
} }
@ -353,9 +437,7 @@ LogicalResult VerifyCompatibleTypes(Type a, Type b) {
} }
LogicalResult Verify(ReplicateOp op) { LogicalResult Verify(ReplicateOp op) {
uint64_t n = op.n().getLimitedValue(); int32_t n = op.n().getSExtValue();
if (n < 2)
return op.emitOpError() << "expects 'n' to be at least 2, got " << n;
// Check number of devices, if set, matches `n`. // Check number of devices, if set, matches `n`.
if (op.devices().hasValue()) { if (op.devices().hasValue()) {
@ -381,22 +463,46 @@ LogicalResult Verify(ReplicateOp op) {
Block& block = op.body().front(); Block& block = op.body().front();
// Check number of operands matches `n` * number of block arguments. auto operand_segment_sizes = op.operand_segment_sizes();
if (op.getNumOperands() != n * block.getNumArguments()) const int32_t num_replicated_inputs =
return op.emitOpError() operand_segment_sizes.getValue<IntegerAttr>({0}).getInt();
<< "expects number of operands (" << op.getNumOperands() const int32_t num_packed_inputs =
<< ") to be equal to 'n' * number of block arguments (" << n << " * " operand_segment_sizes.getValue<IntegerAttr>({1}).getInt();
<< block.getNumArguments() << ")";
// Check replicated input types match block argument types. if (num_replicated_inputs % n != 0)
for (auto block_arg : block.getArguments()) {
Type block_arg_type = block_arg.getType();
for (int i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
if (failed(VerifyCompatibleTypes(block_arg_type,
op.getOperand(i).getType())))
return op.emitOpError() return op.emitOpError()
<< "incompatible types for operand " << i << "expects number of replicated inputs (" << num_replicated_inputs
<< " and block argument " << block_arg.getArgNumber(); << ") to be evenly divisible by 'n' (" << n << ")";
const int32_t num_replicated_block_args = num_replicated_inputs / n;
if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
return op.emitOpError()
<< "expects number of block arguments (" << block.getNumArguments()
<< ") to be equal to number of replicated inputs ("
<< num_replicated_inputs << ") / 'n' (" << n
<< ") + number of packed inputs (" << num_packed_inputs << ")";
// Check input types match block argument types.
auto verify_operand_types = [&](BlockArgument block_arg,
int32_t op_operand_idx) -> LogicalResult {
Type op_operand_type = op.getOperand(op_operand_idx).getType();
if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
return op.emitOpError()
<< "expects operand " << op_operand_idx << " (" << op_operand_type
<< ") and block argument " << block_arg.getArgNumber() << " ("
<< block_arg.getType() << ") to have compatible types";
return success();
};
for (auto block_arg : block.getArguments()) {
if (block_arg.getArgNumber() < num_replicated_block_args) {
for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
if (failed(verify_operand_types(block_arg, i))) return failure();
} else {
const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
num_replicated_inputs;
if (failed(verify_operand_types(block_arg, idx))) return failure();
}
} }
Operation& terminator = block.back(); Operation& terminator = block.back();
@ -412,8 +518,8 @@ LogicalResult Verify(ReplicateOp op) {
for (auto operand_type_and_idx : for (auto operand_type_and_idx :
llvm::enumerate(terminator.getOperandTypes())) { llvm::enumerate(terminator.getOperandTypes())) {
Type operand_type = operand_type_and_idx.value(); Type operand_type = operand_type_and_idx.value();
int operand_idx = operand_type_and_idx.index(); int32_t operand_idx = operand_type_and_idx.index();
for (int i = n * operand_idx, e = i + n; i < e; ++i) for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
if (failed(VerifyCompatibleTypes(operand_type, op.getType(i)))) if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
return op.emitOpError() << "incompatible types for result " << i return op.emitOpError() << "incompatible types for result " << i
<< " and terminator operand " << operand_idx; << " and terminator operand " << operand_idx;
@ -428,7 +534,7 @@ void BuildReplicateOp(
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices, devices,
llvm::ArrayRef<std::pair<OperandsTy, Type>> replicated_inputs, llvm::ArrayRef<std::pair<OperandsTy, Type>> replicated_inputs,
ResultsTy replica_output_types) { llvm::ArrayRef<Value> packed_inputs, ResultsTy replica_output_types) {
DCHECK_GE(n, 2); DCHECK_GE(n, 2);
state->addAttribute("n", builder->getI32IntegerAttr(n)); state->addAttribute("n", builder->getI32IntegerAttr(n));
@ -456,6 +562,17 @@ void BuildReplicateOp(
block.addArgument(replicated_input.second); block.addArgument(replicated_input.second);
} }
for (auto& packed_input : packed_inputs) {
state->addOperands(packed_input);
block.addArgument(packed_input.getType());
}
// Add derived `operand_segment_sizes` attribute.
int32_t num_replicated_inputs = replicated_inputs.size() * n;
auto operand_segment_sizes = DenseIntElementsAttr::get(
VectorType::get({2}, builder->getI32Type()), {num_replicated_inputs, 0});
state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);
for (const auto& output_type : replica_output_types) for (const auto& output_type : replica_output_types)
state->addTypes(llvm::SmallVector<Type, 8>(n, output_type)); state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
} }
@ -466,9 +583,10 @@ void ReplicateOp::build(
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices, devices,
llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs,
llvm::ArrayRef<Type> replica_output_types) { llvm::ArrayRef<Type> replica_output_types) {
BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
replica_output_types); packed_inputs, replica_output_types);
} }
void ReplicateOp::build( void ReplicateOp::build(
@ -476,9 +594,10 @@ void ReplicateOp::build(
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices, devices,
llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs,
Operation::result_type_range replica_output_types) { Operation::result_type_range replica_output_types) {
BuildReplicateOp(&builder, &state, n, devices, replicated_inputs, BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
replica_output_types); packed_inputs, replica_output_types);
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -177,8 +177,8 @@ def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute",
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
} }
def TfDevice_ReplicateOp : def TfDevice_ReplicateOp : TfDevice_Op<"replicate",
TfDevice_Op<"replicate", [SingleBlockImplicitTerminator<"ReturnOp">]> { [SingleBlockImplicitTerminator<"ReturnOp">, AttrSizedOperandSegments]> {
let summary = "Wraps an N-way replicated computation."; let summary = "Wraps an N-way replicated computation.";
let description = [{ let description = [{
@ -187,22 +187,30 @@ across multiple devices. The number of replications is based on the `n`
attribute. Explicit devices can be populated in the `devices` attribute, and it attribute. Explicit devices can be populated in the `devices` attribute, and it
must be a mapping of device alias to list of explicit or aliased device names must be a mapping of device alias to list of explicit or aliased device names
from the outer scope. The device name map specifies devices on which replicated from the outer scope. The device name map specifies devices on which replicated
ops inside tf_device.replicate will be executed. A tf_device.parallel_execute ops inside tf_device.replicate will be executed.
inside the tf_device.replicate op region may be used to represent computations
across a larger set of devices. In that case, the device alias can be used to A tf_device.parallel_execute inside the tf_device.replicate op region may be
specify device assignment and replication of each concurrent execution used to represent computations across a larger set of devices. In that case, the
(i.e. region) defined by tf_device.parallel_execute op. The size of each value device alias can be used to specify device assignment and replication of each
list in the device name map must match `n`. Within a replica, the execution concurrent execution (i.e. region) defined by tf_device.parallel_execute op.
semantics follow standard sequential behavior. Ops in the tf_device.replicate The size of each value list in the device name map must match `n`. Within a
wrapped with a tf_device.launch will have its device set to the associated replica, the execution semantics follow standard sequential behavior. Ops in the
replicated device from `devices` if the tf_device.launch refers to an aliased tf_device.replicate wrapped with a tf_device.launch will have its device set to
device name. Otherwise the device already set in tf_device.launch is used the associated replicated device from `devices` if the tf_device.launch refers
instead. Operands are replicated inputs: each group of `n` inputs corresponds to to an aliased device name. Otherwise the device already set in tf_device.launch
an input for a single individual replica and is mapped to a single region is used instead.
argument. Inside one group the operands are matching in order the `devices`
attribute. Each replicated input must have compatible shapes and types. Operands Operands are replicated inputs and packed inputs.
not replicated can be implicitly captured by ops in the region. Results are
replicated each from the regions terminator. replicated_inputs: each group of `n` inputs corresponds to an input for a single
individual replica and is mapped to a single region argument. Inside one group
the operands are matching in order the `devices` attribute. Each replicated
input must have compatible shapes and types.
packed_inputs: each input corresponds to an input broadcasted across all
replicas and is mapped to a single region argument.
Operands not replicated can be implicitly captured by ops in the region. Results
are replicated each from the regions terminator.
For example: For example:
``` ```
@ -214,46 +222,55 @@ For example:
%5 = "tf.opF"() : () -> tensor<!tf.resource> %5 = "tf.opF"() : () -> tensor<!tf.resource>
%6 = "tf.opG"() : () -> tensor<!tf.string> %6 = "tf.opG"() : () -> tensor<!tf.string>
%7 = "tf.opH"() : () -> tensor<!tf.string> %7 = "tf.opH"() : () -> tensor<!tf.string>
%8 = "tf.opI"() : () -> tensor<i1> %8 = "tf.opI"() : () -> tensor<!tf.variant>
%output:8 = tf_device.replicate([%0, %1] as %input_0:tensor<i32>, %9 = "tf.opJ"() : () -> tensor<i1>
[%2, %3] as %input_1:tensor<f32>, %output:8 = tf_device.replicate([%0, %1] as %input_0: tensor<i32>,
[%4, %5] as %input_2:tensor<!tf.resource> [%2, %3] as %input_1: tensor<f32>,
[%6, %7] as %input_3:tensor<!tf.string>) [%4, %5] as %input_2: tensor<!tf.resource>,
[%6, %7] as %input_3: tensor<!tf.string>,
%8 as %input_4: tensor<!tf.variant>)
{n = 2 : i32, {n = 2 : i32,
devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"], devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} { DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
// Inside the region, %0, %2, %4, and %6 corresponds to // Inside the region, %0, %2, %4, and %6 corresponds to
// "/DEVICE:0"/"/DEVICE:2" and %1, %3, %5, and %7 corresponds to // "/DEVICE:0"/"/DEVICE:2" and %1, %3, %5, and %7 corresponds to
// "/DEVICE:1"/"/DEVICE:3", depending on which device alias is used. // "/DEVICE:1"/"/DEVICE:3", depending on which device alias is used.
%j = "tf_device.launch"() ( { %k = "tf_device.launch"() ( {
%9 = "tf.opJ"(%input_0, %6) : (tensor<i32>, tensor<i1>) -> tensor<i32> %9 = "tf.opK"(%input_0, %input_4, %9) :
(tensor<i32>, tensor<!tf.variant>, tensor<i1>) -> tensor<i32>
tf_device.return %9 : tensor<i32> tf_device.return %9 : tensor<i32>
}) {device = "DEVICE_ALIAS_0"} : () -> tensor<i32> }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i32>
%k = "tf_device.launch"() ( { %l = "tf_device.launch"() ( {
%10 = "tf.opK"(%input_1, %6) : (tensor<f32>, tensor<i1>) -> tensor<f32> %10 = "tf.opL"(%input_1, %input_4, %9) :
(tensor<f32>, tensor<!tf.variant>, tensor<i1>) -> tensor<f32>
tf_device.return %10 : tensor<f32> tf_device.return %10 : tensor<f32>
}) {device = "DEVICE_ALIAS_1"} : () -> tensor<f32> }) {device = "DEVICE_ALIAS_1"} : () -> tensor<f32>
%l = "tf_device.launch"() ( { %m = "tf_device.launch"() ( {
%11 = "tf.opL"(%input_2, %6) : (tensor<!tf.resource>, tensor<i1>) %11 = "tf.opM"(%input_2, %input_4, %9) :
(tensor<!tf.resource>, tensor<!tf.variant>, tensor<i1>)
-> tensor<!tf.resource> -> tensor<!tf.resource>
tf_device.return %11 : tensor<!tf.resource> tf_device.return %11 : tensor<!tf.resource>
}) {device = "/DEVICE:4"} : () -> tensor<f32> }) {device = "/DEVICE:4"} : () -> tensor<f32>
%m = "tf.opM"(%input_3, %6) : (tensor<!tf.string>, tensor<i1>) %n = "tf.opN"(%input_3, %input_4, %9) :
(tensor<!tf.string>, tensor<!tf.variant>, tensor<i1>)
-> tensor<!tf.string> -> tensor<!tf.string>
tf_device.return %j, %k, %l, %m : tf_device.return %k, %l, %m, %n :
tensor<i32>, tensor<f32>, tensor<!tf.resource>, tensor<!tf.string> tensor<i32>, tensor<f32>, tensor<!tf.resource>, tensor<!tf.string>
} }
// %output#0 corresponds to %j returned from "/DEVICE:0" // %output#0 corresponds to %k returned from "/DEVICE:0"
// %output#1 corresponds to %j returned from "/DEVICE:1" // %output#1 corresponds to %k returned from "/DEVICE:1"
// %output#2 corresponds to %k returned from "/DEVICE:2" // %output#2 corresponds to %l returned from "/DEVICE:2"
// %output#3 corresponds to %k returned from "/DEVICE:3" // %output#3 corresponds to %l returned from "/DEVICE:3"
// %output#4, %output#5 corresponds to %l and will be returned from "/DEVICE:4" // %output#4, %output#5 corresponds to %m and will be returned from "/DEVICE:4"
// %output#6, %output#7 corresponds to %m and will have no device set // %output#6, %output#7 corresponds to %n and will have no device set
``` ```
}]; }];
let arguments = (ins let arguments = (ins
Variadic<AnyType>:$replicated_inputs, Variadic<AnyType>:$replicated_inputs,
Variadic<AnyType>:$packed_inputs,
I32ElementsAttr:$operand_segment_sizes,
Confined<I32Attr, [IntMinValue<2>]>:$n, Confined<I32Attr, [IntMinValue<2>]>:$n,
OptionalAttr<DictionaryAttr>:$devices OptionalAttr<DictionaryAttr>:$devices
); );
@ -272,10 +289,12 @@ For example:
OpBuilder<"OpBuilder& builder, OperationState& state, int n, " OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
"const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, " "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
"llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, " "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, "
"llvm::ArrayRef<Value> packed_inputs, "
"llvm::ArrayRef<Type> replica_output_types">, "llvm::ArrayRef<Type> replica_output_types">,
OpBuilder<"OpBuilder& builder, OperationState& state, int n, " OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
"const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, " "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
"llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, " "llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, "
"llvm::ArrayRef<Value> packed_inputs, "
"Operation::result_type_range replica_output_types"> "Operation::result_type_range replica_output_types">
]; ];

View File

@ -4274,6 +4274,117 @@ LogicalResult WhileRegionOp::moveOutOfLoop(
return success(); return success();
} }
//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
struct WhileRegionEliminatePassThrough
: public OpRewritePattern<WhileRegionOp> {
using OpRewritePattern<WhileRegionOp>::OpRewritePattern;
LogicalResult matchAndRewrite(WhileRegionOp while_op,
PatternRewriter &rewriter) const override {
// Replace values that simply pass through the body with external values. The
// block arguments of the body and the while match, so the corresponding cond
// argument can be easily found.
int old_num_operands = while_op.getNumOperands();
int new_num_operands = old_num_operands;
auto &body_block = while_op.body().front();
auto &cond_block = while_op.cond().front();
auto &yield = *body_block.getTerminator();
// Bit mask indicating which operands will be removed.
SmallVector<bool, 16> removed_operand(old_num_operands, false);
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
auto body_arg = body_block.getArgument(op_idx);
if (body_arg == yield.getOperand(op_idx)) {
// Replace the use of the passthrough value with the while operand
// in the body and condition regions, as well as the while output (if
// types match)
// TODO(jurahul): Use PatternRewriter API for IR modification.
auto value = while_op.getOperand(op_idx);
if (body_arg.getType() == value.getType())
body_arg.replaceAllUsesWith(value);
auto cond_arg = cond_block.getArgument(op_idx);
if (cond_arg.getType() == value.getType())
cond_arg.replaceAllUsesWith(value);
auto result = while_op.getResult(op_idx);
if (result.getType() == value.getType())
result.replaceAllUsesWith(value);
}
// Now check if the operand is unused in both regions as well as the
// result. If so, mark it for removal.
if (body_block.getArgument(op_idx).use_empty() &&
cond_block.getArgument(op_idx).use_empty() &&
while_op.getResult(op_idx).use_empty()) {
removed_operand[op_idx] = true;
new_num_operands--;
}
}
if (new_num_operands == old_num_operands) return failure();
// Compress the operands, region arguments, and outputs.
SmallVector<Value, 4> new_while_operands;
SmallVector<Type, 4> new_result_types;
new_while_operands.reserve(new_num_operands);
new_result_types.reserve(new_num_operands);
// Build new operands and result type.
int next_idx = 0;
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
if (removed_operand[op_idx]) continue;
new_while_operands.push_back(while_op.getOperand(op_idx));
new_result_types.push_back(while_op.getResult(op_idx).getType());
next_idx++;
}
// Create the new while operation.
auto new_while_op =
rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
new_while_operands, while_op.getAttrs());
// Move region bodies to the new while.
rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
new_while_op.cond().end());
rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
new_while_op.body().end());
auto &new_cond_block = new_while_op.cond().front();
auto &new_body_block = new_while_op.body().front();
auto &new_yield = *new_body_block.getTerminator();
// Build a vector of new results. Also patch up the region bodies and yield.
SmallVector<Value, 4> new_results;
next_idx = 0;
for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
if (removed_operand[op_idx]) {
new_cond_block.eraseArgument(next_idx);
new_body_block.eraseArgument(next_idx);
new_yield.eraseOperand(next_idx);
new_results.push_back(nullptr);
} else {
new_results.push_back(new_while_op.getResult(next_idx++));
}
}
rewriter.replaceOp(while_op, new_results);
return success();
}
};
} // anonymous namespace
void WhileRegionOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<WhileRegionEliminatePassThrough>(context);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// XdivyOp // XdivyOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -658,7 +658,9 @@ def TL_WhileRegionOp : TF_Op<"WhileRegion",
This implies that the operand and result types for tf.WhileRegion should be This implies that the operand and result types for tf.WhileRegion should be
the same. Note that the condition and body regions can implicitly capture the same. Note that the condition and body regions can implicitly capture
loop invariant values directly. loop invariant values directly. In canonical form, iteration variables that
pass through the loop body unmodified are converted to implicitly captured
references to their values outside the loop.
}]; }];
let arguments = (ins let arguments = (ins
@ -676,6 +678,8 @@ def TL_WhileRegionOp : TF_Op<"WhileRegion",
let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body); let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);
let verifier = [{ return Verify(*this); }]; let verifier = [{ return Verify(*this); }];
let hasCanonicalizer = 1;
} }
def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> { def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
@ -717,7 +721,8 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
DefaultValuedAttr<StrArrayAttr, "{}">:$host_compute_core, DefaultValuedAttr<StrArrayAttr, "{}">:$host_compute_core,
DefaultValuedAttr<StrArrayAttr, "{}">:$padding_map, DefaultValuedAttr<StrArrayAttr, "{}">:$padding_map,
DefaultValuedAttr<StrAttr, "STEP_MARK_AT_ENTRY">:$step_marker_location, DefaultValuedAttr<StrAttr, "STEP_MARK_AT_ENTRY">:$step_marker_location,
DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement,
DefaultValuedAttr<BoolAttr, "false">:$use_spmd_for_xla_partitioning
); );
let results = (outs); let results = (outs);

View File

@ -141,16 +141,27 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
return mlir::success(); return mlir::success();
} }
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor) { Type GetBoundInputArgTypeFor(mlir::Operation *op) {
if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) {
auto type = global_tensor.type().cast<TensorType>(); auto type = global_tensor.type().cast<TensorType>();
return RankedTensorType::get( return RankedTensorType::get(
{}, TF::ResourceType::get({type}, type.getContext())); {}, TF::ResourceType::get({type}, type.getContext()));
}
if (auto asset = llvm::dyn_cast<AssetOp>(op)) {
return RankedTensorType::get({}, TF::StringType::get(asset.getContext()));
}
op->emitError() << "unknown symbol operation";
return {};
} }
static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics, static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
Type arg_type, Type arg_type,
GlobalTensorOp global_tensor) { mlir::Operation *symbol_op) {
auto expected_type = GetBoundInputArgTypeFor(global_tensor); auto expected_type = GetBoundInputArgTypeFor(symbol_op);
if (!expected_type) return failure();
if (arg_type != expected_type) { if (arg_type != expected_type) {
return op_for_diagnostics->emitError() return op_for_diagnostics->emitError()
<< "bound input with type " << arg_type << " expected to have type " << "bound input with type " << arg_type << " expected to have type "
@ -169,14 +180,14 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
} }
auto symbol_name = named_attr.second.cast<FlatSymbolRefAttr>().getValue(); auto symbol_name = named_attr.second.cast<FlatSymbolRefAttr>().getValue();
auto module = op->getParentOfType<ModuleOp>(); auto module = op->getParentOfType<ModuleOp>();
auto global_tensor = module.lookupSymbol<GlobalTensorOp>(symbol_name); mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
if (!global_tensor) { if (!symbol_op) {
return op->emitError() << "'tf_saved_model.bound_input' attribute must " return op->emitError() << "'tf_saved_model.bound_input' attribute must "
"reference a valid symbol, got invalid symbol '" "reference a valid symbol, got invalid symbol '"
<< symbol_name << "'"; << symbol_name << "'";
} }
auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType(); auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType();
return VerifyBoundInputArgType(op, arg_type, global_tensor); return VerifyBoundInputArgType(op, arg_type, symbol_op);
} }
if (named_attr.first == "tf_saved_model.index_path") { if (named_attr.first == "tf_saved_model.index_path") {
return VerifyIndexPath(op, named_attr); return VerifyIndexPath(op, named_attr);
@ -404,12 +415,12 @@ bool HasTfSavedModelSemantics(ModuleOp module) {
return module.getAttr("tf_saved_model.semantics") != nullptr; return module.getAttr("tf_saved_model.semantics") != nullptr;
} }
GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index, Operation *LookupBoundInput(FuncOp func, int arg_index,
const SymbolTable &symbol_table) { const SymbolTable &symbol_table) {
auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>( auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
arg_index, "tf_saved_model.bound_input"); arg_index, "tf_saved_model.bound_input");
if (!attr) return nullptr; if (!attr) return nullptr;
return symbol_table.lookup<GlobalTensorOp>(attr.getValue()); return symbol_table.lookup(attr.getValue());
} }
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) { SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {

View File

@ -36,6 +36,8 @@ class TensorFlowSavedModelDialect : public Dialect {
NamedAttribute named_attr) override; NamedAttribute named_attr) override;
LogicalResult verifyOperationAttribute(Operation *op, LogicalResult verifyOperationAttribute(Operation *op,
NamedAttribute named_attr) override; NamedAttribute named_attr) override;
static StringRef getDialectNamespace() { return "tf_saved_model"; }
}; };
// Declares the operations for this dialect using the generated header. // Declares the operations for this dialect using the generated header.
@ -54,12 +56,19 @@ bool HasTfSavedModelSemantics(ModuleOp module);
// Returns the tf_saved_model.global_tensor op that func's arg_index'th argument // Returns the tf_saved_model.global_tensor op that func's arg_index'th argument
// refers to as a bound input, or null. // refers to as a bound input, or null.
GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index, Operation *LookupBoundInput(FuncOp func, int arg_index,
const SymbolTable &symbol_table); const SymbolTable &symbol_table);
// Gets the type that an exported function arg that is bound to `global_tensor` template <typename T>
// should have. T LookupBoundInputOfType(FuncOp func, int arg_index,
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor); const SymbolTable &symbol_table) {
return llvm::dyn_cast_or_null<T>(
LookupBoundInput(func, arg_index, symbol_table));
}
// Gets the type that an exported function arg that is bound to symbol ops such
// as `global_tensor` and `asset` should have.
Type GetBoundInputArgTypeFor(mlir::Operation *op);
// Returns the session initializer of this module if it exists. Returns null // Returns the session initializer of this module if it exists. Returns null
// otherwise. // otherwise.

View File

@ -19,6 +19,7 @@ limitations under the License.
#define SAVED_MODEL_DIALECT #define SAVED_MODEL_DIALECT
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Dialect definition // Dialect definition
@ -154,4 +155,24 @@ def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
def TfSavedModel_AssetOp: TfSavedModel_Op<"asset", [Symbol]> {
let summary = "Represents an asset in saved model.";
let description = [{
Represents an asset in the saved model that points to an external file. It
is a scalar string tensor and it is passed as an argument to the session
initializer function.
The `sym_name` represents the symbol table name used for internal IR
references.
The `filename` attribute contains the file path to the asset file and it is
relative to saved model directory.
}];
let arguments = (ins
StrAttr:$sym_name,
StrAttr:$filename
);
}
#endif // SAVED_MODEL_DIALECT #endif // SAVED_MODEL_DIALECT

View File

@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
// CHECK-LABEL: func @_func // CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>, // CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas} // CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<?xi32>) // CHECK-SAME: %[[ARG2:.*]]: tensor<?xi32>)
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<?xi32> { func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32> %0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} {
} }
// CHECK-LABEL: func @_func // CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas}, // CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>, // CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {xla_hlo.is_same_data_across_replicas} // CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {mhlo.is_same_data_across_replicas}
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<!tf.resource<tensor<?xi32>>>) -> tensor<?xi32> { func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<!tf.resource<tensor<?xi32>>>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32> %0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32> return %0 : tensor<?xi32>
@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
} }
// CHECK-LABEL: func @_func // CHECK-LABEL: func @_func
// CHECK-NOT: xla_hlo.is_same_data_across_replicas // CHECK-NOT: mhlo.is_same_data_across_replicas
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> { func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32> %0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %0 : tensor<?xi32> return %0 : tensor<?xi32>

View File

@ -648,3 +648,152 @@ func @erase_tf_var_is_initialized(%arg0 : tensor<!tf.resource<tensor<f32>>>) ->
// Unused VarIsInitializedOp is erased. // Unused VarIsInitializedOp is erased.
// CHECK: tf.VarHandleOp // CHECK: tf.VarHandleOp
// CHECK-NEXT: tf.UnknownOp // CHECK-NEXT: tf.UnknownOp
// Simple pass through value
// CHECK-LABEL: testWhileRegionSimplePassThrough
func @testWhileRegionSimplePassThrough(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
// CHECK: "tf.WhileRegion"(%arg1)
%0:2 = "tf.WhileRegion"(%arg0, %arg1) (
{
// condition, check if count has reached 0
^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
%zero = constant dense<0> : tensor<i32>
%ne = "tf.NotEqual"(%carg1, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%ne) : (tensor<i1>) -> ()
},
{
// loop body
^bb0(%barg0: tensor<*xf32>, %barg1: tensor<i32>):
%one = constant dense<1> : tensor<i32>
%sub = "tf.Sub"(%barg1, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%barg0, %sub) : (tensor<*xf32>, tensor<i32>) -> ()
}
) { is_stateless = false } : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
// CHECK: return %arg0 : tensor<*xf32>
return %0#0 : tensor<*xf32>
}
// Multiple pass through values
// CHECK-LABEL: testWhileRegionMultiplePassThrough
func @testWhileRegionMultiplePassThrough(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<i32>) -> tensor<*xf32> {
// Verify that the first 3 operands are eliminated.
// CHECK: "tf.WhileRegion"(%arg3)
%0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) (
{
// condition, check if count has reached 0
^bb0(%carg0 : tensor<*xf32>, %carg1 : tensor<*xf32>, %carg2 : tensor<*xf32>, %carg3 : tensor<i32>):
%zero = constant dense<0> : tensor<i32>
%ne = "tf.NotEqual"(%carg3, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%ne) : (tensor<i1>) -> ()
},
{
// loop body
^bb0(%barg0 : tensor<*xf32>, %barg1 : tensor<*xf32>, %barg2 : tensor<*xf32>, %barg3 : tensor<i32>):
%one = constant dense<1> : tensor<i32>
%sub = "tf.Sub"(%barg3, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%barg0, %barg1, %barg2, %sub) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> ()
}
) { is_stateless = false } : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>)
// CHECK: %[[SUB0:.*]] = "tf.Sub"(%arg0, %arg1)
// CHECK: %[[SUB1:.*]] = "tf.Sub"(%arg2, %[[SUB0]])
// CHECK: return %[[SUB1]]
%sub0 = "tf.Sub" (%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%sub1 = "tf.Sub" (%0#2, %sub0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %sub1 : tensor<*xf32>
}
// Multiple non contiguous pass through values
// CHECK-LABEL: testWhileRegionMultiplePassThroughNonContiguous
func @testWhileRegionMultiplePassThroughNonContiguous(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<i32>) -> tensor<*xf32> {
// Verify arg0 and arg2 are eliminated
// CHECK: %[[WHILE_OUT:.*]]:2 = "tf.WhileRegion"(%arg1, %arg3)
%0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) (
{
// condition, check if count has reached 0
^bb0(%carg0 : tensor<*xf32>, %carg1 : tensor<*xf32>, %carg2 : tensor<*xf32>, %carg3 : tensor<i32>):
%zero = constant dense<0> : tensor<i32>
%ne = "tf.NotEqual"(%carg3, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%ne) : (tensor<i1>) -> ()
},
{
// loop body
^bb0(%barg0 : tensor<*xf32>, %barg1 : tensor<*xf32>, %barg2 : tensor<*xf32>, %barg3 : tensor<i32>):
%arg1neg = "tf.Neg"(%barg1) : (tensor<*xf32>) -> tensor<*xf32>
%one = constant dense<1> : tensor<i32>
%sub = "tf.Sub"(%barg3, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%barg0, %arg1neg, %barg2, %sub) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> ()
}
) { is_stateless = false } : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>)
// Verify that uses of while loop results corresponding to results #0 and #2 of
// the while are replaced with the corresponding WhileRegion operands
// CHECK: %[[SUB0:.*]] = "tf.Sub"(%arg0, %[[WHILE_OUT]]#0)
// CHECK: %[[SUB1:.*]] = "tf.Sub"(%arg2, %[[SUB0]])
// CHECK: return %[[SUB1]]
%sub0 = "tf.Sub" (%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%sub1 = "tf.Sub" (%0#2, %sub0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
return %sub1 : tensor<*xf32>
}
// Pass through but with type mismatch (tensor<*xf32> is compatible with
// tensor<?x?xf32> in the body). WhileRegion canonicalization does not handle
// this.
// CHECK-LABEL: testWhileRegionPassThroughTypeMismatch
func @testWhileRegionPassThroughTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
// Verify that the While stays unchanged
// CHECK: "tf.WhileRegion"(%arg0, %arg1)
%0:2 = "tf.WhileRegion"(%arg0, %arg1) (
{
// condition, check if count has reached 0
^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
%zero = constant dense<0> : tensor<i32>
%ne = "tf.NotEqual"(%carg1, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%ne) : (tensor<i1>) -> ()
},
{
// loop body
^bb0(%barg0: tensor<?x?xf32>, %barg1: tensor<i32>):
%one = constant dense<1> : tensor<i32>
%sub = "tf.Sub"(%barg1, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
"tf.Yield"(%barg0, %sub) : (tensor<?x?xf32>, tensor<i32>) -> ()
}
) { is_stateless = false } : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
// Verify that the result stays unchanged
// CHECK: return %arg0 : tensor<*xf32>
return %0#0 : tensor<*xf32>
}
// Unused values flowing through the while (operands 2 and 3 are unused in the
// while and the corresponding results are unused as well). Canonicalization
// will eliminate them.
// CHECK-LABEL: testWhileRegionUnusedValue
func @testWhileRegionUnusedValue(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
%cst = constant dense <33.0> : tensor<f32>
// Verify that last 2 operands of while (unused) are removed
// CHECK: %[[WHILE_OUT:.*]]:2 = "tf.WhileRegion"(%arg0, %arg1)
%0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %cst) (
{
// condition, check if count has reached 0
^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>, %carg2:tensor<i32>, %carg3:tensor<f32>):
%zero = constant dense<0> : tensor<i32>
%ne = "tf.NotEqual"(%carg1, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%ne) : (tensor<i1>) -> ()
},
{
// loop body
^bb0(%barg0: tensor<*xf32>, %barg1: tensor<i32>, %barg2:tensor<i32>, %barg3:tensor<f32>):
%add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
%one = constant dense<1> : tensor<i32>
%sub = "tf.Sub"(%barg1, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%dummy0 = constant dense<7> : tensor<i32>
%dummy1 = constant dense<3.0> : tensor<f32>
"tf.Yield"(%add, %sub, %dummy0, %dummy1) : (tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<f32>) -> ()
}
) { is_stateless = false } : (tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<f32>) -> (tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<f32>)
// Verify that return still uses while result # 0
// CHECK: return %[[WHILE_OUT]]#0 : tensor<*xf32>
return %0#0 : tensor<*xf32>
}

View File

@ -17,8 +17,8 @@ func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tenso
} }
func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32> %0 = mhlo.add %arg0, %arg0 : tensor<2xi32>
%1 = xla_hlo.add %0, %arg0 : tensor<2xi32> %1 = mhlo.add %0, %arg0 : tensor<2xi32>
return %1 : tensor<2xi32> return %1 : tensor<2xi32>
} }
@ -33,7 +33,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
} }
func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
@ -43,7 +43,7 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
} }
func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32> %0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32> return %0 : tensor<4xi32>
} }
@ -53,17 +53,17 @@ func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
} }
func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32> %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
%0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32> %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
return %0 : tensor<4xf32> return %0 : tensor<4xf32>
} }
func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32> %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
@ -73,7 +73,7 @@ func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
} }
func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32> %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
@ -83,7 +83,7 @@ func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor
} }
func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
%0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32> %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
@ -93,7 +93,7 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
} }
func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32> %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32> return %0 : tensor<4xi32>
} }
@ -103,7 +103,7 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten
} }
func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = xla_hlo.and %arg0, %arg0 : tensor<2xi1> %0 = mhlo.and %arg0, %arg0 : tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -118,7 +118,7 @@ func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
} }
func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
%0 = xla_hlo.or %arg0, %arg0 : tensor<2xi1> %0 = mhlo.or %arg0, %arg0 : tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -133,7 +133,7 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
} }
func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.or %arg0, %arg1 : tensor<4xi32> %0 = mhlo.or %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32> return %0 : tensor<4xi32>
} }
@ -148,7 +148,7 @@ func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?
} }
func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
%0 = xla_hlo.and %arg0, %arg1 : tensor<4xi32> %0 = mhlo.and %arg0, %arg1 : tensor<4xi32>
return %0 : tensor<4xi32> return %0 : tensor<4xi32>
} }
@ -163,69 +163,69 @@ func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<
} }
func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32> %0 = mhlo.power %arg0, %arg0 : tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = xla_hlo.power %arg0, %arg0 : tensor<?xf32> %0 = mhlo.power %arg0, %arg0 : tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2x3xi32> %0 = mhlo.constant dense<0> : tensor<2x3xi32>
%1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%2 = xla_hlo.constant dense<0> : tensor<3xi32> %2 = mhlo.constant dense<0> : tensor<3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%8 = xla_hlo.constant dense<1> : tensor<3xi32> %8 = mhlo.constant dense<1> : tensor<3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<3xi32> %9 = mhlo.subtract %7, %8 : tensor<3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
%13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32> return %14 : tensor<2x3xi32>
} }
func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = xla_hlo.constant dense<0> : tensor<3xi32> %0 = mhlo.constant dense<0> : tensor<3xi32>
%1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
%2 = xla_hlo.constant dense<0> : tensor<2x3xi32> %2 = mhlo.constant dense<0> : tensor<2x3xi32>
%3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
%4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
%5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
%7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%8 = xla_hlo.constant dense<1> : tensor<2x3xi32> %8 = mhlo.constant dense<1> : tensor<2x3xi32>
%9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32> %9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
%10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
%11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
%13 = xla_hlo.divide %11, %12 : tensor<2x3xi32> %13 = mhlo.divide %11, %12 : tensor<2x3xi32>
%14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %14 : tensor<2x3xi32> return %14 : tensor<2x3xi32>
} }
func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32> %0 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
%1 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32> %1 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
%2 = "xla_hlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32> %2 = "mhlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
return %2 : tensor<2xf32> return %2 : tensor<2xf32>
} }
func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
%0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
%2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> %2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
return %2 : tensor<2x3xf16> return %2 : tensor<2x3xf16>
} }
func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -250,7 +250,7 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
} }
func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -270,7 +270,7 @@ func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: ten
} }
func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -280,7 +280,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
} }
func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -290,7 +290,7 @@ func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t
} }
func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -300,7 +300,7 @@ func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2
} }
func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
@ -310,426 +310,426 @@ func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tens
} }
func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> { func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
%2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32> %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
return %2 : tensor<6x3xf32> return %2 : tensor<6x3xf32>
} }
func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> { func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
%2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32> %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
return %2 : tensor<3x6xf32> return %2 : tensor<3x6xf32>
} }
func @const() -> tensor<2xi32> { func @const() -> tensor<2xi32> {
%0 = xla_hlo.constant dense<0> : tensor<2xi32> %0 = mhlo.constant dense<0> : tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32> %0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32> %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
return %1 : tensor<1xi32> return %1 : tensor<1xi32>
} }
func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32> %0 = mhlo.constant dense<0> : tensor<i32>
%1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32> %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
return %1 : tensor<?xi32> return %1 : tensor<?xi32>
} }
func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32> %0 = mhlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32> %1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32> %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32> %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
return %3 : tensor<1xi32> return %3 : tensor<1xi32>
} }
func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> { func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = xla_hlo.constant dense<0> : tensor<i32> %0 = mhlo.constant dense<0> : tensor<i32>
%1 = xla_hlo.constant dense<6> : tensor<i32> %1 = mhlo.constant dense<6> : tensor<i32>
%2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
%3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32> %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
return %3 : tensor<?xi32> return %3 : tensor<?xi32>
} }
func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> { func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32> %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1> %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
%2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32> %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
%3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> %3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
return %3 : tensor<4x8xf32> return %3 : tensor<4x8xf32>
} }
func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> { func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> { func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
return %0 : tensor<3x2xi32> return %0 : tensor<3x2xi32>
} }
func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> { func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
%0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32> %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32> %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
return %2 : tensor<3x2xf32> return %2 : tensor<3x2xf32>
} }
func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi32> %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32>
%1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %2 : tensor<3x2x1xf32> return %2 : tensor<3x2x1xf32>
} }
func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> { func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
%0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64> %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32> %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
return %2 : tensor<3x2x1xf32> return %2 : tensor<3x2x1xf32>
} }
func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> { func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32> %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
return %2 : tensor<4x?xf32> return %2 : tensor<4x?xf32>
} }
func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64> %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
%2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32> %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
return %2 : tensor<*xf32> return %2 : tensor<*xf32>
} }
func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> { func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
%0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32> %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> { func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1> %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
return %0 : tensor<2xi1> return %0 : tensor<2xi1>
} }
func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> { func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1> %0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
return %0 : tensor<?xi1> return %0 : tensor<?xi1>
} }
func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> { func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
%0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1> %0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
return %0 : tensor<*xi1> return %0 : tensor<*xi1>
} }
func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32> %0 = mhlo.constant dense<5.000000e-01> : tensor<f32>
%1 = xla_hlo.constant dense<2> : tensor<1xi64> %1 = mhlo.constant dense<2> : tensor<1xi64>
%2 = xla_hlo.constant dense<5.000000e-01> : tensor<2xf32> %2 = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
%3 = xla_hlo.multiply %arg0, %2 : tensor<2xf32> %3 = mhlo.multiply %arg0, %2 : tensor<2xf32>
%4 = "xla_hlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32> %4 = "mhlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
%5 = xla_hlo.multiply %4, %2 : tensor<2xf32> %5 = mhlo.multiply %4, %2 : tensor<2xf32>
%6 = xla_hlo.add %5, %2 : tensor<2xf32> %6 = mhlo.add %5, %2 : tensor<2xf32>
return %6 : tensor<2xf32> return %6 : tensor<2xf32>
} }
func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32> %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> { func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32> %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
return %0 : tensor<?xf32> return %0 : tensor<?xf32>
} }
func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> { func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32> %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32> return %0 : tensor<*xf32>
} }
func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> { func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32> %0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
return %0 : tensor<2xi32> return %0 : tensor<2xi32>
} }
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%1 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> %1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1> %2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
%3 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32> %3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
%4 = "xla_hlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%5 = "xla_hlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
%6 = "xla_hlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %6 = "mhlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
return %6 : tensor<1x2x3x4xf32> return %6 : tensor<1x2x3x4xf32>
} }
func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> { func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> {
%0 = xla_hlo.constant dense<1> : tensor<i32> %0 = mhlo.constant dense<1> : tensor<i32>
return %0 : tensor<i32> return %0 : tensor<i32>
} }
func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> { func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
%0 = xla_hlo.constant dense<1> : tensor<i64> %0 = mhlo.constant dense<1> : tensor<i64>
return %0 : tensor<i64> return %0 : tensor<i64>
} }
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> { func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>> %0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
return %0 : tensor<3xcomplex<f32>> return %0 : tensor<3xcomplex<f32>>
} }
func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> { func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32> %0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
return %0 : tensor<2xf32> return %0 : tensor<2xf32>
} }
func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> { func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
%0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32> %0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
return %0 : tensor<1x519xf32> return %0 : tensor<1x519xf32>
} }
func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> { func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> {
%0 = "xla_hlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32> %0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
return %0 : tensor<2x2x6xf32> return %0 : tensor<2x2x6xf32>
} }
func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> { func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32> %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32> return %0 : tensor<1xf32>
} }
func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> { func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32> %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
return %0 : tensor<1xf32> return %0 : tensor<1xf32>
} }
func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor<f32> { func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor<f32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32> %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> { func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> {
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32> %0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
return %0 : tensor<1x1xf32> return %0 : tensor<1x1xf32>
} }
func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> { func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32>
} }
func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> { func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> %0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
return %0 : tensor<3x8x8x16xf32> return %0 : tensor<3x8x8x16xf32>
} }
func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> { func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> %0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
return %0 : tensor<3x5x1x4xf32> return %0 : tensor<3x5x1x4xf32>
} }
func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -737,7 +737,7 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>
} }
func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -745,7 +745,7 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2
} }
func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> { func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers = %0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} : feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> (tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
@ -753,22 +753,22 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3
} }
func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32> %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( { %1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.add %arg1, %arg2 : tensor<f32> %2 = mhlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> () "mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32> }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32> return %1 : tensor<1xf32>
} }
func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0xFF800000" represents -INF for f32. // "0xFF800000" represents -INF for f32.
%0 = xla_hlo.constant dense<0xFF800000> : tensor<f32> %0 = mhlo.constant dense<0xFF800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( { %1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.maximum %arg1, %arg2 : tensor<f32> %2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> () "mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32> }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32> return %1 : tensor<1xf32>
} }
@ -776,11 +776,11 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> { func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
// "0x7F800000" represents INF for f32. // "0x7F800000" represents INF for f32.
%0 = xla_hlo.constant dense<0x7F800000> : tensor<f32> %0 = mhlo.constant dense<0x7F800000> : tensor<f32>
%1 = "xla_hlo.reduce"(%arg0, %0) ( { %1 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>): ^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%2 = xla_hlo.minimum %arg1, %arg2 : tensor<f32> %2 = mhlo.minimum %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%2) : (tensor<f32>) -> () "mhlo.return"(%2) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32> }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
return %1 : tensor<1xf32> return %1 : tensor<1xf32>
} }
View File
@@ -42,16 +42,19 @@ func @empty_replicate() {
 // CHECK-LABEL: func @replicate_with_multiple_operands
 func @replicate_with_multiple_operands() {
-  %0 = "tf.opA"() : () -> (tensor<*xi1>)
-  %1 = "tf.opB"() : () -> (tensor<*xi1>)
-  %2 = "tf.opC"() : () -> (tensor<*xi1>)
-  %3 = "tf.opD"() : () -> (tensor<*xi32>)
-  %4 = "tf.opE"() : () -> (tensor<*xi32>)
-  %5 = "tf.opF"() : () -> (tensor<*xi32>)
-  %6 = "tf.opG"() : () -> (tensor<*xf32>)
-  %7 = "tf.opH"() : () -> (tensor<*xf32>)
-  %8 = "tf.opI"() : () -> (tensor<*xf32>)
-  tf_device.replicate([%0, %1, %2] as %input0: tensor<*xi1>, [%3, %4, %5] as %input1: tensor<*xi32>, [%6, %7, %8] as %input2: tensor<*xf32>) {n = 3 : i32} {
+  %0 = "tf.opA"() : () -> tensor<*xi1>
+  %1 = "tf.opB"() : () -> tensor<*xi1>
+  %2 = "tf.opC"() : () -> tensor<*xi1>
+  %3 = "tf.opD"() : () -> tensor<*xi32>
+  %4 = "tf.opE"() : () -> tensor<*xi32>
+  %5 = "tf.opF"() : () -> tensor<*xi32>
+  %6 = "tf.opG"() : () -> tensor<*xf32>
+  %7 = "tf.opH"() : () -> tensor<*xf32>
+  %8 = "tf.opI"() : () -> tensor<*xf32>
+  %9 = "tf.opJ"() : () -> tensor<*xi8>
+  %10 = "tf.opK"() : () -> tensor<*xi16>
+  %11 = "tf.opL"() : () -> tensor<*xi64>
+  tf_device.replicate([%0, %1, %2] as %input0: tensor<*xi1>, %9 as %input1: tensor<*xi8>, %10 as %input2: tensor<*xi16>, [%3, %4, %5] as %input3: tensor<*xi32>, [%6, %7, %8] as %input4: tensor<*xf32>, %11 as %input5: tensor<*xi64>) {n = 3 : i32} {
     tf_device.return
   }
   return
@@ -65,12 +68,32 @@ func @replicate_with_multiple_operands() {
 // CHECK: %[[OP_G:[a-z0-9]*]] = "tf.opG"
 // CHECK: %[[OP_H:[a-z0-9]*]] = "tf.opH"
 // CHECK: %[[OP_I:[a-z0-9]*]] = "tf.opI"
+// CHECK: %[[OP_J:[a-z0-9]*]] = "tf.opJ"
+// CHECK: %[[OP_K:[a-z0-9]*]] = "tf.opK"
+// CHECK: %[[OP_L:[a-z0-9]*]] = "tf.opL"
 // CHECK: tf_device.replicate
-// CHECK-SAME: ([%[[OP_A]], %[[OP_B]], %[[OP_C]]] as %{{[a-z0-9]*}}: tensor<*xi1>, [%[[OP_D]], %[[OP_E]], %[[OP_F]]] as %{{[a-z0-9]*}}: tensor<*xi32>, [%[[OP_G]], %[[OP_H]], %[[OP_I]]] as %{{[a-z0-9]*}}: tensor<*xf32>)
+// CHECK-SAME: [%[[OP_A]], %[[OP_B]], %[[OP_C]]] as %{{[a-z0-9]*}}: tensor<*xi1>
+// CHECK-SAME: [%[[OP_D]], %[[OP_E]], %[[OP_F]]] as %{{[a-z0-9]*}}: tensor<*xi32>
+// CHECK-SAME: [%[[OP_G]], %[[OP_H]], %[[OP_I]]] as %{{[a-z0-9]*}}: tensor<*xf32>
+// CHECK-SAME: %[[OP_J]] as %{{[a-z0-9]*}}: tensor<*xi8>
+// CHECK-SAME: %[[OP_K]] as %{{[a-z0-9]*}}: tensor<*xi16>
+// CHECK-SAME: %[[OP_L]] as %{{[a-z0-9]*}}: tensor<*xi64>
 // CHECK-SAME: n = 3
 // CHECK-NEXT: tf_device.return
 }
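The updated test exercises both input kinds that `tf_device.replicate` now accepts: a replicated input binds a list of exactly `n` per-replica values to one block argument (`[%0, %1, %2] as %input0`), while a packed input binds a single value that every replica shares (`%9 as %input1`). A minimal sketch with assumed value names, combining one of each with n = 2:

    func @mixed_inputs(%r0: tensor<*xi32>, %r1: tensor<*xi32>, %p: tensor<*xf32>) {
      // %x takes %r0 in replica 0 and %r1 in replica 1; %y takes %p in both.
      tf_device.replicate([%r0, %r1] as %x: tensor<*xi32>, %p as %y: tensor<*xf32>) {n = 2 : i32} {
        tf_device.return
      }
      return
    }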
+// CHECK-LABEL: func @replicate_derived_operand_segment_sizes
+func @replicate_derived_operand_segment_sizes() {
+  tf_device.replicate {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} {
+  }
+  return
+// CHECK: tf_device.replicate
+// CHECK-SAME: n = 2
+// CHECK-NOT: operand_segment_sizes
+// CHECK-NEXT: tf_device.return
+}
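The CHECK-NOT above pins down that `operand_segment_sizes` is a derived attribute: the custom parser computes it from the replicated and packed operand lists, so the pretty form must not print it. Only the generic form spells it out, as a two-element vector holding [number of replicated inputs, number of packed inputs]. An illustrative generic form (value names are assumptions) with two replicated inputs and one packed input:

    "tf_device.replicate"(%r0, %r1, %p) ({
    ^bb0(%x: tensor<*xi32>, %y: tensor<*xf32>):
      tf_device.return
    }) {n = 2 : i32, operand_segment_sizes = dense<[2, 1]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xf32>) -> ()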
 // CHECK-LABEL: func @replicate_with_return
 // CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>)
 func @replicate_with_return(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xi32>) {
@@ -61,7 +61,7 @@ func @parser_replicate_terminator() {
 func @verifier_replicate_no_block() {
   "tf_device.replicate" () ({
 // expected-error@-1 {{'tf_device.replicate' op region #0 ('body') failed to verify constraint: region with 1 blocks}}
-  }) {n = 2 : i32} : () -> ()
+  }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
   return
 }
@@ -72,7 +72,7 @@ func @verifier_replicate_empty_block() {
   "tf_device.replicate" () ({
 // expected-error@-1 {{'tf_device.replicate' op expects a non-empty block}}
   ^entry:
-  }) {n = 2 : i32} : () -> ()
+  }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
   return
 }
@@ -85,7 +85,7 @@ func @verifier_replicate_terminator() {
 // expected-error@-1 {{'tf_device.replicate' op expects regions to end with 'tf_device.return', found 'std.return'}}
   ^entry:
     return
-  }) {n = 2 : i32} : () -> ()
+  }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
   return
 }
@@ -97,7 +97,7 @@ func @verifier_replicate_n() {
 // expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 2}}
   ^entry:
     tf_device.return
-  }) {n = 1 : i32} : () -> ()
+  }) {n = 1 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
 }
 // -----
@@ -109,43 +109,66 @@ func @verifier_replicate_n_device() {
 // expected-error@-1 {{'tf_device.replicate' op expects number of devices (2) to be equal to 'n' (3)}}
   ^entry:
     tf_device.return
-  }) {n = 3 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}} : () -> ()
+  }) {devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}, n = 3 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
 }
 // -----
-// Check that replicate op's `devices` attribute must consist of dictionary
+// Check that replicate op's 'devices' attribute must consist of dictionary
 // with values as list with size equal to 'n' attribute.
 func @verifier_replicate_n_device_multiple_alias() {
   "tf_device.replicate" () ({
 // expected-error@-1 {{'tf_device.replicate' op expects number of devices (2) to be equal to 'n' (3)}}
   ^entry:
     tf_device.return
-  }) {n = 3 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2"]}} : () -> ()
+  }) {devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2"]}, n = 3 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
 }
 // -----
-// Check that a replicate with mismatched operand and block arg counts is
-// invalid.
-func @verifier_replicate_operand_block_arg_count(%arg0: tensor<*xi32>) {
-  "tf_device.replicate" (%arg0, %arg0, %arg0) ({
-// expected-error@-1 {{'tf_device.replicate' op expects number of operands (3) to be equal to 'n' * number of block arguments (2 * 1)}}
-  ^entry(%input0: tensor<*xi32>):
+// Check number of replicated inputs is evenly divisible by 'n'.
+func @verifier_replicate_bad_operand_segment_sizes(%arg0: tensor<*xi32>) {
+  "tf_device.replicate" (%arg0, %arg0, %arg0, %arg0) ({
+// expected-error@-1 {{'tf_device.replicate' op expects number of replicated inputs (4) to be evenly divisible by 'n' (3)}}
+  ^entry(%input0: tensor<*xi32>, %input1: tensor<*xi32>):
     tf_device.return
-  }) {n = 2 : i32} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> ()
+  }) {n = 3 : i32, operand_segment_sizes = dense<[4, 0]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> ()
 }
 // -----
-// Check that a replicate with incompatible operand and block argument type is
-// invalid.
-func @verifier_replicate_operand_block_arg_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
+// Check number of replicated inputs / 'n' + number of packed inputs matches the
+// number of block arguments.
+func @verifier_replicate_num_block_args(%arg0: tensor<*xi32>) {
+  "tf_device.replicate" (%arg0, %arg0, %arg0, %arg0, %arg0) ({
+// expected-error@-1 {{'tf_device.replicate' op expects number of block arguments (2) to be equal to number of replicated inputs (3) / 'n' (3) + number of packed inputs (2)}}
+  ^entry(%input0: tensor<*xi32>, %input1: tensor<*xi32>):
+    tf_device.return
+  }) {n = 3 : i32, operand_segment_sizes = dense<[3, 2]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> ()
+}
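Both verifier errors above fall out of one formula: the block must declare (number of replicated inputs / n) + (number of packed inputs) arguments, which in turn requires the replicated count to divide evenly by n. Worked through for these two tests:

    // Test 1: operand_segment_sizes = [4, 0], n = 3
    //   4 mod 3 != 0                                 -> divisibility error
    // Test 2: operand_segment_sizes = [3, 2], n = 3
    //   expected block args = 3/3 + 2 = 3; block has 2 -> count error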
+// -----
+// Check that a replicate with incompatible replicated operand and block
+// argument type is invalid.
+func @verifier_replicate_replicated_operand_block_arg_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
   "tf_device.replicate" (%arg0, %arg1) ({
-// expected-error@-1 {{'tf_device.replicate' op incompatible types for operand 1 and block argument 0}}
+// expected-error@-1 {{'tf_device.replicate' op expects operand 1 ('tensor<*xi1>') and block argument 0 ('tensor<*xi32>') to have compatible types}}
   ^entry(%input0: tensor<*xi32>):
     tf_device.return
-  }) {n = 2 : i32} : (tensor<*xi32>, tensor<*xi1>) -> ()
+  }) {n = 2 : i32, operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi1>) -> ()
 }
+// -----
+// Check that a replicate with incompatible packed operand and block argument
+// type is invalid.
+func @verifier_replicate_packed_operand_block_arg_type(%arg0: tensor<*xi1>) {
+  "tf_device.replicate" (%arg0) ({
+// expected-error@-1 {{'tf_device.replicate' op expects operand 0 ('tensor<*xi1>') and block argument 0 ('tensor<*xi32>') to have compatible types}}
+  ^entry(%input0: tensor<*xi32>):
+    tf_device.return
+  }) {n = 2 : i32, operand_segment_sizes = dense<[0, 1]> : vector<2xi32>} : (tensor<*xi1>) -> ()
+}
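The two type tests differ only in which operand segment the bad value sits in: a replicated group checks each of its n values against the shared block argument, while a packed operand is checked one-to-one against its argument. The packed test would verify cleanly if the types agreed; an illustrative fix, not part of this diff:

    "tf_device.replicate"(%arg0) ({
    ^entry(%input0: tensor<*xi1>):
      tf_device.return
    }) {n = 2 : i32, operand_segment_sizes = dense<[0, 1]> : vector<2xi32>} : (tensor<*xi1>) -> ()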
 // -----
@@ -157,7 +180,7 @@ func @verifier_replicate_result_return_operand_count(%arg0: tensor<*xi32>) {
 // expected-error@-1 {{'tf_device.replicate' op expects number of results (3) to be equal to 'n' * number of terminator operands (2 * 1)}}
   ^entry:
     tf_device.return %arg0 : tensor<*xi32>
-  }) {n = 2 : i32} : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>)
+  }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>)
 }
 // -----
@@ -169,7 +192,7 @@ func @verifier_replicate_result_return_operand_type(%arg0: tensor<*xi32>) {
 // expected-error@-1 {{'tf_device.replicate' op incompatible types for result 1 and terminator operand 0}}
   ^entry:
     tf_device.return %arg0 : tensor<*xi32>
-  }) {n = 2 : i32} : () -> (tensor<*xi32>, tensor<*xi1>)
+  }) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> (tensor<*xi32>, tensor<*xi1>)
 }
 // -----
Some files were not shown because too many files have changed in this diff.