Merge branch 'master' of github.com:ashahba/tensorflow into ashahba/ubuntu-onednn-partials

commit f7dabcae30
@@ -243,8 +243,10 @@ TFE_TensorHandle* CopyTensorFromParallelDevice(TFE_Context* context,
                                                const char* target_device_name,
                                                TF_Status* status,
                                                void* device_info) {
-  TF_SetStatus(status, TF_INTERNAL,
-               "Trying to copy a tensor out of a parallel device.");
+  TF_SetStatus(status, TF_UNIMPLEMENTED,
+               "Trying to copy a tensor out of a parallel device. Since there "
+               "are multiple components to parallel tensors, they must be "
+               "unpacked explicitly.");
   return nullptr;
 }
@@ -157,7 +157,7 @@ TEST(PARALLEL_DEVICE, TestExplicitCopies) {
   // Copies off of parallel devices must be explicit.
   TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
       device_value.get(), context.get(), first_device_name, status.get()));
-  ASSERT_EQ(TF_GetCode(status.get()), TF_INTERNAL);
+  ASSERT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);
 }

 TEST(PARALLEL_DEVICE, TestDifferentShapes) {
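For context, a caller-side sketch of what this status change means (a hypothetical snippet using the public TFE C API; `parallel_value`, `context`, and `status` are assumed to exist):

    // An implicit copy off the parallel device now reports TF_UNIMPLEMENTED
    // rather than TF_INTERNAL, so callers can detect it and fall back to
    // unpacking the per-device components explicitly.
    TFE_TensorHandle* copied = TFE_TensorHandleCopyToDevice(
        parallel_value, context,
        "/job:localhost/replica:0/task:0/device:CPU:0", status);
    if (TF_GetCode(status) == TF_UNIMPLEMENTED) {
      // Expected for parallel tensors: unpack the components explicitly
      // instead of copying the packed handle.
    }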
@@ -73,6 +73,14 @@ void ParseGCSPath(const std::string& fname, bool object_empty_ok,
   }
 }

+/// Appends a trailing slash if the name doesn't already have one.
+static void MaybeAppendSlash(std::string* name) {
+  if (name->empty())
+    *name = "/";
+  else if (name->back() != '/')
+    name->push_back('/');
+}
+
 // SECTION 1. Implementation for `TF_RandomAccessFile`
 // ----------------------------------------------------------------------------
 namespace tf_random_access_file {
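A quick illustration of the helper's semantics (a sketch, not part of the diff):

    std::string p1;           // ""
    std::string p2 = "dir";   // no trailing slash
    std::string p3 = "dir/";  // already has one
    MaybeAppendSlash(&p1);    // p1 == "/"
    MaybeAppendSlash(&p2);    // p2 == "dir/"
    MaybeAppendSlash(&p3);    // p3 == "dir/" (unchanged)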
@@ -410,6 +418,70 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
   }
 }

+void CreateDir(const TF_Filesystem* filesystem, const char* path,
+               TF_Status* status) {
+  std::string bucket, object;
+  ParseGCSPath(path, true, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
+  if (object.empty()) {
+    auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
+    TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
+    return;
+  }
+
+  MaybeAppendSlash(&object);
+  auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
+  TF_SetStatusFromGCSStatus(object_metadata.status(), status);
+  if (TF_GetCode(status) == TF_NOT_FOUND) {
+    auto insert_metadata =
+        gcs_file->gcs_client.InsertObject(bucket, object, "");
+    TF_SetStatusFromGCSStatus(insert_metadata.status(), status);
+  } else if (TF_GetCode(status) == TF_OK) {
+    TF_SetStatus(status, TF_ALREADY_EXISTS, path);
+  }
+}
+
+void DeleteFile(const TF_Filesystem* filesystem, const char* path,
+                TF_Status* status) {
+  std::string bucket, object;
+  ParseGCSPath(path, false, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
+  auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
+  TF_SetStatusFromGCSStatus(gcs_status, status);
+}
+
+void DeleteDir(const TF_Filesystem* filesystem, const char* path,
+               TF_Status* status) {
+  std::string bucket, object;
+  ParseGCSPath(path, false, &bucket, &object, status);
+  if (TF_GetCode(status) != TF_OK) return;
+  MaybeAppendSlash(&object);
+  auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
+  int object_count = 0;
+  for (auto&& metadata :
+       gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
+    if (!metadata) {
+      TF_SetStatusFromGCSStatus(metadata.status(), status);
+      return;
+    }
+    ++object_count;
+    // We consider a path to be a non-empty directory in two cases:
+    // - There is more than one object whose key starts with the name of this
+    //   directory.
+    // - There is exactly one object whose key contains the name of this
+    //   directory (but is not equal to it).
+    if (object_count > 1 || metadata->name() != object) {
+      TF_SetStatus(status, TF_FAILED_PRECONDITION,
+                   "Cannot delete a non-empty directory.");
+      return;
+    }
+  }
+  auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
+  TF_SetStatusFromGCSStatus(gcs_status, status);
+}
+
 } // namespace tf_gcs_filesystem

 static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
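Since GCS has no native directories, the functions above emulate them: CreateDir inserts a zero-byte object whose key ends in '/', and DeleteDir refuses to remove that marker while any other object shares its prefix. A hypothetical sequence (`fs` and `status` assumed to exist; status codes follow the implementation above):

    CreateDir(fs, "gs://bucket/logs", status);  // inserts empty object "logs/"
    CreateDir(fs, "gs://bucket/logs", status);  // second call: TF_ALREADY_EXISTS
    DeleteDir(fs, "gs://bucket/logs", status);  // TF_FAILED_PRECONDITION while
                                                // any object under "logs/" remains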
@@ -19,9 +19,6 @@ package(

 cc_library(
     name = "concrete_function",
-    srcs = [
-        "concrete_function.cc",
-    ],
     hdrs = [
         "concrete_function.h",
     ],
@@ -29,7 +26,6 @@ cc_library(
        ":function_metadata",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:protos_all_cc",
    ],
)

@@ -60,10 +56,13 @@ cc_library(
        "saved_model_utils.h",
    ],
    deps = [
        ":function_metadata",
        "//tensorflow/c:tf_tensor_internal",
        "//tensorflow/c/eager:immediate_execution_context",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
        "//tensorflow/c/experimental/saved_model/core/revived_types:variable",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
    ],
)
@@ -91,6 +90,18 @@ cc_library(
    ],
)

cc_library(
    name = "tf_concrete_function_test_protos",
    testonly = True,
    srcs = ["tf_concrete_function_test_protos.cc"],
    hdrs = ["tf_concrete_function_test_protos.h"],
    deps = [
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "tf_saved_model_impl",
    srcs = [
@@ -114,12 +125,16 @@ cc_library(
        "saved_model_api.h",
    ],
    visibility = ["//tensorflow/python:__pkg__"],
    deps = [
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/core:lib",
    ],
)

filegroup(
    name = "mobile_srcs_only_runtime",
    srcs = [
        "concrete_function.cc",
        "concrete_function.h",
        "function_metadata.h",
        "saved_model_api.h",
@@ -172,3 +187,28 @@ tf_cc_test(
        "//tensorflow/core/common_runtime/eager:core",
    ],
)

tf_cc_test(
    name = "tf_concrete_function_loading_test",
    srcs = [
        "tf_concrete_function_loading_test.cc",
    ],
    deps = [
        ":saved_model_utils",
        ":test_utils",
        ":tf_concrete_function_test_protos",
        "//tensorflow/c:tensor_interface",
        "//tensorflow/c/eager:immediate_execution_tensor_handle",
        "//tensorflow/c/experimental/saved_model/core/revived_types:constant",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
        "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/common_runtime:core_cpu_lib",
        "//tensorflow/core/common_runtime/eager:context",
        "//tensorflow/core/common_runtime/eager:core",
    ],
)
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_

#include <memory>
#include <vector>

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/core/framework/function.pb.h"

namespace tensorflow {

@@ -35,19 +35,14 @@ namespace tensorflow {
 // and have only a single implementation.
 class ConcreteFunction {
  public:
-  virtual ~ConcreteFunction() = 0;
+  virtual ~ConcreteFunction() = default;

   // This method returns the "Call" Op used to execute the function.
-  virtual ImmediateExecutionOperation* GetCallOp() = 0;
+  virtual Status GetCallOp(ImmediateOpPtr* out) = 0;

-  const std::vector<tensorflow::ImmediateExecutionTensorHandle*>& GetCaptures()
-      const;
-  const FunctionMetadata& GetFunctionMetadata() const;
-
- private:
-  FunctionMetadata metadata_;
-  std::vector<tensorflow::ImmediateExecutionTensorHandle*> captures_;
-  FunctionDef* function_;
+  virtual const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
+      const = 0;
+  virtual const FunctionMetadata& GetFunctionMetadata() const = 0;
 };

 } // namespace tensorflow
@@ -14,6 +14,27 @@ package(
     licenses = ["notice"],  # Apache 2.0
 )

+cc_library(
+    name = "restore_ops",
+    srcs = [
+        "restore_ops.cc",
+    ],
+    hdrs = [
+        "restore_ops.h",
+    ],
+    deps = [
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/lib/llvm_rtti",
+    ],
+)
+
 cc_library(
     name = "variable_ops",
     srcs = [
@@ -37,16 +58,45 @@
 )

 tf_cc_test(
-    name = "variable_ops_test",
+    name = "restore_ops_test",
     srcs = [
-        "variable_ops_test.cc",
+        "restore_ops_test.cc",
     ],
+    data = [
+        "//tensorflow/cc/saved_model:saved_model_half_plus_two",
+    ],
     deps = [
-        ":variable_ops",
+        ":restore_ops",
         "//tensorflow/c:tensor_interface",
         "//tensorflow/c/eager:abstract_tensor_handle",
         "//tensorflow/c/eager:immediate_execution_context",
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
         "//tensorflow/c/experimental/saved_model/core:test_utils",
         "//tensorflow/cc/saved_model:constants",
         "//tensorflow/core:all_kernels",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/common_runtime:core_cpu_lib",
         "//tensorflow/core/common_runtime/eager:context",
         "//tensorflow/core/common_runtime/eager:core",
     ],
 )

+tf_cc_test(
+    name = "variable_ops_test",
+    srcs = [
+        "variable_ops_test.cc",
+    ],
+    deps = [
+        ":variable_ops",
+        "//tensorflow/c:tensor_interface",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core:test_utils",
+        "//tensorflow/core:all_kernels",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
tensorflow/c/experimental/saved_model/core/ops/restore_ops.cc (new file, 111 lines)
@@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace internal {

namespace {

// Creates a scalar string tensorhandle containing a single string `s`.
Status CreateStringScalarTensorHandle(ImmediateExecutionContext* ctx,
                                      const std::string& s,
                                      ImmediateTensorHandlePtr* out) {
  AbstractTensorPtr tensor(ctx->CreateStringScalar(s));
  if (tensor.get() == nullptr) {
    return errors::Internal(
        "Failed to create scalar string tensor for checkpoint restore");
  }

  out->reset(ctx->CreateLocalHandle(tensor.get()));
  return Status();
}

// Creates a rank-1 string tensorhandle containing a single string `s`.
Status CreateStringVectorTensorHandle(ImmediateExecutionContext* ctx,
                                      const std::string& s,
                                      ImmediateTensorHandlePtr* out) {
  int64 flat_shape[] = {1};
  AbstractTensorPtr tensor(ctx->CreateTensor(DT_STRING, flat_shape));
  if (tensor.get() == nullptr) {
    return errors::Internal(
        "Failed to create vector string tensor for checkpoint restore");
  }
  // Use placement new to construct the string, since we don't have
  // access to Tensor::flat. This is conceptually equivalent to:
  // tensor.flat<tstring>()(0) = s
  new (tensor->Data()) tstring(s);

  out->reset(ctx->CreateLocalHandle(tensor.get()));
  return Status();
}

} // namespace

Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
                     const std::string& checkpoint_key, DataType dtype,
                     ImmediateTensorHandlePtr* out) {
  // Create the EagerOp
  ImmediateOpPtr restore_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(restore_op->Reset("RestoreV2", "/cpu:0"));
  TF_RETURN_IF_ERROR(restore_op->SetAttrTypeList("dtypes", &dtype, 1));

  ImmediateTensorHandlePtr prefix_handle;
  TF_RETURN_IF_ERROR(
      CreateStringScalarTensorHandle(ctx, prefix, &prefix_handle));

  ImmediateTensorHandlePtr names_handle;
  TF_RETURN_IF_ERROR(
      CreateStringVectorTensorHandle(ctx, checkpoint_key, &names_handle));

  // Note that the empty string is the slice spec used for a non-partitioned
  // ResourceVariable:
  // https://github.com/tensorflow/tensorflow/blob/06ff30f7ea35098cb68a231a9eb7ff3ff4be4e1e/tensorflow/python/training/saving/saveable_object_util.py#L194
  ImmediateTensorHandlePtr shapes_and_slices_handle;
  TF_RETURN_IF_ERROR(
      CreateStringVectorTensorHandle(ctx, "", &shapes_and_slices_handle));

  TF_RETURN_IF_ERROR(restore_op->AddInput(prefix_handle.get()));
  TF_RETURN_IF_ERROR(restore_op->AddInput(names_handle.get()));
  TF_RETURN_IF_ERROR(restore_op->AddInput(shapes_and_slices_handle.get()));

  AbstractTensorHandle* restored_handle = nullptr;
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(restore_op->Execute(
      absl::MakeSpan(&restored_handle, num_retvals), &num_retvals));
  AbstractTensorHandlePtr owned_restored_handle(restored_handle);
  if (!tensorflow::isa<ImmediateExecutionTensorHandle>(
          owned_restored_handle.get())) {
    return errors::Internal("Unexpected tensor handle kind.");
  }
  out->reset(reinterpret_cast<ImmediateExecutionTensorHandle*>(
      owned_restored_handle.release()));
  return Status();
}

} // namespace internal
} // namespace tensorflow
tensorflow/c/experimental/saved_model/core/ops/restore_ops.h (new file, 40 lines)
@@ -0,0 +1,40 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_

#include <string>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace internal {

// TODO(bmzhao): Add a function to restore multiple tensors in one call.

// Restores a single non-partitioned tensorhandle of dtype `dtype`, using
// checkpoint at `prefix`, with a value stored in `checkpoint_key`.
Status SingleRestore(ImmediateExecutionContext* ctx, const std::string& prefix,
                     const std::string& checkpoint_key, DataType dtype,
                     ImmediateTensorHandlePtr* out);

} // namespace internal
} // namespace tensorflow

#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_OPS_RESTORE_OP_H_
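A minimal usage sketch of this API (the prefix and key values here are hypothetical; the restore test below exercises real ones, and `ctx` is assumed to be an ImmediateExecutionContext*):

    // Hypothetical: restore a single DT_FLOAT tensor from a checkpoint.
    ImmediateTensorHandlePtr restored;
    Status s = internal::SingleRestore(
        ctx, /*prefix=*/"/tmp/model/variables/variables",
        /*checkpoint_key=*/"x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT,
        &restored);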
@@ -0,0 +1,111 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"

#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

std::string CheckpointPrefix(StringPiece saved_model_dir) {
  return io::JoinPath(testing::TensorFlowSrcRoot(), "cc/saved_model/testdata",
                      saved_model_dir, kSavedModelVariablesDirectory,
                      kSavedModelVariablesFilename);
}

class RestoreOpsTest : public ::testing::Test {
 public:
  RestoreOpsTest()
      : device_mgr_(testing::CreateTestingDeviceMgr()),
        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}

  EagerContext* context() { return ctx_.get(); }

 private:
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  EagerContextPtr ctx_;
};

// One way of obtaining a checkpoint's tensor names is:
// bazel run //tensorflow/python/tools:inspect_checkpoint -- --all_tensors
// --file_name="$CKPT_PREFIX".
// Here are the values for VarsAndArithmeticObjectGraph:
// tensor: child/z/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 3.0
// tensor: x/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 1.0
// tensor: y/.ATTRIBUTES/VARIABLE_VALUE (float32) []
// 2.0

TEST_F(RestoreOpsTest, RestoreSuccessful) {
  ImmediateTensorHandlePtr x_handle;
  TF_EXPECT_OK(internal::SingleRestore(
      context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
      "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle));
  AbstractTensorPtr x = testing::TensorHandleToTensor(x_handle.get());
  EXPECT_EQ(x->Type(), DT_FLOAT);
  EXPECT_EQ(x->NumElements(), 1);
  EXPECT_EQ(x->NumDims(), 0);
  EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(x->Data()), 1.0f);

  ImmediateTensorHandlePtr y_handle;
  TF_EXPECT_OK(internal::SingleRestore(
      context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
      "y/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &y_handle));
  AbstractTensorPtr y = testing::TensorHandleToTensor(y_handle.get());
  EXPECT_EQ(y->Type(), DT_FLOAT);
  EXPECT_EQ(y->NumElements(), 1);
  EXPECT_EQ(y->NumDims(), 0);
  EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(y->Data()), 2.0f);

  ImmediateTensorHandlePtr z_handle;
  TF_EXPECT_OK(internal::SingleRestore(
      context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
      "child/z/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &z_handle));
  AbstractTensorPtr z = testing::TensorHandleToTensor(z_handle.get());
  EXPECT_EQ(z->Type(), DT_FLOAT);
  EXPECT_EQ(z->NumElements(), 1);
  EXPECT_EQ(z->NumDims(), 0);
  EXPECT_FLOAT_EQ(*reinterpret_cast<float*>(z->Data()), 3.0f);
}

TEST_F(RestoreOpsTest, BadCheckpointPrefixShouldFail) {
  ImmediateTensorHandlePtr x_handle;
  Status status = internal::SingleRestore(
      context(), CheckpointPrefix("unknown_bad_checkpoint_prefix"),
      "x/.ATTRIBUTES/VARIABLE_VALUE", DT_FLOAT, &x_handle);
  EXPECT_FALSE(status.ok()) << status.error_message();
}

TEST_F(RestoreOpsTest, BadCheckpointKeyShouldFail) {
  ImmediateTensorHandlePtr x_handle;
  Status status = internal::SingleRestore(
      context(), CheckpointPrefix("VarsAndArithmeticObjectGraph"),
      "bad_checkpoint_key", DT_FLOAT, &x_handle);
  EXPECT_FALSE(status.ok()) << status.error_message();
}

} // namespace
} // namespace tensorflow
@@ -18,6 +18,7 @@ limitations under the License.
 #include <memory>

 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
+#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
 #include "tensorflow/c/tensor_interface.h"
 #include "tensorflow/core/common_runtime/device_mgr.h"
 #include "tensorflow/core/common_runtime/eager/context.h"
@@ -39,17 +40,8 @@ ImmediateTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
 class VariableOpsTest : public ::testing::Test {
  public:
   VariableOpsTest()
-      : device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
-            "CPU", {}, "/job:localhost/replica:0/task:0"))),
-        ctx_(new EagerContext(
-            SessionOptions(),
-            tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
-            tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
-            /* async= */ false,
-            /* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
-            /* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
-            /* custom_kernel_creator= */ nullptr,
-            /* cluster_flr= */ nullptr)) {}
+      : device_mgr_(testing::CreateTestingDeviceMgr()),
+        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}

   EagerContext* context() { return ctx_.get(); }

@@ -58,3 +58,24 @@ cc_library(
         "//tensorflow/c/eager:immediate_execution_tensor_handle",
     ],
 )
+
+cc_library(
+    name = "tf_concrete_function",
+    srcs = [
+        "tf_concrete_function.cc",
+    ],
+    hdrs = [
+        "tf_concrete_function.h",
+    ],
+    deps = [
+        ":tensorhandle_convertible",
+        "//tensorflow/c/eager:immediate_execution_context",
+        "//tensorflow/c/eager:immediate_execution_operation",
+        "//tensorflow/c/eager:immediate_execution_tensor_handle",
+        "//tensorflow/c/experimental/saved_model/core:concrete_function",
+        "//tensorflow/c/experimental/saved_model/core:function_metadata",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/common_runtime/eager:context",
+    ],
+)
@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"

#include <memory>
#include <string>

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"

namespace tensorflow {

TFConcreteFunction::TFConcreteFunction(
    const std::string& name,
    std::vector<ImmediateExecutionTensorHandle*> captures,
    FunctionMetadata metadata, ImmediateExecutionContext* ctx)
    : name_(name),
      captures_(std::move(captures)),
      metadata_(std::move(metadata)),
      ctx_(ctx) {}

TFConcreteFunction::~TFConcreteFunction() {
  Status status = ctx_->RemoveFunction(name_);
  if (!status.ok()) {
    LOG(ERROR) << "Failed to remove functiondef " << name_ << ". "
               << status.error_message();
  }
}

Status TFConcreteFunction::Create(
    const FunctionDef* function_def,
    std::vector<ImmediateExecutionTensorHandle*> captures,
    FunctionMetadata metadata, ImmediateExecutionContext* ctx,
    std::unique_ptr<TFConcreteFunction>* out) {
  TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def));
  out->reset(new TFConcreteFunction(function_def->signature().name(),
                                    std::move(captures), std::move(metadata),
                                    ctx));
  return Status();
}

const std::vector<ImmediateExecutionTensorHandle*>&
TFConcreteFunction::GetCaptures() const {
  return captures_;
}

const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const {
  return metadata_;
}

Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) {
  out->reset(ctx_->CreateOperation());
  // In eager mode, TF2 python executes functions by constructing an op with
  // the name of the functiondef:
  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545
  // In graph mode, we create a PartitionedCallOp instead:
  // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573

  // TODO(bmzhao): After discussing with Allen, we should execute this via a
  // PartitionedCallOp for compatibility with "tooling that assumes functions in
  // graphs are PartitionedCallOps".
  TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr));
  return Status();
}

} // namespace tensorflow
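A caller-side sketch of the new out-parameter API (hypothetical: `fn` is a loaded ConcreteFunction* and `user_args` its user-provided argument handles; per the validation logic elsewhere in this change, captured inputs are appended after the user arguments):

    ImmediateOpPtr call_op;
    TF_RETURN_IF_ERROR(fn->GetCallOp(&call_op));
    // Feed user-provided arguments first, then the captured handles.
    for (ImmediateExecutionTensorHandle* arg : user_args) {
      TF_RETURN_IF_ERROR(call_op->AddInput(arg));
    }
    for (ImmediateExecutionTensorHandle* capture : fn->GetCaptures()) {
      TF_RETURN_IF_ERROR(call_op->AddInput(capture));
    }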
@@ -0,0 +1,87 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow {

// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a
// saved model.
class TFConcreteFunction : public ConcreteFunction {
 public:
  // Factory function for creating a TFConcreteFunction.
  //
  // Params:
  //  function_def - The function_def associated with the created
  //                 TFConcreteFunction. TFConcreteFunction will register this
  //                 function_def with `ctx` on creation, and de-register it on
  //                 destruction. function_def must be non-null, but
  //                 otherwise has no lifetime requirements.
  //  captures - The captured TensorHandles associated with this
  //             TFConcreteFunction.
  //  metadata - The FunctionMetadata associated with this TFConcreteFunction.
  //  ctx - A handle to the Tensorflow runtime. This MUST be non-null and
  //        outlive TFConcreteFunction.
  //  out - The output TFConcreteFunction.
  static Status Create(const FunctionDef* function_def,
                       std::vector<ImmediateExecutionTensorHandle*> captures,
                       FunctionMetadata metadata,
                       ImmediateExecutionContext* ctx,
                       std::unique_ptr<TFConcreteFunction>* out);

  // This method returns the "Call" Op used to execute the function.
  Status GetCallOp(ImmediateOpPtr* out) override;

  const std::vector<ImmediateExecutionTensorHandle*>& GetCaptures()
      const override;

  const FunctionMetadata& GetFunctionMetadata() const override;

  ~TFConcreteFunction() override;

 private:
  TFConcreteFunction(const std::string& name,
                     std::vector<ImmediateExecutionTensorHandle*> captures,
                     FunctionMetadata metadata, ImmediateExecutionContext* ctx);

  TFConcreteFunction(const TFConcreteFunction&) = delete;
  TFConcreteFunction& operator=(const TFConcreteFunction&) = delete;

  // Name of the FunctionDef corresponding to this TFConcreteFunction
  std::string name_;
  std::vector<ImmediateExecutionTensorHandle*> captures_;
  FunctionMetadata metadata_;
  ImmediateExecutionContext* ctx_;
};

} // namespace tensorflow

#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_
@@ -17,14 +17,125 @@ limitations under the License.

#include <memory>

#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/tf_tensor_internal.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"

namespace tensorflow {
namespace internal {
namespace {

// This returns the size of `tf.nest.flatten(value)`, on values that are
// used in tf.function's input_signatures.
int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) {
  // This follows the logic from
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
  switch (value.kind_case()) {
    case StructuredValue::kDictValue: {
      const DictValue& dict = value.dict_value();
      int size = 0;
      for (const auto& field : dict.fields()) {
        size += FlattenedSize(field.second, status);
      }
      return size;
    }
    case StructuredValue::kTupleValue: {
      const TupleValue& tuple = value.tuple_value();
      int size = 0;
      for (const StructuredValue& value : tuple.values()) {
        size += FlattenedSize(value, status);
      }
      return size;
    }
    case StructuredValue::kListValue: {
      const ListValue& list = value.list_value();
      int size = 0;
      for (const StructuredValue& value : list.values()) {
        size += FlattenedSize(value, status);
      }
      return size;
    }
    case StructuredValue::kTensorSpecValue: {
      return 1;
    }
    case StructuredValue::kNoneValue: {
      // Base case: do nothing.
      // This arises, for example, as the top-level object of an output
      // signature when there are no return values.
      return 0;
    }
    default: {
      status->Update(errors::Internal("Unhandled structured value kind ",
                                      value.kind_case()));
      return 0;
    }
  }
}

// Perform some basic sanity checks on SavedConcreteFunction's input and
// output signatures with respect to the corresponding FunctionDef's input
// and output args.
Status ValidateSavedFunctionCompatibleWithFunctionDef(
    const SavedConcreteFunction& saved_concrete_function,
    const FunctionDef* function_def) {
  // tf.functions go through many transformations before becoming FunctionDefs:
  // 1. flatten user-provided inputs:
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675
  // 2. convert user-provided inputs to tensors:
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688
  // 3. filter any non-tensor, non-variable inputs:
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841
  // 4. concatenate any captured inputs:
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912

  // Since our API is limited to tf.functions annotated with input signatures,
  // conditions 2 and 3 are trivially satisfied.
  // We need to ensure that:
  // flatten(input_signature).size() + captures.size() =
  //     fdef.signature().input_arg_size()
  // A concrete function's serialized "canonicalized_input_signature" comes
  // from encoding its "structured_input_signature" field:
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71
  // The "structured_input_signature" is guaranteed to be a tuple of the python
  // args, kwargs that correspond to the tf.function:
  // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979

  const std::string& name = function_def->signature().name();
  const StructuredValue& input_signature =
      saved_concrete_function.canonicalized_input_signature();
  Status status;
  int input_signature_size = FlattenedSize(input_signature, &status);
  TF_RETURN_IF_ERROR(status);
  if (input_signature_size + saved_concrete_function.bound_inputs_size() !=
      function_def->signature().input_arg_size()) {
    return errors::FailedPrecondition(
        "FunctionDef ", name, " has ",
        function_def->signature().input_arg_size(),
        " inputs, but the SavedConcreteFunction has ", input_signature_size,
        " flattened user inputs and ",
        saved_concrete_function.bound_inputs_size(), " captured inputs.");
  }

  const StructuredValue& output_signature =
      saved_concrete_function.output_signature();
  int output_signature_size = FlattenedSize(output_signature, &status);
  TF_RETURN_IF_ERROR(status);
  if (output_signature_size != function_def->signature().output_arg_size()) {
    return errors::FailedPrecondition(
        "FunctionDef ", name, " has ",
        function_def->signature().output_arg_size(),
        " outputs, but the SavedConcreteFunction has ", output_signature_size,
        " flattened outputs.");
  }

  return status;
}

} // namespace

Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
                             const TensorProto& proto,
@@ -54,5 +165,31 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
   return Status();
 }

+Status LoadTFConcreteFunction(
+    const SavedConcreteFunction& saved_concrete_function,
+    const FunctionDef* function_def,
+    const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
+        captured_objects,
+    ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out) {
+  TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef(
+      saved_concrete_function, function_def));
+
+  // Copy over captures
+  std::vector<ImmediateExecutionTensorHandle*> captures;
+  captures.reserve(saved_concrete_function.bound_inputs_size());
+  for (int bound_input : saved_concrete_function.bound_inputs()) {
+    auto iter = captured_objects.find(bound_input);
+    if (iter == captured_objects.end()) {
+      return errors::FailedPrecondition("Failed to find bound_input ",
+                                        bound_input,
+                                        " for SavedConcreteFunction");
+    }
+    captures.push_back(iter->second->handle());
+  }
+
+  return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx,
+                                    out);
+}
+
 } // namespace internal
 } // namespace tensorflow
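A worked example of the arithmetic behind this validation (hypothetical signature, not from the diff): a tf.function saved with input_signature ((TensorSpec "x", TensorSpec "y"), {}) flattens to 2 tensor specs; if the SavedConcreteFunction additionally records 1 bound capture, the corresponding FunctionDef must declare 2 + 1 = 3 input args. Likewise, an output_signature consisting of a single TensorSpec flattens to 1, so the FunctionDef must declare exactly 1 output arg; any mismatch is reported as FAILED_PRECONDITION, as exercised by the loading tests below.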
@@ -21,6 +21,7 @@ limitations under the License.

 #include "tensorflow/c/eager/immediate_execution_context.h"
 #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
+#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
 #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
 #include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/protobuf/saved_object_graph.pb.h"
@@ -43,6 +44,14 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx,
                          const SavedVariable& variable,
                          std::unique_ptr<Variable>* output);

+// Creates a TFConcreteFunction from a SavedConcreteFunction.
+Status LoadTFConcreteFunction(
+    const SavedConcreteFunction& saved_concrete_function,
+    const FunctionDef* function_def,
+    const std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>&
+        captured_objects,
+    ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);
+
 } // namespace internal
 } // namespace tensorflow
@@ -139,5 +139,13 @@ void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
   }
 }

+AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle) {
+  Status status;
+  AbstractTensorPtr tensor(handle->Resolve(&status));
+  CHECK(status.ok()) << status.error_message();
+  CHECK_NE(tensor.get(), nullptr);
+  return tensor;
+}
+
 } // namespace testing
 } // namespace tensorflow
@@ -69,6 +69,10 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer,
 void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
                             void* b);

+// Converts a TensorHandle to a Tensor, and dies if unsuccessful. This should
+// only be used for testing purposes.
+AbstractTensorPtr TensorHandleToTensor(ImmediateExecutionTensorHandle* handle);
+
 } // namespace testing
 } // namespace tensorflow
@@ -0,0 +1,271 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <unordered_map>

#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/test_utils.h"
#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"

namespace tensorflow {
namespace {

class SavedConcreteFunctionLoadingTest : public ::testing::Test {
 public:
  SavedConcreteFunctionLoadingTest()
      : device_mgr_(testing::CreateTestingDeviceMgr()),
        ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {}

  EagerContext* context() { return ctx_.get(); }

 private:
  std::unique_ptr<StaticDeviceMgr> device_mgr_;
  EagerContextPtr ctx_;
};

class DummyCapture : public TensorHandleConvertible {
 public:
  DummyCapture(ImmediateExecutionContext* ctx, int8 value)
      : TensorHandleConvertible(
            testing::CreateTensorHandle(ctx, DT_FLOAT, {2, 4}, value)) {}
};

FunctionDef FuncDefWithNumInputsOutputs(int num_inputs, int num_outputs) {
  FunctionDef func;
  OpDef* signature = func.mutable_signature();
  for (int i = 0; i < num_inputs; ++i) {
    signature->add_input_arg();
  }
  for (int i = 0; i < num_outputs; ++i) {
    signature->add_output_arg();
  }
  return func;
}

// A SavedConcreteFunction whose canonicalized input signature
// has fewer inputs than its corresponding FunctionDef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) {
  // `saved` has 1 input
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::SingleArgInputSignature();
  *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();

  // `func` has 2 inputs
  FunctionDef func = FuncDefWithNumInputsOutputs(2, 0);

  std::unique_ptr<TFConcreteFunction> result;
  Status status =
      internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction whose canonicalized input signature length plus
// number of captures is less than its corresponding FunctionDef's input count
// should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest,
       TooFewInputsWithCapturesInSavedConcreteFunction) {
  // `saved` has 1 input, and 1 capture, for a total of 2 inputs
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::SingleArgInputSignature();
  *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
  saved.add_bound_inputs(5);

  // `func` has 3 inputs
  FunctionDef func = FuncDefWithNumInputsOutputs(3, 0);

  std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
  captures[5] = std::make_unique<DummyCapture>(context(), 10);

  std::unique_ptr<TFConcreteFunction> result;
  Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
                                                   context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction whose canonicalized input signature
// has more inputs than its corresponding FunctionDef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) {
  // `saved` has 3 inputs
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::ThreeArgInputSignature();
  *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();

  // `func` has 2 inputs
  FunctionDef func = FuncDefWithNumInputsOutputs(2, 0);

  std::unique_ptr<TFConcreteFunction> result;
  Status status =
      internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction whose canonicalized input signature
// has the same number of inputs as its corresponding FunctionDef, but has
// additional captures, should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest,
       TooManyInputsWithCaptureInSavedConcreteFunction) {
  // `saved` has 3 inputs, and 1 capture, for a total of 4 inputs.
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::ThreeArgInputSignature();
  *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
  saved.add_bound_inputs(5);

  // `func` has 3 inputs.
  FunctionDef func = FuncDefWithNumInputsOutputs(3, 0);

  std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
  captures[5] = std::make_unique<DummyCapture>(context(), 10);

  std::unique_ptr<TFConcreteFunction> result;
  Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
                                                   context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction whose capture refers to an index not in the capture
// map should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) {
  // `saved` has 3 inputs, 1 capture, for a total of 4 inputs
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::ThreeArgInputSignature();
  *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
  // Capture is at index "10"
  saved.add_bound_inputs(10);

  // `func` has 4 inputs
  FunctionDef func = FuncDefWithNumInputsOutputs(4, 0);

  // `captures` only has a capture for index 5
  std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
  captures[5] = std::make_unique<DummyCapture>(context(), 10);

  std::unique_ptr<TFConcreteFunction> result;
  Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
                                                   context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction with fewer outputs than its corresponding
// functiondef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) {
  // `saved` has 0 inputs, 1 output
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::ZeroArgInputSignature();
  *saved.mutable_output_signature() = testing::SingleReturnOutputSignature();

  // `func` has 0 inputs, 2 outputs
  FunctionDef func = FuncDefWithNumInputsOutputs(0, 2);

  std::unique_ptr<TFConcreteFunction> result;
  Status status =
      internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction with more outputs than its corresponding
// functiondef should cause an error.
TEST_F(SavedConcreteFunctionLoadingTest,
       TooManyOutputsInSavedConcreteFunction) {
  // `saved` has 1 input, 3 outputs
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::SingleArgInputSignature();
  *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature();

  // `func` has 1 input, 2 outputs
  FunctionDef func = FuncDefWithNumInputsOutputs(1, 2);

  std::unique_ptr<TFConcreteFunction> result;
  Status status =
      internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
  EXPECT_EQ(status.code(), error::FAILED_PRECONDITION)
      << status.error_message();
}

// A SavedConcreteFunction whose (inputs + captures) = functiondef inputs,
// and whose outputs = functiondef outputs should successfully load.
TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) {
  // `saved` has 1 input, 2 captures, 3 outputs
  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::SingleArgInputSignature();
  *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature();
  saved.add_bound_inputs(2);
  saved.add_bound_inputs(5);

  // `func` has 3 inputs, 3 outputs
  FunctionDef func = FuncDefWithNumInputsOutputs(3, 3);

  std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>> captures;
  captures[2] = std::make_unique<DummyCapture>(context(), 1);
  captures[5] = std::make_unique<DummyCapture>(context(), 10);

  std::unique_ptr<TFConcreteFunction> result;
  Status status = internal::LoadTFConcreteFunction(saved, &func, captures,
                                                   context(), &result);
  TF_EXPECT_OK(status) << status.error_message();
}

// A TFConcreteFunction should register functiondefs on creation, and
// remove them upon deletion.
TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) {
  std::string func_name = "FooBarBazWombatFunction";

  SavedConcreteFunction saved;
  *saved.mutable_canonicalized_input_signature() =
      testing::ZeroArgInputSignature();
  *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature();
  FunctionDef func = FuncDefWithNumInputsOutputs(0, 0);
  *func.mutable_signature()->mutable_name() = func_name;

  {
    std::unique_ptr<TFConcreteFunction> result;
    Status status =
        internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result);
    TF_EXPECT_OK(status) << status.error_message();
    // The function should be registered with context.
    EXPECT_TRUE(context()->FindFunctionByName(func_name));
  }

  // After `result's` destructor runs, the function should no longer be
  // registered with context.
  EXPECT_FALSE(context()->FindFunctionByName(func_name));
}

} // namespace
} // namespace tensorflow
@ -0,0 +1,212 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"

#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/protobuf/struct.pb.h"

namespace tensorflow {
namespace testing {
namespace {

constexpr absl::string_view kZeroArgInputSignatureTextProto = R"(
tuple_value: {
  values: {
    tuple_value: {
    }
  }
  values: {
    dict_value: {
    }
  }
}
)";

constexpr absl::string_view kSingleArgInputSignatureTextProto = R"(
tuple_value: {
  values: {
    tuple_value: {
      values: {
        tensor_spec_value: {
          name : "x"
          shape: {
            dim: {
              size: 1
            }
            dim: {
              size: 10
            }
          }
          dtype: DT_FLOAT
        }
      }
    }
  }
  values: {
    dict_value: {
    }
  }
}
)";

constexpr absl::string_view kThreeArgInputSignatureTextProto = R"(
tuple_value: {
  values: {
    tuple_value: {
      values: {
        tensor_spec_value: {
          name : "x"
          shape: {
            dim: {
              size: 1
            }
          }
          dtype: DT_FLOAT
        }
      }
      values: {
        tensor_spec_value: {
          name : "y"
          shape: {
            dim: {
              size: 1
            }
          }
          dtype: DT_FLOAT
        }
      }
      values: {
        tensor_spec_value: {
          name : "z"
          shape: {
            dim: {
              size: 1
            }
          }
          dtype: DT_FLOAT
        }
      }
    }
  }
  values: {
    dict_value: {
    }
  }
}

)";

constexpr absl::string_view kZeroReturnOutputSignatureTextProto = R"(
none_value: {}
)";

constexpr absl::string_view kSingleReturnOutputSignatureTextProto = R"(
tensor_spec_value: {
  shape: {
    dim: {
      size: 1
    }
  }
  dtype: DT_FLOAT
}
)";

constexpr absl::string_view kThreeReturnOutputSignatureTextProto = R"(
tuple_value: {
  values: {
    dict_value: {
      fields: {
        key : "a"
        value: {
          tensor_spec_value: {
            name : "0/a"
            shape: {
              dim: {
                size: 1
              }
            }
            dtype: DT_FLOAT
          }
        }
      }
      fields: {
        key : "b"
        value: {
          tensor_spec_value: {
            name : "0/b"
            shape: {
              dim: {
                size: 1
              }
            }
            dtype: DT_FLOAT
          }
        }
      }
    }
  }
  values: {
    tensor_spec_value: {
      name : "1"
      shape: {
        dim: {
          size: 1
        }
      }
      dtype: DT_FLOAT
    }
  }
}
)";

StructuredValue ParseStructuredValue(absl::string_view text_proto) {
  StructuredValue value;
  CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto),
                                                          &value));
  return value;
}

}  // namespace

StructuredValue ZeroArgInputSignature() {
  return ParseStructuredValue(kZeroArgInputSignatureTextProto);
}

StructuredValue SingleArgInputSignature() {
  return ParseStructuredValue(kSingleArgInputSignatureTextProto);
}

StructuredValue ThreeArgInputSignature() {
  return ParseStructuredValue(kThreeArgInputSignatureTextProto);
}

StructuredValue ZeroReturnOutputSignature() {
  return ParseStructuredValue(kZeroReturnOutputSignatureTextProto);
}

StructuredValue SingleReturnOutputSignature() {
  return ParseStructuredValue(kSingleReturnOutputSignatureTextProto);
}

StructuredValue ThreeReturnOutputSignature() {
  return ParseStructuredValue(kThreeReturnOutputSignatureTextProto);
}

}  // namespace testing
}  // namespace tensorflow

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_

#include "tensorflow/core/protobuf/struct.pb.h"

namespace tensorflow {
namespace testing {

// Returns a StructuredValue corresponding to the serialized InputSignature of a
// tf.function with 0 inputs
StructuredValue ZeroArgInputSignature();

// Returns a StructuredValue corresponding to the serialized InputSignature of a
// tf.function with 1 input
StructuredValue SingleArgInputSignature();

// Returns a StructuredValue corresponding to the serialized InputSignature of a
// tf.function with 3 inputs
StructuredValue ThreeArgInputSignature();

// Returns a StructuredValue corresponding to the serialized OutputSignature of
// a tf.function with no return values
StructuredValue ZeroReturnOutputSignature();

// Returns a StructuredValue corresponding to the serialized OutputSignature of
// a tf.function with a single tensor output
StructuredValue SingleReturnOutputSignature();

// Returns a StructuredValue corresponding to the serialized OutputSignature of
// a tf.function with three tensor outputs
StructuredValue ThreeReturnOutputSignature();

}  // namespace testing
}  // namespace tensorflow
#endif  // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_

@ -41,11 +41,13 @@ cc_library(
        ":tensorhandle_list",
        ":tensorhandle_list_type",
        "//tensorflow/c:c_api_macros",
        "//tensorflow/c:tf_status_internal",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_internal",
        "//tensorflow/c/eager:immediate_execution_operation",
        "//tensorflow/c/eager:tfe_op_internal",
        "//tensorflow/c/experimental/saved_model/core:concrete_function",
        "//tensorflow/c/experimental/saved_model/core:function_metadata",
        "//tensorflow/core:lib",
    ],
)

@ -205,9 +207,13 @@ tf_cc_test(
    ],
    deps = [
        "//tensorflow/c:tf_status",
        "//tensorflow/c:tf_tensor",
        "//tensorflow/c/eager:c_api",
        "//tensorflow/c/eager:c_api_experimental",
        "//tensorflow/c/eager:c_api_test_util",
        "//tensorflow/c/experimental/saved_model/public:concrete_function",
        "//tensorflow/c/experimental/saved_model/public:saved_model_api",
        "//tensorflow/c/experimental/saved_model/public:tensorhandle_list",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",

@ -15,12 +15,15 @@ limitations under the License.

#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"

#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/tfe_op_internal.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"

extern "C" {

@ -34,8 +37,11 @@ const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(
  return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures());
}

TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) {
  return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp());
TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func,
                                     TF_Status* status) {
  tensorflow::ImmediateOpPtr call_op(nullptr);
  status->status = tensorflow::unwrap(func)->GetCallOp(&call_op);
  return tensorflow::wrap(call_op.release());
}

}  // end extern "C"

@ -41,7 +41,7 @@ TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures(

// Returns a TFE_Op suitable for executing this function.
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp(
    TF_ConcreteFunction* func);
    TF_ConcreteFunction* func, TF_Status* status);

#ifdef __cplusplus
}  // end extern "C"

@ -244,9 +244,7 @@ static bool MustAliasOutput(
  if (input_output_alias.shape().tuple_shapes_size() == 0) {
    return false;
  }
  return input_output_alias.OutputHasAlias(output_index) &&
         input_output_alias.GetAliasedParameter(output_index).value().kind ==
             xla::HloInputOutputAliasConfig::kUserAlias;
  return input_output_alias.OutputHasAlias(output_index);
}

// Returns an aliased tensor if it exists, nullptr otherwise.

@ -482,8 +482,8 @@ tf_cc_test(
)

cc_library(
    name = "xla_hlo_fusion",
    srcs = ["lib/Dialect/mhlo/transforms/xla_hlo_fusion.cc"],
    name = "mhlo_fusion",
    srcs = ["lib/Dialect/mhlo/transforms/mhlo_fusion.cc"],
    deps = [
        ":cycle_detector",
        ":hlo",

@ -680,3 +680,40 @@ cc_library(
    ],
    alwayslink = 1,
)

cc_library(
    name = "all_xla_passes_for_testing",
    visibility = [
        "//tensorflow/compiler/mlir:__subpackages__",
    ],
    deps = [
        ":chlo_legalize_to_hlo",
        ":hlo_dialect_registration",
        ":hlo_legalize_to_lhlo",
        ":lhlo",
        ":lhlo_copy_removal",
        ":lhlo_fuse_linalg",
        ":lhlo_legalize_to_affine",
        ":lhlo_legalize_to_gpu",
        ":lhlo_legalize_to_parallel_loops",
        ":mhlo_fusion",
        ":xla_legalize_control_flow",
        ":xla_legalize_tanh_to_approximation",
        ":xla_legalize_to_linalg",
        ":xla_legalize_to_standard",
        ":xla_lower",
        ":xla_sink_constants_to_control_flow",
        ":xla_test_passes",
        ":xla_transform_unranked_hlo",
    ],
)

cc_binary(
    name = "mlir-hlo-opt",
    deps = [
        ":all_xla_passes_for_testing",
        "@llvm-project//mlir:AllPassesAndDialects",
        "@llvm-project//mlir:MlirOptLib",
        "@llvm-project//mlir:MlirOptMain",
    ],
)

@ -17,12 +17,12 @@ limitations under the License.
// These ops are not necessarily orthogonal or optimized for transformation but
// for ease of expression in certain cases deemed important for client
// libraries (i.e. implicit broadcasting, helper ops, etc).
// This dialect is considered to exist in addition to augment the xla_hlo
// This dialect is considered to exist in addition to augment the mhlo
// dialect for ergonomic needs, not duplicate/replace it.
//
// The typical use of this dialect is for client libraries to be able to emit
// less constrained ops and rely on the conversion framework to lower any
// xla_chlo ops to canonical xla_hlo ops.
// xla_chlo ops to canonical mhlo ops.
//
// See: https://www.tensorflow.org/xla/operation_semantics

@ -44,7 +44,7 @@ def HLOClient_Dialect : Dialect {
  let description = [{
    This dialect contains ops that align closely with the API surface area
    of the XlaBuilder C++ API, where such ops have semantics that go beyond
    what exists in the lower level dialects (such as `xla_hlo`). Essentially,
    what exists in the lower level dialects (such as `mhlo`). Essentially,
    whenever the client library uses syntactic sugar or composition
    of multiple ops for an API call, this dialect tries to model the API call
    and provide conversion patterns to fully materialize into lower level

@ -65,7 +65,7 @@ class HLOClient_Op<string mnemonic, list<OpTrait> traits> :
// broadcasting (via the broadcast_dimensions attribute) and implicit degenerate
// shape broadcasting.
//
// These correspond to operations in the xla_hlo dialect without the
// These correspond to operations in the mhlo dialect without the
// "broadcast_" prefix, except that those ops require same-shaped operands and
// results.
//

@ -37,12 +37,12 @@ class OpBuilder;

#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.h.inc"

namespace xla_hlo {
namespace mhlo {

class XlaHloDialect : public Dialect {
 public:
  explicit XlaHloDialect(MLIRContext *context);
  static StringRef getDialectNamespace() { return "xla_hlo"; }
  static StringRef getDialectNamespace() { return "mhlo"; }

  // Registered hook to materialize a constant operation from a given attribute
  // value with the desired resultant type.

@ -82,7 +82,7 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
//   %1 = index_cast %0 : index to i64
//   %2 = dim %arg0, 1 : memref<?x?xf32>
//   %3 = index_cast %2 : index to i64
//   %4 = "xla_hlo.scalars_to_dimension_tensor"(%1, %3)
//   %4 = "mhlo.scalars_to_dimension_tensor"(%1, %3)
//     : (i64, i64) -> tensor<2xi64>
//
// and returns %4 as the shape value.

@ -93,7 +93,7 @@ LogicalResult deriveShapeFromFirstOperand(
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h.inc"

}  // end namespace xla_hlo
}  // end namespace mhlo
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_HLO_OPS_H_

@ -29,8 +29,8 @@ include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"

def HLO_Dialect : Dialect {
  let name = "xla_hlo";
  let cppNamespace = "xla_hlo";
  let name = "mhlo";
  let cppNamespace = "mhlo";
}

class HLO_Op<string mnemonic, list<OpTrait> traits> :

@ -78,6 +78,7 @@ def HLO_IotaOp : HLO_Op<"iota", [NoSideEffect]>, BASE_HLO_IotaOp {

  // TODO(b/130357376): Iota has special conversion logic to HLO.
  let hasCustomHLOConverter = 1;
  let hasCanonicalizer = 1;
}

def HLO_DynamicIotaOp: HLO_Op<"dynamic_iota", [NoSideEffect]> {

@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {

template <typename HloOpTy>
struct HloToLhloOpImpl {

@ -31,10 +31,10 @@ struct HloToLhloOpImpl {
template <typename HloOpTy>
using HloToLhloOp = typename HloToLhloOpImpl<HloOpTy>::Type;

#define MAP_HLO_TO_LHLO(OpName)             \
  template <>                               \
  struct HloToLhloOpImpl<xla_hlo::OpName> { \
    using Type = xla_lhlo::OpName;          \
#define MAP_HLO_TO_LHLO(OpName)          \
  template <>                            \
  struct HloToLhloOpImpl<mhlo::OpName> { \
    using Type = xla_lhlo::OpName;       \
  }

MAP_HLO_TO_LHLO(AbsOp);

@ -74,7 +74,7 @@ MAP_HLO_TO_LHLO(TanhOp);

#undef MAP_HLO_TO_LHLO

}  // namespace xla_hlo
}  // namespace mhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_HLO_TO_LHLO_OP_H_

@ -464,7 +464,7 @@ struct XlaOpToStdScalarOp {
  template <typename XlaOpTy, typename LhloOpTy = XlaOpTy,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
                std::is_same<typename xla_hlo::HloToLhloOp<LhloOpTy>,
                std::is_same<typename mhlo::HloToLhloOp<LhloOpTy>,
                             std::false_type>::value>>
  static Value map(XlaOpTy op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b, unsigned i = 0) {

@ -472,8 +472,8 @@ struct XlaOpToStdScalarOp {
                                                      args, b);
  }

  // Implementation for HLO ops except xla_hlo::CompareOp.
  template <typename XlaOpTy, typename LhloOpTy = xla_hlo::HloToLhloOp<XlaOpTy>,
  // Implementation for HLO ops except mhlo::CompareOp.
  template <typename XlaOpTy, typename LhloOpTy = mhlo::HloToLhloOp<XlaOpTy>,
            typename = std::enable_if_t<
                !std::is_same<LhloOpTy, xla_lhlo::CompareOp>::value &&
                !std::is_same<LhloOpTy, std::false_type>::value>>

@ -493,10 +493,11 @@ struct XlaOpToStdScalarOp {
        op.getLoc(), comparison_direction, result_types, args, b);
  }

  // Implementation for xla_hlo::CompareOp.
  template <typename HloOpTy, typename = std::enable_if_t<std::is_same<
                                  HloOpTy, xla_hlo::CompareOp>::value>>
  static Value map(xla_hlo::CompareOp op, ArrayRef<Type> result_types,
  // Implementation for mhlo::CompareOp.
  template <typename HloOpTy,
            typename =
                std::enable_if_t<std::is_same<HloOpTy, mhlo::CompareOp>::value>>
  static Value map(mhlo::CompareOp op, ArrayRef<Type> result_types,
                   ArrayRef<Value> args, OpBuilder* b) {
    auto comparison_direction = op.comparison_direction();
    return impl::MapXlaCompareOpToStdScalarOp<xla_lhlo::CompareOp>(

@ -29,7 +29,7 @@ template <typename T>
class OperationPass;
class Pass;

namespace xla_hlo {
namespace mhlo {

/// Lowers HLO control flow ops to the Standard dialect.
std::unique_ptr<OperationPass<FuncOp>> createLegalizeControlFlowPass();

@ -55,10 +55,10 @@ std::unique_ptr<OperationPass<FuncOp>> createTransformUnrankedHloPass();
// necessary to export to XLA.
std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass();

// fuse xla_hlo ops to kLoop/kInput fusion patterns
// fuse mhlo ops to kLoop/kInput fusion patterns
std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusionPass();

}  // namespace xla_hlo
}  // namespace mhlo

namespace xla_lhlo {

@ -27,7 +27,7 @@ class LLVMTypeConverter;
class LowerToLLVMOptions;
class OwningRewritePatternList;
class BufferAssignmentPlacer;
namespace xla_hlo {
namespace mhlo {

// Collection of rewrite patterns for lowering a general dot product.
void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,

@ -73,7 +73,7 @@ void PopulateTransformUnrankedHloPatterns(MLIRContext *context,
void PopulateUnfuseBatchNormPatterns(MLIRContext *context,
                                     OwningRewritePatternList *patterns);

}  // namespace xla_hlo
}  // namespace mhlo

namespace xla_lhlo {

@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"

// Static initialization for XLA dialect registration.
static mlir::DialectRegistration<mlir::xla_hlo::XlaHloDialect> xla_hlo_ops;
static mlir::DialectRegistration<mlir::mhlo::XlaHloDialect> mhlo_ops;
static mlir::DialectRegistration<mlir::xla_chlo::XlaHloClientDialect>
    xla_chlo_ops;
static mlir::DialectRegistration<mlir::xla_lhlo::XlaLhloDialect> xla_lhlo_ops;

@ -60,7 +60,7 @@ limitations under the License.

namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace xla_hlo {
namespace mhlo {

Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
                                              Attribute value, Type type,

@ -68,8 +68,7 @@ Operation* XlaHloDialect::materializeConstant(OpBuilder& builder,
  // HLO dialect constants only support ElementsAttr unlike standard dialect
  // constant which supports all attributes.
  if (value.isa<ElementsAttr>())
    return builder.create<xla_hlo::ConstOp>(loc, type,
                                            value.cast<ElementsAttr>());
    return builder.create<mhlo::ConstOp>(loc, type, value.cast<ElementsAttr>());
  return nullptr;
}

@ -167,7 +166,7 @@ void ConstOp::build(OpBuilder& builder, OperationState& result,
  }

  // TODO: support other XLA specific types.
  assert(type && "unsupported attribute type for building xla_hlo.constant");
  assert(type && "unsupported attribute type for building mhlo.constant");
  result.types.push_back(type);
  result.addAttribute("value", value);
}

@ -215,6 +214,41 @@ static LogicalResult Verify(IotaOp op) {
  return success();
}

// Iota operations across multiple dimensions can be reduced to an iota and a
// ranked broadcast.
struct IotaBroadcast : public OpRewritePattern<IotaOp> {
  using OpRewritePattern<IotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension.getLimitedValue())},
        result_ty.getElementType());

    auto new_iota = rewriter.create<IotaOp>(iota.getLoc(), iota_type,
                                            rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<BroadcastInDimOp>(iota, result_ty, new_iota,
                                                  broadcast_attr);
    return success();
  }
};

void IotaOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
                                         MLIRContext* context) {
  results.insert<IotaBroadcast>(context);
}

//===----------------------------------------------------------------------===//
// DynamicIotaOp
//===----------------------------------------------------------------------===//

@ -236,11 +270,63 @@ struct DynamicIotaIsStatic : public OpRewritePattern<DynamicIotaOp> {
  }
};

// Dynamic Iota operations across multiple dimensions can be reduced to an iota
// and a ranked broadcast.
struct DynamicIotaBroadcast : public OpRewritePattern<DynamicIotaOp> {
  using OpRewritePattern<DynamicIotaOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DynamicIotaOp iota,
                                PatternRewriter& rewriter) const override {
    auto result_ty = iota.getType().cast<ShapedType>();
    if (!result_ty.hasRank() || result_ty.getRank() < 2) {
      return failure();
    }

    auto iota_dimension = iota.iota_dimension();
    auto iota_dimension_int = iota_dimension.getLimitedValue();

    auto converted_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            iota.output_shape().getType().cast<ShapedType>().getShape(),
            rewriter.getI64Type()),
        iota.output_shape());

    auto sliced_shape = rewriter.create<SliceOp>(
        iota.getLoc(), converted_shape,
        GetI64ElementsAttr(iota_dimension_int, &rewriter),
        GetI64ElementsAttr(iota_dimension_int + 1, &rewriter),
        GetI64ElementsAttr(1, &rewriter));

    auto converted_sliced_shape = rewriter.create<IndexCastOp>(
        iota.getLoc(),
        RankedTensorType::get(
            {1},
            iota.output_shape().getType().cast<ShapedType>().getElementType()),
        sliced_shape);

    auto iota_type = RankedTensorType::get(
        {result_ty.getDimSize(iota_dimension_int)}, result_ty.getElementType());

    auto new_iota = rewriter.create<DynamicIotaOp>(
        iota.getLoc(), iota_type, converted_sliced_shape,
        rewriter.getI64IntegerAttr(0));

    auto broadcast_attr = DenseIntElementsAttr::get(
        RankedTensorType::get({1}, rewriter.getIntegerType(64)),
        {iota_dimension});
    rewriter.replaceOpWithNewOp<DynamicBroadcastInDimOp>(
        iota, result_ty, new_iota, iota.output_shape(), broadcast_attr);
    return success();
  }
};

}  // namespace

void DynamicIotaOp::getCanonicalizationPatterns(
    OwningRewritePatternList& results, MLIRContext* context) {
  results.insert<DynamicIotaIsStatic>(context);
  results.insert<DynamicIotaBroadcast>(context);
}

//===----------------------------------------------------------------------===//

@ -387,7 +473,7 @@ static LogicalResult Verify(GetTupleElementOp op) {

OpFoldResult GetTupleElementOp::fold(ArrayRef<Attribute> operands) {
  if (auto tupleOp =
          dyn_cast_or_null<xla_hlo::TupleOp>(getOperand().getDefiningOp())) {
          dyn_cast_or_null<mhlo::TupleOp>(getOperand().getDefiningOp())) {
    return tupleOp.getOperand(index().getLimitedValue());
  }

@ -693,10 +779,8 @@ void ComplexOp::build(OpBuilder& builder, OperationState& state, Value lhs,
}

OpFoldResult ComplexOp::fold(ArrayRef<Attribute> operands) {
  auto real_op =
      dyn_cast_or_null<xla_hlo::RealOp>(getOperand(0).getDefiningOp());
  auto imag_op =
      dyn_cast_or_null<xla_hlo::ImagOp>(getOperand(1).getDefiningOp());
  auto real_op = dyn_cast_or_null<mhlo::RealOp>(getOperand(0).getDefiningOp());
  auto imag_op = dyn_cast_or_null<mhlo::ImagOp>(getOperand(1).getDefiningOp());
  if (real_op && imag_op && real_op.getOperand() == imag_op.getOperand()) {
    return real_op.getOperand();
  }

@ -727,7 +811,7 @@ void ImagOp::build(OpBuilder& builder, OperationState& state, Value val) {

OpFoldResult ImagOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op =
          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
          dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
    return complex_op.getOperand(1);
  }

@ -740,7 +824,7 @@ void RealOp::build(OpBuilder& builder, OperationState& state, Value val) {

OpFoldResult RealOp::fold(ArrayRef<Attribute> operands) {
  if (auto complex_op =
          dyn_cast_or_null<xla_hlo::ComplexOp>(getOperand().getDefiningOp())) {
          dyn_cast_or_null<mhlo::ComplexOp>(getOperand().getDefiningOp())) {
    return complex_op.getOperand(0);
  }

@ -1148,7 +1232,7 @@ static LogicalResult Verify(MapOp op) {
// RecvOp
//===----------------------------------------------------------------------===//

// Checks that the result type is of the form `tuple<any_type, xla_hlo::token>`
// Checks that the result type is of the form `tuple<any_type, mhlo::token>`
static LogicalResult Verify(RecvOp op) {
  auto result_ty = op.getResult().getType().cast<TupleType>();
  auto subtypes = result_ty.getTypes();

@ -2020,7 +2104,7 @@ void CompareOp::build(OpBuilder& builder, OperationState& result, Value lhs,
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"

//===----------------------------------------------------------------------===//
// xla_hlo Dialect Interfaces
// mhlo Dialect Interfaces
//===----------------------------------------------------------------------===//

namespace {

@ -2032,7 +2116,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
                       BlockAndValueMapping& valueMapping) const final {
    return true;
  }
  // Operations in xla_hlo dialect are always legal to inline since they are
  // Operations in mhlo dialect are always legal to inline since they are
  // pure.
  bool isLegalToInline(Operation*, Region*, BlockAndValueMapping&) const final {
    return true;

@ -2041,7 +2125,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
}  // end anonymous namespace

//===----------------------------------------------------------------------===//
// xla_hlo Dialect Constructor
// mhlo Dialect Constructor
//===----------------------------------------------------------------------===//

XlaHloDialect::XlaHloDialect(MLIRContext* context)

@ -2061,8 +2145,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const {
  if (parser.parseKeyword(&data_type)) return Type();

  if (data_type == "token") return TokenType::get(getContext());
  parser.emitError(parser.getNameLoc())
      << "unknown xla_hlo type: " << data_type;
  parser.emitError(parser.getNameLoc()) << "unknown mhlo type: " << data_type;
  return nullptr;
}

@ -2071,7 +2154,7 @@ void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const {
    os << "token";
    return;
  }
  os << "<unknown xla_hlo type>";
  os << "<unknown mhlo type>";
}

//===----------------------------------------------------------------------===//

@ -2106,5 +2189,5 @@ LogicalResult deriveShapeFromFirstOperand(
  return success();
}

}  // namespace xla_hlo
}  // namespace mhlo
}  // namespace mlir

@ -30,7 +30,7 @@ namespace xla_chlo {
namespace {

// Converts binary ops that statically are determined to not broadcast directly
// to the corresponding xla_hlo non-broadcasting op.
// to the corresponding mhlo non-broadcasting op.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
  using OpRewritePattern<ChloOpTy>::OpRewritePattern;

@ -63,7 +63,7 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
};

// Converts a binary op with ranked broadcasting operands to explicitly
// broadcast and invoke the corresponding xla_hlo non-broadcasting op.
// broadcast and invoke the corresponding mhlo non-broadcasting op.
// Note that dynamic broadcasting supported by this pattern is only valid for
// "numpy" broadcasting semantics as defined here:
// https://docs.scipy.org/doc/numpy/reference/ufuncs.html

@ -136,7 +136,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
    // properly.
    auto lhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - lhs_type.getRank(), result_rank));
    Value broadcasted_lhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
    Value broadcasted_lhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              lhs_type.getElementType()),

@ -144,7 +144,7 @@ struct ConvertRankedDynamicBroadcastBinaryOp
        rewriter.getI64TensorAttr(lhs_broadcast_dimensions));
    auto rhs_broadcast_dimensions = llvm::to_vector<4>(
        llvm::seq<int64_t>(result_rank - rhs_type.getRank(), result_rank));
    Value broadcasted_rhs = rewriter.create<xla_hlo::DynamicBroadcastInDimOp>(
    Value broadcasted_rhs = rewriter.create<mhlo::DynamicBroadcastInDimOp>(
        loc,
        RankedTensorType::get(result_type.getShape(),
                              rhs_type.getElementType()),

@ -182,23 +182,21 @@ struct HloBinaryElementwiseAdaptor {
};

struct HloComplexAdaptor {
  static xla_hlo::ComplexOp CreateOp(BroadcastComplexOp from_op,
                                     Type result_type, Value broadcasted_lhs,
                                     Value broadcasted_rhs,
                                     OpBuilder &builder) {
    return builder.create<xla_hlo::ComplexOp>(from_op.getLoc(), result_type,
                                              broadcasted_lhs, broadcasted_rhs);
  static mhlo::ComplexOp CreateOp(BroadcastComplexOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::ComplexOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs);
  }
};

struct HloCompareAdaptor {
  static xla_hlo::CompareOp CreateOp(BroadcastCompareOp from_op,
                                     Type result_type, Value broadcasted_lhs,
                                     Value broadcasted_rhs,
                                     OpBuilder &builder) {
    return builder.create<xla_hlo::CompareOp>(from_op.getLoc(), result_type,
                                              broadcasted_lhs, broadcasted_rhs,
                                              from_op.comparison_direction());
  static mhlo::CompareOp CreateOp(BroadcastCompareOp from_op, Type result_type,
                                  Value broadcasted_lhs, Value broadcasted_rhs,
                                  OpBuilder &builder) {
    return builder.create<mhlo::CompareOp>(from_op.getLoc(), result_type,
                                           broadcasted_lhs, broadcasted_rhs,
                                           from_op.comparison_direction());
  }
};

@ -214,28 +212,27 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context,
                      HloBinaryElementwiseAdaptor<ChloOp, HloOp>>(context, \
                                                                  patterns);

  POPULATE_BCAST(BroadcastAddOp, xla_hlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, xla_hlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, xla_hlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, xla_hlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, xla_hlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, xla_hlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, xla_hlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, xla_hlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, xla_hlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, xla_hlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, xla_hlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp,
                 xla_hlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, xla_hlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, xla_hlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, xla_hlo::XorOp);
  POPULATE_BCAST(BroadcastAddOp, mhlo::AddOp);
  POPULATE_BCAST(BroadcastAndOp, mhlo::AndOp);
  POPULATE_BCAST(BroadcastAtan2Op, mhlo::Atan2Op);
  POPULATE_BCAST(BroadcastDivOp, mhlo::DivOp);
  POPULATE_BCAST(BroadcastMaxOp, mhlo::MaxOp);
  POPULATE_BCAST(BroadcastMinOp, mhlo::MinOp);
  POPULATE_BCAST(BroadcastMulOp, mhlo::MulOp);
  POPULATE_BCAST(BroadcastOrOp, mhlo::OrOp);
  POPULATE_BCAST(BroadcastPowOp, mhlo::PowOp);
  POPULATE_BCAST(BroadcastRemOp, mhlo::RemOp);
  POPULATE_BCAST(BroadcastShiftLeftOp, mhlo::ShiftLeftOp);
  POPULATE_BCAST(BroadcastShiftRightArithmeticOp, mhlo::ShiftRightArithmeticOp);
  POPULATE_BCAST(BroadcastShiftRightLogicalOp, mhlo::ShiftRightLogicalOp);
  POPULATE_BCAST(BroadcastSubOp, mhlo::SubOp);
  POPULATE_BCAST(BroadcastXorOp, mhlo::XorOp);

  // Broadcasting ops requiring special construction.
  PopulateForBinaryOp<BroadcastComplexOp, xla_hlo::ComplexOp,
                      HloComplexAdaptor>(context, patterns);
  PopulateForBinaryOp<BroadcastCompareOp, xla_hlo::CompareOp,
                      HloCompareAdaptor>(context, patterns);
  PopulateForBinaryOp<BroadcastComplexOp, mhlo::ComplexOp, HloComplexAdaptor>(
      context, patterns);
  PopulateForBinaryOp<BroadcastCompareOp, mhlo::CompareOp, HloCompareAdaptor>(
      context, patterns);
}

}  // namespace xla_chlo

@ -32,8 +32,8 @@ struct TestChloLegalizeToHloPass
    OwningRewritePatternList conversionPatterns;

    conversionTarget.addIllegalDialect<XlaHloClientDialect>();
    // Consider the xla_hlo dialect legal for tests.
    conversionTarget.addLegalDialect<xla_hlo::XlaHloDialect>();
    // Consider the mhlo dialect legal for tests.
    conversionTarget.addLegalDialect<mhlo::XlaHloDialect>();
    // The conversion uses helpers from the Standard dialect.
    conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
    conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();

@ -37,7 +37,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {

template <typename T>

@ -128,20 +128,20 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
            op->getLoc(), result.value(), results_shape.front(), &rewriter));
      }
    }
    rewriter.create<xla_hlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                   buffer_args, op->getAttrs());
    rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
                                                buffer_args, op->getAttrs());
    rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
    return success();
  }
};

struct HloToLhloDynamicBroadcastInDimOpConverter
    : public BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp> {
    : public BaseOpConversion<mhlo::DynamicBroadcastInDimOp> {
 public:
  using BaseOpConversion<xla_hlo::DynamicBroadcastInDimOp>::BaseOpConversion;
  using BaseOpConversion<mhlo::DynamicBroadcastInDimOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      xla_hlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      mhlo::DynamicBroadcastInDimOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    Value resultBuffer = InsertDynamicAllocAndDealloc(

@ -162,7 +162,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
  // and size of the target dimension if size-1 dimension expansion is
  // necessary.
  xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp(
      xla_hlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
      mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const {
    auto loc = op.getLoc();
    auto operand_type = operand.getType().cast<MemRefType>();
    auto operand_shape = operand_type.getShape();

@ -220,12 +220,12 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
  }
};

struct HloToLhloReduceOpConverter : public BaseOpConversion<xla_hlo::ReduceOp> {
struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
 public:
  using BaseOpConversion<xla_hlo::ReduceOp>::BaseOpConversion;
  using BaseOpConversion<mhlo::ReduceOp>::BaseOpConversion;

  LogicalResult matchAndRewrite(
      xla_hlo::ReduceOp op, ArrayRef<Value> operands,
      mhlo::ReduceOp op, ArrayRef<Value> operands,
      ConversionPatternRewriter& rewriter) const final {
    auto loc = op.getLoc();
    // TODO(b/137624192) Implement variadic reduce.

@ -314,10 +314,10 @@ class HloToLhloTensorStoreOpConverter
// "xla_lhlo.fusion"() ({
//   %0 = tensor_load %arg1 : memref<2x2xf32>
//   %1 = tensor_load %arg2 : memref<2x2xf32>
//   %2 = "xla_hlo.add"(%0, %1) :
//   %2 = "mhlo.add"(%0, %1) :
//       (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//   %3 = tensor_load %arg0 : memref<2x2xf32>
//   %4 = "xla_hlo.multiply"(%2, %3) :
//   %4 = "mhlo.multiply"(%2, %3) :
//       (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
//   tensor_store %4, %arg3 : memref<2x2xf32>
//   "xla_lhlo.terminator"() : () -> ()

@ -344,8 +344,8 @@ class HloToLhloTensorStoreOpConverter
// FuncOp signature conversion example:
//
// func @func_op(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
//   %0 = "xla_hlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
//   tensor<4xf32> %1 = "xla_hlo.add"(%arg0, %0) : (tensor<4xf32>,
//   %0 = "mhlo.maximum"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) ->
//   tensor<4xf32> %1 = "mhlo.add"(%arg0, %0) : (tensor<4xf32>,
//   tensor<4xf32>) -> tensor<4xf32> return %1 : tensor<4xf32>
// }
//

@ -388,7 +388,7 @@ struct HloLegalizeToLhlo
    target.addIllegalOp<mlir::TensorStoreOp>();
    target.addLegalOp<ModuleTerminatorOp>();
    target.addLegalOp<TensorFromElementsOp>();
    target.addIllegalDialect<xla_hlo::XlaHloDialect>();
    target.addIllegalDialect<mhlo::XlaHloDialect>();

    BufferAssignmentTypeConverter converter;
    target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {

@ -442,38 +442,38 @@ void populateHLOToLHLOConversionPattern(
  // clang-format off
  patterns->insert<
      HloToLhloDynamicBroadcastInDimOpConverter,
      HloToLhloOpConverter<xla_hlo::AbsOp>,
      HloToLhloOpConverter<xla_hlo::AddOp>,
      HloToLhloOpConverter<xla_hlo::AndOp>,
      HloToLhloOpConverter<xla_hlo::BroadcastInDimOp>,
      HloToLhloOpConverter<xla_hlo::CeilOp>,
      HloToLhloOpConverter<xla_hlo::CompareOp>,
      HloToLhloOpConverter<xla_hlo::ComplexOp>,
      HloToLhloOpConverter<xla_hlo::ConstOp>,
      HloToLhloOpConverter<xla_hlo::ConvOp>,
      HloToLhloOpConverter<xla_hlo::ConvertOp>,
      HloToLhloOpConverter<xla_hlo::CopyOp>,
      HloToLhloOpConverter<xla_hlo::CosOp>,
      HloToLhloOpConverter<xla_hlo::DivOp>,
      HloToLhloOpConverter<xla_hlo::DotOp>,
      HloToLhloOpConverter<xla_hlo::ExpOp>,
      HloToLhloOpConverter<xla_hlo::GatherOp>,
      HloToLhloOpConverter<xla_hlo::ImagOp>,
      HloToLhloOpConverter<xla_hlo::IotaOp>,
      HloToLhloOpConverter<xla_hlo::LogOp>,
      HloToLhloOpConverter<xla_hlo::MaxOp>,
      HloToLhloOpConverter<xla_hlo::MinOp>,
      HloToLhloOpConverter<xla_hlo::MulOp>,
      HloToLhloOpConverter<xla_hlo::NegOp>,
      HloToLhloOpConverter<xla_hlo::RealOp>,
      HloToLhloOpConverter<xla_hlo::RemOp>,
      HloToLhloOpConverter<xla_hlo::RsqrtOp>,
      HloToLhloOpConverter<xla_hlo::ReshapeOp>,
      HloToLhloOpConverter<xla_hlo::SelectOp>,
      HloToLhloOpConverter<xla_hlo::SignOp>,
      HloToLhloOpConverter<xla_hlo::SqrtOp>,
      HloToLhloOpConverter<xla_hlo::SubOp>,
      HloToLhloOpConverter<xla_hlo::TanhOp>,
      HloToLhloOpConverter<mhlo::AbsOp>,
      HloToLhloOpConverter<mhlo::AddOp>,
      HloToLhloOpConverter<mhlo::AndOp>,
      HloToLhloOpConverter<mhlo::BroadcastInDimOp>,
      HloToLhloOpConverter<mhlo::CeilOp>,
      HloToLhloOpConverter<mhlo::CompareOp>,
      HloToLhloOpConverter<mhlo::ComplexOp>,
      HloToLhloOpConverter<mhlo::ConstOp>,
      HloToLhloOpConverter<mhlo::ConvOp>,
      HloToLhloOpConverter<mhlo::ConvertOp>,
      HloToLhloOpConverter<mhlo::CopyOp>,
      HloToLhloOpConverter<mhlo::CosOp>,
      HloToLhloOpConverter<mhlo::DivOp>,
      HloToLhloOpConverter<mhlo::DotOp>,
      HloToLhloOpConverter<mhlo::ExpOp>,
      HloToLhloOpConverter<mhlo::GatherOp>,
      HloToLhloOpConverter<mhlo::ImagOp>,
      HloToLhloOpConverter<mhlo::IotaOp>,
      HloToLhloOpConverter<mhlo::LogOp>,
      HloToLhloOpConverter<mhlo::MaxOp>,
      HloToLhloOpConverter<mhlo::MinOp>,
      HloToLhloOpConverter<mhlo::MulOp>,
      HloToLhloOpConverter<mhlo::NegOp>,
      HloToLhloOpConverter<mhlo::RealOp>,
      HloToLhloOpConverter<mhlo::RemOp>,
      HloToLhloOpConverter<mhlo::RsqrtOp>,
      HloToLhloOpConverter<mhlo::ReshapeOp>,
      HloToLhloOpConverter<mhlo::SelectOp>,
      HloToLhloOpConverter<mhlo::SignOp>,
      HloToLhloOpConverter<mhlo::SqrtOp>,
      HloToLhloOpConverter<mhlo::SubOp>,
      HloToLhloOpConverter<mhlo::TanhOp>,
      HloToLhloReduceOpConverter,
      HloToLhloTensorLoadOpConverter,
      HloToLhloTensorStoreOpConverter

@ -489,5 +489,5 @@ std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToLhloPass(
static PassRegistration<HloLegalizeToLhlo> legalize_pass(
    "hlo-legalize-to-lhlo", "Legalize from HLO dialect to LHLO dialect");

}  // namespace xla_hlo
}  // namespace mhlo
}  // namespace mlir

@ -35,7 +35,7 @@ limitations under the License.
using mlir::PassRegistration;

namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {
struct LegalizeControlFlow
    : public mlir::PassWrapper<LegalizeControlFlow, FunctionPass> {

@ -51,7 +51,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
                                 OpBuilder* builder) {
  for (auto& old_block : region->getBlocks()) {
    Block* block = mapper.lookup(&old_block);
    auto return_op = dyn_cast<xla_hlo::ReturnOp>(block->getTerminator());
    auto return_op = dyn_cast<mhlo::ReturnOp>(block->getTerminator());
    if (!return_op) continue;
    builder->setInsertionPointToEnd(block);
    builder->create<mlir::BranchOp>(loc, target_block, return_op.getOperands());

@ -61,7 +61,7 @@ LogicalResult ReplaceTerminators(Region* region, Block* target_block,
  return success();
}

LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
LogicalResult LowerIfOp(mlir::mhlo::IfOp if_op) {
  Operation* op_inst = if_op.getOperation();
  mlir::OpBuilder builder(if_op);
  auto orig_block = op_inst->getBlock();

@ -106,13 +106,13 @@ LogicalResult LowerIfOp(mlir::xla_hlo::IfOp if_op) {
  return success();
}

LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) {
  // Converts an XLA while loop into control flow. This generates a set of MLIR
  // blocks and branches, along with inlining the regions provided by the XLA
  // while loop. The structure should be similar to below:
  //
  //   <prior operations>
  //   %0 = "xla_hlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
  //   %0 = "mhlo.while"(%arg0) {^cond(...){...}, ^body(...){...}}
  //   <post operations>
  auto* op_inst = while_op.getOperation();
  mlir::OpBuilder builder(while_op);

@ -147,7 +147,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  // extract_element and conditional branch. This changes the block below:
  //   ^cond(%0):
  //     <inlined conditional region>
  //    "xla_hlo".return(%1)
  //    "mhlo".return(%1)
  //
  // Into:
  //   ^cond(%0):

@ -156,14 +156,14 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  //     cond_br %2, ^body(%0), ^tail(%0) // Branch.
  builder.setInsertionPointToStart(cond_block);

  // Replace the xla_hlo::ReturnOp with a branch back to the condition block.
  // This is required as the xla_hlo::ReturnOp is used to mark the end of a
  // Replace the mhlo::ReturnOp with a branch back to the condition block.
  // This is required as the mhlo::ReturnOp is used to mark the end of a
  // block for regions nested inside of a operations (MLIR ReturnOp cannot be
  // nested within an non-function region).
  for (auto& block : while_op.cond()) {
    auto new_block = mapper.lookup(&block);

    auto return_op = dyn_cast<xla_hlo::ReturnOp>(new_block->getTerminator());
    auto return_op = dyn_cast<mhlo::ReturnOp>(new_block->getTerminator());
    if (!return_op) continue;
    builder.setInsertionPointToEnd(new_block);

@ -183,7 +183,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  // conditional block. This changes the block below:
  //   ^body(%0):
  //     <inlined body block>
  //    "xla_hlo".return(%1)
  //    "mhlo".return(%1)
  //
  // Into:
  //   ^body(%0):

@ -191,8 +191,7 @@ LogicalResult LowerWhileOp(mlir::xla_hlo::WhileOp while_op) {
  //     br ^cond(%0) // Branch.
  for (auto& block : while_op.body()) {
    auto new_block = mapper.lookup(&block);
    auto return_op =
        dyn_cast<mlir::xla_hlo::ReturnOp>(new_block->getTerminator());
    auto return_op = dyn_cast<mlir::mhlo::ReturnOp>(new_block->getTerminator());
    if (!return_op) continue;
    builder.setInsertionPointToEnd(new_block);
    builder.create<mlir::BranchOp>(loc, cond_block, return_op.getOperands());

@ -224,14 +223,14 @@ void LegalizeControlFlow::runOnFunction() {
  }
}
}  // namespace
}  // namespace xla_hlo
}  // namespace mhlo
}  // namespace mlir

std::unique_ptr<mlir::OperationPass<mlir::FuncOp>>
mlir::xla_hlo::createLegalizeControlFlowPass() {
mlir::mhlo::createLegalizeControlFlowPass() {
  return std::make_unique<LegalizeControlFlow>();
}

static PassRegistration<mlir::xla_hlo::LegalizeControlFlow> legalize_cf_pass(
static PassRegistration<mlir::mhlo::LegalizeControlFlow> legalize_cf_pass(
    "xla-legalize-control-flow",
    "Legalize from XLA control flow to MLIR control flow");

@ -28,14 +28,14 @@ namespace mlir {
namespace {
#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_legalize_to_standard.inc"
}  // end anonymous namespace
namespace xla_hlo {
namespace mhlo {
namespace {

class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
class CompareIConvert : public OpRewritePattern<mhlo::CompareOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
  LogicalResult matchAndRewrite(mhlo::CompareOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();

@ -68,11 +68,11 @@ class CompareIConvert : public OpRewritePattern<xla_hlo::CompareOp> {
  }
};

class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
class CompareFConvert : public OpRewritePattern<mhlo::CompareOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::CompareOp op,
  LogicalResult matchAndRewrite(mhlo::CompareOp op,
                                PatternRewriter &rewriter) const override {
    auto lhs = op.lhs();
    auto rhs = op.rhs();

@ -109,11 +109,11 @@ class CompareFConvert : public OpRewritePattern<xla_hlo::CompareOp> {
// convert the integer constant to iota result type. For complex types, the real
// part is replaced with the generated constant and the imaginary part is
// replaced with zero tensor.
class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
class ConvertIotaOp : public OpRewritePattern<mhlo::IotaOp> {
 public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(xla_hlo::IotaOp op,
  LogicalResult matchAndRewrite(mhlo::IotaOp op,
                                PatternRewriter &rewriter) const override {
    auto output_type = op.getType().cast<ShapedType>();
    auto output_size = output_type.getNumElements();

@ -168,8 +168,7 @@ class ConvertIotaOp : public OpRewritePattern<xla_hlo::IotaOp> {
        loc, DenseIntElementsAttr::get(int_shape_type, APInt(bitwidth, 0)));
    auto imag_zeroes =
        rewriter.create<ConvertOp>(loc, int_or_float_shape_ty, zeroes);
    rewriter.replaceOpWithNewOp<xla_hlo::ComplexOp>(op, iota_const,
                                                    imag_zeroes);
    rewriter.replaceOpWithNewOp<mhlo::ComplexOp>(op, iota_const, imag_zeroes);
    return success();
  }
};

@ -197,12 +196,12 @@ void PopulateXlaToStdPatterns(OwningRewritePatternList *patterns,
/// Perform the lowering to standard dialect.
void LegalizeToStandard::runOnFunction() {
  OwningRewritePatternList patterns;
  mlir::xla_hlo::PopulateXlaToStdPatterns(&patterns, &getContext());
  mlir::mhlo::PopulateXlaToStdPatterns(&patterns, &getContext());
  applyPatternsAndFoldGreedily(getFunction(), patterns);
}

static PassRegistration<LegalizeToStandard> legalize_pass(
    "xla-legalize-to-std", "Legalize from XLA dialect to standard dialect");

}  // end namespace xla_hlo
}  // end namespace mhlo
}  // end namespace mlir

@ -84,14 +84,14 @@ Value TransposeReshape(Value arg, mlir::Location loc,
    transposed_shape.push_back(arg_shape[val]);
  }
  auto transpose_type = RankedTensorType::get(transposed_shape, element_type);
  auto transpose_result = rewriter->create<mlir::xla_hlo::TransposeOp>(
  auto transpose_result = rewriter->create<mlir::mhlo::TransposeOp>(
      loc, transpose_type, arg, transpose_permutation_attr);

  // Return the final result.
  auto reshaped_type =
      RankedTensorType::get({left_size, right_size}, element_type);
  return rewriter->create<mlir::xla_hlo::ReshapeOp>(loc, reshaped_type,
                                                    transpose_result);
  return rewriter->create<mlir::mhlo::ReshapeOp>(loc, reshaped_type,
                                                 transpose_result);
}

Value ProcessDotArg(Value arg, mlir::Location loc,

@ -125,8 +125,7 @@ Value ProcessDotArg(Value arg, mlir::Location loc,
  return TransposeReshape(arg, loc, contract_dims, outer_dims, shape, rewriter);
}

struct GeneralDotConvert
    : public OpRewritePattern<mlir::xla_hlo::DotGeneralOp> {
struct GeneralDotConvert : public OpRewritePattern<mlir::mhlo::DotGeneralOp> {
  // Attempts to lower a General Dot operator to a standard Dot operator.
  // General dots include batching dimensions and can have collapsing
  // dimensions along any axis. Inserting correctly arrange transpose and

@ -138,7 +137,7 @@ struct GeneralDotConvert
  explicit GeneralDotConvert(MLIRContext *context)
      : OpRewritePattern(context) {}

  LogicalResult matchAndRewrite(mlir::xla_hlo::DotGeneralOp op,
  LogicalResult matchAndRewrite(mlir::mhlo::DotGeneralOp op,
                                PatternRewriter &rewriter) const override {
    auto dot_element_type = mlir::getElementTypeOrSelf(op);

@ -162,11 +161,11 @@ struct GeneralDotConvert
    auto new_dot_type =
        RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, dot_element_type);

    auto new_dot_op = rewriter.create<mlir::xla_hlo::DotOp>(
    auto new_dot_op = rewriter.create<mlir::mhlo::DotOp>(
        op.getLoc(), new_dot_type, lhs, rhs, *(op.precision_config()));

    rewriter.replaceOpWithNewOp<mlir::xla_hlo::ReshapeOp>(op, op.getType(),
                                                          new_dot_op);
    rewriter.replaceOpWithNewOp<mlir::mhlo::ReshapeOp>(op, op.getType(),
                                                       new_dot_op);
    return success();
  }
};

@ -176,15 +175,14 @@ struct LegalizeGeneralDot
  /// Lower all general dots that can be represented as a non-batched matmul.
  void runOnFunction() override {
    OwningRewritePatternList patterns;
    mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(&patterns,
                                                        &getContext());
    mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(&patterns, &getContext());
    applyPatternsAndFoldGreedily(getFunction(), patterns);
  }
};

}  // namespace

void mlir::xla_hlo::PopulateGeneralDotOpLoweringPatterns(
void mlir::mhlo::PopulateGeneralDotOpLoweringPatterns(
    OwningRewritePatternList *patterns, MLIRContext *ctx) {
  patterns->insert<GeneralDotConvert>(ctx);
}

@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {

namespace {

@ -86,5 +86,5 @@ void PopulateMaterializeBroadcastsPatterns(MLIRContext *context,
  patterns->insert<ClampWithBroadcastConvert>(context);
}

}  // namespace xla_hlo
}  // namespace mhlo
}  // namespace mlir

@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {

namespace {

@@ -33,7 +33,7 @@ struct TestMaterializeBroadcastsPass
ConversionTarget conversionTarget(getContext());
OwningRewritePatternList conversionPatterns;

// Consider the xla_hlo dialect legal for tests.
// Consider the mhlo dialect legal for tests.
conversionTarget.addLegalDialect<XlaHloDialect>();
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
@@ -50,9 +50,9 @@ struct TestMaterializeBroadcastsPass

} // namespace

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

static mlir::PassRegistration<mlir::xla_hlo::TestMaterializeBroadcastsPass>
pass("test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");
static mlir::PassRegistration<mlir::mhlo::TestMaterializeBroadcastsPass> pass(
"test-xla-materialize-broadcasts",
"Test pass for materializing 'broadcast_dimensions' attributes");
@@ -60,7 +60,7 @@ limitations under the License.
// shape dialect once it is ready.

namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {

using llvm::EquivalenceClasses;
@@ -544,7 +544,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
}

FusionOp fusion =
b.create<xla_hlo::FusionOp>(fused_loc, output_types, inputs);
b.create<mhlo::FusionOp>(fused_loc, output_types, inputs);
Region& region = fusion.fused_computation();
region.push_back(new Block);
Block& block = region.front();
@@ -552,7 +552,7 @@ struct XlaHloFusion : public mlir::PassWrapper<XlaHloFusion, FunctionPass> {
op->moveBefore(&block, block.end());
}
b.setInsertionPoint(&block, block.end());
b.create<xla_hlo::ReturnOp>(fused_loc, outputs);
b.create<mhlo::ReturnOp>(fused_loc, outputs);

for (auto output_and_result : llvm::zip(outputs, fusion.getResults())) {
Value output = std::get<0>(output_and_result);
@@ -572,8 +572,8 @@ std::unique_ptr<OperationPass<FuncOp>> createXlaHloFusion() {
return std::make_unique<XlaHloFusion>();
}

static PassRegistration<XlaHloFusion> xla_hlo_fusion_pass(
"xla-hlo-fusion", "fuse xla_hlo ops to kLoop/kInput fusion patterns.");
static PassRegistration<XlaHloFusion> mhlo_fusion_pass(
"xla-hlo-fusion", "fuse mhlo ops to kLoop/kInput fusion patterns.");

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
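For context: the fusion built above moves a producer/consumer group of mhlo ops into the single region of an mhlo.fusion op and terminates it with mhlo.return. Schematically, with hypothetical ops and shapes (the exact region syntax varies across MLIR revisions):

  %fused = "mhlo.fusion"(%a, %b) ( {
    %0 = mhlo.add %a, %b : tensor<4xf32>
    %1 = mhlo.multiply %0, %b : tensor<4xf32>
    "mhlo.return"(%1) : (tensor<4xf32>) -> ()
  }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>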
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {

namespace {

@@ -81,5 +81,5 @@ std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() {
return std::make_unique<SinkConstantsToControlFlow>();
}

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {

namespace {

@@ -40,12 +40,12 @@ Value BroadcastToFeatureDim(Location loc, RankedTensorType result_type,
auto dims_type = RankedTensorType::get({1}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, {feature_dim});
if (shape_value) {
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
loc, result_type, value_1d, shape_value, dims);
}
assert(result_type.hasStaticShape());
return rewriter.create<xla_hlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
return rewriter.create<mhlo::BroadcastInDimOp>(loc, result_type, value_1d,
dims);
}

// Calculate the shape value of operand, assuming it is a dynamic shape with
@@ -89,25 +89,25 @@ Value MaterializeEpsilon(Operation* op, FloatAttr epsilon_attr,
auto epsilon_tensor_attr =
DenseElementsAttr::get(scalar_type, {epsilon_attr.cast<Attribute>()});
Value epsilon =
rewriter.create<xla_hlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
rewriter.create<mhlo::ConstOp>(op->getLoc(), epsilon_tensor_attr);
auto dims_type = RankedTensorType::get({0}, b.getIntegerType(64));
auto dims = DenseIntElementsAttr::get(dims_type, SmallVector<int64_t, 1>{});
if (broadcast_to_type.hasStaticShape()) {
return rewriter.create<xla_hlo::BroadcastInDimOp>(
return rewriter.create<mhlo::BroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, /*broadcast_dims=*/dims);
}
Value shape_value = CalculateShapeValue(op->getLoc(), variance, rewriter);
return rewriter.createOrFold<xla_hlo::DynamicBroadcastInDimOp>(
return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
op->getLoc(), broadcast_to_type, epsilon, shape_value,
/*broadcast_dims=*/dims);
}

class UnfuseBatchNormInferencePattern
: public OpRewritePattern<xla_hlo::BatchNormInferenceOp> {
: public OpRewritePattern<mhlo::BatchNormInferenceOp> {
public:
using OpRewritePattern<xla_hlo::BatchNormInferenceOp>::OpRewritePattern;
using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(xla_hlo::BatchNormInferenceOp bn_op,
LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bn_op,
PatternRewriter& rewriter) const override {
// Enforce type invariants.
// Note that we deduce the actual element type from the variance,
@@ -132,9 +132,9 @@ class UnfuseBatchNormInferencePattern
if (!epsilon) {
return failure();
}
Value stddev = rewriter.create<xla_hlo::AddOp>(bn_op.getLoc(),
bn_op.variance(), epsilon);
stddev = rewriter.create<xla_hlo::SqrtOp>(bn_op.getLoc(), stddev);
Value stddev =
rewriter.create<mhlo::AddOp>(bn_op.getLoc(), bn_op.variance(), epsilon);
stddev = rewriter.create<mhlo::SqrtOp>(bn_op.getLoc(), stddev);

// Broadcast all terms.
Value shape_value;
@@ -156,14 +156,13 @@ class UnfuseBatchNormInferencePattern

// Compute:
// scale * (input - mean) / stddev + offset
Value result = rewriter.create<xla_hlo::SubOp>(
bn_op.getLoc(), bn_op.operand(), broadcast_mean);
result = rewriter.create<xla_hlo::MulOp>(bn_op.getLoc(), result,
broadcast_scale);
result = rewriter.create<xla_hlo::DivOp>(bn_op.getLoc(), result,
broadcast_stddev);
rewriter.replaceOpWithNewOp<xla_hlo::AddOp>(bn_op, result,
broadcast_offset);
Value result = rewriter.create<mhlo::SubOp>(bn_op.getLoc(), bn_op.operand(),
broadcast_mean);
result =
rewriter.create<mhlo::MulOp>(bn_op.getLoc(), result, broadcast_scale);
result =
rewriter.create<mhlo::DivOp>(bn_op.getLoc(), result, broadcast_stddev);
rewriter.replaceOpWithNewOp<mhlo::AddOp>(bn_op, result, broadcast_offset);

return success();
}
@@ -180,5 +179,5 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext* context,
patterns->insert<UnfuseBatchNormInferencePattern>(context);
}

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
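For context: the pattern above expands a fused mhlo.batch_norm_inference into its defining formula, result = scale * (input - mean) / sqrt(variance + epsilon) + offset. The epsilon attribute is materialized as an mhlo.constant broadcast to the variance shape, stddev = sqrt(variance + epsilon) is computed once per feature, and each 1-D per-feature term (mean, stddev, scale, offset) is broadcast along the feature dimension with (dynamic_)broadcast_in_dim before the subtract/multiply/divide/add chain shown in the hunk.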
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {

namespace {

@@ -38,9 +38,9 @@ struct TestUnfuseBatchNormPass

} // namespace

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir

static mlir::PassRegistration<mlir::xla_hlo::TestUnfuseBatchNormPass> pass(
static mlir::PassRegistration<mlir::mhlo::TestUnfuseBatchNormPass> pass(
"test-xla-unfuse-batch-norm",
"Test pass for materializing 'broadcast_dimensions' attributes");
|
@ -182,7 +182,7 @@ struct ConvToLinalgConverter : public OpConversionPattern<xla_lhlo::ConvOp> {
|
||||
using OpConversionPattern<xla_lhlo::ConvOp>::OpConversionPattern;
|
||||
|
||||
// This code has been adapted from IREE's
|
||||
// (https://github.com/google/iree/) xla_hlo -> linalg conversion.
|
||||
// (https://github.com/google/iree/) mhlo -> linalg conversion.
|
||||
LogicalResult matchAndRewrite(
|
||||
xla_lhlo::ConvOp op, ArrayRef<Value> args,
|
||||
ConversionPatternRewriter& rewriter) const final {
|
||||
@ -348,14 +348,14 @@ class BroadcastConverter
|
||||
|
||||
class HloBroadcastInDimConverter
|
||||
: public DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||
xla_hlo::BroadcastInDimOp, false> {
|
||||
mhlo::BroadcastInDimOp, false> {
|
||||
public:
|
||||
using DataMovementOpConverter<HloBroadcastInDimConverter,
|
||||
xla_hlo::BroadcastInDimOp,
|
||||
mhlo::BroadcastInDimOp,
|
||||
false>::DataMovementOpConverter;
|
||||
|
||||
static SmallVector<AffineMap, 2> getIndexingMaps(
|
||||
xla_hlo::BroadcastInDimOp broadcastOp, Builder* b) {
|
||||
mhlo::BroadcastInDimOp broadcastOp, Builder* b) {
|
||||
auto resultType = getXLAOpResultType<false>(broadcastOp);
|
||||
auto operandType =
|
||||
broadcastOp.operand().getType().template cast<ShapedType>();
|
||||
@ -845,7 +845,7 @@ struct HloLegalizeToLinalg
|
||||
target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect>();
|
||||
|
||||
auto func = getFunction();
|
||||
xla_hlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
mhlo::populateHLOToLinalgConversionPattern(func.getContext(), &patterns);
|
||||
if (failed(applyPartialConversion(func, target, patterns, nullptr))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
@ -863,40 +863,40 @@ static PassRegistration<LhloLegalizeToLinalg> legalize_lhlo_pass(
|
||||
"lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect");
|
||||
} // namespace xla_lhlo
|
||||
|
||||
namespace xla_hlo {
|
||||
namespace mhlo {
|
||||
|
||||
void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
patterns->insert<BroadcastConverter<xla_hlo::BroadcastOp, false>,
|
||||
patterns->insert<BroadcastConverter<mhlo::BroadcastOp, false>,
|
||||
HloBroadcastInDimConverter,
|
||||
PointwiseToLinalgConverter<xla_hlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ComplexOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ConvertOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<xla_hlo::TanhOp, false>,
|
||||
ReshapeOpConverter<xla_hlo::ReshapeOp, false>,
|
||||
ReverseConverter<xla_hlo::ReverseOp, false>,
|
||||
TransposeConverter<xla_hlo::TransposeOp, false>>(context);
|
||||
PointwiseToLinalgConverter<mhlo::AbsOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AddOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::AndOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CeilOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CompareOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ComplexOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ConvertOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CopyOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::CosOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::DivOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ExpOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::ImagOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::LogOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MaxOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::MulOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::NegOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RealOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RemOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::RsqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SelectOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SinOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SqrtOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::SubOp, false>,
|
||||
PointwiseToLinalgConverter<mhlo::TanhOp, false>,
|
||||
ReshapeOpConverter<mhlo::ReshapeOp, false>,
|
||||
ReverseConverter<mhlo::ReverseOp, false>,
|
||||
TransposeConverter<mhlo::TransposeOp, false>>(context);
|
||||
}
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
@ -905,5 +905,5 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeHloToLinalgPass() {
|
||||
|
||||
static PassRegistration<HloLegalizeToLinalg> legalize_hlo_pass(
|
||||
"hlo-legalize-to-linalg", "Legalize from HLO dialect to Linalg dialect");
|
||||
} // namespace xla_hlo
|
||||
} // namespace mhlo
|
||||
} // namespace mlir
|
||||
|
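For context: each PointwiseToLinalgConverter listed above maps one mhlo elementwise op to a linalg.generic whose scalar body applies the corresponding Standard-dialect op. Schematically, for mhlo.add on a ranked tensor (indexing-map and attribute syntax abbreviated; it differs across MLIR revisions):

  %0 = linalg.generic { /* identity indexing maps, "parallel" iterators */ }
      %arg0, %arg1 {
    ^bb0(%a: f32, %b: f32):
      %sum = addf %a, %b : f32
      linalg.yield %sum : f32
  } : tensor<4xf32>, tensor<4xf32> -> tensor<4xf32>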
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"

namespace mlir {
namespace xla_hlo {
namespace mhlo {
namespace {

// TODO(frgossen): Make it variadic.
@@ -69,7 +69,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.create<TensorFromElementsOp>(loc, numElementsAsIndex);
auto flatTensorTy = RankedTensorType::get({ShapedType::kDynamicSize},
operandTy.getElementType());
Value flatOperand = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value flatOperand = rewriter.create<mhlo::DynamicReshapeOp>(
loc, flatTensorTy, operand, flatShapeAsDimTensor);

// Generate IR for the actual operation.
@@ -80,7 +80,7 @@ struct UnaryElementwiseOpConversion : public OpRewritePattern<OpTy> {
rewriter.getIndexType());
Value shapeAsExtentTensor =
rewriter.create<shape::ToExtentTensorOp>(loc, extentTensorTy, shape);
Value result = rewriter.create<xla_hlo::DynamicReshapeOp>(
Value result = rewriter.create<mhlo::DynamicReshapeOp>(
loc, operandTy, flatResult, shapeAsExtentTensor);
rewriter.replaceOp(op, result);

@@ -184,5 +184,5 @@ static PassRegistration<TransformUnrankedHloPass> transform_unranked_hlo_pass(
"transform-unranked-hlo",
"Realize element-wise operations on ranked tensors where possible");

} // namespace xla_hlo
} // namespace mhlo
} // namespace mlir
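For context: UnaryElementwiseOpConversion makes a unary op on an unranked tensor expressible on ranked types by flattening the operand to a 1-D dynamically sized tensor, applying the op, and restoring the original shape. Roughly, with a hypothetical op and the shape-materialization steps abbreviated:

  %shape = shape.shape_of %arg : tensor<*xf32>
  // The element count is packed into a rank-1 extent tensor %flat_shape, then:
  %flat = "mhlo.dynamic_reshape"(%arg, %flat_shape)
      : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
  %res = "mhlo.sqrt"(%flat) : (tensor<?xf32>) -> tensor<?xf32>
  %out = "mhlo.dynamic_reshape"(%res, %extents)
      : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>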
tensorflow/compiler/mlir/hlo/tests/BUILD (new file, 19 lines)
@@ -0,0 +1,19 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")

package(licenses = ["notice"])

glob_lit_tests(
data = [":test_utilities"],
driver = "@llvm-project//mlir:run_lit.sh",
test_file_exts = ["mlir"],
)

# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/hlo:mlir-hlo-opt",
"@llvm-project//llvm:FileCheck",
],
)
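Assuming glob_lit_tests here behaves like the macro of the same name used elsewhere in the TensorFlow tree, each *.mlir file in this directory becomes an individual lit test driven by mlir-hlo-opt and FileCheck, e.g. runnable as bazel test //tensorflow/compiler/mlir/hlo/tests:canonicalize.mlir.test (the exact target name is an assumption based on the macro's usual naming).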
tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir (new file, 499 lines)
@@ -0,0 +1,499 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: add_fold
func @add_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%1 = mhlo.constant dense<[5, 6, 7, 8]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[6, 8, 10, 12]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: add_scalar_fold
func @add_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<1> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<6>
%2 = "mhlo.add"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: add_fold_float
func @add_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 6.0, 7.0, 8.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[6.000000e+00, 8.000000e+00, 1.000000e+01, 1.200000e+01]>
%2 = "mhlo.add"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: sub_scalar_fold
func @sub_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<1> : tensor<4xi64>
// CHECK: mhlo.constant dense<4>
%2 = "mhlo.subtract"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: multiply_scalar_fold
func @multiply_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<5> : tensor<4xi64>
%1 = mhlo.constant dense<3> : tensor<4xi64>
// CHECK: mhlo.constant dense<15>
%2 = "mhlo.multiply"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: divide_scalar_fold
func @divide_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<1>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: divide_fold_float
func @divide_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[1.000000e+00, 2.200000e+01, 2.500000e+00, 2.500000e-01]>
%2 = "mhlo.divide"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: max_scalar_fold
func @max_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<5> : tensor<4xi64>
// CHECK: mhlo.constant dense<7>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: max_fold_float
func @max_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 6.600000e+01, 5.000000e+00, 4.000000e+00]>
%2 = "mhlo.maximum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: min_scalar_fold
func @min_scalar_fold() -> tensor<4xi64> {
%0 = mhlo.constant dense<7> : tensor<4xi64>
%1 = mhlo.constant dense<-5> : tensor<4xi64>
// CHECK: mhlo.constant dense<-5>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xi64>, tensor<4xi64>) -> (tensor<4xi64>)
return %2 : tensor<4xi64>
}

// CHECK-LABEL: min_fold_float
func @min_fold_float() -> tensor<4xf64> {
%0 = mhlo.constant dense<[5.0, 66.0, 5.0, 1.0]> : tensor<4xf64>
%1 = mhlo.constant dense<[5.0, 3.0, 2.0, 4.0]> : tensor<4xf64>
// CHECK: mhlo.constant dense<[5.000000e+00, 3.000000e+00, 2.000000e+00, 1.000000e+00]>
%2 = "mhlo.minimum"(%0, %1) : (tensor<4xf64>, tensor<4xf64>) -> (tensor<4xf64>)
return %2 : tensor<4xf64>
}

// CHECK-LABEL: concatenate_noop
func @concatenate_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG:%.+]]: tensor<4xi32>
%0 = "mhlo.concatenate"(%arg0) { dimension = 0 : i64 } : (tensor<4xi32>) -> tensor<4xi32>

// CHECK: return [[ARG]]
return %0 : tensor<4xi32>
}

// CHECK-LABEL: concatenate_remove_operand
func @concatenate_remove_operand(%arg0: tensor<4xi32>, %arg1: tensor<0xi32>) -> tensor<4xi32> {
// CHECK-SAME: [[ARG0:%.+]]: tensor<4xi32>
// CHECK-SAME: [[ARG1:%.+]]: tensor<0xi32>
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<4xi32>, tensor<0xi32>) -> tensor<4xi32>

// CHECK: return [[ARG0]]
return %0 : tensor<4xi32>
}

// CHECK-LABEL: concatenate_empty_bool
func @concatenate_empty_bool(%arg0: tensor<0xi1>, %arg1: tensor<0xi1>) -> tensor<0xi1> {
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi1>, tensor<0xi1>) -> tensor<0xi1>

return %0 : tensor<0xi1>
}

// CHECK-LABEL: concatenate_empty_int
func @concatenate_empty_int(%arg0: tensor<0xi32>, %arg1: tensor<0xi32>) -> tensor<0xi32> {
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xi32>, tensor<0xi32>) -> tensor<0xi32>

return %0 : tensor<0xi32>
}

// CHECK-LABEL: concatenate_empty_float
func @concatenate_empty_float(%arg0: tensor<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf32> {
// CHECK: mhlo.constant
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<0xf32>, tensor<0xf32>) -> tensor<0xf32>

return %0 : tensor<0xf32>
}

// CHECK-LABEL: concatenate_const_1D
func @concatenate_const_1D() -> tensor<4xi32> {
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[0, 1, 2, 3]>
%0 = mhlo.constant dense<[0, 1]> : tensor<2xi32>
%1 = mhlo.constant dense<[2, 3]> : tensor<2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xi32>, tensor<2xi32>) -> tensor<4xi32>

// CHECK: return [[VAL]]
return %2 : tensor<4xi32>
}

// CHECK-LABEL: concatenate_const_1D_float
func @concatenate_const_1D_float() -> tensor<4xf32> {
// CHECK: [[VAL:%.+]] = mhlo.constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>

%0 = mhlo.constant dense<[0.0, 1.0]> : tensor<2xf32>
%1 = mhlo.constant dense<[2.0, 3.0]> : tensor<2xf32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<2xf32>, tensor<2xf32>) -> tensor<4xf32>

// CHECK: return [[VAL]]
return %2 : tensor<4xf32>
}

// CHECK-LABEL: concatenate_const_2D_vertical
func @concatenate_const_2D_vertical() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 1], [2, 3]
// CHECK-SAME: ]>
%0 = mhlo.constant dense<[[0, 1]]> : tensor<1x2xi32>
%1 = mhlo.constant dense<[[2, 3]]> : tensor<1x2xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 0 : i64 } : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32>

// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
}

// CHECK-LABEL: concatenate_const_2D_horizontal
func @concatenate_const_2D_horizontal() -> tensor<2x2xi32> {
// CHECK: [[VAL:%.+]]= mhlo.constant dense<[
// CHECK-SAME: [0, 2], [1, 3]
// CHECK-SAME: ]>
%0 = mhlo.constant dense<[[0], [1]]> : tensor<2x1xi32>
%1 = mhlo.constant dense<[[2], [3]]> : tensor<2x1xi32>
%2 = "mhlo.concatenate"(%0, %1) { dimension = 1 : i64 } : (tensor<2x1xi32>, tensor<2x1xi32>) -> tensor<2x2xi32>

// CHECK: return [[VAL]]
return %2 : tensor<2x2xi32>
}

// CHECK-LABEL: dynamic_slice_variable_start
func @dynamic_slice_variable_start(%arg0: tensor<3x4xi32>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<1x4xi32> {
// CHECK: "mhlo.dynamic-slice"
%1 = "mhlo.dynamic-slice"(%arg0, %arg1, %arg2) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<3x4xi32>, tensor<i64>, tensor<i64>) -> tensor<1x4xi32>
return %1 : tensor<1x4xi32>
}

// CHECK-LABEL: dynamic_slice_constant_start
func @dynamic_slice_constant_start(%arg0: tensor<4xi32>) -> tensor<2xi32> {
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<3> : tensor<1xi64>
// CHECK-DAG-SAME: start_indices = dense<1> : tensor<1xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK: return %[[RESULT]] : tensor<2xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = "mhlo.dynamic-slice"(%arg0, %0) {slice_sizes = dense<2> : tensor<1xi64>} : (tensor<4xi32>, tensor<i64>) -> tensor<2xi32>
return %1 : tensor<2xi32>
}

// CHECK-LABEL: dynamic_slice_constant_start_dynamic_shape
func @dynamic_slice_constant_start_dynamic_shape(%arg0: tensor<?x4xi32>, %arg1: tensor<2xi64>) -> tensor<?x4xi32> {
// CHECK: %[[RESULT:.*]] = "mhlo.slice"(%arg0)
// CHECK-DAG-SAME: limit_indices = dense<[2, 4]> : tensor<2xi64>
// CHECK-DAG-SAME: start_indices = dense<[1, 0]> : tensor<2xi64>
// CHECK-DAG-SAME: strides = dense<1> : tensor<2xi64>
// CHECK: return %[[RESULT]] : tensor<?x4xi32>
%0 = mhlo.constant dense<1> : tensor<i64>
%1 = mhlo.constant dense<0> : tensor<i64>
%2 = "mhlo.dynamic-slice"(%arg0, %0, %1) {slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<?x4xi32>, tensor<i64>, tensor<i64>) -> tensor<?x4xi32>
return %2 : tensor<?x4xi32>
}

// CHECK-LABEL: slice_2D_noop
// CHECK-SAME: [[ARG:%.+]]: tensor<2x2xi64>
func @slice_2D_noop(%arg0: tensor<2x2xi64>) -> tensor<2x2xi64> {
%0 = "mhlo.slice"(%arg0) { limit_indices = dense<[2, 2]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x2xi64>) -> (tensor<2x2xi64>)

// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x2xi64>
}

// CHECK-LABEL: slice_1D_fold
func @slice_1D_fold() -> tensor<2xi64> {
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 9]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}

// CHECK-LABEL: slice_1D_fp
func @slice_1D_fp() -> tensor<2xf32> {
%0 = mhlo.constant dense<[5.0, 7.0, 9.0, 10.0]> : tensor<4xf32>
// CHECK: mhlo.constant dense<[7.000000e+00, 9.000000e+00]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> (tensor<2xf32>)
return %1 : tensor<2xf32>
}

// CHECK-LABEL: slice_1D_strided_fold
func @slice_1D_strided_fold() -> tensor<2xi64> {
%0 = mhlo.constant dense<[5, 7, 9, 10]> : tensor<4xi64>
// CHECK: mhlo.constant dense<[7, 10]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4]> : tensor<1xi64>, start_indices = dense<[1]> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} : (tensor<4xi64>) -> (tensor<2xi64>)
return %1 : tensor<2xi64>
}

// CHECK-LABEL: slice_2D_fold
func @slice_2D_fold() -> tensor<2x2xi64> {
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [6, 7],
// CHECK-SAME: [10, 11]
// CHECK-SAME: ]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 4]> : tensor<2xi64>, start_indices = dense<[1, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<2x2xi64>)
return %1 : tensor<2x2xi64>
}

// CHECK-LABEL: slice_2D_fold_horizontal
func @slice_2D_fold_horizontal() -> tensor<1x4xi64> {
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [0, 1, 2, 3]
// CHECK-SAME: ]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 4]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<1x4xi64>)
return %1 : tensor<1x4xi64>
}

// CHECK-LABEL: slice_2D_fold_vertical
func @slice_2D_fold_vertical() -> tensor<4x1xi64> {
%0 = mhlo.constant dense<[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]> : tensor<4x4xi64>
// CHECK-NEXT: mhlo.constant dense<[
// CHECK-SAME: [2], [6], [10], [14]
// CHECK-SAME: ]>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 3]> : tensor<2xi64>, start_indices = dense<[0, 2]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x4xi64>) -> (tensor<4x1xi64>)
return %1 : tensor<4x1xi64>
}

// CHECK-LABEL: slice_concat_fold_first
func @slice_concat_fold_first(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg0
return %1 : tensor<1x5xf32>
}

// CHECK-LABEL: slice_concat_fold_second
func @slice_concat_fold_second(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x5xf32>)
// CHECK: return %arg1
return %1 : tensor<1x5xf32>
}

// CHECK-LABEL: slice_concat_fold_second_with_slice
func @slice_concat_fold_second_with_slice(%arg0: tensor<1x5xf32>, %arg1: tensor<1x5xf32>) -> tensor<1x4xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<1x5xf32>) -> tensor<2x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[1, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x5xf32>) -> tensor<1x4xf32>
%1 = "mhlo.slice"(%0) { limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<2x5xf32>) -> (tensor<1x4xf32>)

// CHECK: return [[SLICE]]
return %1 : tensor<1x4xf32>
}

// CHECK-LABEL: slice_concat_fold_middle
func @slice_concat_fold_middle(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<1x5xf32> {
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>
// CHECK: [[SLICE:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<1x5xf32>)

// CHECK: return [[SLICE]]
return %1 : tensor<1x5xf32>
}

// CHECK-LABEL: slice_concat_fold_two
func @slice_concat_fold_two(%arg0: tensor<1x5xf32>, %arg1: tensor<2x5xf32>, %arg2: tensor<1x5xf32>) -> tensor<2x5xf32> {
// CHECK: [[CONCAT:%.+]] = "mhlo.concatenate"(%arg1, %arg2) {dimension = 0 : i64}
%0 = "mhlo.concatenate"(%arg0, %arg1, %arg2) { dimension = 0 : i64 } : (tensor<1x5xf32>, tensor<2x5xf32>, tensor<1x5xf32>) -> tensor<4x5xf32>

// CHECK: [[SLICE:%.+]] = "mhlo.slice"([[CONCAT]]) {limit_indices = dense<[3, 5]> : tensor<2xi64>, start_indices = dense<[1, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
%1 = "mhlo.slice"(%0) { limit_indices = dense<[4, 5]> : tensor<2xi64>, start_indices = dense<[2, 0]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<4x5xf32>) -> (tensor<2x5xf32>)

// CHECK: return [[SLICE]]
return %1 : tensor<2x5xf32>
}

// CHECK-LABEL: func @broadcast_in_dim_identity
func @broadcast_in_dim_identity(%arg0: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
// CHECK: return %arg0
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
return %0 : tensor<2x3x4xf32>
}

// CHECK-LABEL: func @broadcast_in_dim_not_identity_because_it_actually_broadcasts
func @broadcast_in_dim_not_identity_because_it_actually_broadcasts(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> {
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<1x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

// CHECK-LABEL: func @broadcast_in_dim_not_identity_permutation
func @broadcast_in_dim_not_identity_permutation(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: mhlo.broadcast_in_dim
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}


// CHECK-LABEL: func @dynamic_broadcast_in_dim_op_not_actually_dynamic
func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %arg1: tensor<2xi64>) -> tensor<5x4xf32> {
// CHECK: %[[RESULT:.+]] = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<5x4xf32>
%0 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %arg1) { broadcast_dimensions = dense<1> : tensor<1xi64> } : (tensor<4xf32>, tensor<2xi64>) -> tensor<5x4xf32>
// CHECK: return %[[RESULT]] : tensor<5x4xf32>
return %0 : tensor<5x4xf32>
}

// CHECK-LABEL: @complex_expand_fold
func @complex_expand_fold(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> (tensor<4xcomplex<f32>>)
%1 = "mhlo.real"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.imag"(%0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
// CHECK: return %arg0, %arg1
return %1, %2 : tensor<4xf32>, tensor<4xf32>
}

// CHECK-LABEL: @complex_collapse_fold
func @complex_collapse_fold(%arg0: tensor<4xcomplex<f32>>) -> tensor<4xcomplex<f32>> {
%0 = "mhlo.real"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%1 = "mhlo.imag"(%arg0) : (tensor<4xcomplex<f32>>) -> (tensor<4xf32>)
%2 = "mhlo.complex"(%0, %1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: return %arg0
return %2 : tensor<4xcomplex<f32>>
}

// CHECK-LABEL: @dynamic_iota_is_static
func @dynamic_iota_is_static(%arg0 : tensor<1xindex>) -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<4xi32>
return %0 : tensor<4xi32>
}

// CHECK-LABEL: @dynamic_iota_broadcast
func @dynamic_iota_broadcast(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
// CHECK: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>, tensor<2xindex>) -> tensor<5x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 0 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>

// CHECK: return [[BROADCAST]]
return %0 : tensor<5x?xi32>
}

// CHECK-LABEL: @dynamic_iota_broadcast_second
func @dynamic_iota_broadcast_second(%arg0 : tensor<2xindex>) -> tensor<5x?xi32> {
// CHECK-NEXT: [[CAST1:%.+]] = index_cast %arg0 : tensor<2xindex> to tensor<2xi64>
// CHECK-NEXT: [[SLICE:%.+]] = "mhlo.slice"([[CAST1]]) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi64>) -> tensor<1xi64>
// CHECK-NEXT: [[CAST2:%.+]] = index_cast [[SLICE]] : tensor<1xi64> to tensor<1xindex>
// CHECK-NEXT: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[CAST2]]) {iota_dimension = 0 : i64} : (tensor<1xindex>) -> tensor<?xi32>
// CHECK-NEXT: [[BROADCAST:%.+]] = "mhlo.dynamic_broadcast_in_dim"([[IOTA]], %arg0) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xi32>, tensor<2xindex>) -> tensor<5x?xi32>
%0 = "mhlo.dynamic_iota"(%arg0) {iota_dimension = 1 : i64} : (tensor<2xindex>) -> tensor<5x?xi32>

// CHECK: return [[BROADCAST]]
return %0 : tensor<5x?xi32>
}


// CHECK-LABEL: @iota_not_lowered_to_constant
func @iota_not_lowered_to_constant() -> tensor<4xi32> {
// CHECK: [[RESULT:%.*]] = "mhlo.iota"
// CHECK: return [[RESULT]]
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
return %0 : tensor<4xi32>
}

// CHECK-LABEL: @iota_broadcast
func @iota_broadcast() -> tensor<5x4xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xi32>
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<5xi32>) -> tensor<5x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5x4xi32>

return %0 : tensor<5x4xi32>
}

// CHECK-LABEL: @iota_broadcast
func @iota_broadcast_second() -> tensor<5x4xi32> {
// CHECK: [[IOTA:%.+]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK: [[RESULT:%.+]] = "mhlo.broadcast_in_dim"([[IOTA]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>) -> tensor<5x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<5x4xi32>

return %0 : tensor<5x4xi32>
}

// CHECK-LABEL: @unary_einsum
func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> {
// CHECK: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<f32>
// CHECK: "mhlo.einsum"(%[[ONE]], %arg0) {einsum_config = ",ab->aa"}
%0 = "mhlo.unary_einsum"(%arg0) {einsum_config = "ab->aa"} : (tensor<2x3xf32>) -> tensor<2x2xf32>
return %0 : tensor<2x2xf32>
}

// CHECK-LABEL: func @fold_copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_copy(%arg : tensor<1x4xf32>) -> tensor<1x4xf32> {
// CHECK: return [[ARG]]
%0 = "mhlo.copy"(%arg) : (tensor<1x4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}

// CHECK-LABEL: func @dynamic_reshape_not_actually_dynamic
func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<2xindex>) -> tensor<4x1xf32> {
// CHECK: mhlo.reshape
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<4xf32>, tensor<2xindex>) -> tensor<4x1xf32>
return %0 : tensor<4x1xf32>
}

// CHECK-LABEL: do_not_dce_while_with_outfeed
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.create_token"() : () -> !mhlo.token
// Side-effecting op outfeed present inside while.
%2 = "mhlo.outfeed"(%arg1, %1) {outfeed_config = ""} : (tensor<i64>, !mhlo.token) -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>

return %arg0 : tensor<i64>
}

// CHECK-LABEL: dce_while_without_side_effect
func @dce_while_without_side_effect(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NOT: mhlo.while
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
%1 = "mhlo.create_token"() : () -> !mhlo.token
"mhlo.return"(%arg1) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>

return %arg0 : tensor<i64>
}
@@ -1,4 +1,4 @@
// RUN: xla-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -test-xla-infer-shaped-type-methods -allow-unregistered-dialect -split-input-file -verify-diagnostics %s -o - | FileCheck %s

// CHECK-LABEL: @broadcast_add
// Note that all broadcast_ops are expanded from the same template, so
@@ -1,10 +1,10 @@
// RUN: xla-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -test-xla-chlo-legalize-to-hlo -cse -split-input-file -verify-diagnostics %s -o - | FileCheck %s

// Check the non-broadcast case for each registered op, then just check a
// representative op for detailed broadcast semantics.
// CHECK-LABEL: @addWithoutBroadcast
func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.add %arg0, %arg1
// CHECK: mhlo.add %arg0, %arg1
%0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -20,9 +20,9 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = xla_hlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
// CHECK-NEXT: %[[RESULT:.+]] = mhlo.add %[[ARG0_B]], %[[ARG1_B]]
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xf32>
@@ -41,9 +41,9 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "xla_hlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-NEXT: %[[RESULT:.+]] = "mhlo.complex"(%[[ARG0_B]], %[[ARG1_B]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xcomplex<f32>>
// CHECK-NEXT: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK-NEXT: return %[[FINAL_RESULT]] : tensor<?x?xcomplex<f32>>
@@ -62,9 +62,9 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "xla_hlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK: %[[RESULT:.+]] = "mhlo.compare"(%[[ARG0_B]], %[[ARG1_B]]) {comparison_direction = "EQ"} : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: shape.assuming_yield %[[RESULT]]
// CHECK-NEXT: }
// CHECK: return %[[FINAL_RESULT]] : tensor<?x?xi1>
@@ -76,7 +76,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarBroadcastDimensions
func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@@ -85,7 +85,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<
// Verifies that broadcast_dimensions validity checks are valid.
// CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions
func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<f32>) -> tensor<1x4xf32> {
// CHECK: xla_hlo.add
// CHECK: mhlo.add
%0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor<f32>) -> tensor<1x4xf32>
return %0 : tensor<1x4xf32>
}
@@ -113,7 +113,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1:
// expansions. Tests below merely verify that the op has an expansion.
// CHECK-LABEL: @andWithoutBroadcast
func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.and %arg0, %arg1
// CHECK: mhlo.and %arg0, %arg1
%0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@@ -121,7 +121,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
// -----
// CHECK-LABEL: @atan2WithoutBroadcast
func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.atan2 %arg0, %arg1
// CHECK: mhlo.atan2 %arg0, %arg1
%0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -129,7 +129,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @compareWithoutBroadcast
func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> {
// CHECK: "xla_hlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@@ -137,7 +137,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @complexWithoutBroadcast
func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex<f32>> {
// CHECK: "xla_hlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
// CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
%0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@@ -145,7 +145,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @divideWithoutBroadcast
func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.divide %arg0, %arg1
// CHECK: mhlo.divide %arg0, %arg1
%0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -153,7 +153,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tens
// -----
// CHECK-LABEL: @maximumWithoutBroadcast
func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.maximum %arg0, %arg1
// CHECK: mhlo.maximum %arg0, %arg1
%0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -161,7 +161,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @minimumWithoutBroadcast
func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.minimum %arg0, %arg1
// CHECK: mhlo.minimum %arg0, %arg1
%0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -169,7 +169,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten
// -----
// CHECK-LABEL: @multiplyWithoutBroadcast
func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.multiply %arg0, %arg1
// CHECK: mhlo.multiply %arg0, %arg1
%0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -177,7 +177,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te
// -----
// CHECK-LABEL: @orWithoutBroadcast
func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.or %arg0, %arg1
// CHECK: mhlo.or %arg0, %arg1
%0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@@ -185,7 +185,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi
// -----
// CHECK-LABEL: @powerWithoutBroadcast
func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.power %arg0, %arg1
// CHECK: mhlo.power %arg0, %arg1
%0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -193,7 +193,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso
// -----
// CHECK-LABEL: @remainderWithoutBroadcast
func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.remainder %arg0, %arg1
// CHECK: mhlo.remainder %arg0, %arg1
%0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -201,7 +201,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t
// -----
// CHECK-LABEL: @shift_leftWithoutBroadcast
func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_left %arg0, %arg1
// CHECK: mhlo.shift_left %arg0, %arg1
%0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -209,7 +209,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) ->
// -----
// CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast
func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_arithmetic %arg0, %arg1
// CHECK: mhlo.shift_right_arithmetic %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -217,7 +217,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor
// -----
// CHECK-LABEL: @shift_right_logicalWithoutBroadcast
func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.shift_right_logical %arg0, %arg1
// CHECK: mhlo.shift_right_logical %arg0, %arg1
%0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -225,7 +225,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x
// -----
// CHECK-LABEL: @subWithoutBroadcast
func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK: xla_hlo.subtract %arg0, %arg1
// CHECK: mhlo.subtract %arg0, %arg1
%0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
@@ -233,7 +233,7 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<
// -----
// CHECK-LABEL: @xorWithoutBroadcast
func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> {
// CHECK: xla_hlo.xor %arg0, %arg1
// CHECK: mhlo.xor %arg0, %arg1
%0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
@ -1,9 +1,9 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @single_operand
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @single_operand(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> {
  %0 = "xla_hlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
  %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
  // CHECK-NEXT: return [[ARG]]
  return %0 : tensor<1x2xf32>
}
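The RUN-line change above only swaps the test driver (xla-opt becomes mlir-hlo-opt); the FileCheck expectations themselves are unchanged. As a sketch of what this test asserts, inferred from the CHECK lines rather than stated in the diff, canonicalization folds a single-operand concatenate away entirely:

  // Before func(canonicalize):
  %0 = "mhlo.concatenate"(%arg) {dimension = 0 : i64} : (tensor<1x2xf32>) -> tensor<1x2xf32>
  return %0 : tensor<1x2xf32>
  // After, which is what "// CHECK-NEXT: return [[ARG]]" requires:
  return %arg : tensor<1x2xf32>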
225 tensorflow/compiler/mlir/hlo/tests/convert.mlir Normal file
@ -0,0 +1,225 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s

// -----

// CHECK-LABEL: func @same_type
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @same_type(%arg: tensor<f32>) -> tensor<f32> {
  %0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<f32>
  // CHECK-NEXT: return [[ARG]]
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @int_widening
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_widening(%arg: tensor<i32>) -> tensor<i64> {
  // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i64>
  %0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i64>
  // CHECK-NEXT: return [[RES]]
  return %0 : tensor<i64>
}

// -----

// CHECK-LABEL: func @int_narrowing
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_narrowing(%arg: tensor<i32>) -> tensor<i16> {
  // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<i16>
  %0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<i16>
  // CHECK-NEXT: return [[RES]]
  return %0 : tensor<i16>
}

// -----

// CHECK-LABEL: func @float_int
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @float_int(%arg: tensor<f32>) -> tensor<i32> {
  // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<f32>) -> tensor<i32>
  %0 = "mhlo.convert"(%arg) : (tensor<f32>) -> tensor<i32>
  // CHECK-NEXT: return [[RES]]
  return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @int_float
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @int_float(%arg: tensor<i32>) -> tensor<f32> {
  // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<i32>) -> tensor<f32>
  %0 = "mhlo.convert"(%arg) : (tensor<i32>) -> tensor<f32>
  // CHECK-NEXT: return [[RES]]
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @high_rank_tensor
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @high_rank_tensor(%arg: tensor<2x3xi32>) -> tensor<2x3xf32> {
  // CHECK-NEXT: [[RES:%.+]] = "mhlo.convert"([[ARG]]) : (tensor<2x3xi32>) -> tensor<2x3xf32>
  %0 = "mhlo.convert"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xf32>
  // CHECK-NEXT: return [[RES]]
  return %0 : tensor<2x3xf32>
}

// -----


// CHECK-LABEL: func @const_same_type
func @const_same_type() -> tensor<i32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
  %cst = mhlo.constant dense<42> : tensor<i32>
  %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @const_float_int
func @const_float_int() -> tensor<i32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
  %cst = mhlo.constant dense<42.0> : tensor<f32>
  %0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @const_int_float
func @const_int_float() -> tensor<f32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<f32>
  %cst = mhlo.constant dense<4> : tensor<i32>
  %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @const_negative_int_float
func @const_negative_int_float() -> tensor<f32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-4.{{0*}}e+00> : tensor<f32>
  %cst = mhlo.constant dense<-4> : tensor<i32>
  %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<f32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @const_int_bf16
func @const_int_bf16() -> tensor<bf16> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.{{0*}}e+00> : tensor<bf16>
  %cst = mhlo.constant dense<4> : tensor<i32>
  %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<bf16>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<bf16>
}

// -----

// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i16> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i16>
  %cst = mhlo.constant dense<42.0> : tensor<bf16>
  %0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i16>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i16>
}

// -----

// CHECK-LABEL: func @const_int_narrowing
func @const_int_narrowing() -> tensor<i32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
  %cst = mhlo.constant dense<42> : tensor<i64>
  %0 = "mhlo.convert"(%cst) : (tensor<i64>) -> tensor<i32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @const_int_widening
func @const_int_widening() -> tensor<i64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
  %cst = mhlo.constant dense<42> : tensor<i32>
  %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i64>
}

// -----

// CHECK-LABEL: func @const_negative_int_widening
func @const_negative_int_widening() -> tensor<i64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<-42> : tensor<i64>
  %cst = mhlo.constant dense<-42> : tensor<i32>
  %0 = "mhlo.convert"(%cst) : (tensor<i32>) -> tensor<i64>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i64>
}

// -----

// CHECK-LABEL: func @const_float_narrowing
func @const_float_narrowing() -> tensor<f32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<f32>
  %cst = mhlo.constant dense<4.2> : tensor<f64>
  %0 = "mhlo.convert"(%cst) : (tensor<f64>) -> tensor<f32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<f32>
}

// -----

// CHECK-LABEL: func @const_f32_bf16
func @const_f32_bf16() -> tensor<bf16> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+01> : tensor<bf16>
  %cst = mhlo.constant dense<42.0> : tensor<f32>
  %0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<bf16>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<bf16>
}

// -----

// CHECK-LABEL: func @const_bf16_f64
func @const_bf16_f64() -> tensor<f64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.187500e+00> : tensor<f64>
  %cst = mhlo.constant dense<4.2> : tensor<bf16>
  %0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<f64>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<f64>
}

// -----

// CHECK-LABEL: func @const_bf16_int
func @const_bf16_int() -> tensor<i64> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i64>
  %cst = mhlo.constant dense<42.0> : tensor<bf16>
  %0 = "mhlo.convert"(%cst) : (tensor<bf16>) -> tensor<i64>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<i64>
}


// -----

// CHECK-LABEL: func @const_high_rank_tensor
func @const_high_rank_tensor() -> tensor<2x3xi32> {
  // CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
  // CHECK-SAME: [1, 2, 3], [4, 5, 6]
  // CHECK-SAME: ]> : tensor<2x3xi32>
  %cst = mhlo.constant dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>
  %0 = "mhlo.convert"(%cst) : (tensor<2x3xf32>) -> tensor<2x3xi32>
  // CHECK-NEXT: return [[CST]]
  return %0 : tensor<2x3xi32>
}
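The new convert.mlir file above exercises constant folding of mhlo.convert under func(canonicalize). A condensed sketch of the rewrite the CHECK-NEXT lines expect (reconstructed from the test cases, not part of the diff): a convert whose operand is a constant is replaced by a single constant of the target element type, with the cast performed at compile time.

  // Input pattern:
  %cst = mhlo.constant dense<42.0> : tensor<f32>
  %0 = "mhlo.convert"(%cst) : (tensor<f32>) -> tensor<i32>
  // Expected after canonicalization:
  %0 = mhlo.constant dense<42> : tensor<i32>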
@ -1,10 +1,10 @@
// RUN: xla-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
// RUN: xla-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=PRE,BOTH %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement -split-input-file %s -o - | FileCheck --check-prefixes=ESC,BOTH %s

// BOTH-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.exponential"(%tensor_operand)
  %tensor_result = "mhlo.exponential"(%tensor_operand)
      {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>}
@ -28,11 +28,11 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> {

// BOTH-LABEL: func @func_op_long
func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %1 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = xla_hlo.add %arg0, %1 : tensor<4xf32>
  %3 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
  %4 = xla_hlo.subtract %arg1, %3 : tensor<4xf32>
  %5 = xla_hlo.multiply %2, %4 : tensor<4xf32>
  %1 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
  %2 = mhlo.add %arg0, %1 : tensor<4xf32>
  %3 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
  %4 = mhlo.subtract %arg1, %3 : tensor<4xf32>
  %5 = mhlo.multiply %2, %4 : tensor<4xf32>
  return %5 : tensor<4xf32>
}
// PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>)
@ -65,12 +65,12 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
  // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<2x2xf32>
  %tensor_summand_1 = tensor_load %summand_1 : memref<2x2xf32>
  %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32>
  %sum = "xla_hlo.add"(%tensor_summand_1, %tensor_summand_2)
  %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]])
  // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32>
  %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32>
  %tensor_result = "xla_hlo.multiply"(%sum, %tensor_multiplier)
  %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]])
  // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32>
@ -86,7 +86,7 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>,
// BOTH-LABEL: func @copy
func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.copy"(%tensor_operand)
  %tensor_result = "mhlo.copy"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -98,7 +98,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @exp
func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.exponential"(%tensor_operand)
  %tensor_result = "mhlo.exponential"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -110,7 +110,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @log
func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.log"(%tensor_operand)
  %tensor_result = "mhlo.log"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -125,7 +125,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
  %tensor_pred = tensor_load %pred : memref<2x2xi1>
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
  %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -138,7 +138,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>,
func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.compare"(%tensor_lhs, %tensor_rhs)
  %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs)
      {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
  // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"}
@ -151,7 +151,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x
// BOTH-LABEL: func @broadcast
func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
  %tensor_operand = tensor_load %operand : memref<5xf32>
  %tensor_result = "xla_hlo.broadcast_in_dim"(%tensor_operand)
  %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand)
      {broadcast_dimensions = dense<1> : tensor<1xi64>}
      : (tensor<5xf32>) -> tensor<10x5xf32>
  // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>}
@ -170,7 +170,7 @@ func @dyn_broadcast(%operand: memref<?x?xf32>) {
  // BOTH-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
  %tensor_operand = tensor_load %operand : memref<?x?xf32>
  %shape = call @external_func() : () -> tensor<3xi64>
  %tensor_result = "xla_hlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
  %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
    broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
  } : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
  // BOTH: %[[SHAPE:.*]] = call @external_func()
@ -226,7 +226,7 @@ func @complex(%real: memref<2x2xf32>,
              %result: memref<2x2xcomplex<f32>>) {
  %tensor_real = tensor_load %real : memref<2x2xf32>
  %tensor_imag = tensor_load %imag : memref<2x2xf32>
  %tensor_result = "xla_hlo.complex"(%tensor_real, %tensor_imag)
  %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex<f32>>
  // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xcomplex<f32>>
@ -238,7 +238,7 @@ func @complex(%real: memref<2x2xf32>,
// BOTH-LABEL: func @real
func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
  %tensor_result = "xla_hlo.real"(%tensor_operand)
  %tensor_result = "mhlo.real"(%tensor_operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -250,7 +250,7 @@ func @real(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @imag
func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xcomplex<f32>>
  %tensor_result = "xla_hlo.imag"(%tensor_operand)
  %tensor_result = "mhlo.imag"(%tensor_operand)
      : (tensor<2x2xcomplex<f32>>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -261,7 +261,7 @@ func @imag(%operand: memref<2x2xcomplex<f32>>, %result: memref<2x2xf32>) {

// BOTH-LABEL: func @iota
func @iota(%result: memref<10xi32>) {
  %tensor_result = "xla_hlo.iota"()
  %tensor_result = "mhlo.iota"()
      {iota_dimension = 0 : i64} : () -> tensor<10xi32>
  // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64}
  tensor_store %tensor_result, %result : memref<10xi32>
@ -273,7 +273,7 @@ func @iota(%result: memref<10xi32>) {
// BOTH-LABEL: func @abs
func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.abs"(%tensor_operand)
  %tensor_result = "mhlo.abs"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -285,7 +285,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @ceil
func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.ceil"(%tensor_operand)
  %tensor_result = "mhlo.ceil"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -297,7 +297,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @convert
func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.convert"(%tensor_operand)
  %tensor_result = "mhlo.convert"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}})
  // BOTH-NOT: tensor_store
@ -310,7 +310,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @cos
func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.cosine"(%tensor_operand)
  %tensor_result = "mhlo.cosine"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -322,7 +322,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @neg
func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.negate"(%tensor_operand)
  %tensor_result = "mhlo.negate"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -334,7 +334,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @rsqrt
func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.rsqrt"(%tensor_operand)
  %tensor_result = "mhlo.rsqrt"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -346,7 +346,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sign
func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.sign"(%tensor_operand)
  %tensor_result = "mhlo.sign"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -358,7 +358,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @sqrt
func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.sqrt"(%tensor_operand)
  %tensor_result = "mhlo.sqrt"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -370,7 +370,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
// BOTH-LABEL: func @tanh
func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_operand = tensor_load %operand : memref<2x2xf32>
  %tensor_result = "xla_hlo.tanh"(%tensor_operand)
  %tensor_result = "mhlo.tanh"(%tensor_operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -383,7 +383,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) {
  %tensor_lhs = tensor_load %lhs : memref<2x2xf32>
  %tensor_rhs = tensor_load %rhs : memref<2x2xf32>
  %tensor_result = "xla_hlo.remainder"(%tensor_lhs, %tensor_rhs)
  %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}})
  tensor_store %tensor_result, %result : memref<2x2xf32>
@ -395,7 +395,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x
// Dynamic shape binary element-wise operation.
// BOTH-LABEL: func @add_dyn
func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
  %result = "xla_hlo.add"(%lhs, %rhs)
  %result = "mhlo.add"(%lhs, %rhs)
      : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
  // BOTH: %[[C0:.*]] = constant 0 : index
  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -420,7 +420,7 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// Dynamic shape unary element-wise operation.
// BOTH-LABEL: func @tanh_dyn
func @tanh_dyn(%arg0: tensor<?x?xf32>) {
  %result = "xla_hlo.tanh"(%arg0)
  %result = "mhlo.tanh"(%arg0)
      : (tensor<?x?xf32>) -> tensor<?x?xf32>
  // BOTH: %[[C0:.*]] = constant 0 : index
  // BOTH: %[[DIM0:.*]] = dim %arg0, %[[C0]] : memref<?x?xf32>
@ -448,7 +448,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> {
// ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]]
// BOTH-NEXT: %[[ALLOC:.*]] = alloc
// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> ()
  %dot = "xla_hlo.dot"(%arg0, %arg0)
  %dot = "mhlo.dot"(%arg0, %arg0)
      : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32>
  // PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]])
  // ESC: return %[[ALLOC]]
@ -466,7 +466,7 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
  // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64>
  // BOTH-SAME: rhs_dilation = dense<[1, 2]>
  // BOTH-SAME: window_strides = dense<[2, 1]>
  %out = "xla_hlo.convolution"(%filter, %input) {
  %out = "mhlo.convolution"(%filter, %input) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
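The hlo-legalize-to-lhlo cases above all follow one pattern: a tensor-level mhlo op framed by tensor_load/tensor_store is rewritten into the matching xla_lhlo op that writes into a caller-provided memref. A minimal before/after sketch, inferred from the CHECK lines rather than taken from the diff:

  // Tensor form (test input):
  %t = tensor_load %operand : memref<2x2xf32>
  %r = "mhlo.exponential"(%t) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  tensor_store %r, %result : memref<2x2xf32>
  // Buffer form (expected output): the result buffer becomes an operand.
  "xla_lhlo.exponential"(%operand, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> ()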
@ -1,4 +1,4 @@
// RUN: xla-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s
// RUN: mlir-hlo-opt %s -hlo-legalize-to-linalg -split-input-file | FileCheck %s

// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @float_add
@ -10,7 +10,7 @@ func @float_add(%lhs: tensor<2x2xf32>,
  // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: f32
  // CHECK: %[[RESULT:[a-zA-Z0-9_]*]] = addf %[[ARG0]], %[[ARG1]]
  // CHECK: linalg.yield %[[RESULT]]
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -22,7 +22,7 @@ func @integer_add(%lhs: tensor<2x2xi32>,
    %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: addi
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -34,7 +34,7 @@ func @float_mul(%lhs: tensor<2x2xf32>,
    %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: mulf
  %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -46,7 +46,7 @@ func @integer_mul(%lhs: tensor<2x2xi32>,
    %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: muli
  %0 = "xla_hlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.multiply"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -58,7 +58,7 @@ func @float_remainder(%lhs: tensor<2x2xf32>,
    %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: remf
  %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -70,7 +70,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,
    %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: remi_signed
  %0 = "xla_hlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.remainder"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -79,7 +79,7 @@ func @integer_remainder(%lhs: tensor<2x2xi32>,

// CHECK-LABEL: func @float_rsqrt
func @float_rsqrt(%operand: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %tensor_result = "xla_hlo.rsqrt"(%operand)
  %tensor_result = "mhlo.rsqrt"(%operand)
      : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: linalg.generic
  // CHECK: rsqrt
@ -93,7 +93,7 @@ func @float_sub(%lhs: tensor<2x2xf32>,
    %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: subf
  %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
  %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xf32>,
      tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -105,7 +105,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
    %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: subi
  %0 = "xla_hlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.subtract"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -116,7 +116,7 @@ func @integer_sub(%lhs: tensor<2x2xi32>,
func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: absf
  %0 = "xla_hlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.abs"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -126,7 +126,7 @@ func @float_abs(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: exp
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -136,7 +136,7 @@ func @float_exp(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: log
  %0 = "xla_hlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.log"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -146,7 +146,7 @@ func @float_log(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: ceilf
  %0 = "xla_hlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.ceil"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -156,7 +156,7 @@ func @float_ceil(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: negf
  %0 = "xla_hlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.negate"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -166,7 +166,7 @@ func @float_neg(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_tanh(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: tanh
  %0 = "xla_hlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.tanh"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -177,7 +177,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
    %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  // CHECK: linalg.generic
  // CHECK: and
  %0 = "xla_hlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
  %0 = "mhlo.and"(%lhs, %rhs) : (tensor<2x2xi32>,
      tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -187,7 +187,7 @@ func @integer_and(%lhs: tensor<2x2xi32>,
// CHECK-LABEL: func @float_cmp
func @float_cmp(%lhs: tensor<2x2xf32>,
    %rhs: tensor<2x2xf32>) -> (tensor<2x2xi1>) {
  %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "EQ"}
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1>
  return %0 : tensor<2x2xi1>
}
@ -201,7 +201,7 @@ func @float_cmp(%lhs: tensor<2x2xf32>,
// CHECK-LABEL: func @int_cmp
func @int_cmp(%lhs: tensor<2x2xi32>,
    %rhs: tensor<2x2xi32>) -> tensor<2x2xi1> {
  %0 = "xla_hlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
  %0 = "mhlo.compare"(%lhs, %rhs) {comparison_direction = "LT"}
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> (tensor<2x2xi1>)
  return %0 : tensor<2x2xi1>
}
@ -216,7 +216,7 @@ func @int_cmp(%lhs: tensor<2x2xi32>,
func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: cos
  %0 = "xla_hlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.cosine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -226,7 +226,7 @@ func @float_cos(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
  // CHECK: linalg.generic
  // CHECK: sin
  %0 = "xla_hlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  %0 = "mhlo.sine"(%arg0) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}

@ -235,7 +235,7 @@ func @float_sin(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK-LABEL: func @copy
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
  %0 = "xla_hlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
  %0 = "mhlo.copy"(%input) : (tensor<2x4x8xf32>) -> (tensor<2x4x8xf32>)
  return %0 : tensor<2x4x8xf32>
}
// CHECK: return [[ARG]] : tensor<2x4x8xf32>
@ -245,7 +245,7 @@ func @copy(%input: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
// CHECK-LABEL: func @select
func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
    %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = "xla_hlo.select"(%pred, %lhs, %rhs)
  %0 = "mhlo.select"(%pred, %lhs, %rhs)
      : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>)
  return %0 : tensor<2x2xf32>
}
@ -260,7 +260,7 @@ func @select(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>,
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
  %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
  %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<f32>) -> tensor<4x2x1xf32>
  return %0: tensor<4x2x1xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -273,7 +273,7 @@ func @broadcast_scalar(%arg: tensor<f32>) -> tensor<4x2x1xf32> {
// CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>
// CHECK-LABEL: func @broadcast
func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
  %0 = "xla_hlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
  %0 = "mhlo.broadcast"(%arg) {broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64>} : (tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32>
  return %0: tensor<4x2x1x4x?x16xf32>
}
// CHECK: linalg.generic {{{.*}}indexing_maps = [#[[OPERAND_MAP]], #[[RESULT_MAP]]]
@ -286,7 +286,7 @@ func @broadcast(%arg: tensor<4x?x16xf32>) -> tensor<4x2x1x4x?x16xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
// CHECK-LABEL: func @broadcast_in_dim
func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%operand)
  %0 = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64>}
      : (tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32>
  return %0 : tensor<7x10x6x4x5xf32>
@ -302,7 +302,7 @@ func @broadcast_in_dim(%operand: tensor<5x7x1xf32>) -> tensor<7x10x6x4x5xf32> {
// CHECK-LABEL: func @broadcast_in_dim_with_one_to_one
func @broadcast_in_dim_with_one_to_one(
    %operand: tensor<1xf32>) -> tensor<1x5xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%operand)
  %0 = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<[0]> : tensor<1xi64>}
      : (tensor<1xf32>) -> tensor<1x5xf32>
  return %0 : tensor<1x5xf32>
@ -317,7 +317,7 @@ func @broadcast_in_dim_with_one_to_one(
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-LABEL: func @broadcast_scalar
func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
  %0 = "xla_hlo.broadcast_in_dim"(%operand)
  %0 = "mhlo.broadcast_in_dim"(%operand)
      {broadcast_dimensions = dense<[]> : tensor<0xi64>}
      : (tensor<f32>) -> tensor<7x10x6xf32>
  return %0 : tensor<7x10x6xf32>
@ -332,7 +332,7 @@ func @broadcast_scalar(%operand: tensor<f32>) -> tensor<7x10x6xf32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
// CHECK-LABEL: func @transpose
func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
  %0 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
  %0 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}
      : (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
  return %0 : tensor<3x2x5x9xi32>
}
@ -344,7 +344,7 @@ func @transpose(%arg0: tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)>
// CHECK-LABEL: func @reshape_3D_2D
func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<12x1x42xi32>) -> tensor<12x42xi32>
  return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -355,7 +355,7 @@ func @reshape_3D_2D(%arg0: tensor<12x1x42xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)>
// CHECK-LABEL: func @reshape_4D_2D
func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<12x42x1x1xi32>) -> tensor<12x42xi32>
  return %0 : tensor<12x42xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -366,7 +366,7 @@ func @reshape_4D_2D(%arg0: tensor<12x42x1x1xi32>) -> tensor<12x42xi32> {
// CHECK-DAG: #[[RESHAPE_MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
// CHECK-LABEL: func @reshape_2D_4D
func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<12x42xi32>) -> tensor<12x1x42x1xi32>
  return %0 : tensor<12x1x42x1xi32>
}
// CHECK: linalg.tensor_reshape %{{.*}} [#[[RESHAPE_MAP1]], #[[RESHAPE_MAP2]]]
@ -375,7 +375,7 @@ func @reshape_2D_4D(%arg0: tensor<12x42xi32>) -> tensor<12x1x42x1xi32> {

// CHECK-LABEL: func @minf
func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {
  %0 = "xla_hlo.minimum"(%lhs, %rhs)
  %0 = "mhlo.minimum"(%lhs, %rhs)
      : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  return %0 : tensor<2x2xf32>
}
@ -389,7 +389,7 @@ func @minf(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> {

// CHECK-LABEL: func @maxi
func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
  %0 = "xla_hlo.maximum"(%lhs, %rhs)
  %0 = "mhlo.maximum"(%lhs, %rhs)
      : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32>
  return %0 : tensor<2x2xi32>
}
@ -404,7 +404,7 @@ func @maxi(%lhs: tensor<2x2xi32>, %rhs: tensor<2x2xi32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[MAP:.*]] = affine_map<() -> ()>
// CHECK-LABEL: func @add_scalar
func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {
  %0 = "xla_hlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  %0 = "mhlo.add"(%lhs, %rhs) : (tensor<f32>, tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
// CHECK: linalg.generic
@ -417,7 +417,7 @@ func @add_scalar(%lhs: tensor<f32>, %rhs: tensor<f32>) -> tensor<f32> {

func @reshape_collapse_single_dim
    (%arg0: tensor<1x28x28x1xf32>) -> tensor<1x784xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x28x28x1xf32>) -> tensor<1x784xf32>
  return %0 : tensor<1x784xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -428,7 +428,7 @@ func @reshape_collapse_single_dim
// -----

func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32>
  return %0 : tensor<2x4x3xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d0)>
@ -440,7 +440,7 @@ func @reshape_collapse(%arg0: tensor<2x2x2x3xf32>) -> tensor<2x4x3xf32> {
// -----

func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<2x8xf32>) -> tensor<2x4x2xf32>
  return %0 : tensor<2x4x2xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0)>
@ -451,7 +451,7 @@ func @reshape_expand(%arg0: tensor<2x8xf32>) -> tensor<2x4x2xf32> {
// -----

func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<8xf32>) -> tensor<1x4x2xf32>
  return %0 : tensor<1x4x2xf32>
}
// CHECK: #[[MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@ -462,7 +462,7 @@ func @reshape_single_expand(%arg0 : tensor<8xf32>) -> tensor<1x4x2xf32> {

func @reshape_multiple_collapse
    (%arg0 : tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32> {
  %0 = "xla_hlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
  %0 = "mhlo.reshape"(%arg0) : (tensor<1x2x2x5x3x2xf32>) -> tensor<1x4x5x6xf32>
  return %0 : tensor<1x4x5x6xf32>
}
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0)>
@ -476,7 +476,7 @@ func @reshape_multiple_collapse

// CHECK-LABEL: func @convert_i32_to_f32
func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xf32>
  return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -488,7 +488,7 @@ func @convert_i32_to_f32(%input: tensor<2x2xi32>) -> tensor<2x2xf32> {

// CHECK-LABEL: func @convert_i16_to_i32
func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xi16>) -> tensor<2x2xi32>
  return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -500,7 +500,7 @@ func @convert_i16_to_i32(%input: tensor<2x2xi16>) -> tensor<2x2xi32> {

// CHECK-LABEL: func @convert_i32_to_i16
func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
  %result = "mhlo.convert"(%input) : (tensor<2x2xi32>) -> tensor<2x2xi16>
  return %result : tensor<2x2xi16>
}
// CHECK: linalg.generic
@ -512,7 +512,7 @@ func @convert_i32_to_i16(%input: tensor<2x2xi32>) -> tensor<2x2xi16> {

// CHECK-LABEL: func @convert_f32_to_f64
func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
  %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xf64>
  return %result : tensor<2x2xf64>
}
// CHECK: linalg.generic
@ -524,7 +524,7 @@ func @convert_f32_to_f64(%input: tensor<2x2xf32>) -> tensor<2x2xf64> {

// CHECK-LABEL: func @convert_f64_to_f32
func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xf64>) -> tensor<2x2xf32>
  return %result : tensor<2x2xf32>
}
// CHECK: linalg.generic
@ -536,7 +536,7 @@ func @convert_f64_to_f32(%input: tensor<2x2xf64>) -> tensor<2x2xf32> {

// CHECK-LABEL: func @convert_f32_to_i32
func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
  %result = "xla_hlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
  %result = "mhlo.convert"(%input) : (tensor<2x2xf32>) -> tensor<2x2xi32>
  return %result : tensor<2x2xi32>
}
// CHECK: linalg.generic
@ -550,7 +550,7 @@ func @convert_f32_to_i32(%input: tensor<2x2xf32>) -> tensor<2x2xi32> {
// CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @reverse
func @reverse(%input: tensor<2x3xf32>) -> tensor<2x3xf32> {
  %result = "xla_hlo.reverse"(%input) {
  %result = "mhlo.reverse"(%input) {
    dimensions = dense<1> : tensor<1xi64>
  } : (tensor<2x3xf32>) -> tensor<2x3xf32>
  return %result : tensor<2x3xf32>
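Each element-wise case in the hlo-legalize-to-linalg file pairs an mhlo op with the scalar op expected inside the generated linalg.generic body (addf/addi, mulf/muli, remf/remi_signed, and so on). As a sketch reconstructed from the float_add CHECK lines, the region body the pass is expected to emit looks like:

  ^bb0(%arg0: f32, %arg1: f32):  // one scalar element from each operand
    %result = addf %arg0, %arg1 : f32
    linalg.yield %result : f32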
@ -1,28 +1,28 @@
// RUN: xla-opt %s -inline | FileCheck %s
// RUN: mlir-hlo-opt %s -inline | FileCheck %s

// Test case: Basic test of inlining into xla_hlo.while.
// Test case: Basic test of inlining into mhlo.while.

// CHECK-LABEL: func @caller
// CHECK: "xla_hlo.while"{{.*}}( {
// CHECK: "mhlo.while"{{.*}}( {
// CHECK: }, {
// CHECK: "xla_hlo.exponential"
// CHECK: "mhlo.exponential"
// CHECK: })
// CHECK-LABEL: func @callee

func @caller(%arg0: tensor<f32>, %pred: tensor<i1>) -> tensor<f32> {
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^entry(%unused: tensor<f32>):
    "xla_hlo.return"(%pred) : (tensor<i1>) -> ()
    "mhlo.return"(%pred) : (tensor<i1>) -> ()
  }, {
  ^entry(%0: tensor<f32>):
    %1 = call @callee(%0) : (tensor<f32>) -> (tensor<f32>)
    "xla_hlo.return"(%1) : (tensor<f32>) -> ()
    "mhlo.return"(%1) : (tensor<f32>) -> ()
  } ) : (tensor<f32>) -> (tensor<f32>)
  return %0 : tensor<f32>
}


func @callee(%arg0: tensor<f32>) -> tensor<f32> {
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<f32>) -> tensor<f32>
  return %0 : tensor<f32>
}
@ -1,24 +1,24 @@
// RUN: xla-opt -xla-legalize-control-flow %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -xla-legalize-control-flow %s -o - | FileCheck %s

// CHECK-LABEL: func @while(%arg0: tensor<i64>) -> tensor<i64> {
func @while(%arg0: tensor<i64>) -> tensor<i64> {
  //CHECK: br ^bb1(%arg0 : tensor<i64>)
  //CHECK: ^bb1([[VAL0:%.+]]: tensor<i64>):
  //CHECK: [[VAL1:%.+]] = "xla_hlo.compare"([[VAL0]], [[VAL0]])
  //CHECK: [[VAL1:%.+]] = "mhlo.compare"([[VAL0]], [[VAL0]])
  //CHECK: [[VAL2:%.+]] = extract_element [[VAL1]][] : tensor<i1>
  //CHECK: cond_br [[VAL2]], ^bb2([[VAL0]] : tensor<i64>), ^bb3([[VAL0]] : tensor<i64>)
  //CHECK: ^bb2([[VAL3:%.+]]: tensor<i64>):
  //CHECK: [[VAL4:%.+]] = xla_hlo.add [[VAL3]], [[VAL3]]
  //CHECK: [[VAL4:%.+]] = mhlo.add [[VAL3]], [[VAL3]]
  //CHECK: br ^bb1([[VAL4]] : tensor<i64>)
  //CHECK: ^bb3([[VAL5:%.+]]: tensor<i64>):
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^bb0(%arg1: tensor<i64>):
    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^bb0(%arg1: tensor<i64>):
    %1 = xla_hlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
    %1 = mhlo.add %arg1, %arg1 {name = "compare.0"} : tensor<i64>
    "mhlo.return"(%1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>

  // CHECK-NEXT: return [[VAL5]]
@ -30,27 +30,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
  // CHECK: [[C0:%.+]] = constant dense<1.000000e+01> : tensor<f32>
  %cst = constant dense<1.000000e+01> : tensor<f32>

  // CHECK: [[VAL0:%.+]] = "xla_hlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "xla_hlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  // CHECK: [[VAL0:%.+]] = "mhlo.compare"(%arg0, [[C0]]) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
  %0 = "mhlo.compare"(%arg0, %cst) {comparison_direction = "LT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>

  // CHECK: [[VAL1:%.+]] = extract_element [[VAL0]][] : tensor<i1>
  // CHECK: cond_br [[VAL1]], ^bb1(%arg0 : tensor<f32>), ^bb2(%arg0 : tensor<f32>)
  %1 = "xla_hlo.if"(%0, %arg0, %arg0) ( {
  %1 = "mhlo.if"(%0, %arg0, %arg0) ( {

  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb1([[VAL2:%.+]]: tensor<f32>):
    // CHECK: [[VAL3:%.+]] = "xla_hlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL3:%.+]] = "mhlo.log"([[VAL2]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL3]] : tensor<f32>)
    %2 = "xla_hlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.log"(%arg1) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }, {

  ^bb0(%arg1: tensor<f32>):
    // CHECK: ^bb2([[VAL4:%.+]]: tensor<f32>):
    // CHECK: [[VAL5:%.+]] = "xla_hlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: [[VAL5:%.+]] = "mhlo.exponential"([[VAL4]]) : (tensor<f32>) -> tensor<f32>
    // CHECK: br ^bb3([[VAL5]] : tensor<f32>)
    %2 = "xla_hlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.exponential"(%arg1) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>

  // CHECK: ^bb3([[VAL6:%.+]]: tensor<f32>):
@ -62,27 +62,27 @@ func @conditional(%arg0: tensor<f32>) -> tensor<f32> {
func @while_with_multiple_blocks_in_body(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: br ^[[COND_ENTRY:.+]](%arg0 : tensor<i64>)
  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
  // CHECK: %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %2 = extract_element %1[] : tensor<i1>
  // CHECK: cond_br %2, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
  // CHECK: ^[[BODY_ENTRY]](%3: tensor<i64>):
  // CHECK: br ^[[BODY_SUCC:.+]](%3 : tensor<i64>)
  // CHECK: ^[[BODY_SUCC]](%4: tensor<i64>):
  // CHECK: %5 = xla_hlo.add %4, %4 : tensor<i64>
  // CHECK: %5 = mhlo.add %4, %4 : tensor<i64>
  // CHECK: br ^[[COND_ENTRY]](%5 : tensor<i64>)
  // CHECK: ^[[EXIT]](%6: tensor<i64>):
  // CHECK: return %6 : tensor<i64>
  // CHECK: }
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^cond_entry(%arg1: tensor<i64>):
    %1 = "xla_hlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    %1 = "mhlo.compare"(%arg1, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^body_entry(%arg1: tensor<i64>):
    br ^body_succ(%arg1: tensor<i64>)
  ^body_succ(%0: tensor<i64>):
    %1 = xla_hlo.add %0, %0 : tensor<i64>
    "xla_hlo.return"(%1) : (tensor<i64>) -> ()
    %1 = mhlo.add %0, %0 : tensor<i64>
    "mhlo.return"(%1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>

  return %0 : tensor<i64>
@ -94,7 +94,7 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: ^[[COND_ENTRY]](%0: tensor<i64>):
  // CHECK: br ^[[COND_SUCC:.+]](%0 : tensor<i64>)
  // CHECK: ^[[COND_SUCC]](%1: tensor<i64>):
  // CHECK: %2 = "xla_hlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %2 = "mhlo.compare"(%1, %1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
  // CHECK: %3 = extract_element %2[] : tensor<i1>
  // CHECK: cond_br %3, ^[[BODY_ENTRY:.+]](%0 : tensor<i64>), ^[[EXIT:.+]](%0 : tensor<i64>)
  // CHECK: ^[[BODY_ENTRY]](%4: tensor<i64>):
@ -102,15 +102,15 @@ func @while_with_multiple_blocks_in_cond(%arg0: tensor<i64>) -> tensor<i64> {
  // CHECK: ^[[EXIT]](%5: tensor<i64>):
  // CHECK: return %5 : tensor<i64>
  // CHECK: }
  %0 = "xla_hlo.while"(%arg0) ( {
  %0 = "mhlo.while"(%arg0) ( {
  ^cond_entry(%arg1: tensor<i64>):
    br ^cond_succ(%arg1: tensor<i64>)
  ^cond_succ(%0: tensor<i64>):
    %1 = "xla_hlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "xla_hlo.return"(%1) : (tensor<i1>) -> ()
    %1 = "mhlo.compare"(%0, %0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
    "mhlo.return"(%1) : (tensor<i1>) -> ()
  }, {
  ^body_entry(%arg1: tensor<i64>):
    "xla_hlo.return"(%arg1) : (tensor<i64>) -> ()
    "mhlo.return"(%arg1) : (tensor<i64>) -> ()
  }) : (tensor<i64>) -> tensor<i64>

  return %0 : tensor<i64>
@ -123,24 +123,24 @@ func @conditional_with_multiple_blocks(%arg0: tensor<f32>, %arg1: tensor<f32>, %
  // CHECK: ^[[THEN_ENTRY]](%1: tensor<f32>):
  // CHECK: br ^[[THEN_SUCC:.+]](%1 : tensor<f32>)
  // CHECK: ^[[THEN_SUCC]](%2: tensor<f32>):
  // CHECK: %3 = "xla_hlo.log"(%2) : (tensor<f32>) -> tensor<f32>
  // CHECK: %3 = "mhlo.log"(%2) : (tensor<f32>) -> tensor<f32>
  // CHECK: br ^[[EXIT:.+]](%3 : tensor<f32>)
  // CHECK: ^[[ELSE_ENTRY]](%4: tensor<f32>):
  // CHECK: %5 = "xla_hlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
  // CHECK: %5 = "mhlo.exponential"(%4) : (tensor<f32>) -> tensor<f32>
  // CHECK: br ^[[EXIT]](%5 : tensor<f32>)
  // CHECK: ^[[EXIT]](%6: tensor<f32>):
  // CHECK: return %6 : tensor<f32>
  // CHECK: }
  %1 = "xla_hlo.if"(%pred, %arg0, %arg1) ( {
  %1 = "mhlo.if"(%pred, %arg0, %arg1) ( {
  ^then_entry(%arg2: tensor<f32>):
    br ^then_succ(%arg2: tensor<f32>)
  ^then_succ(%0: tensor<f32>):
    %2 = "xla_hlo.log"(%0) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.log"(%0) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }, {
  ^else_entry(%arg2: tensor<f32>):
    %2 = "xla_hlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
    "xla_hlo.return"(%2) : (tensor<f32>) -> ()
    %2 = "mhlo.exponential"(%arg2) : (tensor<f32>) -> tensor<f32>
    "mhlo.return"(%2) : (tensor<f32>) -> ()
  }) : (tensor<i1>, tensor<f32>, tensor<f32>) -> tensor<f32>
  return %1 : tensor<f32>
}
@ -1,21 +1,21 @@
// RUN: xla-opt -xla-legalize-to-std %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -xla-legalize-to-std %s -o - | FileCheck %s

// CHECK-LABEL: func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %0 = addf %arg0, %arg1 : tensor<4xf32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %1 = mulf %0, %arg1 : tensor<4xf32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %2 = subf %1, %arg1 : tensor<4xf32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %3 = divf %2, %arg1 : tensor<4xf32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: %4 = remf %3, %arg1 : tensor<4xf32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>

// CHECK-NEXT: return %4 : tensor<4xf32>
return %4 : tensor<4xf32>
@ -24,19 +24,19 @@ func @binary_ops_float(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf
// CHECK-LABEL: func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
// CHECK-NEXT: %0 = addi %arg0, %arg1 : tensor<4xi32>
%0 = "xla_hlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%0 = "mhlo.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %1 = muli %0, %arg1 : tensor<4xi32>
%1 = "xla_hlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%1 = "mhlo.multiply"(%0, %arg1) {name = "mul.4"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %2 = subi %1, %arg1 : tensor<4xi32>
%2 = "xla_hlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%2 = "mhlo.subtract"(%1, %arg1) {name = "sub.5"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %3 = divi_signed %2, %arg1 : tensor<4xi32>
%3 = "xla_hlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%3 = "mhlo.divide"(%2, %arg1) {name = "div.6"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: %4 = remi_signed %3, %arg1 : tensor<4xi32>
%4 = "xla_hlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
%4 = "mhlo.remainder"(%3, %arg1) : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>

// CHECK-NEXT: return %4 : tensor<4xi32>
return %4 : tensor<4xi32>
@ -45,17 +45,17 @@ func @binary_ops_int(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32
// CHECK-LABEL: func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>) {
func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpi "eq", %arg0, %arg0 : tensor<4xi32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpi "ne", %arg0, %arg0 : tensor<4xi32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpi "slt", %arg0, %arg0 : tensor<4xi32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpi "sle", %arg0, %arg0 : tensor<4xi32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpi "sgt", %arg0, %arg0 : tensor<4xi32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpi "sge", %arg0, %arg0 : tensor<4xi32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi1>
// CHECK-NEXT: return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
return %0, %1, %2, %3, %4, %5 : tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
@ -63,28 +63,28 @@ func @compare_int(%arg0: tensor<4xi32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi
// CHECK-LABEL: func @compare_float
func @compare_float(%arg0: tensor<4xf32>) -> (tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>,tensor<4xi1>) {
// CHECK-NEXT: %0 = cmpf "oeq", %arg0, %arg0 : tensor<4xf32>
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %1 = cmpf "une", %arg0, %arg0 : tensor<4xf32>
%1 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%1 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %2 = cmpf "olt", %arg0, %arg0 : tensor<4xf32>
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %3 = cmpf "ole", %arg0, %arg0 : tensor<4xf32>
%3 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%3 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %4 = cmpf "ogt", %arg0, %arg0 : tensor<4xf32>
%4 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%4 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
// CHECK-NEXT: %5 = cmpf "oge", %arg0, %arg0 : tensor<4xf32>
%5 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
%5 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1>
return %0, %1, %2, %3, %4, %5: tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>, tensor<4xi1>
}
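A note the CHECK lines above make visible (the IEEE framing is mine, not from the test file): EQ, LT, LE, GT, and GE lower to the ordered cmpf predicates (oeq, olt, ole, ogt, oge), which are false when either operand is NaN, while NE lowers to the unordered predicate une, which is true for NaN operands. This keeps x != y as the exact negation of x == y.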

// CHECK-LABEL: func @int_constant
func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<i32>
%0 = "xla_hlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
%0 = "mhlo.constant"() {value = dense<0> : tensor<i32>} : () -> (tensor<i32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xi32>
%1 = "xla_hlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%1 = "mhlo.constant"() {value = dense<1> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xi32>
%2 = "xla_hlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
%2 = "mhlo.constant"() {value = dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>} : () -> (tensor<2x3xi32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
return %0, %1, %2: tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>
}
@ -92,11 +92,11 @@ func @int_constant() -> (tensor<i32>, tensor<2x3xi32>, tensor<2x3xi32>) {
// CHECK-LABEL: func @float_constant
func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-NEXT: [[CST0:%.+]] = constant {{.+}} : tensor<f32>
%0 = "xla_hlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
%0 = "mhlo.constant"() {value = dense<0.0> : tensor<f32>} : () -> (tensor<f32>)
// CHECK-NEXT: [[CST1:%.+]] = constant {{.+}} : tensor<2x3xf32>
%1 = "xla_hlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%1 = "mhlo.constant"() {value = dense<1.0> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: [[CST2:%.+]] = constant {{.+}} : tensor<2x3xf32>
%2 = "xla_hlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
%2 = "mhlo.constant"() {value = dense<[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]> : tensor<2x3xf32>} : () -> (tensor<2x3xf32>)
// CHECK-NEXT: return [[CST0]], [[CST1]], [[CST2]] : tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
return %0, %1, %2: tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>
}
@ -105,7 +105,7 @@ func @float_constant() -> (tensor<f32>, tensor<2x3xf32>, tensor<2x3xf32>) {
// CHECK-LABEL: func @iota.const.1() -> tensor<4xi32> {
func @iota.const.1() -> tensor<4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0, 1, 2, 3]> : tensor<4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<4xi32>
return %0 : tensor<4xi32>
}
@ -113,7 +113,7 @@ func @iota.const.1() -> tensor<4xi32> {
// CHECK-LABEL: func @iota.const.2() -> tensor<2x4xi32> {
func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 0, 0, 0], [1, 1, 1, 1]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
@ -121,7 +121,7 @@ func @iota.const.2() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.3() -> tensor<2x4xi32> {
func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[}}0, 1, 2, 3], [0, 1, 2, 3]]> : tensor<2x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}
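The iota folds above and below follow one rule; as a worked equation (notation mine, not from the test file): a rank-n iota with iota_dimension = d folds to

    result[i_0, ..., i_{n-1}] = i_d

For tensor<2x4xi32>, d = 0 gives rows [0, 0, 0, 0] and [1, 1, 1, 1], while d = 1 gives [0, 1, 2, 3] in both rows, matching the constants the CHECK lines expect.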
@ -129,7 +129,7 @@ func @iota.const.3() -> tensor<2x4xi32> {
// CHECK-LABEL: func @iota.const.4() -> tensor<2x3x4xi32> {
func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0{{\]\]}}, {{\[\[}}1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -137,7 +137,7 @@ func @iota.const.4() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.5() -> tensor<2x3x4xi32> {
func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2{{\]\]}}, {{\[\[}}0, 0, 0, 0], [1, 1, 1, 1], [2, 2, 2, 2]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -145,7 +145,7 @@ func @iota.const.5() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.6() -> tensor<2x3x4xi32> {
func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<{{\[\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3{{\]\]}}, {{\[\[}}0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]]> : tensor<2x3x4xi32>
%0 = "xla_hlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
%0 = "mhlo.iota"() {iota_dimension = 2 : i64} : () -> tensor<2x3x4xi32>
// CHECK-NEXT: return %[[CST]] : tensor<2x3x4xi32>
return %0 : tensor<2x3x4xi32>
}
@ -153,7 +153,7 @@ func @iota.const.6() -> tensor<2x3x4xi32> {
// CHECK-LABEL: func @iota.const.f32
func @iota.const.f32() -> tensor<4xf32> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %[[CST]] : tensor<4xf32>
return %0 : tensor<4xf32>
}
@ -161,7 +161,7 @@ func @iota.const.f32() -> tensor<4xf32> {
// CHECK-LABEL: func @iota.const.f64
func @iota.const.f64() -> tensor<4xf64> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf64>
// CHECK-NEXT: return %[[CST]] : tensor<4xf64>
return %0 : tensor<4xf64>
}
@ -169,7 +169,7 @@ func @iota.const.f64() -> tensor<4xf64> {
// CHECK-LABEL: func @iota.const.bf16
func @iota.const.bf16() -> tensor<4xbf16> {
// CHECK-NEXT: %[[CST:.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xbf16>
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xbf16>
// CHECK-NEXT: return %[[CST]] : tensor<4xbf16>
return %0 : tensor<4xbf16>
}
@ -178,8 +178,8 @@ func @iota.const.bf16() -> tensor<4xbf16> {
func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf32>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf32>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f32>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f32>>
return %0 : tensor<4xcomplex<f32>>
}
@ -188,8 +188,8 @@ func @iota.const.complex.f32() -> tensor<4xcomplex<f32>> {
func @iota.const.complex.f64() -> tensor<4xcomplex<f64>> {
// CHECK-NEXT: [[REAL:%.*]] = constant dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<4xf64>
// CHECK-NEXT: [[IMAG:%.*]] = constant dense<0.000000e+00> : tensor<4xf64>
// CHECK-NEXT: [[COMPLEX:%.*]] = "xla_hlo.complex"([[REAL]], [[IMAG]])
%0 = "xla_hlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: [[COMPLEX:%.*]] = "mhlo.complex"([[REAL]], [[IMAG]])
%0 = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xcomplex<f64>>
// CHECK-NEXT: return [[COMPLEX]] : tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}
@ -1,4 +1,4 @@
// RUN: xla-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt -xla-legalize-tanh-to-approximation -split-input-file %s | FileCheck %s

func @tanh_f64(%arg0 : f64) -> f64 {
%res = tanh %arg0 : f64
@ -1,4 +1,4 @@
// RUN: xla-opt -lhlo-copy-removal %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -lhlo-copy-removal %s -o - | FileCheck %s

// CHECK-LABEL: func @remove_simple
func @remove_simple(%arg0: memref<2x2xf32>) {
@ -1,6 +1,6 @@
// RUN: xla-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: xla-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: xla-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP
// RUN: mlir-hlo-opt -lhlo-fuse-linalg %s -split-input-file | FileCheck %s --dump-input=always
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=tile-sizes=2,3 %s -split-input-file | FileCheck %s -check-prefix=TILED
// RUN: mlir-hlo-opt -lhlo-fuse-linalg=use-parallel-loops %s -split-input-file | FileCheck %s -check-prefix=PLOOP

#map0 = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {args_in = 2, args_out = 1, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel", "parallel"]}
@ -4,7 +4,7 @@
// Lowering to STD dialect and store forwarding pass would be required to get
// rid of them. This is exactly what is done in the real MLIR GPU pipeline, but
// here we disable verification with `verify-each=0` to check the output IR.
// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize --verify-each=0 | FileCheck %s

func @select_and_scatter(%arg: memref<112x112xf32>,
%src: memref<56x56xf32>,
@ -1,4 +1,4 @@
// RUN: xla-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s
// RUN: mlir-hlo-opt -lhlo-legalize-to-affine %s -o - | FileCheck %s

// Smoke test.
// CHECK-LABEL: func @min_op
@ -1,4 +1,4 @@
// RUN: xla-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-gpu -split-input-file | FileCheck %s

func @reduce(%arg: memref<100x10xf32>,
%init: memref<f32>,
@ -1,4 +1,4 @@
// RUN: xla-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-linalg -split-input-file | FileCheck %s

// CHECK: #map0 = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: func @element_wise
@ -1,4 +1,4 @@
// RUN: xla-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s
// RUN: mlir-hlo-opt %s --test-lhlo-legalize-to-llvm -split-input-file | FileCheck %s

// CHECK-LABEL: func @static_memref_cast
func @static_memref_cast(%buf : memref<10x1x5xf32>) {
@ -1,4 +1,4 @@
// RUN: xla-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s
// RUN: mlir-hlo-opt %s -lhlo-legalize-to-parallel-loops -canonicalize -split-input-file | FileCheck %s

func @reduce(%arg: memref<100x10x5xf32>,
%init: memref<f32>,
@ -1,4 +1,4 @@
// RUN: xla-opt %s -verify-diagnostics -split-input-file | xla-opt | FileCheck %s
// RUN: mlir-hlo-opt %s -verify-diagnostics -split-input-file | mlir-hlo-opt | FileCheck %s

// -----

@ -396,9 +396,9 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m
"xla_lhlo.fusion"() ( {
%0 = tensor_load %input1 : memref<10xf32>
%1 = tensor_load %input2 : memref<10xf32>
%2 = "xla_hlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%3 = tensor_load %input3 : memref<10xf32>
%4 = "xla_hlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
%4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
tensor_store %4, %out : memref<10xf32>
"xla_lhlo.terminator"() : () -> ()
} ) : () -> ()
@ -803,15 +803,15 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a
func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () {
"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{ replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> ()

"xla_lhlo.all_reduce"(%arg0, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>):
%max = xla_hlo.maximum %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%max) : (tensor<f32>) -> ()
%max = mhlo.maximum %lhs, %rhs : tensor<f32>
"mhlo.return"(%max) : (tensor<f32>) -> ()
})
{
replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>,
@ -958,8 +958,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
%updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () {
"xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({
^bb0(%lhs: tensor<f32>, %rhs: tensor<f32>): // no predecessors
%add = xla_hlo.add %lhs, %rhs : tensor<f32>
"xla_hlo.return"(%add) : (tensor<f32>) -> ()
%add = mhlo.add %lhs, %rhs : tensor<f32>
"mhlo.return"(%add) : (tensor<f32>) -> ()
}) {
scatter_dimension_numbers = {
update_window_dims = dense<[1]> : tensor<1xi64>,
@ -979,8 +979,8 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32
func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () {
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<20xf32>) -> ()
return
}
@ -991,8 +991,8 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref
// expected-error@+1{{requires the same shape for all operands}}
"xla_lhlo.map"(%arg0, %arg1, %arg_out) ({
^bb0(%a: tensor<f32>, %b: tensor<f32>):
%c = xla_hlo.add %a, %b : tensor<f32>
"xla_hlo.return"(%c) : (tensor<f32>) -> ()
%c = mhlo.add %a, %b : tensor<f32>
"mhlo.return"(%c) : (tensor<f32>) -> ()
}) {dimensions = dense<0> : tensor<1xi64>} : (memref<20xf32>, memref<20xf32>, memref<10xf32>) -> ()
return
}
@ -1012,8 +1012,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64, is_stable = true} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1025,8 +1025,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) {dimension = 1 : i64} : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
@ -1038,8 +1038,8 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>,
%out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () {
"xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( {
^bb0(%a: tensor<f32>, %b: tensor<f32>, %c: tensor<f16>, %d: tensor<f16>):
%7 = "xla_hlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"xla_hlo.return"(%7) : (tensor<i1>) -> ()
%7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor<f32>, tensor<f32>) -> tensor<i1>
"mhlo.return"(%7) : (tensor<i1>) -> ()
}) : (memref<16x16xf32>, memref<16x16xf16>, memref<16x16xf32>, memref<16x16xf16>) -> ()
return
}
tensorflow/compiler/mlir/hlo/tests/lower-complex.mlir (new file, 224 lines)
@ -0,0 +1,224 @@
// RUN: mlir-hlo-opt %s -test-xla-chlo-legalize-to-hlo -test-xla-lower-complex | FileCheck %s

// CHECK-LABEL: @add
func @add(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
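The two CHECK-DAG adds encode componentwise complex addition; spelled out (a sketch of the identity, not text from the file):

    (a + bi) + (c + di) = (a + c) + (b + d)i

with %arg0/%arg1 as the first operand's real and imaginary parts and %arg2/%arg3 as the second's, so no cross terms arise.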

// CHECK-LABEL: @add_unranked
func @add_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.add %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.add %arg1, %arg3
%4 = "mhlo.add"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @sub
func @sub(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}

// CHECK-LABEL: @sub_unranked
func @sub_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.subtract %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.subtract %arg1, %arg3
%4 = "mhlo.subtract"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL0]], [[VAL1]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @mul
func @mul(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return %2, %5 : tensor<2xf32>, tensor<2xf32>
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
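The six CHECK-DAG lines in @mul trace the standard complex product (a sketch, not from the file):

    (a + bi)(c + di) = (ac - bd) + (ad + bc)i

VAL0 through VAL2 assemble the real part ac - bd; VAL3 through VAL5 assemble the imaginary part ad + bc.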

// CHECK-LABEL: @mul_unranked
func @mul_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg3
// CHECK-DAG: [[VAL2:%.+]] = mhlo.subtract [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply %arg0, %arg3
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.add [[VAL3]], [[VAL4]]
%4 = "mhlo.multiply"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return %2, %5 : tensor<*xf32>, tensor<*xf32>
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @div
func @div(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>, %arg2 : tensor<2xf32>, %arg3 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)

// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]

// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]

// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]

// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<2xcomplex<f32>>, tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)

%5 = "mhlo.real"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%6 = "mhlo.imag"(%4) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<2xf32>, tensor<2xf32>
}
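@div follows the multiply-by-the-conjugate derivation (a sketch, not from the file):

    (a + bi) / (c + di) = (a + bi)(c - di) / (c^2 + d^2) = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)

A single negate of rhs.imag (VAL0) feeds all three products written in subtract/add form, so c - di never has to be materialized as a complex value.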

// -----

// CHECK-LABEL: @div_unranked
func @div_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%2 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)
%3 = "mhlo.complex"(%arg2, %arg3) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "mhlo.negate"(%arg3)

// Compute the numerator's real component:
// numerator.real = lhs.real * rhs.real + lhs.imag * rhs.imag
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg0, %arg2
// CHECK-DAG: [[VAL2:%.+]] = mhlo.multiply %arg1, [[VAL0]]
// CHECK-DAG: [[VAL3:%.+]] = mhlo.subtract [[VAL1]], [[VAL2]]

// Compute the real valued denominator as rhs * conj(rhs):
// denominator = rhs.real * rhs.real + rhs.imag * rhs.imag
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply %arg2, %arg2
// CHECK-DAG: [[VAL5:%.+]] = mhlo.multiply %arg3, [[VAL0]]
// CHECK-DAG: [[VAL6:%.+]] = mhlo.subtract [[VAL4]], [[VAL5]]

// Compute the numerator's imaginary component:
// numerator.imag = lhs.imag * rhs.real - lhs.real * rhs.imag
// CHECK-DAG: [[VAL7:%.+]] = mhlo.multiply %arg1, %arg2
// CHECK-DAG: [[VAL8:%.+]] = mhlo.multiply %arg0, [[VAL0]]
// CHECK-DAG: [[VAL9:%.+]] = mhlo.add [[VAL8]], [[VAL7]]

// Divide the numerator by the real valued denominator.
// CHECK-DAG: [[VAL10:%.+]] = mhlo.divide [[VAL3]], [[VAL6]]
// CHECK-DAG: [[VAL11:%.+]] = mhlo.divide [[VAL9]], [[VAL6]]
%4 = "mhlo.divide"(%2, %3) : (tensor<*xcomplex<f32>>, tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)

%5 = "mhlo.real"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%6 = "mhlo.imag"(%4) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL10]], [[VAL11]]
return %5, %6 : tensor<*xf32>, tensor<*xf32>
}

// CHECK-LABEL: @abs
func @abs(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = mhlo.multiply %arg0, %arg0
// CHECK-DAG: [[VAL1:%.+]] = mhlo.multiply %arg1, %arg1
// CHECK-DAG: [[VAL2:%.+]] = mhlo.add [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL3:%.+]] = "mhlo.sqrt"([[VAL2]])
%1 = "mhlo.abs"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL3]]
return %2 : tensor<2xf32>
}
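@abs checks the usual modulus formula (sketch): |a + bi| = sqrt(a^2 + b^2). The test routes the result of mhlo.abs through mhlo.real because, as written here, the abs op's result type is still tensor<2xcomplex<f32>>.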

// CHECK-LABEL: @exp
func @exp(%arg0 : tensor<2xf32>, %arg1 : tensor<2xf32>) -> (tensor<2xf32>, tensor<2xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> (tensor<2xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<2xcomplex<f32>>) -> (tensor<2xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)
%3 = "mhlo.imag"(%1) : (tensor<2xcomplex<f32>>) -> (tensor<2xf32>)

// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<2xf32>, tensor<2xf32>
}
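@exp follows Euler's formula (sketch, not from the file): e^(a + bi) = e^a cos(b) + i e^a sin(b), which is why one exponential of the real part (VAL0) and one cosine/sine pair of the imaginary part (VAL1, VAL2) suffice.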

// CHECK-LABEL: @exp_unranked
func @exp_unranked(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> (tensor<*xcomplex<f32>>)

// CHECK-DAG: [[VAL0:%.+]] = "mhlo.exponential"(%arg0)
// CHECK-DAG: [[VAL1:%.+]] = "mhlo.cosine"(%arg1)
// CHECK-DAG: [[VAL2:%.+]] = "mhlo.sine"(%arg1)
// CHECK-DAG: [[VAL3:%.+]] = mhlo.multiply [[VAL0]], [[VAL1]]
// CHECK-DAG: [[VAL4:%.+]] = mhlo.multiply [[VAL0]], [[VAL2]]
%1 = "mhlo.exponential"(%0) : (tensor<*xcomplex<f32>>) -> (tensor<*xcomplex<f32>>)
%2 = "mhlo.real"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)
%3 = "mhlo.imag"(%1) : (tensor<*xcomplex<f32>>) -> (tensor<*xf32>)

// CHECK: return [[VAL3]], [[VAL4]]
return %2, %3 : tensor<*xf32>, tensor<*xf32>
}
tensorflow/compiler/mlir/hlo/tests/lower-general-dot.mlir (new file, 35 lines)
@ -0,0 +1,35 @@
// RUN: mlir-hlo-opt -test-xla-lower-general-dot -split-input-file %s -o - | FileCheck %s

// CHECK-LABEL: @testDebatch1
func @testDebatch1(%arg0: tensor<1x1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x1x3xf32> {
// CHECK-DAG: [[R0:%.+]] = "mhlo.reshape"(%arg0) : (tensor<1x1x2xf32>) -> tensor<1x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.dot"([[R0]], %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
// CHECK: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<1x3xf32>) -> tensor<1x1x3xf32>
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<2> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<0> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x1x2xf32>, tensor<2x3xf32>) -> tensor<1x1x3xf32>

return %0 : tensor<1x1x3xf32>
}
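Worked shapes for @testDebatch1 (my restatement of what the CHECK lines verify): with no batch dimensions and contracting dimensions (lhs 2, rhs 0), the dot_general reduces to a plain 2-D dot: reshape the tensor<1x1x2xf32> lhs to 1x2, compute (1x2) x (2x3) = 1x3, then reshape back to the 1x1x3 result.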

// -----

// CHECK-LABEL: @testDebatch2
func @testDebatch2(%arg0: tensor<2x3xf32>, %arg1: tensor<1x1x2xf32>) -> tensor<3x1x1xf32> {
// CHECK-DAG: [[R0:%.+]] = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
// CHECK-DAG: [[R1:%.+]] = "mhlo.transpose"(%arg1) {permutation = dense<[2, 0, 1]> : tensor<3xi64>} : (tensor<1x1x2xf32>) -> tensor<2x1x1xf32>
// CHECK-DAG: [[R2:%.+]] = "mhlo.reshape"([[R1]]) : (tensor<2x1x1xf32>) -> tensor<2x1xf32>
// CHECK-DAG: [[R3:%.+]] = "mhlo.dot"([[R0]], [[R2]]) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2xf32>, tensor<2x1xf32>) -> tensor<3x1xf32>
// CHECK: [[R4:%.+]] = "mhlo.reshape"([[R3]]) : (tensor<3x1xf32>) -> tensor<3x1x1xf32>

%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[]> : tensor<0xi64>, lhs_contracting_dimensions = dense<0> : tensor<1xi64>, rhs_batching_dimensions = dense<[]> : tensor<0xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x3xf32>, tensor<1x1x2xf32>) -> tensor<3x1x1xf32>
return %0 : tensor<3x1x1xf32>
}

// -----

// CHECK-LABEL: @testBatchPassthrough
func @testBatchPassthrough(%arg0: tensor<2x2x3xf32>, %arg1: tensor<2x1x2xf32>) -> tensor<3x2x1xf32> {
// CHECK-NEXT: "mhlo.dot_general"(%arg0, %arg1)
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<[0]> : tensor<1xi64>, lhs_contracting_dimensions = dense<1> : tensor<1xi64>, rhs_batching_dimensions = dense<[0]> : tensor<1xi64>, rhs_contracting_dimensions = dense<2> : tensor<1xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<2x2x3xf32>, tensor<2x1x2xf32>) -> tensor<3x2x1xf32>
return %0 : tensor<3x2x1xf32>
}

@ -0,0 +1,11 @@
// RUN: mlir-hlo-opt -test-xla-materialize-broadcasts -split-input-file %s -o - | FileCheck %s

// CHECK-LABEL: @clampBroadcast
// CHECK-SAME: (%[[MIN:.+]]: tensor<f32>, %[[VAL:.+]]: tensor<4xf32>, %[[MAX:.+]]: tensor<f32>)
func @clampBroadcast(%min: tensor<f32>, %value: tensor<4xf32>, %max: tensor<f32>) -> tensor<4xf32> {
// CHECK-DAG: %[[MIN_BC:.+]] = "mhlo.broadcast"(%[[MIN]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK-DAG: %[[MAX_BC:.+]] = "mhlo.broadcast"(%[[MAX]]) {broadcast_sizes = dense<4> : tensor<1xi64>} : (tensor<f32>) -> tensor<4xf32>
// CHECK: "mhlo.clamp"(%[[MIN_BC]], %[[VAL]], %[[MAX_BC]]) : (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%0 = "mhlo.clamp"(%min, %value, %max) : (tensor<f32>, tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
(File diff suppressed because it is too large.)
@ -1,14 +1,14 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<4x8xf32>)
// CHECK: return %[[ARG0]]
func @noop(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "xla_hlo.reduce"(%arg0, %0) ( {
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%2 = "mhlo.reduce"(%arg0, %0) ( {
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
%4 = xla_hlo.add %arg1, %arg2 : tensor<f32>
"xla_hlo.return"(%4) : (tensor<f32>) -> ()
%4 = mhlo.add %arg1, %arg2 : tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<4x8xf32>, tensor<f32>) -> tensor<4x8xf32>
return %2 : tensor<4x8xf32>
}
tensorflow/compiler/mlir/hlo/tests/reshape.mlir (new file, 149 lines)
@ -0,0 +1,149 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @const_fold_collapse_to_scalar
func @const_fold_collapse_to_scalar() -> tensor<i32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<i32>
%cst = mhlo.constant dense<42> : tensor<1x1xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x1xi32>) -> tensor<i32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<i32>
}

// -----

// CHECK-LABEL: func @const_fold_collapse_to_tensor
func @const_fold_collapse_to_tensor() -> tensor<2xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<2xi32>
%cst = mhlo.constant dense<42> : tensor<1x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<1x2xi32>) -> tensor<2xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2xi32>
}

// -----

// CHECK-LABEL: func @const_fold_expand
func @const_fold_expand() -> tensor<1xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<1xi32>
%cst = mhlo.constant dense<42> : tensor<i32>
%0 = "mhlo.reshape"(%cst) : (tensor<i32>) -> tensor<1xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<1xi32>
}

// -----

// CHECK-LABEL: func @const_fold_nontrivial
func @const_fold_nontrivial() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}

// -----

// CHECK-LABEL: func @const_fold_flatten
func @const_fold_flatten() -> tensor<16xi64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<42> : tensor<16xi64>
%cst = mhlo.constant dense<42> : tensor<4x4xi64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xi64>) -> tensor<16xi64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xi64>
}

// -----

// CHECK-LABEL: func @const_fold_6
func @const_fold_6() -> tensor<6xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%cst = mhlo.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<6xi32>
}
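These folds only relinearize the constant: row-major element order is preserved, so dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32> flattens to dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>, exactly the value the @const_fold_6 CHECK expects.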

// -----

// CHECK-LABEL: func @const_fold_same_shape
func @const_fold_same_shape() -> tensor<2x3xi32> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<[
// CHECK-SAME: [1, 2, 3], [4, 5, 6]
// CHECK-SAME: ]> : tensor<2x3xi32>
%cst = mhlo.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
%0 = "mhlo.reshape"(%cst) : (tensor<6xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<2x3xi32>
}

// -----

// CHECK-LABEL: func @const_fold_float
func @const_fold_float() -> tensor<16xf64> {
// CHECK-NEXT: [[CST:%.+]] = mhlo.constant dense<4.2{{0*}}e+00> : tensor<16xf64>
%cst = mhlo.constant dense<4.2> : tensor<4x4xf64>
%0 = "mhlo.reshape"(%cst) : (tensor<4x4xf64>) -> tensor<16xf64>
// CHECK-NEXT: return [[CST]]
return %0 : tensor<16xf64>
}

// -----

// CHECK-LABEL: func @non_const_same_shape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_same_shape(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-NEXT: return [[ARG]]
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}

// -----

// CHECK-LABEL: func @non_const_chained_reshape
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape(%arg : tensor<2x3xi32>) -> (tensor<3x2xi32>, tensor<6xi32>) {
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<3x2xi32>
// CHECK-NEXT: "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
return %0, %1 : tensor<3x2xi32>, tensor<6xi32> // return both so nothing is removed
}

// -----

// CHECK-LABEL: func @non_const_chained_reshape_unused_parent
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_unused_parent(%arg : tensor<2x3xi32>) -> tensor<6xi32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3xi32>) -> tensor<6xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<6xi32>
// CHECK-NEXT: return [[RES]]
return %1 : tensor<6xi32>
}

// -----

// CHECK-LABEL: func @non_const_chained_reshape_becomes_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_chained_reshape_becomes_noop(%arg : tensor<2x3xi32>) -> tensor<2x3xi32> {
%0 = "mhlo.reshape"(%arg) : (tensor<2x3xi32>) -> tensor<3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<3x2xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return [[ARG]]
return %1 : tensor<2x3xi32>
}

// -----

// CHECK-LABEL: func @non_const_many_chained_reshapes
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @non_const_many_chained_reshapes(%arg : tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32> {
// CHECK-NEXT: [[RES:%.+]] = "mhlo.reshape"([[ARG]]) : (tensor<2x3x4xi32>) -> tensor<1x2x4x3xi32>
%0 = "mhlo.reshape"(%arg) : (tensor<2x3x4xi32>) -> tensor<4x3x2xi32>
%1 = "mhlo.reshape"(%0) : (tensor<4x3x2xi32>) -> tensor<12x2xi32>
%2 = "mhlo.reshape"(%1) : (tensor<12x2xi32>) -> tensor<2x12xi32>
%3 = "mhlo.reshape"(%2) : (tensor<2x12xi32>) -> tensor<24xi32>
%4 = "mhlo.reshape"(%3) : (tensor<24xi32>) -> tensor<1x2x4x3xi32>
// CHECK-NEXT: return [[RES]]
return %4 : tensor<1x2x4x3xi32>
}
@ -1,9 +1,9 @@
// RUN: xla-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s
// RUN: mlir-hlo-opt %s -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @noop
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x2xf32>)
func @noop(%arg0: tensor<1x2xf32>) -> tensor<1x2xf32> {
%0 = "xla_hlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
%0 = "mhlo.reverse"(%arg0) {dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x2xf32>) -> tensor<1x2xf32>
// CHECK: return %[[ARG0]]
return %0 : tensor<1x2xf32>
}
@ -0,0 +1,60 @@
// RUN: mlir-hlo-opt %s -xla-hlo-sink-constants-to-control-flow | FileCheck %s

// Tests sinking constants to a while loop.

// CHECK-LABEL: func @sink_const_to_while
func @sink_const_to_while(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-NEXT: mhlo.while
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.while"(%arg0) ( {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1A:.+]]: tensor<i64>
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
// CHECK: "mhlo.compare"(%[[C0]], %[[ARG1A]])
%1 = "mhlo.compare"(%c0, %arg1) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
"mhlo.return"(%1) : (tensor<i1>) -> ()
}, {
^bb0(%arg1: tensor<i64>):
// CHECK: %[[ARG1B:.+]]: tensor<i64>
// CHECK-DAG: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
// CHECK-DAG: %[[ADD0:.+]] = mhlo.add %[[ARG1B]], %[[ARG1B]]
%2 = mhlo.add %arg1, %arg1 : tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]], %[[ADD0]]
%3 = mhlo.add %c1, %2 : tensor<i64>
// CHECK: %[[ADD2:.+]] = mhlo.add %[[C1]], %[[ADD1]]
%4 = mhlo.add %c1, %3 : tensor<i64>
"mhlo.return"(%4) : (tensor<i64>) -> ()
}) : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}

// Tests sinking constants to a conditional op.

// CHECK-LABEL: func @sink_const_to_conditional
func @sink_const_to_conditional(%arg0: tensor<i64>) -> tensor<i64> {
%c0 = mhlo.constant dense<1> : tensor<i64>
%c1 = mhlo.constant dense<2> : tensor<i64>
%0 = "mhlo.compare"(%arg0, %c0) {comparison_direction = "LT"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
%1 = "mhlo.tuple"(%arg0) : (tensor<i64>) -> tuple<tensor<i64>>
// CHECK: mhlo.if
%2 = "mhlo.if"(%0, %1, %1) ( {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C0:.+]] = mhlo.constant dense<1> : tensor<i64>
%3 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD0:.+]] = mhlo.add %[[C0]],
%4 = mhlo.add %c0, %3 : tensor<i64>
%5 = "mhlo.tuple"(%4) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%5) : (tuple<tensor<i64>>) -> ()
}, {
^bb0(%arg1: tuple<tensor<i64>>):
// CHECK: %[[C1:.+]] = mhlo.constant dense<2> : tensor<i64>
%6 = "mhlo.get_tuple_element"(%arg1) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
// CHECK: %[[ADD1:.+]] = mhlo.add %[[C1]],
%7 = mhlo.add %c1, %6 : tensor<i64>
%8 = "mhlo.tuple"(%7) : (tensor<i64>) -> tuple<tensor<i64>>
"mhlo.return"(%8) : (tuple<tensor<i64>>) -> ()
}) : (tensor<i1>, tuple<tensor<i64>>, tuple<tensor<i64>>) -> tuple<tensor<i64>>
%9 = "mhlo.get_tuple_element"(%2) {index = 0 : i32} : (tuple<tensor<i64>>) -> tensor<i64>
return %9 : tensor<i64>
}
@ -1,9 +1,9 @@
// RUN: xla-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @remove_noop
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
%0 = "mhlo.transpose"(%arg) {permutation = dense<[0, 1, 2, 3]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32>
// CHECK-NEXT: return [[ARG]]
return %0 : tensor<2x3x9x5xi32>
}
@ -13,8 +13,8 @@ func @remove_noop(%arg : tensor<2x3x9x5xi32>) -> tensor<2x3x9x5xi32> {
// CHECK-LABEL: func @keep_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>}: (tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32>
return %0 : tensor<3x2x5x9xi32>
}

@ -23,7 +23,7 @@ func @keep_real_transpose(%arg : tensor<2x3x9x5xi32>) -> tensor<3x2x5x9xi32> {
// CHECK-LABEL: func @keep_same_shape_real_transpose
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @keep_same_shape_real_transpose(%arg : tensor<4x4xi32>) -> tensor<4x4xi32> {
// CHECK-NEXT: "xla_hlo.transpose"([[ARG]])
%0 = "xla_hlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
// CHECK-NEXT: "mhlo.transpose"([[ARG]])
%0 = "mhlo.transpose"(%arg) {permutation = dense<[1, 0]> : tensor<2xi64>}: (tensor<4x4xi32>) -> tensor<4x4xi32>
return %0 : tensor<4x4xi32>
}

10
tensorflow/compiler/mlir/hlo/tests/tuple.mlir
Normal file
@ -0,0 +1,10 @@
// RUN: mlir-hlo-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s

// CHECK-LABEL: func @fold_access
// CHECK-SAME: [[ARG:%[a-zA-Z0-9]+]]
func @fold_access(%arg : tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return [[ARG]]
%tuple = "mhlo.tuple"(%arg) : (tensor<i32>) -> tuple<tensor<i32>>
%element = "mhlo.get_tuple_element"(%tuple) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
return %element : tensor<i32>
}

@ -1,4 +1,4 @@
// RUN: xla-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s
// RUN: mlir-hlo-opt -split-input-file -test-xla-unfuse-batch-norm -verify-diagnostics %s | FileCheck --enable-var-scope %s

// CHECK-LABEL: @batchNormInference_2D_inner_features
// CHECK-SAME: %[[X:[^:[:space:]]+]]
@ -10,19 +10,19 @@ func @batchNormInference_2D_inner_features(
%x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<4x256xf32>) {
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.001000e-05> : tensor<f32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[EPS]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>) -> tensor<256xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<256xf32>) -> tensor<256xf32>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[OFFSET]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<4x256xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} :
(tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<4x256xf32>
@ -36,12 +36,12 @@ func @batchNormInference_2D_inner_features(
// the verifier to enforce the rest.
// CHECK-SAME: %[[X:[^:]+]]
// CHECK-SAME: %[[SCALE:[^:]+]]
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<256xf32>) -> tensor<3x4x256x6xf32>
func @batchNormInference_4D_middle_features(
%x: tensor<3x4x256x6xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>,
%mean: tensor<256xf32>, %variance: tensor<256xf32>)
-> (tensor<3x4x256x6xf32>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.001000e-05 : f32, feature_index = 2 : i64} :
(tensor<3x4x256x6xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>,
tensor<256xf32>) -> tensor<3x4x256x6xf32>
@ -51,12 +51,12 @@ func @batchNormInference_4D_middle_features(
// -----
// CHECK-LABEL: @batchNormInference_f64
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f64>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f64>
func @batchNormInference_f64(
%x: tensor<4x256xf64>, %scale: tensor<256xf64>, %offset: tensor<256xf64>,
%mean: tensor<256xf64>, %variance: tensor<256xf64>)
-> (tensor<4x256xf64>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf64>, tensor<256xf64>, tensor<256xf64>, tensor<256xf64>,
tensor<256xf64>) -> tensor<4x256xf64>
@ -66,12 +66,12 @@ func @batchNormInference_f64(
// -----
// CHECK-LABEL: @batchNormInference_f16
// Validate that epsilon is properly promoted to f64
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e+00> : tensor<f16>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e+00> : tensor<f16>
func @batchNormInference_f16(
%x: tensor<4x256xf16>, %scale: tensor<256xf16>, %offset: tensor<256xf16>,
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 1.0 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@ -85,7 +85,7 @@ func @batchNormInference_f16_overflow(
%mean: tensor<256xf16>, %variance: tensor<256xf16>)
-> (tensor<4x256xf16>) {
// expected-warning @+1 {{Could not convert batch_norm epsilon to target fp type: opStatus = 24}}
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.00000001 : f32, feature_index = 1 : i64} :
(tensor<4x256xf16>, tensor<256xf16>, tensor<256xf16>, tensor<256xf16>,
tensor<256xf16>) -> tensor<4x256xf16>
@ -108,26 +108,26 @@ func @batchNormInference_dynamic_shape(
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
// CHECK-DAG: %[[C3:.*]] = constant 3 : index
// CHECK-DAG: %[[EPS:.+]] = xla_hlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[EPS:.+]] = mhlo.constant dense<1.000000e-03> : tensor<f32>
// CHECK-DAG: %[[DIM:.+]] = dim %[[VARIANCE]], %[[C0]] : tensor<?xf32>
// CHECK-DAG: %[[TO_DIM_TENSOR:.+]] = tensor_from_elements(%[[DIM]]) : tensor<1xindex>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = xla_hlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "xla_hlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[EPS_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[EPS]], %[[TO_DIM_TENSOR]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<?xf32>
// CHECK-DAG: %[[STDDEV:.+]] = "mhlo.sqrt"(%[[VARIANCE_EPS]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-DAG: %[[INPUT_DIM_0:.+]] = dim %[[X]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_1:.+]] = dim %[[X]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_2:.+]] = dim %[[X]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[INPUT_DIM_3:.+]] = dim %[[X]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[TO_INPUT_DIM_TENSOR:.+]] = tensor_from_elements(%[[INPUT_DIM_0]], %[[INPUT_DIM_1]], %[[INPUT_DIM_2]], %[[INPUT_DIM_3]]) : tensor<4xindex>
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "xla_hlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = xla_hlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = xla_hlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = xla_hlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = xla_hlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "xla_hlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
// CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[STDDEV]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[SCALE]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[OFFSET_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[OFFSET]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[MEAN]], %[[TO_INPUT_DIM_TENSOR]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<4xindex>) -> tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<?x?x?x?xf32>
// CHECK-DAG: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<?x?x?x?xf32>
%0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance)
{epsilon = 0.001 : f32, feature_index = 1 : i64} :
(tensor<?x?x?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>,
tensor<?xf32>) -> tensor<?x?x?x?xf32>

97
tensorflow/compiler/mlir/hlo/tests/xla-hlo-fusion.mlir
Normal file
@ -0,0 +1,97 @@
// RUN: mlir-hlo-opt %s -xla-hlo-fusion -split-input-file | FileCheck %s

// CHECK-LABEL: func @multi_outputs_same
func @multi_outputs_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%1, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.return
return %1, %2 : tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @multi_outputs_same_2
func @multi_outputs_same_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "mhlo.abs"(%arg0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.abs"(%arg1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%2 = "mhlo.add"(%0, %1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%3 = "mhlo.abs"(%0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
%4 = "mhlo.abs"(%1) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET:.*]]:3 = "mhlo.fusion"
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.abs
// CHECK-NEXT: mhlo.return
return %2, %3, %4 : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @multi_outputs_not_sure_same
func @multi_outputs_not_sure_same(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
%0 = "mhlo.add"(%arg0, %arg0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-NOT: mhlo.fusion
%1 = "mhlo.subtract"(%arg1, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
}

// -----

// CHECK-LABEL: func @reduce
func @reduce(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[RET0:.*]] = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.return
// Currently we do not support fuse arguments and ops without direct producer-consumer
// relationship. Thus Reduce Op should not be fused with above two ops.

%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%arg0, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
// Above two ops should not be fused since reduce op can not be
// fused with its consumer.
// CHECK-NOT: mhlo.fusion

return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

// -----

// CHECK-LABEL: func @reduce_2
func @reduce_2(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?xf32>) {
%0 = "mhlo.add"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
%1 = "mhlo.subtract"(%arg0, %0) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>

%2 = mhlo.constant dense<0.000000e+00> : tensor<f32>
%3 = "mhlo.reduce"(%1, %2) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
%4 = "mhlo.add"(%arg2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>
"mhlo.return"(%4) : (tensor<f32>) -> ()
}) {dimensions = dense<[1]> : tensor<1xi64>} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
// CHECK: %[[RET0:.*]]:2 = "mhlo.fusion"
// CHECK-NEXT: mhlo.add
// CHECK-NEXT: mhlo.subtract
// CHECK-NEXT: mhlo.constant
// CHECK-NEXT: mhlo.reduce
// CHECK: mhlo.return

// Following op should not be fused with the above ops since reduce op can not be
// fused with its consumer.
// CHECK-NOT: mhlo.fusion
%4 = "mhlo.add"(%3, %3) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
return %1, %4 : tensor<?x?xf32>, tensor<?xf32>
}

@ -1,4 +1,4 @@
// RUN: xla-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s
// RUN: mlir-hlo-opt -transform-unranked-hlo -split-input-file %s | FileCheck %s

// Check the validity of expected IR.
// CHECK-LABEL: @sqr_transform_result
@ -9,15 +9,15 @@ func @sqr_transform_result(%a: tensor<*xf32>) -> tensor<*xf32> {
%num_elements = shape.num_elements %shape
%num_elements_as_index = shape.size_to_index %num_elements
%flat_shape = tensor_from_elements(%num_elements_as_index) : tensor<1xindex>
%flat_a = "xla_hlo.dynamic_reshape"(%a, %flat_shape)
%flat_a = "mhlo.dynamic_reshape"(%a, %flat_shape)
: (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>

// Apply operation.
%flat_b = "xla_hlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>
%flat_b = "mhlo.sqrt"(%flat_a) : (tensor<?xf32>) -> tensor<?xf32>

// Restore original shape.
%shape_as_extent_tensor = shape.to_extent_tensor %shape : tensor<?xindex>
%b = "xla_hlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
%b = "mhlo.dynamic_reshape"(%flat_b, %shape_as_extent_tensor)
: (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>

return %b : tensor<*xf32>
@ -33,12 +33,12 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-NEXT: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK-NEXT: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK-NEXT: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK-NEXT: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "xla_hlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK-NEXT: %[[FLAT_B:.*]] = "mhlo.sqrt"(%[[FLAT_A]]) : (tensor<?xf32>) -> tensor<?xf32>
// CHECK-NEXT: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_B]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK-NEXT: return %[[B]] : tensor<*xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
%b = "mhlo.sqrt"(%a) : (tensor<*xf32>) -> tensor<*xf32>
return %b : tensor<*xf32>
}

@ -48,9 +48,9 @@ func @sqrt(%a: tensor<*xf32>) -> tensor<*xf32> {
// CHECK-LABEL: @sqrt_ranked
// CHECK-SAME: (%[[A:.*]]: tensor<3x?xf32>)
func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<3x?xf32>) -> tensor<3x?xf32>
// CHECK-NEXT: return %[[B]] : tensor<3x?xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
%b = "mhlo.sqrt"(%a) : (tensor<3x?xf32>) -> tensor<3x?xf32>
return %b : tensor<3x?xf32>
}

@ -60,9 +60,9 @@ func @sqrt_ranked(%a: tensor<3x?xf32>) -> tensor<3x?xf32> {
// CHECK-LABEL: @sqrt_static
// CHECK-SAME: (%[[A:.*]]: tensor<2x3xf32>)
func @sqrt_static(%a: tensor<2x3xf32>) -> tensor<2x3xf32> {
// CHECK-NEXT: %[[B:.*]] = "xla_hlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: %[[B:.*]] = "mhlo.sqrt"(%[[A]]) : (tensor<2x3xf32>) -> tensor<2x3xf32>
// CHECK-NEXT: return %[[B]] : tensor<2x3xf32>
%b = "xla_hlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
%b = "mhlo.sqrt"(%a) : (tensor<2x3xf32>) -> tensor<2x3xf32>
return %b : tensor<2x3xf32>
}

@ -77,12 +77,12 @@ func @add_unranked(%a : tensor<*xf32>, %b : tensor<*xf32>) -> tensor<*xf32> {
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE]]
// CHECK: %[[NUM_ELEMENTS_AS_INDEX:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[FLAT_SHAPE:.*]] = tensor_from_elements(%[[NUM_ELEMENTS_AS_INDEX]]) : tensor<1xindex>
// CHECK: %[[FLAT_A:.*]] = "xla_hlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "xla_hlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = xla_hlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[FLAT_A:.*]] = "mhlo.dynamic_reshape"(%[[A]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_B:.*]] = "mhlo.dynamic_reshape"(%[[B]], %[[FLAT_SHAPE]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[FLAT_RESULT:.*]] = mhlo.add %[[FLAT_A]], %[[FLAT_B]] : tensor<?xf32>
// CHECK: %[[SHAPE_AS_EXTENT_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE]] : tensor<?xindex>
// CHECK: %[[RESULT:.*]] = "xla_hlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: %[[RESULT:.*]] = "mhlo.dynamic_reshape"(%[[FLAT_RESULT]], %[[SHAPE_AS_EXTENT_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESULT]] : tensor<*xf32>
%result = xla_hlo.add %a, %b : tensor<*xf32>
%result = mhlo.add %a, %b : tensor<*xf32>
return %result : tensor<*xf32>
}

@ -115,12 +115,12 @@ Status MlirFunctionOptimizationPass::Run(
});

if (!is_enabled) {
VLOG(1) << "None of the MLIR optimization passes are enabled "
VLOG(0) << "None of the MLIR optimization passes are enabled "
<< "(registered " << registry_->passes().size() << ")";
return Status::OK();
}

VLOG(1) << "Running MLIR Graph Optimization Passes "
VLOG(0) << "Running MLIR Graph Optimization Passes "
<< "(registered " << registry_->passes().size() << " passes)";

GraphDebugInfo debug_info;
@ -187,12 +187,12 @@ Status MlirV1CompatGraphOptimizationPass::Run(
});

if (!is_enabled) {
VLOG(1) << "None of the MLIR optimization passes are enabled "
VLOG(0) << "None of the MLIR optimization passes are enabled "
<< "(registered" << registry_->passes().size() << " passes)";
return Status::OK();
}

VLOG(1) << "Running MLIR Graph Optimization V1 Compat Passes "
VLOG(0) << "Running MLIR Graph Optimization V1 Compat Passes "
<< "(registered" << registry_->passes().size() << " passes)";

GraphDebugInfo debug_info;

@ -70,7 +70,7 @@ tool_dirs = config.mlir_tf_tools_dirs + [
config.mlir_tools_dir, config.llvm_tools_dir
]
tool_names = [
'mlir-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'mlir-opt', 'mlir-hlo-opt', 'mlir-translate', 'tf-opt', 'tf_tfl_translate',
'tf_tfjs_translate', 'flatbuffer_to_string', 'flatbuffer_translate',
'tf-mlir-translate', 'mlir-tflite-runner', 'tfcompile',
'json_to_flatbuffer', 'xla-gpu-opt', 'xla-opt', 'hlo_to_llvm_ir'

@ -42,6 +42,7 @@ config.suffixes = ['.td', '.mlir', '.pbtxt']

mlir_tf_tools_dirs = [
'tensorflow/compiler/mlir',
'tensorflow/compiler/mlir/hlo',
'tensorflow/compiler/mlir/lite',
'tensorflow/compiler/mlir/tensorflow',
'tensorflow/compiler/mlir/tfjs',

@ -144,6 +144,7 @@ gentbl(
td_srcs = [
"@llvm-project//mlir:include/mlir/IR/OpBase.td",
"@llvm-project//mlir:include/mlir/Dialect/StandardOps/IR/Ops.td",
"@llvm-project//mlir:include/mlir/IR/SymbolInterfaces.td",
],
test = True,
)
@ -786,7 +787,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler/utils:transitive_fanin",
"//tensorflow/core/platform:protobuf_internal",
"//tensorflow/core/platform:types",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",

Binary file not shown.
After Width: | Height: | Size: 180 KiB |
196
tensorflow/compiler/mlir/tensorflow/g3doc/space_to_depth.md
Normal file
@ -0,0 +1,196 @@
# Automatic Space to Depth Transform in MLIR Bridge

Author: wangtao@, yuanzx@, hinsu@, lyandy@, chiachenc@, aminim@, jpienaar@,
dehao@

## TL;DR

_This document describes an automatic space to depth transform for the first
convolution in the new MLIR bridge to improve MXU efficiency of low batch size
convolutions._

## Background

For image models, the first layer is usually not MXU friendly because it has a
feature size of 3. This results in poor performance, especially at small batch
sizes.

One way to address this issue is to use the `space-to-depth` transform. This
optimization tiles the 2x2 space dimensions into the feature dimension so that
the feature dimension becomes 3\*4=12, which is more MXU friendly. To make this
optimization efficient, the shape of the weight needs to be padded and
transposed to the shape that the convolution emitter expects. The input also
needs to be transposed on the host and padded on the device to make the
convolution efficient. Although a 2x2 space-to-depth transform works only when
the first convolution has a stride of 2, many image models, ResNet-like ones in
particular, have a stride-2 convolution in the first layer.

Space to depth helped models such as MaskRCNN, SSD and I3D gain more than 2X
speedup and reduce memory usage in the first convolution.

The first convolution in many image models, including ResNet or ResNet-like, is
a (kernel=7, stride=2) 2D convolution. The input of the convolution is images,
which usually have RGB channels. The input of this first convolution is of shape
[batch\_size, height, width, 3] and the kernel size is [kernel\_size,
kernel\_size, 3, out\_channel]. Space to depth transforms this first
convolution's input to [batch\_size, height // stride, width // stride, 3 \*
stride \* stride] and the kernel to [kernel\_size // stride, kernel\_size //
stride, 3 \* stride \* stride, out\_channel] to improve TPU MXU utilization.



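As a quick sanity check of the shape arithmetic above, the following sketch
(the concrete sizes are illustrative assumptions, not taken from this doc)
confirms that `tf.nn.space_to_depth` produces the transformed input shape:

```python
import tensorflow as tf

batch, h, w, c, stride = 2, 224, 224, 3, 2  # Illustrative sizes only.

images = tf.random.normal([batch, h, w, c])
# With block_size == stride, the result has the shape
# [batch, h // stride, w // stride, c * stride * stride] described above.
transformed = tf.nn.space_to_depth(images, block_size=stride)
assert transformed.shape == (batch, h // stride, w // stride,
                             c * stride * stride)
```
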
This optimization can be done automatically by the graph optimizer: the weight
transformation is done at variable loading time and the input transformation is
done for every inference invocation. A further optimization can fuse this (on
the host) with the double transpose to minimize memory operations on the host.

## Proposed Method

**block\_size** is defined as the number of space sizes transformed to the depth
dimension. _stride % block\_size == 0_ and _stride >= block\_size_ are required
to do the transform. There are three parts to the automatic space to depth
transformation:

1.  Transform input on the host.

    Space-to-depth performs the following permutation, which is equivalent to
    `tf.nn.space_to_depth`.

    ```python
    images = tf.reshape(images, [batch, h // block_size, block_size,
                                 w // block_size, block_size, c])
    images = tf.transpose(images, [0, 1, 3, 2, 4, 5])
    images = tf.reshape(images, [batch, h // block_size, w // block_size,
                                 c * (block_size ** 2)])
    ```

    `SpaceToDepthOp` can be called on the host to perform the transform.

1.  Weight Transformation

    Weight transformation is similar to the input transform. A weight transform
    is needed to apply the space to depth optimization for a model that needs
    to load a pre-trained checkpoint. This transform can be done on the host or
    the TPU device based on the cost. As the size of the kernel is relatively
    small, this won't add additional cost to TPU device time. Below is the
    logic to transform a kernel of shape [7, 7, 3, 64] to [4, 4, 12, 64].

    ```python
    conv0 = tf.compat.v1.layers.Conv2D(
        filters=filters,
        kernel_size=kernel_size,
        strides=2,
        padding=('SAME' if strides == 1 else 'VALID'),
        use_bias=False,
        kernel_initializer=tf.variance_scaling_initializer(),
        data_format=data_format)

    # Use the image size without space-to-depth transform as the input of conv0.
    batch_size, h, w, channel = inputs.get_shape().as_list()
    conv0.build([
        batch_size, h * space_to_depth_block_size, w * space_to_depth_block_size,
        channel // (space_to_depth_block_size**2)
    ])

    kernel = conv0.weights[0]
    # [7, 7, 3, 64] --> [8, 8, 3, 64]
    kernel = tf.pad(
        kernel,
        paddings=tf.constant([[1, 0], [1, 0], [0, 0], [0, 0]]),
        mode='CONSTANT',
        constant_values=0.)
    # Transform the kernel following the space-to-depth logic:
    # https://www.tensorflow.org/api_docs/python/tf/nn/space_to_depth
    kernel = tf.reshape(
        kernel,
        [4, space_to_depth_block_size, 4, space_to_depth_block_size, 3, filters])
    kernel = tf.transpose(kernel, [0, 2, 1, 3, 4, 5])
    kernel = tf.reshape(kernel, [4, 4, int(channel), filters])
    kernel = tf.cast(kernel, inputs.dtype)
    ```

    If kernel\_size % block\_size != 0, the weight needs to be padded before
    the transform, and the input of the convolution needs to be padded as well,
    as sketched below.

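    A minimal sketch of that padding rule (the helper name is hypothetical and
    only illustrates the note above):

    ```python
    def padded_kernel_size(kernel_size: int, block_size: int) -> int:
      # Round kernel_size up to the next multiple of block_size.
      remainder = kernel_size % block_size
      if remainder == 0:
        return kernel_size
      return kernel_size + (block_size - remainder)

    # For the kernel=7, block_size=2 case above: 7 -> 8, matching the
    # [7, 7, 3, 64] --> [8, 8, 3, 64] tf.pad step.
    assert padded_kernel_size(7, 2) == 8
    ```
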
1.  Rewrite the first convolution

    The first convolution's input shape needs to be rewritten from
    [batch\_size, height, width, 3] to [batch\_size, height // block\_size,
    width // block\_size, 3 \* block\_size \* block\_size], and its kernel
    shape from [kernel\_size, kernel\_size, 3, out\_channel] to [kernel\_size
    // block\_size, kernel\_size // block\_size, 3 \* block\_size \*
    block\_size, out\_channel].

This is the proposed workflow for automatic space to depth transformation. All
the transformations will be triggered in an MLIR SpaceToDepthRewritePass; this
rewrite pass will be triggered before TPURewrite so that no metadata rewrite is
needed.

*   First, the rewrite pass will walk through all the convolutions in the func
    of tf\_device::LaunchOp and get the first convolution and its shape;
*   Second, the rewrite pass will apply transformations to the first
    convolution, the padding before the first convolution, the first
    convolution's filters and its Conv2DBackPropFilter;
*   Finally, the rewrite pass will insert SpaceToDepthOp after IteratorGetNext
    where the iterator's result has the same shape as the first convolution's
    input.

#### Pseudo MLIR code before and after RewritePass

```mlir
// Example: original program:
//
module {
  func @while_body {
    %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}:
      -> tensor<2x224x224x3xf32>
    %device_launch = "tf_device.launch_func"(%input,...) {func = @_func,...)
    return ...
  }
  func @_func(%input: tensor<2x224x224x3xf32>,
              %filter: tensor<7x7x3x64xf32>) {
    %6 = "tf.Conv2D"(%input, %filter) {strides = [1, 2, 2, 1]}:
      (tensor<2x230x230x3xf32>, tensor<7x7x3x64xf32>) ->
      tensor<2x112x112x64xf32>
  }
}

// With this pass, the program will be transformed into:
module {
  func @while_body {
    %input = "tf.IteratorGetNext"(...) {device = "/CPU:0"}
      -> tensor<2x224x224x3xf32>
    %space_to_depth = "tf.SpaceToDepth"(%input) {block_size = 2, ...}:
      (tensor<2x224x224x3xf32>) -> tensor<2x112x112x12xf32>
    %device_launch = "tf_device.launch_func"(%space_to_depth,...) {func = @_func,...)
    return ...
  }
  func @_func(%input: tensor<2x112x112x12xf32>,
              %filter: tensor<7x7x3x64xf32>) {
    %filter_transform = "tf.Pad/tf.Transpose/tf.Reshape"(%filter):
      (tensor<7x7x3x64xf32>) -> tensor<4x4x12x64xf32>
    %conv = "tf.Conv2D"(%input, %filter_transform) {strides = [1, 1, 1, 1]}:
      (tensor<2x112x112x12xf32>, tensor<4x4x12x64xf32>) ->
      tensor<2x112x112x64xf32>
  }
}
```

### SpaceToDepth Trigger Condition

Space to depth will only be triggered when the batch size is small and the
first convolution's channel size is small. The stride of the convolution should
be bigger than 1 as well. A cost model will be built that takes the input shape
and host cost into consideration to trigger the transformation, as sketched
below. There will be a flag to disable this feature as well.

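A minimal sketch of such a trigger check, assuming illustrative thresholds and
a hypothetical helper name (the real pass would consult the cost model, which
is not specified here):

```python
def should_apply_space_to_depth(batch_size: int, in_channels: int,
                                stride: int) -> bool:
  # Illustrative thresholds only; a real cost model would also weigh the
  # input shape and the host-side transpose/reshape cost.
  small_batch = batch_size <= 8
  small_channels = in_channels <= 4
  return small_batch and small_channels and stride > 1

# The (batch=2, channels=3, stride=2) first convolution from the examples
# above would qualify.
assert should_apply_space_to_depth(batch_size=2, in_channels=3, stride=2)
```
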
### Fuse SpaceToDepth with Automatic Double Transpose

The transpose and reshape ops in SpaceToDepthOp on TPU hosts may cause an image
model to be infeed bound. To reduce host time, the space to depth transform can
be fused with `automatic double transpose` to reduce extra overhead on the host.

### Extend from Conv2D to Conv3D

SpaceToDepth helps not only 2D image models but also 3D image models such as
I3D. The plan is to apply automatic space to depth for Conv2D as the first
step. After Conv2D is well tested, this technique will be generalized to
Conv3D.

@ -229,7 +229,8 @@ namespace {
ParseResult ParseReplicateOpOperands(
OpAsmParser* parser, OperationState* state,
llvm::SmallVectorImpl<llvm::SmallVector<OpAsmParser::OperandType, 8>>*
operands,
replicated_inputs,
llvm::SmallVectorImpl<OpAsmParser::OperandType>* packed_inputs,
llvm::SmallVectorImpl<OpAsmParser::OperandType>* region_args,
llvm::SmallVectorImpl<Type>* region_arg_types) {
// No operands or empty operand list.
@ -238,26 +239,61 @@ ParseResult ParseReplicateOpOperands(
return success();

// Parse comma separated operands of the following format:
// [%a, ...] as %block_arg: type
// replicated_input
// [%a, ...] as %block_arg0: type
// packed_input
// %b as %block_arg1: type
//
// Replicated inputs are placed before packed inputs when forming the op.
llvm::SmallVector<OpAsmParser::OperandType, 8> replicated_region_args;
llvm::SmallVector<OpAsmParser::OperandType, 8> packed_region_args;
llvm::SmallVector<Type, 8> replicated_region_arg_types;
llvm::SmallVector<Type, 8> packed_region_arg_types;
do {
if (parser->parseOperandList(operands->emplace_back(),
OpAsmParser::Delimiter::Square) ||
parser->parseKeyword("as",
" between replicated inputs and block argument") ||
parser->parseRegionArgument(region_args->emplace_back()) ||
parser->parseColonType(region_arg_types->emplace_back()))
OpAsmParser::OperandType operand_type;
if (parser->parseOptionalOperand(operand_type).hasValue()) {
packed_inputs->emplace_back(operand_type);
if (parser->parseKeyword("as",
" between packed input and block argument") ||
parser->parseRegionArgument(packed_region_args.emplace_back()) ||
parser->parseColonType(packed_region_arg_types.emplace_back()))
return failure();
} else if (parser->parseOperandList(replicated_inputs->emplace_back(),
OpAsmParser::Delimiter::Square) ||
parser->parseKeyword(
"as", " between replicated inputs and block argument") ||
parser->parseRegionArgument(
replicated_region_args.emplace_back()) ||
parser->parseColonType(
replicated_region_arg_types.emplace_back())) {
return failure();
}
} while (succeeded(parser->parseOptionalComma()));

region_args->reserve(replicated_region_args.size() +
packed_region_args.size());
region_args->append(replicated_region_args.begin(),
replicated_region_args.end());
region_args->append(packed_region_args.begin(), packed_region_args.end());

region_arg_types->reserve(replicated_region_arg_types.size() +
packed_region_arg_types.size());
region_arg_types->append(replicated_region_arg_types.begin(),
replicated_region_arg_types.end());
region_arg_types->append(packed_region_arg_types.begin(),
packed_region_arg_types.end());

// Parse remaining `)` surrounding operands.
return parser->parseRParen();
}

ParseResult SetOperands(
ParseResult SetReplicateOpOperands(
llvm::SMLoc loc, OpAsmParser* parser, OperationState* state,
llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>> operands,
llvm::ArrayRef<Type> region_arg_types, int* n) {
if (operands.empty()) return success();
llvm::ArrayRef<llvm::SmallVector<OpAsmParser::OperandType, 8>>
replicated_inputs,
llvm::ArrayRef<OpAsmParser::OperandType> packed_inputs,
llvm::ArrayRef<Type> region_arg_types, int32_t* n) {
if (replicated_inputs.empty() && packed_inputs.empty()) return success();

for (const auto& attr : state->attributes)
if (attr.first.strref() == "n")
@ -267,38 +303,68 @@ ParseResult SetOperands(
if (*n < 2)
return parser->emitError(loc) << "expects 'n' to be at least 2, got " << *n;

for (int i = 0, e = operands.size(); i < e; ++i) {
const auto& operand = operands[i];
for (auto replicated_input_and_idx : llvm::enumerate(replicated_inputs)) {
const int32_t idx = replicated_input_and_idx.index();
const auto& replicated_input = replicated_input_and_idx.value();
// Check if replicated input matches `n`.
if (operand.size() != *n)
if (replicated_input.size() != *n)
return parser->emitError(loc)
<< "expects number of operands for replicated input " << i
<< " to be 'n' (" << *n << "), got " << operand.size();
<< "expects number of operands for replicated input " << idx
<< " to be 'n' (" << *n << "), got " << replicated_input.size();

// Resolve replicated input and block argument type.
if (parser->resolveOperands(operand, region_arg_types[i], state->operands))
if (parser->resolveOperands(replicated_input, region_arg_types[idx],
state->operands))
return failure();
}

const int32_t num_replicated_block_args = replicated_inputs.size();
for (auto packed_input_and_idx : llvm::enumerate(packed_inputs)) {
const int32_t idx = packed_input_and_idx.index();
const auto& packed_input = packed_input_and_idx.value();

// Resolve packed input and block argument type.
if (parser->resolveOperand(
packed_input, region_arg_types[idx + num_replicated_block_args],
state->operands))
return failure();
}

return success();
}

constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";

ParseResult ParseReplicateOp(OpAsmParser* parser, OperationState* state) {
llvm::SMLoc loc = parser->getCurrentLocation();

// Parse operands, attributes, and region of op.
llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8> operands;
llvm::SmallVector<llvm::SmallVector<OpAsmParser::OperandType, 8>, 8>
replicated_inputs;
llvm::SmallVector<OpAsmParser::OperandType, 8> packed_inputs;
llvm::SmallVector<OpAsmParser::OperandType, 8> region_args;
llvm::SmallVector<Type, 8> region_arg_types;
int n = 0;
int32_t n = 0;
Region& body = *state->addRegion();
if (ParseReplicateOpOperands(parser, state, &operands, &region_args,
if (ParseReplicateOpOperands(parser, state, &replicated_inputs,
&packed_inputs, &region_args,
&region_arg_types) ||
parser->parseOptionalAttrDict(state->attributes) ||
SetOperands(loc, parser, state, operands, region_arg_types, &n) ||
SetReplicateOpOperands(loc, parser, state, replicated_inputs,
packed_inputs, region_arg_types, &n) ||
parser->parseRegion(body, region_args, region_arg_types))
return failure();

// Add derived `operand_segment_sizes` attribute based on parsed operands.
if (!state->attributes.get(kOperandSegmentSizesAttr)) {
int32_t num_replicated_inputs = replicated_inputs.size() * n;
int32_t num_packed_inputs = packed_inputs.size();
auto attr = DenseIntElementsAttr::get(
VectorType::get({2}, parser->getBuilder().getI32Type()),
{num_replicated_inputs, num_packed_inputs});
state->addAttribute(kOperandSegmentSizesAttr, attr);
}

// Ensure that the region is well formed: it contains at least a block with
// a ReturnOp terminator.
ReplicateOp::ensureTerminator(body, parser->getBuilder(), state->location);
@ -323,22 +389,40 @@ void Print(ReplicateOp op, OpAsmPrinter* p) {
*p << op.getOperationName();

// Print comma separated operands of the following format:
// [%a, ...] as %block_arg: type
int n = op.getAttrOfType<IntegerAttr>("n").getInt();
// replicated_input
// [%a, ...] as %block_arg0: type
// packed_input
// %b as %block_arg1: type
const int32_t n = op.n().getSExtValue();
const int32_t num_replicated_inputs =
(*op.operand_segment_sizes().int_value_begin()).getSExtValue();
const int32_t num_replicated_block_args = num_replicated_inputs / n;

if (op.getNumOperands()) {
*p << '(';
Block& block = op.body().front();
interleaveComma(block.getArguments(), *p, [&](BlockArgument arg) {
const int block_arg_num = arg.getArgNumber();
*p << '[';
p->printOperands(std::next(op.operand_begin(), block_arg_num * n),
std::next(op.operand_begin(), (block_arg_num + 1) * n));
*p << "] as " << arg << ": " << arg.getType();
if (block_arg_num < num_replicated_block_args) {
*p << '[';
p->printOperands(
std::next(op.replicated_inputs().begin(), block_arg_num * n),
std::next(op.replicated_inputs().begin(), (block_arg_num + 1) * n));
*p << "]";
} else {
p->printOperand(*std::next(op.packed_inputs().begin(),
block_arg_num - num_replicated_block_args));
}
*p << " as " << arg << ": " << arg.getType();
});
*p << ')';
}

p->printOptionalAttrDict(op.getAttrs());
// Skip derived `operand_segment_sizes` attribute as custom print format of
// operands holds enough information to calculate these variadic operand list
// lengths.
p->printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/ArrayRef<StringRef>{
kOperandSegmentSizesAttr});
p->printRegion(op.body(), /*printEntryBlockArgs=*/false);
}

@ -353,9 +437,7 @@ LogicalResult VerifyCompatibleTypes(Type a, Type b) {
}

LogicalResult Verify(ReplicateOp op) {
uint64_t n = op.n().getLimitedValue();
if (n < 2)
return op.emitOpError() << "expects 'n' to be at least 2, got " << n;
int32_t n = op.n().getSExtValue();

// Check number of devices, if set, matches `n`.
if (op.devices().hasValue()) {
@ -381,22 +463,46 @@ LogicalResult Verify(ReplicateOp op) {

Block& block = op.body().front();

// Check number of operands matches `n` * number of block arguments.
if (op.getNumOperands() != n * block.getNumArguments())
return op.emitOpError()
<< "expects number of operands (" << op.getNumOperands()
<< ") to be equal to 'n' * number of block arguments (" << n << " * "
<< block.getNumArguments() << ")";
auto operand_segment_sizes = op.operand_segment_sizes();
const int32_t num_replicated_inputs =
operand_segment_sizes.getValue<IntegerAttr>({0}).getInt();
const int32_t num_packed_inputs =
operand_segment_sizes.getValue<IntegerAttr>({1}).getInt();

// Check replicated input types match block argument types.
if (num_replicated_inputs % n != 0)
return op.emitOpError()
<< "expects number of replicated inputs (" << num_replicated_inputs
<< ") to be evenly divisible by 'n' (" << n << ")";

const int32_t num_replicated_block_args = num_replicated_inputs / n;
if (num_replicated_block_args + num_packed_inputs != block.getNumArguments())
return op.emitOpError()
<< "expects number of block arguments (" << block.getNumArguments()
<< ") to be equal to number of replicated inputs ("
<< num_replicated_inputs << ") / 'n' (" << n
<< ") + number of packed inputs (" << num_packed_inputs << ")";

// Check input types match block argument types.
auto verify_operand_types = [&](BlockArgument block_arg,
int32_t op_operand_idx) -> LogicalResult {
Type op_operand_type = op.getOperand(op_operand_idx).getType();
if (failed(VerifyCompatibleTypes(block_arg.getType(), op_operand_type)))
return op.emitOpError()
<< "expects operand " << op_operand_idx << " (" << op_operand_type
<< ") and block argument " << block_arg.getArgNumber() << " ("
<< block_arg.getType() << ") to have compatible types";

return success();
};
for (auto block_arg : block.getArguments()) {
Type block_arg_type = block_arg.getType();
for (int i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
if (failed(VerifyCompatibleTypes(block_arg_type,
op.getOperand(i).getType())))
return op.emitOpError()
<< "incompatible types for operand " << i
<< " and block argument " << block_arg.getArgNumber();
if (block_arg.getArgNumber() < num_replicated_block_args) {
for (int32_t i = n * block_arg.getArgNumber(), e = i + n; i < e; ++i)
if (failed(verify_operand_types(block_arg, i))) return failure();
} else {
const int32_t idx = block_arg.getArgNumber() - num_replicated_block_args +
num_replicated_inputs;
if (failed(verify_operand_types(block_arg, idx))) return failure();
}
}

Operation& terminator = block.back();
@ -412,8 +518,8 @@ LogicalResult Verify(ReplicateOp op) {
for (auto operand_type_and_idx :
llvm::enumerate(terminator.getOperandTypes())) {
Type operand_type = operand_type_and_idx.value();
int operand_idx = operand_type_and_idx.index();
for (int i = n * operand_idx, e = i + n; i < e; ++i)
int32_t operand_idx = operand_type_and_idx.index();
for (int32_t i = n * operand_idx, e = i + n; i < e; ++i)
if (failed(VerifyCompatibleTypes(operand_type, op.getType(i))))
return op.emitOpError() << "incompatible types for result " << i
<< " and terminator operand " << operand_idx;
@ -428,7 +534,7 @@ void BuildReplicateOp(
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices,
llvm::ArrayRef<std::pair<OperandsTy, Type>> replicated_inputs,
ResultsTy replica_output_types) {
llvm::ArrayRef<Value> packed_inputs, ResultsTy replica_output_types) {
DCHECK_GE(n, 2);
state->addAttribute("n", builder->getI32IntegerAttr(n));

@ -456,6 +562,17 @@ void BuildReplicateOp(
block.addArgument(replicated_input.second);
}

for (auto& packed_input : packed_inputs) {
state->addOperands(packed_input);
block.addArgument(packed_input.getType());
}

// Add derived `operand_segment_sizes` attribute.
int32_t num_replicated_inputs = replicated_inputs.size() * n;
auto operand_segment_sizes = DenseIntElementsAttr::get(
VectorType::get({2}, builder->getI32Type()), {num_replicated_inputs, 0});
state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes);

for (const auto& output_type : replica_output_types)
state->addTypes(llvm::SmallVector<Type, 8>(n, output_type));
}
@ -466,9 +583,10 @@ void ReplicateOp::build(
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices,
llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs,
llvm::ArrayRef<Type> replica_output_types) {
BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
replica_output_types);
packed_inputs, replica_output_types);
}

void ReplicateOp::build(
@ -476,9 +594,10 @@ void ReplicateOp::build(
const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>&
devices,
llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs,
llvm::ArrayRef<Value> packed_inputs,
Operation::result_type_range replica_output_types) {
BuildReplicateOp(&builder, &state, n, devices, replicated_inputs,
replica_output_types);
packed_inputs, replica_output_types);
}

//===----------------------------------------------------------------------===//

@ -177,8 +177,8 @@ def TfDevice_ParallelExecuteOp : TfDevice_Op<"parallel_execute",
|
||||
let verifier = [{ return Verify(*this); }];
|
||||
}
|
||||
|
||||
def TfDevice_ReplicateOp :
|
||||
TfDevice_Op<"replicate", [SingleBlockImplicitTerminator<"ReturnOp">]> {
|
||||
def TfDevice_ReplicateOp : TfDevice_Op<"replicate",
|
||||
[SingleBlockImplicitTerminator<"ReturnOp">, AttrSizedOperandSegments]> {
|
||||
let summary = "Wraps an N-way replicated computation.";
|
||||
|
||||
let description = [{
|
||||
@ -187,22 +187,30 @@ across multiple devices. The number of replications is based on the `n`
attribute. Explicit devices can be populated in the `devices` attribute, and it
must be a mapping of device alias to list of explicit or aliased device names
from the outer scope. The device name map specifies devices on which replicated
ops inside tf_device.replicate will be executed. A tf_device.parallel_execute
inside the tf_device.replicate op region may be used to represent computations
across a larger set of devices. In that case, the device alias can be used to
specify device assignment and replication of each concurrent execution
(i.e. region) defined by tf_device.parallel_execute op. The size of each value
list in the device name map must match `n`. Within a replica, the execution
semantics follow standard sequential behavior. Ops in the tf_device.replicate
wrapped with a tf_device.launch will have its device set to the associated
replicated device from `devices` if the tf_device.launch refers to an aliased
device name. Otherwise the device already set in tf_device.launch is used
instead. Operands are replicated inputs: each group of `n` inputs corresponds to
an input for a single individual replica and is mapped to a single region
argument. Inside one group the operands are matching in order the `devices`
attribute. Each replicated input must have compatible shapes and types. Operands
not replicated can be implicitly captured by ops in the region. Results are
replicated each from the regions terminator.
ops inside tf_device.replicate will be executed.

A tf_device.parallel_execute inside the tf_device.replicate op region may be
used to represent computations across a larger set of devices. In that case, the
device alias can be used to specify device assignment and replication of each
concurrent execution (i.e. region) defined by the tf_device.parallel_execute op.
The size of each value list in the device name map must match `n`. Within a
replica, the execution semantics follow standard sequential behavior. Ops in the
tf_device.replicate wrapped with a tf_device.launch will have their device set
to the associated replicated device from `devices` if the tf_device.launch
refers to an aliased device name. Otherwise the device already set in
tf_device.launch is used instead.

Operands are replicated inputs and packed inputs.

replicated_inputs: each group of `n` inputs corresponds to an input for a single
individual replica and is mapped to a single region argument. Inside one group,
the operands match the order of the `devices` attribute. Each replicated input
must have compatible shapes and types.
packed_inputs: each input corresponds to an input broadcast across all replicas
and is mapped to a single region argument.

Operands not replicated can be implicitly captured by ops in the region. Results
are replicated per replica from the operands of the region's terminator.

For example:
```
@ -214,46 +222,55 @@ For example:
%5 = "tf.opF"() : () -> tensor<!tf.resource>
%6 = "tf.opG"() : () -> tensor<!tf.string>
%7 = "tf.opH"() : () -> tensor<!tf.string>
%8 = "tf.opI"() : () -> tensor<i1>
%output:8 = tf_device.replicate([%0, %1] as %input_0:tensor<i32>,
                                [%2, %3] as %input_1:tensor<f32>,
                                [%4, %5] as %input_2:tensor<!tf.resource>
                                [%6, %7] as %input_3:tensor<!tf.string>)
%8 = "tf.opI"() : () -> tensor<!tf.variant>
%9 = "tf.opJ"() : () -> tensor<i1>
%output:8 = tf_device.replicate([%0, %1] as %input_0: tensor<i32>,
                                [%2, %3] as %input_1: tensor<f32>,
                                [%4, %5] as %input_2: tensor<!tf.resource>,
                                [%6, %7] as %input_3: tensor<!tf.string>,
                                %8 as %input_4: tensor<!tf.variant>)
                {n = 2 : i32,
                 devices = {DEVICE_ALIAS_0 = ["/DEVICE:0", "/DEVICE:1"],
                            DEVICE_ALIAS_1 = ["/DEVICE:2", "/DEVICE:3"]}} {
  // Inside the region, %0, %2, %4, and %6 corresponds to
  // "/DEVICE:0"/"/DEVICE:2" and %1, %3, %5, and %7 corresponds to
  // "/DEVICE:1"/"/DEVICE:3", depending on which device alias is used.
  %j = "tf_device.launch"() ( {
    %9 = "tf.opJ"(%input_0, %6) : (tensor<i32>, tensor<i1>) -> tensor<i32>
  %k = "tf_device.launch"() ( {
    %9 = "tf.opK"(%input_0, %input_4, %9) :
      (tensor<i32>, tensor<!tf.variant>, tensor<i1>) -> tensor<i32>
    tf_device.return %9 : tensor<i32>
  }) {device = "DEVICE_ALIAS_0"} : () -> tensor<i32>
  %k = "tf_device.launch"() ( {
    %10 = "tf.opK"(%input_1, %6) : (tensor<f32>, tensor<i1>) -> tensor<f32>
  %l = "tf_device.launch"() ( {
    %10 = "tf.opL"(%input_1, %input_4, %9) :
      (tensor<f32>, tensor<!tf.variant>, tensor<i1>) -> tensor<f32>
    tf_device.return %10 : tensor<f32>
  }) {device = "DEVICE_ALIAS_1"} : () -> tensor<f32>
  %l = "tf_device.launch"() ( {
    %11 = "tf.opL"(%input_2, %6) : (tensor<!tf.resource>, tensor<i1>)
      -> tensor<!tf.resource>
  %m = "tf_device.launch"() ( {
    %11 = "tf.opM"(%input_2, %input_4, %9) :
      (tensor<!tf.resource>, tensor<!tf.variant>, tensor<i1>)
      -> tensor<!tf.resource>
    tf_device.return %11 : tensor<!tf.resource>
  }) {device = "/DEVICE:4"} : () -> tensor<!tf.resource>
  %m = "tf.opM"(%input_3, %6) : (tensor<!tf.string>, tensor<i1>)
    -> tensor<!tf.string>
  tf_device.return %j, %k, %l, %m :
  %n = "tf.opN"(%input_3, %input_4, %9) :
    (tensor<!tf.string>, tensor<!tf.variant>, tensor<i1>)
    -> tensor<!tf.string>
  tf_device.return %k, %l, %m, %n :
    tensor<i32>, tensor<f32>, tensor<!tf.resource>, tensor<!tf.string>
}
// %output#0 corresponds to %j returned from "/DEVICE:0"
// %output#1 corresponds to %j returned from "/DEVICE:1"
// %output#2 corresponds to %k returned from "/DEVICE:2"
// %output#3 corresponds to %k returned from "/DEVICE:3"
// %output#4, %output#5 corresponds to %l and will be returned from "/DEVICE:4"
// %output#6, %output#7 corresponds to %m and will have no device set
// %output#0 corresponds to %k returned from "/DEVICE:0"
// %output#1 corresponds to %k returned from "/DEVICE:1"
// %output#2 corresponds to %l returned from "/DEVICE:2"
// %output#3 corresponds to %l returned from "/DEVICE:3"
// %output#4, %output#5 corresponds to %m and will be returned from "/DEVICE:4"
// %output#6, %output#7 corresponds to %n and will have no device set
```
}];

  let arguments = (ins
    Variadic<AnyType>:$replicated_inputs,
    Variadic<AnyType>:$packed_inputs,

    I32ElementsAttr:$operand_segment_sizes,
    Confined<I32Attr, [IntMinValue<2>]>:$n,
    OptionalAttr<DictionaryAttr>:$devices
  );
@ -272,10 +289,12 @@ For example:
    OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
              "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
              "llvm::ArrayRef<std::pair<llvm::ArrayRef<Value>, Type>> replicated_inputs, "
              "llvm::ArrayRef<Value> packed_inputs, "
              "llvm::ArrayRef<Type> replica_output_types">,
    OpBuilder<"OpBuilder& builder, OperationState& state, int n, "
              "const llvm::SmallDenseMap<StringRef, llvm::SmallVector<StringRef, 4>>& devices, "
              "llvm::ArrayRef<std::pair<Operation::operand_range, Type>> replicated_inputs, "
              "llvm::ArrayRef<Value> packed_inputs, "
              "Operation::result_type_range replica_output_types">
  ];

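For orientation, a hypothetical call site for the first overload might look like the sketch below. The helper name and the in-scope values (`a`, `b`, `packed`, `result_type`) are illustrative assumptions, not part of this change:

```cpp
// Hypothetical helper: build a 2-way replicate with one replicated group and
// one packed (broadcast) input. Sketch only; assumes the MLIR and tf_device
// headers are available and all passed-in values are valid.
tf_device::ReplicateOp CreateTwoWayReplicate(
    mlir::OpBuilder& builder, mlir::Location loc, mlir::Value a, mlir::Value b,
    mlir::Value packed, mlir::Type result_type) {
  llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
      devices;
  devices["DEVICE_ALIAS"] = {"/DEVICE:0", "/DEVICE:1"};

  llvm::SmallVector<mlir::Value, 2> group{a, b};  // one value per replica
  std::pair<llvm::ArrayRef<mlir::Value>, mlir::Type> replicated_input(
      group, a.getType());

  return builder.create<tf_device::ReplicateOp>(
      loc, /*n=*/2, devices, llvm::makeArrayRef(replicated_input),
      /*packed_inputs=*/llvm::makeArrayRef(packed),
      /*replica_output_types=*/llvm::makeArrayRef(result_type));
}
```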
@ -4274,6 +4274,117 @@ LogicalResult WhileRegionOp::moveOutOfLoop(
  return success();
}

//===----------------------------------------------------------------------===//
// WhileRegionOp canonicalization
//===----------------------------------------------------------------------===//
namespace {
// Eliminate values that pass through the WhileRegionOp body.
struct WhileRegionEliminatePassThrough
    : public OpRewritePattern<WhileRegionOp> {
  using OpRewritePattern<WhileRegionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(WhileRegionOp while_op,
                                PatternRewriter &rewriter) const override {
    // Replace values that simply pass through the body with extern values. The
    // block arguments of body and while match and so the corresponding cond
    // argument can be easily found.
    int old_num_operands = while_op.getNumOperands();
    int new_num_operands = old_num_operands;
    auto &body_block = while_op.body().front();
    auto &cond_block = while_op.cond().front();
    auto &yield = *body_block.getTerminator();

    // Bit mask indicating which operands will be removed.
    SmallVector<bool, 16> removed_operand(old_num_operands, false);

    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      auto body_arg = body_block.getArgument(op_idx);
      if (body_arg == yield.getOperand(op_idx)) {
        // Replace the use of the passthrough value with the while operand
        // in the body and condition regions, as well as the while output (if
        // types match).
        // TODO(jurahul): Use PatternRewriter API for IR modification.
        auto value = while_op.getOperand(op_idx);
        if (body_arg.getType() == value.getType())
          body_arg.replaceAllUsesWith(value);

        auto cond_arg = cond_block.getArgument(op_idx);
        if (cond_arg.getType() == value.getType())
          cond_arg.replaceAllUsesWith(value);

        auto result = while_op.getResult(op_idx);
        if (result.getType() == value.getType())
          result.replaceAllUsesWith(value);
      }

      // Now check if the operand is unused in both regions as well as the
      // result. If so, mark it for removal.
      if (body_block.getArgument(op_idx).use_empty() &&
          cond_block.getArgument(op_idx).use_empty() &&
          while_op.getResult(op_idx).use_empty()) {
        removed_operand[op_idx] = true;
        new_num_operands--;
      }
    }

    if (new_num_operands == old_num_operands) return failure();

    // Compress the operands, region arguments, and outputs.
    SmallVector<Value, 4> new_while_operands;
    SmallVector<Type, 4> new_result_types;
    new_while_operands.reserve(new_num_operands);
    new_result_types.reserve(new_num_operands);

    // Build new operands and result type.
    int next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand[op_idx]) continue;
      new_while_operands.push_back(while_op.getOperand(op_idx));
      new_result_types.push_back(while_op.getResult(op_idx).getType());
      next_idx++;
    }

    // Create the new while operation.
    auto new_while_op =
        rewriter.create<WhileRegionOp>(while_op.getLoc(), new_result_types,
                                       new_while_operands, while_op.getAttrs());

    // Move region bodies to the new while.
    rewriter.inlineRegionBefore(while_op.cond(), new_while_op.cond(),
                                new_while_op.cond().end());
    rewriter.inlineRegionBefore(while_op.body(), new_while_op.body(),
                                new_while_op.body().end());

    auto &new_cond_block = new_while_op.cond().front();
    auto &new_body_block = new_while_op.body().front();
    auto &new_yield = *new_body_block.getTerminator();

    // Build a vector of new results. Also patch up the region bodies and yield.
    SmallVector<Value, 4> new_results;
    next_idx = 0;
    for (int op_idx : llvm::seq<int>(0, old_num_operands)) {
      if (removed_operand[op_idx]) {
        new_cond_block.eraseArgument(next_idx);
        new_body_block.eraseArgument(next_idx);
        new_yield.eraseOperand(next_idx);
        new_results.push_back(nullptr);
      } else {
        new_results.push_back(new_while_op.getResult(next_idx++));
      }
    }

    rewriter.replaceOp(while_op, new_results);
    return success();
  }
};

} // anonymous namespace

void WhileRegionOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<WhileRegionEliminatePassThrough>(context);
}
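For completeness, the registered pattern is exercised by the standard canonicalizer pass; a minimal driver sketch, assuming an in-scope `MLIRContext context` and a loaded `ModuleOp module`:

```cpp
// Run the standard canonicalizer, which now picks up
// WhileRegionEliminatePassThrough for tf.WhileRegion ops.
mlir::PassManager pm(&context);
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
if (mlir::failed(pm.run(module))) {
  // Diagnostics were emitted through the context; handle the failure here.
}
```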

//===----------------------------------------------------------------------===//
// XdivyOp
//===----------------------------------------------------------------------===//

@ -658,7 +658,9 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",

    This implies that the operand and result types for tf.WhileRegion should be
    the same. Note that the condition and body regions can implicitly capture
    loop invariant values directly.
    loop invariant values directly. In canonical form, iteration variables that
    pass through the loop body unmodified are converted to implicitly captured
    references to their values outside the loop.
  }];

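As a plain-C++ analogy of that canonical form (illustrative only, not TensorFlow code): a loop-carried value the body never modifies can be dropped from the loop state and referenced from outside the loop instead.

```cpp
#include <iostream>

int main() {
  // Before canonicalization (conceptually): loop state = {invariant, count}.
  // After: loop state = {count}; `invariant` is captured from outside.
  const int invariant = 41;
  int count = 3;
  while (count != 0) --count;  // `invariant` never enters the loop state
  std::cout << invariant << "\n";  // still available afterwards: prints 41
}
```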
  let arguments = (ins
@ -676,6 +678,8 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
  let regions = (region SizedRegion<1>:$cond, SizedRegion<1>:$body);

  let verifier = [{ return Verify(*this); }];

  let hasCanonicalizer = 1;
}

def TF_TensorListReserveOp : TF_TensorListInitOp<"TensorListReserve"> {
@ -717,7 +721,8 @@ This operation holds the metadata common to operations of a `tpu.replicate()` co
    DefaultValuedAttr<StrArrayAttr, "{}">:$host_compute_core,
    DefaultValuedAttr<StrArrayAttr, "{}">:$padding_map,
    DefaultValuedAttr<StrAttr, "STEP_MARK_AT_ENTRY">:$step_marker_location,
    DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement
    DefaultValuedAttr<BoolAttr, "false">:$allow_soft_placement,
    DefaultValuedAttr<BoolAttr, "false">:$use_spmd_for_xla_partitioning
  );

  let results = (outs);

@ -141,16 +141,27 @@ static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
  return mlir::success();
}

Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor) {
  auto type = global_tensor.type().cast<TensorType>();
  return RankedTensorType::get(
      {}, TF::ResourceType::get({type}, type.getContext()));
Type GetBoundInputArgTypeFor(mlir::Operation *op) {
  if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) {
    auto type = global_tensor.type().cast<TensorType>();
    return RankedTensorType::get(
        {}, TF::ResourceType::get({type}, type.getContext()));
  }

  if (auto asset = llvm::dyn_cast<AssetOp>(op)) {
    return RankedTensorType::get({}, TF::StringType::get(asset.getContext()));
  }

  op->emitError() << "unknown symbol operation";
  return {};
}

static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
                                             Type arg_type,
                                             GlobalTensorOp global_tensor) {
  auto expected_type = GetBoundInputArgTypeFor(global_tensor);
                                             mlir::Operation *symbol_op) {
  auto expected_type = GetBoundInputArgTypeFor(symbol_op);
  if (!expected_type) return failure();

  if (arg_type != expected_type) {
    return op_for_diagnostics->emitError()
           << "bound input with type " << arg_type << " expected to have type "
@ -169,14 +180,14 @@ LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
    }
    auto symbol_name = named_attr.second.cast<FlatSymbolRefAttr>().getValue();
    auto module = op->getParentOfType<ModuleOp>();
    auto global_tensor = module.lookupSymbol<GlobalTensorOp>(symbol_name);
    if (!global_tensor) {
    mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
    if (!symbol_op) {
      return op->emitError() << "'tf_saved_model.bound_input' attribute must "
                                "reference a valid symbol, got invalid symbol '"
                             << symbol_name << "'";
    }
    auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType();
    return VerifyBoundInputArgType(op, arg_type, global_tensor);
    return VerifyBoundInputArgType(op, arg_type, symbol_op);
  }
  if (named_attr.first == "tf_saved_model.index_path") {
    return VerifyIndexPath(op, named_attr);
@ -404,12 +415,12 @@ bool HasTfSavedModelSemantics(ModuleOp module) {
  return module.getAttr("tf_saved_model.semantics") != nullptr;
}

GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
                                const SymbolTable &symbol_table) {
Operation *LookupBoundInput(FuncOp func, int arg_index,
                            const SymbolTable &symbol_table) {
  auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
      arg_index, "tf_saved_model.bound_input");
  if (!attr) return nullptr;
  return symbol_table.lookup<GlobalTensorOp>(attr.getValue());
  return symbol_table.lookup(attr.getValue());
}

SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
@ -36,6 +36,8 @@ class TensorFlowSavedModelDialect : public Dialect {
                                     NamedAttribute named_attr) override;
  LogicalResult verifyOperationAttribute(Operation *op,
                                         NamedAttribute named_attr) override;

  static StringRef getDialectNamespace() { return "tf_saved_model"; }
};

// Declares the operations for this dialect using the generated header.
@ -54,12 +56,19 @@ bool HasTfSavedModelSemantics(ModuleOp module);

// Returns the tf_saved_model.global_tensor op that func's arg_index'th argument
// refers to as a bound input, or null.
GlobalTensorOp LookupBoundInput(FuncOp func, int arg_index,
                                const SymbolTable &symbol_table);
Operation *LookupBoundInput(FuncOp func, int arg_index,
                            const SymbolTable &symbol_table);

// Gets the type that an exported function arg that is bound to `global_tensor`
// should have.
Type GetBoundInputArgTypeFor(GlobalTensorOp global_tensor);
template <typename T>
T LookupBoundInputOfType(FuncOp func, int arg_index,
                         const SymbolTable &symbol_table) {
  return llvm::dyn_cast_or_null<T>(
      LookupBoundInput(func, arg_index, symbol_table));
}
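A hypothetical caller of the new template helper (illustrative; assumes `func`, `arg_index`, and `symbol_table` are in scope):

```cpp
// Resolve a bound input but keep only global tensors. An asset-backed
// argument yields null here because the dyn_cast fails.
if (auto global_tensor = LookupBoundInputOfType<GlobalTensorOp>(
        func, arg_index, symbol_table)) {
  // Treat the argument as a resource handle bound to `global_tensor`.
}
```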

// Gets the type that an exported function arg that is bound to symbol ops such
// as `global_tensor` and `asset` should have.
Type GetBoundInputArgTypeFor(mlir::Operation *op);

// Returns the session initializer of this module if it exists. Returns null
// otherwise.
@ -19,6 +19,7 @@ limitations under the License.
#define SAVED_MODEL_DIALECT

include "mlir/IR/OpBase.td"
include "mlir/IR/SymbolInterfaces.td"

//===----------------------------------------------------------------------===//
// Dialect definition
@ -154,4 +155,24 @@ def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
  let hasCanonicalizer = 1;
}

def TfSavedModel_AssetOp: TfSavedModel_Op<"asset", [Symbol]> {
  let summary = "Represents an asset in the saved model.";
  let description = [{
    Represents an asset in the saved model that points to an external file. It
    is a scalar string tensor and it is passed as an argument to the session
    initializer function.

    The `sym_name` attribute is the symbol table name used for internal IR
    references.

    The `filename` attribute contains the file path to the asset file, relative
    to the saved model directory.
  }];

  let arguments = (ins
    StrAttr:$sym_name,
    StrAttr:$filename
  );
}

#endif // SAVED_MODEL_DIALECT
@ -19,7 +19,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {

// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<?xi32>)
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
@ -54,9 +54,9 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}

// CHECK-LABEL: func @_func
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {xla_hlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG0:.*]]: tensor<?xi32> {mhlo.is_same_data_across_replicas},
// CHECK-SAME: %[[ARG1:.*]]: tensor<?xi32>,
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {xla_hlo.is_same_data_across_replicas}
// CHECK-SAME: %[[ARG2:.*]]: tensor<!tf.resource<tensor<?xi32>>> {mhlo.is_same_data_across_replicas}
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>, %arg2: tensor<!tf.resource<tensor<?xi32>>>) -> tensor<?xi32> {
  %0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
@ -78,7 +78,7 @@ module attributes {tf.versions = {producer = 888 : i32}} {
}

// CHECK-LABEL: func @_func
// CHECK-NOT: xla_hlo.is_same_data_across_replicas
// CHECK-NOT: mhlo.is_same_data_across_replicas
func @_func(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf._D"(%arg0, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  return %0 : tensor<?xi32>
@ -648,3 +648,152 @@ func @erase_tf_var_is_initialized(%arg0 : tensor<!tf.resource<tensor<f32>>>) ->
// Unused VarIsInitializedOp is erased.
// CHECK: tf.VarHandleOp
// CHECK-NEXT: tf.UnknownOp


// Simple pass through value
// CHECK-LABEL: testWhileRegionSimplePassThrough
func @testWhileRegionSimplePassThrough(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
  // CHECK: "tf.WhileRegion"(%arg1)
  %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
    {
      // condition, check if count has reached 0
      ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
        %zero = constant dense<0> : tensor<i32>
        %ne = "tf.NotEqual"(%carg1, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
        "tf.Yield"(%ne) : (tensor<i1>) -> ()
    },
    {
      // loop body
      ^bb0(%barg0: tensor<*xf32>, %barg1: tensor<i32>):
        %one = constant dense<1> : tensor<i32>
        %sub = "tf.Sub"(%barg1, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
        "tf.Yield"(%barg0, %sub) : (tensor<*xf32>, tensor<i32>) -> ()
    }
  ) { is_stateless = false } : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
  // CHECK: return %arg0 : tensor<*xf32>
  return %0#0 : tensor<*xf32>
}

// Multiple pass through values
// CHECK-LABEL: testWhileRegionMultiplePassThrough
func @testWhileRegionMultiplePassThrough(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<i32>) -> tensor<*xf32> {
  // Verify that first 3 operands are eliminated.
  // CHECK: "tf.WhileRegion"(%arg3)
  %0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) (
    {
      // condition, check if count has reached 0
      ^bb0(%carg0 : tensor<*xf32>, %carg1 : tensor<*xf32>, %carg2 : tensor<*xf32>, %carg3 : tensor<i32>):
        %zero = constant dense<0> : tensor<i32>
        %ne = "tf.NotEqual"(%carg3, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
        "tf.Yield"(%ne) : (tensor<i1>) -> ()
    },
    {
      // loop body
      ^bb0(%barg0 : tensor<*xf32>, %barg1 : tensor<*xf32>, %barg2 : tensor<*xf32>, %barg3 : tensor<i32>):
        %one = constant dense<1> : tensor<i32>
        %sub = "tf.Sub"(%barg3, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
        "tf.Yield"(%barg0, %barg1, %barg2, %sub) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> ()
    }
  ) { is_stateless = false } : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>)

  // CHECK: %[[SUB0:.*]] = "tf.Sub"(%arg0, %arg1)
  // CHECK: %[[SUB1:.*]] = "tf.Sub"(%arg2, %[[SUB0]])
  // CHECK: return %[[SUB1]]
  %sub0 = "tf.Sub" (%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  %sub1 = "tf.Sub" (%0#2, %sub0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %sub1 : tensor<*xf32>
}

// Multiple non contiguous pass through values
// CHECK-LABEL: testWhileRegionMultiplePassThroughNonContiguous
func @testWhileRegionMultiplePassThroughNonContiguous(%arg0 : tensor<*xf32>, %arg1 : tensor<*xf32>, %arg2 : tensor<*xf32>, %arg3 : tensor<i32>) -> tensor<*xf32> {
  // Verify arg0 and arg2 are eliminated
  // CHECK: %[[WHILE_OUT:.*]]:2 = "tf.WhileRegion"(%arg1, %arg3)
  %0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %arg3) (
    {
      // condition, check if count has reached 0
      ^bb0(%carg0 : tensor<*xf32>, %carg1 : tensor<*xf32>, %carg2 : tensor<*xf32>, %carg3 : tensor<i32>):
        %zero = constant dense<0> : tensor<i32>
        %ne = "tf.NotEqual"(%carg3, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
        "tf.Yield"(%ne) : (tensor<i1>) -> ()
    },
    {
      // loop body
      ^bb0(%barg0 : tensor<*xf32>, %barg1 : tensor<*xf32>, %barg2 : tensor<*xf32>, %barg3 : tensor<i32>):
        %arg1neg = "tf.Neg"(%barg1) : (tensor<*xf32>) -> tensor<*xf32>
        %one = constant dense<1> : tensor<i32>
        %sub = "tf.Sub"(%barg3, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
        "tf.Yield"(%barg0, %arg1neg, %barg2, %sub) : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> ()
    }
  ) { is_stateless = false } : (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<i32>)

  // Verify that uses of while loop results corresponding to results #0 and #2
  // of the while are replaced with the corresponding WhileRegion operands
  // CHECK: %[[SUB0:.*]] = "tf.Sub"(%arg0, %[[WHILE_OUT]]#0)
  // CHECK: %[[SUB1:.*]] = "tf.Sub"(%arg2, %[[SUB0]])
  // CHECK: return %[[SUB1]]
  %sub0 = "tf.Sub" (%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  %sub1 = "tf.Sub" (%0#2, %sub0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
  return %sub1 : tensor<*xf32>
}

// Pass through but with type mismatch (tensor<*xf32> is compatible with
// tensor<?x?xf32> in the body). WhileRegion canonicalization does not handle
// this.
// CHECK-LABEL: testWhileRegionPassThroughTypeMismatch
func @testWhileRegionPassThroughTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<*xf32> {
  // Verify that the While stays unchanged
  // CHECK: "tf.WhileRegion"(%arg0, %arg1)
  %0:2 = "tf.WhileRegion"(%arg0, %arg1) (
    {
      // condition, check if count has reached 0
      ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>):
        %zero = constant dense<0> : tensor<i32>
        %ne = "tf.NotEqual"(%carg1, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
        "tf.Yield"(%ne) : (tensor<i1>) -> ()
    },
    {
      // loop body
      ^bb0(%barg0: tensor<?x?xf32>, %barg1: tensor<i32>):
        %one = constant dense<1> : tensor<i32>
        %sub = "tf.Sub"(%barg1, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
        "tf.Yield"(%barg0, %sub) : (tensor<?x?xf32>, tensor<i32>) -> ()
    }
  ) { is_stateless = false } : (tensor<*xf32>, tensor<i32>) -> (tensor<*xf32>, tensor<i32>)
  // Verify that the result stays unchanged
  // CHECK: return %arg0 : tensor<*xf32>
  return %0#0 : tensor<*xf32>
}

// Unused values flowing through the while (operands 2 and 3 are unused in the
// while and the corresponding results are unused as well). Canonicalization
// will eliminate them.
// CHECK-LABEL: testWhileRegionUnusedValue
func @testWhileRegionUnusedValue(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2: tensor<i32>) -> tensor<*xf32> {
  %cst = constant dense <33.0> : tensor<f32>
  // Verify that last 2 operands of while (unused) are removed
  // CHECK: %[[WHILE_OUT:.*]]:2 = "tf.WhileRegion"(%arg0, %arg1)
  %0:4 = "tf.WhileRegion"(%arg0, %arg1, %arg2, %cst) (
    {
      // condition, check if count has reached 0
      ^bb0(%carg0: tensor<*xf32>, %carg1: tensor<i32>, %carg2:tensor<i32>, %carg3:tensor<f32>):
        %zero = constant dense<0> : tensor<i32>
        %ne = "tf.NotEqual"(%carg1, %zero) : (tensor<i32>, tensor<i32>) -> tensor<i1>
        "tf.Yield"(%ne) : (tensor<i1>) -> ()
    },
    {
      // loop body
      ^bb0(%barg0: tensor<*xf32>, %barg1: tensor<i32>, %barg2:tensor<i32>, %barg3:tensor<f32>):
        %add = "tf.Add"(%barg0, %barg0) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
        %one = constant dense<1> : tensor<i32>
        %sub = "tf.Sub"(%barg1, %one) : (tensor<i32>, tensor<i32>) -> tensor<i32>
        %dummy0 = constant dense<7> : tensor<i32>
        %dummy1 = constant dense<3.0> : tensor<f32>
        "tf.Yield"(%add, %sub, %dummy0, %dummy1) : (tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<f32>) -> ()
    }
  ) { is_stateless = false } : (tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<f32>) -> (tensor<*xf32>, tensor<i32>, tensor<i32>, tensor<f32>)

  // Verify that return still uses while result # 0
  // CHECK: return %[[WHILE_OUT]]#0 : tensor<*xf32>
  return %0#0 : tensor<*xf32>
}
@ -17,8 +17,8 @@ func @biasAdd_dynamic(%arg0: tensor<?x?x?x?xi32>, %arg1: tensor<?xi32>) -> tenso
}

func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = xla_hlo.add %arg0, %arg0 : tensor<2xi32>
  %1 = xla_hlo.add %0, %arg0 : tensor<2xi32>
  %0 = mhlo.add %arg0, %arg0 : tensor<2xi32>
  %1 = mhlo.add %0, %arg0 : tensor<2xi32>
  return %1 : tensor<2xi32>
}

@ -33,7 +33,7 @@ func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi3
}

func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
  %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
  return %0 : tensor<2xi32>
}

@ -43,7 +43,7 @@ func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}

func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
  %0 = xla_hlo.shift_left %arg0, %arg1 : tensor<4xi32>
  %0 = mhlo.shift_left %arg0, %arg1 : tensor<4xi32>
  return %0 : tensor<4xi32>
}

@ -53,17 +53,17 @@ func @div_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<?x?xi32>) -> tensor<?x?xi3
}

func @maximum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %0 = xla_hlo.maximum %arg0, %arg1 : tensor<4xf32>
  %0 = mhlo.maximum %arg0, %arg1 : tensor<4xf32>
  return %0 : tensor<4xf32>
}

func @minimum(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
  %0 = xla_hlo.minimum %arg0, %arg1 : tensor<4xf32>
  %0 = mhlo.minimum %arg0, %arg1 : tensor<4xf32>
  return %0 : tensor<4xf32>
}

func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = xla_hlo.multiply %arg0, %arg0 : tensor<2xi32>
  %0 = mhlo.multiply %arg0, %arg0 : tensor<2xi32>
  return %0 : tensor<2xi32>
}

@ -73,7 +73,7 @@ func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}

func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xi32>
  %0 = mhlo.divide %arg0, %arg0 : tensor<2xi32>
  return %0 : tensor<2xi32>
}

@ -83,7 +83,7 @@ func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor
}

func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> {
  %0 = xla_hlo.subtract %arg0, %arg0 : tensor<2xi32>
  %0 = mhlo.subtract %arg0, %arg0 : tensor<2xi32>
  return %0 : tensor<2xi32>
}

@ -93,7 +93,7 @@ func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2x
}

func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
  %0 = xla_hlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
  %0 = mhlo.shift_right_arithmetic %arg0, %arg1 : tensor<4xi32>
  return %0 : tensor<4xi32>
}

@ -103,7 +103,7 @@ func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> ten
}

func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> {
  %0 = xla_hlo.and %arg0, %arg0 : tensor<2xi1>
  %0 = mhlo.and %arg0, %arg0 : tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -118,7 +118,7 @@ func @and_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
}

func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> {
  %0 = xla_hlo.or %arg0, %arg0 : tensor<2xi1>
  %0 = mhlo.or %arg0, %arg0 : tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -133,7 +133,7 @@ func @or_dynamic(%arg0: tensor<?xi1>, %arg1: tensor<1xi1>) -> tensor<?xi1> {
}

func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
  %0 = xla_hlo.or %arg0, %arg1 : tensor<4xi32>
  %0 = mhlo.or %arg0, %arg1 : tensor<4xi32>
  return %0 : tensor<4xi32>
}

@ -148,7 +148,7 @@ func @bitwise_or_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<?
}

func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> {
  %0 = xla_hlo.and %arg0, %arg1 : tensor<4xi32>
  %0 = mhlo.and %arg0, %arg1 : tensor<4xi32>
  return %0 : tensor<4xi32>
}

@ -163,69 +163,69 @@ func @bitwise_and_dynamic(%arg0: tensor<?xi32>, %arg1: tensor<1xi32>) -> tensor<
}

func @pow(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = xla_hlo.power %arg0, %arg0 : tensor<2xf32>
  %0 = mhlo.power %arg0, %arg0 : tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @pow_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = xla_hlo.power %arg0, %arg0 : tensor<?xf32>
  %0 = mhlo.power %arg0, %arg0 : tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<2x3xi32>
  %0 = mhlo.constant dense<0> : tensor<2x3xi32>
  %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
  %2 = xla_hlo.constant dense<0> : tensor<3xi32>
  %2 = mhlo.constant dense<0> : tensor<3xi32>
  %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
  %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1>
  %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  %6 = "xla_hlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %7 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
  %8 = xla_hlo.constant dense<1> : tensor<3xi32>
  %9 = xla_hlo.subtract %7, %8 : tensor<3xi32>
  %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
  %8 = mhlo.constant dense<1> : tensor<3xi32>
  %9 = mhlo.subtract %7, %8 : tensor<3xi32>
  %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %12 = "xla_hlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
  %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32>
  %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32>
  %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %14 : tensor<2x3xi32>
}

func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<3xi32>
  %1 = "xla_hlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
  %2 = xla_hlo.constant dense<0> : tensor<2x3xi32>
  %0 = mhlo.constant dense<0> : tensor<3xi32>
  %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1>
  %2 = mhlo.constant dense<0> : tensor<2x3xi32>
  %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1>
  %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1>
  %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  %6 = "xla_hlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
  %7 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %8 = xla_hlo.constant dense<1> : tensor<2x3xi32>
  %9 = xla_hlo.subtract %7, %8 : tensor<2x3xi32>
  %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32>
  %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %8 = mhlo.constant dense<1> : tensor<2x3xi32>
  %9 = mhlo.subtract %7, %8 : tensor<2x3xi32>
  %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  %11 = "xla_hlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %12 = "xla_hlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %13 = xla_hlo.divide %11, %12 : tensor<2x3xi32>
  %14 = "xla_hlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32>
  %13 = mhlo.divide %11, %12 : tensor<2x3xi32>
  %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
  return %14 : tensor<2x3xi32>
}

func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32>
  %1 = xla_hlo.divide %arg0, %arg0 : tensor<2xf32>
  %2 = "xla_hlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
  %1 = mhlo.divide %arg0, %arg0 : tensor<2xf32>
  %2 = "mhlo.floor"(%1) : (tensor<2xf32>) -> tensor<2xf32>
  return %2 : tensor<2xf32>
}

func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> {
  %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16>
  %2 = "xla_hlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
  %2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16>
  return %2 : tensor<2x3xf16>
}

func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "EQ"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -250,7 +250,7 @@ func @equal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: tensor
}

func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -270,7 +270,7 @@ func @notequal_incompatible_shape_broadcastable(%arg0: tensor<?xi32>, %arg1: ten
}

func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -280,7 +280,7 @@ func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<
}

func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "GE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -290,7 +290,7 @@ func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> t
}

func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LT"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -300,7 +300,7 @@ func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2
}

func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> {
  %0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  %0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "LE"} : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

@ -310,426 +310,426 @@ func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tens
}

func @concat_v2(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<6x3xf32> {
  %2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
  %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 0 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<6x3xf32>
  return %2 : tensor<6x3xf32>
}

func @concat_v2_1d_axis(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x6xf32> {
  %2 = "xla_hlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
  %2 = "mhlo.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x6xf32>
  return %2 : tensor<3x6xf32>
}

func @const() -> tensor<2xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<2xi32>
  %0 = mhlo.constant dense<0> : tensor<2xi32>
  return %0 : tensor<2xi32>
}

func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<i32>
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<1xi32>) -> tensor<1xi32>
  return %1 : tensor<1xi32>
}

func @relu_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<i32>
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
  return %1 : tensor<?xi32>
}

func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<i32>
  %1 = xla_hlo.constant dense<6> : tensor<i32>
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<6> : tensor<i32>
  %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor<i32>) -> tensor<1xi32>
  return %3 : tensor<1xi32>
}

func @relu6_unranked(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = xla_hlo.constant dense<0> : tensor<i32>
  %1 = xla_hlo.constant dense<6> : tensor<i32>
  %0 = mhlo.constant dense<0> : tensor<i32>
  %1 = mhlo.constant dense<6> : tensor<i32>
  %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<?xi32>, tensor<i32>) -> tensor<?xi32>
  return %3 : tensor<?xi32>
}

func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor<?x?xf32>) -> tensor<4x8xf32> {
  %0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
  %0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
  %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor<?x?xf32>, tensor<f32>) -> tensor<?x?xi1>
  %2 = xla_hlo.constant dense<0.000000e+00> : tensor<4x8xf32>
  %3 = "xla_hlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32>
  %3 = "mhlo.select"(%1, %arg0, %2) : (tensor<?x?xi1>, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32>
  return %3 : tensor<4x8xf32>
}

func @select(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

func @select_float(%arg0: tensor<2xi1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @select_multidimensional(%arg0: tensor<3x2xi1>, %arg1: tensor<3x2xi32>, %arg2: tensor<3x2xi32>) -> tensor<3x2xi32> {
  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<3x2xi1>, tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32>
  return %0 : tensor<3x2xi32>
}

func @selectv2(%arg0: tensor<2xi1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<2xi1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

func @selectv2_pred_scalar(%arg0: tensor<i1>, %arg1: tensor<2xi32>, %arg2: tensor<2xi32>) -> tensor<2xi32> {
  %0 = "xla_hlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  %0 = "mhlo.select"(%arg0, %arg1, %arg2) : (tensor<i1>, tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
  return %0 : tensor<2xi32>
}

func @transpose_2d(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> {
  %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
  %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
  %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
  %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<2x3xf32>) -> tensor<3x2xf32>
  return %2 : tensor<3x2xf32>
}

func @transpose_3d_int32(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi32>
  %1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi32>
  %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  return %2 : tensor<3x2x1xf32>
}

func @transpose_3d(%arg0: tensor<1x2x3xf32>) -> tensor<3x2x1xf32> {
  %0 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %1 = xla_hlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  %0 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %1 = mhlo.constant dense<[2, 1, 0]> : tensor<3xi64>
  %2 = "mhlo.transpose"(%arg0) {permutation = dense<[2, 1, 0]> : tensor<3xi64>} : (tensor<1x2x3xf32>) -> tensor<3x2x1xf32>
  return %2 : tensor<3x2x1xf32>
}

func @transpose_dynamic_2d(%arg0: tensor<?x4xf32>) -> tensor<4x?xf32> {
  %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
  %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
  %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
  %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<?x4xf32>) -> tensor<4x?xf32>
  return %2 : tensor<4x?xf32>
}

func @transpose_unranked_2d(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
  %1 = xla_hlo.constant dense<[1, 0]> : tensor<2xi64>
  %2 = "xla_hlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
  %0 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  %1 = mhlo.constant dense<[1, 0]> : tensor<2xi64>
  %2 = "mhlo.transpose"(%arg0) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<*xf32>) -> tensor<*xf32>
  return %2 : tensor<*xf32>
}

func @abs(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.abs"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @abs_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @abs_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.abs"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @ceil(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.ceil"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @ceil_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.ceil"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @ceil_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.ceil"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @complex_abs(%arg0: tensor<2xcomplex<f32>>) -> tensor<2xf32> {
  %0 = "xla_hlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  %0 = "mhlo.abs"(%arg0) : (tensor<2xcomplex<f32>>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @cos(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.cosine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @cos_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.cosine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @cos_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.cosine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @exp(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @exp_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @exp_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.exponential"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @floor(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.floor"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @floor_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.floor"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @floor_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.floor"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @is_finite(%arg0: tensor<2xf32>) -> tensor<2xi1> {
  %0 = "xla_hlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  %0 = "mhlo.is_finite"(%arg0) : (tensor<2xf32>) -> tensor<2xi1>
  return %0 : tensor<2xi1>
}

func @is_finite_dynamic(%arg0: tensor<?xf32>) -> tensor<?xi1> {
  %0 = "xla_hlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  %0 = "mhlo.is_finite"(%arg0) : (tensor<?xf32>) -> tensor<?xi1>
  return %0 : tensor<?xi1>
}

func @is_finite_unranked(%arg0: tensor<*xf32>) -> tensor<*xi1> {
  %0 = "xla_hlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  %0 = "mhlo.is_finite"(%arg0) : (tensor<*xf32>) -> tensor<*xi1>
  return %0 : tensor<*xi1>
}

func @log(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.log"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @log_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.log"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @log_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.log"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @log1p(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @log1p_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @log1p_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.log_plus_one"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @neg(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.negate"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @neg_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.negate"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @neg_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.negate"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}

func @sigmoid(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = xla_hlo.constant dense<5.000000e-01> : tensor<f32>
  %1 = xla_hlo.constant dense<2> : tensor<1xi64>
  %2 = xla_hlo.constant dense<5.000000e-01> : tensor<2xf32>
  %3 = xla_hlo.multiply %arg0, %2 : tensor<2xf32>
  %4 = "xla_hlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
  %5 = xla_hlo.multiply %4, %2 : tensor<2xf32>
  %6 = xla_hlo.add %5, %2 : tensor<2xf32>
  %0 = mhlo.constant dense<5.000000e-01> : tensor<f32>
  %1 = mhlo.constant dense<2> : tensor<1xi64>
  %2 = mhlo.constant dense<5.000000e-01> : tensor<2xf32>
  %3 = mhlo.multiply %arg0, %2 : tensor<2xf32>
  %4 = "mhlo.tanh"(%3) : (tensor<2xf32>) -> tensor<2xf32>
  %5 = mhlo.multiply %4, %2 : tensor<2xf32>
  %6 = mhlo.add %5, %2 : tensor<2xf32>
  return %6 : tensor<2xf32>
}

func @sin(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  %0 = "xla_hlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  %0 = "mhlo.sine"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}

func @sin_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
  %0 = "xla_hlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  %0 = "mhlo.sine"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
  return %0 : tensor<?xf32>
}

func @sin_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
  %0 = "xla_hlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
  %0 = "mhlo.sine"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @rsqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "mhlo.rsqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
func @rsqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "mhlo.rsqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @rsqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "xla_hlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "mhlo.rsqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @sqrt(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "mhlo.sqrt"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
func @sqrt_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "mhlo.sqrt"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @sqrt_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "xla_hlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "mhlo.sqrt"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @tanh(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
func @tanh_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "xla_hlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "mhlo.tanh"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @tanh_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "xla_hlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "mhlo.tanh"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @bitcast(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
func @bitcast_dynamic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @bitcast_unranked(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<*xf32>) -> tensor<*xf32>
|
||||
return %0 : tensor<*xf32>
|
||||
}
|
||||
|
||||
func @bitcast_same_widths(%arg0: tensor<2xf32>) -> tensor<2xi32> {
|
||||
%0 = "xla_hlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
|
||||
%0 = "mhlo.bitcast_convert"(%arg0) : (tensor<2xf32>) -> tensor<2xi32>
|
||||
return %0 : tensor<2xi32>
|
||||
}
|
||||
|
||||
func @sign(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> {
|
||||
%0 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%1 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
%2 = "xla_hlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%3 = xla_hlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
%4 = "xla_hlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
%5 = "xla_hlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
%6 = "xla_hlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
%0 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%1 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
%2 = "mhlo.compare"(%arg0, %arg0) {comparison_direction = "NE"} : (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xi1>
|
||||
%3 = mhlo.constant dense<0.000000e+00> : tensor<1x2x3x4xf32>
|
||||
%4 = "mhlo.sign"(%arg0) : (tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
%5 = "mhlo.select"(%2, %3, %4) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
%6 = "mhlo.select"(%0, %1, %5) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32>
|
||||
return %6 : tensor<1x2x3x4xf32>
|
||||
}
|
||||
|
||||
func @size_rank_one_i32(%arg0: tensor<f32>) -> tensor<i32> {
|
||||
%0 = xla_hlo.constant dense<1> : tensor<i32>
|
||||
%0 = mhlo.constant dense<1> : tensor<i32>
|
||||
return %0 : tensor<i32>
|
||||
}
|
||||
|
||||
func @size_rank_one_i64(%arg0: tensor<f32>) -> tensor<i64> {
|
||||
%0 = xla_hlo.constant dense<1> : tensor<i64>
|
||||
%0 = mhlo.constant dense<1> : tensor<i64>
|
||||
return %0 : tensor<i64>
|
||||
}
|
||||
|
||||
func @complex(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xcomplex<f32>> {
|
||||
%0 = "xla_hlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
|
||||
%0 = "mhlo.complex"(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xcomplex<f32>>
|
||||
return %0 : tensor<3xcomplex<f32>>
|
||||
}
|
||||
|
||||
func @convert_i32_f32(%arg0: tensor<2xi32>) -> tensor<2xf32> {
|
||||
%0 = "xla_hlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
||||
%0 = "mhlo.convert"(%arg0) : (tensor<2xi32>) -> tensor<2xf32>
|
||||
return %0 : tensor<2xf32>
|
||||
}
|
||||
|
||||
func @convert_slice(%arg0: tensor<1x4672xf32>) -> tensor<1x519xf32> {
|
||||
%0 = "xla_hlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
|
||||
%0 = "mhlo.slice"(%arg0) {limit_indices = dense<[1, 4672]> : tensor<2xi64>, start_indices = dense<[0, 4153]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<1x4672xf32>) -> tensor<1x519xf32>
|
||||
return %0 : tensor<1x519xf32>
|
||||
}
|
||||
|
||||
func @reshape(%arg0: tensor<4x6xf32>) -> tensor<2x2x6xf32> {
|
||||
%0 = "xla_hlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
|
||||
%0 = "mhlo.reshape"(%arg0) : (tensor<4x6xf32>) -> tensor<2x2x6xf32>
|
||||
return %0 : tensor<2x2x6xf32>
|
||||
|
||||
}
|
||||
|
||||
func @convert_dot_1d_2d(%arg0: tensor<256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1xf32> {
|
||||
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256x1xf32>) -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @convert_dot_2d_1d(%arg0: tensor<1x256xf32>, %arg1: tensor<256xf32>) -> tensor<1xf32> {
|
||||
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256xf32>) -> tensor<1xf32>
|
||||
return %0 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @convert_dot_1d_1d(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>) -> tensor<f32> {
|
||||
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<256xf32>, tensor<256xf32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
func @convert_dot_2d_2d(%arg0: tensor<1x256xf32>, %arg1: tensor<256x1xf32>) -> tensor<1x1xf32> {
|
||||
%0 = "xla_hlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
|
||||
%0 = "mhlo.dot"(%arg0, %arg1) {precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x256xf32>, tensor<256x1xf32>) -> tensor<1x1xf32>
|
||||
return %0 : tensor<1x1xf32>
|
||||
}
|
||||
|
||||
func @broadcast_in_dim_tf_style(%arg0: tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32> {
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<8x1x16xf32>) -> tensor<3x8x8x16xf32>
|
||||
return %0 : tensor<3x8x8x16xf32>
|
||||
}
|
||||
|
||||
func @broadcast_in_dim_general_case(%arg0: tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32> {
|
||||
%0 = "xla_hlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
|
||||
%0 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 2, 3]> : tensor<3xi64>, name = "broadcast.0"} : (tensor<3x1x16xf32>) -> tensor<3x8x8x16xf32>
|
||||
return %0 : tensor<3x8x8x16xf32>
|
||||
}
|
||||
|
||||
func @convert_dot_general(%arg0: tensor<3x2x6x5x1xf32>, %arg1: tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32> {
|
||||
%0 = "xla_hlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
|
||||
%0 = "mhlo.dot_general"(%arg0, %arg1) {dot_dimension_numbers = {lhs_batching_dimensions = dense<0> : tensor<1xi64>, lhs_contracting_dimensions = dense<[1, 2]> : tensor<2xi64>, rhs_batching_dimensions = dense<0> : tensor<1xi64>, rhs_contracting_dimensions = dense<[1, 3]> : tensor<2xi64>}, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<3x2x6x5x1xf32>, tensor<3x2x4x6xf32>) -> tensor<3x5x1x4xf32>
|
||||
return %0 : tensor<3x5x1x4xf32>
|
||||
}
|
||||
|
||||
func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
|
||||
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
|
||||
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
|
||||
@ -737,7 +737,7 @@ func @convert_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>
|
||||
}
|
||||
|
||||
func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
|
||||
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||
feature_group_count = 207 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<1> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
|
||||
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
|
||||
@ -745,7 +745,7 @@ func @convert_depthwise_conv2d(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x2
|
||||
}
|
||||
|
||||
func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32> {
|
||||
%0 = "xla_hlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
%0 = "mhlo.convolution"(%arg0, %arg1) {batch_group_count = 1 : i64, dimension_numbers =
|
||||
{input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>},
|
||||
feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<2xi64>, padding = dense<0> : tensor<2x2xi64>, precision_config = ["DEFAULT", "DEFAULT"], rhs_dilation = dense<1> : tensor<2xi64>, window_strides = dense<1> : tensor<2xi64>} :
|
||||
(tensor<1x8x8x207xf32>, tensor<3x3x207x16xf32>) -> tensor<1x8x8x16xf32>
|
||||
@ -753,22 +753,22 @@ func @convert_conv2d_valid_padding(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3
|
||||
}
|
||||
|
||||
func @convert_reduce_to_sum(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
%0 = xla_hlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
|
||||
%0 = mhlo.constant dense<0.000000e+00> : tensor<f32>
|
||||
%1 = "mhlo.reduce"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%2 = xla_hlo.add %arg1, %arg2 : tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = mhlo.add %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
|
||||
return %1 : tensor<1xf32>
|
||||
}
|
||||
|
||||
func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
// "0xFF800000" represents -INF for f32.
|
||||
%0 = xla_hlo.constant dense<0xFF800000> : tensor<f32>
|
||||
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
|
||||
%0 = mhlo.constant dense<0xFF800000> : tensor<f32>
|
||||
%1 = "mhlo.reduce"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%2 = xla_hlo.maximum %arg1, %arg2 : tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = mhlo.maximum %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
|
||||
return %1 : tensor<1xf32>
|
||||
}
|
||||
@ -776,11 +776,11 @@ func @convert_reduce_to_max(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
|
||||
func @convert_reduce_to_min(%arg0: tensor<1x256xf32>) -> tensor<1xf32> {
|
||||
// "0x7F800000" represents INF for f32.
|
||||
%0 = xla_hlo.constant dense<0x7F800000> : tensor<f32>
|
||||
%1 = "xla_hlo.reduce"(%arg0, %0) ( {
|
||||
%0 = mhlo.constant dense<0x7F800000> : tensor<f32>
|
||||
%1 = "mhlo.reduce"(%arg0, %0) ( {
|
||||
^bb0(%arg1: tensor<f32>, %arg2: tensor<f32>):
|
||||
%2 = xla_hlo.minimum %arg1, %arg2 : tensor<f32>
|
||||
"xla_hlo.return"(%2) : (tensor<f32>) -> ()
|
||||
%2 = mhlo.minimum %arg1, %arg2 : tensor<f32>
|
||||
"mhlo.return"(%2) : (tensor<f32>) -> ()
|
||||
}) {dimensions = dense<1> : tensor<1xi64>} : (tensor<1x256xf32>, tensor<f32>) -> tensor<1xf32>
|
||||
return %1 : tensor<1xf32>
|
||||
}
|
||||
|
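The hunks above are a mechanical dialect rename: each test now spells the op prefix mhlo where it previously spelled xla_hlo, with operand and result types left untouched. A minimal sketch of the before/after pattern (the function name here is hypothetical, not part of this diff):

func @rename_sketch(%arg0: tensor<2xf32>) -> tensor<2xf32> {
  // Before the rename, this line would have read: %0 = "xla_hlo.tanh"(%arg0) ...
  %0 = "mhlo.tanh"(%arg0) : (tensor<2xf32>) -> tensor<2xf32>
  return %0 : tensor<2xf32>
}
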
@ -42,16 +42,19 @@ func @empty_replicate() {

// CHECK-LABEL: func @replicate_with_multiple_operands
func @replicate_with_multiple_operands() {
%0 = "tf.opA"() : () -> (tensor<*xi1>)
%1 = "tf.opB"() : () -> (tensor<*xi1>)
%2 = "tf.opC"() : () -> (tensor<*xi1>)
%3 = "tf.opD"() : () -> (tensor<*xi32>)
%4 = "tf.opE"() : () -> (tensor<*xi32>)
%5 = "tf.opF"() : () -> (tensor<*xi32>)
%6 = "tf.opG"() : () -> (tensor<*xf32>)
%7 = "tf.opH"() : () -> (tensor<*xf32>)
%8 = "tf.opI"() : () -> (tensor<*xf32>)
tf_device.replicate([%0, %1, %2] as %input0: tensor<*xi1>, [%3, %4, %5] as %input1: tensor<*xi32>, [%6, %7, %8] as %input2: tensor<*xf32>) {n = 3 : i32} {
%0 = "tf.opA"() : () -> tensor<*xi1>
%1 = "tf.opB"() : () -> tensor<*xi1>
%2 = "tf.opC"() : () -> tensor<*xi1>
%3 = "tf.opD"() : () -> tensor<*xi32>
%4 = "tf.opE"() : () -> tensor<*xi32>
%5 = "tf.opF"() : () -> tensor<*xi32>
%6 = "tf.opG"() : () -> tensor<*xf32>
%7 = "tf.opH"() : () -> tensor<*xf32>
%8 = "tf.opI"() : () -> tensor<*xf32>
%9 = "tf.opJ"() : () -> tensor<*xi8>
%10 = "tf.opK"() : () -> tensor<*xi16>
%11 = "tf.opL"() : () -> tensor<*xi64>
tf_device.replicate([%0, %1, %2] as %input0: tensor<*xi1>, %9 as %input1: tensor<*xi8>, %10 as %input2: tensor<*xi16>, [%3, %4, %5] as %input3: tensor<*xi32>, [%6, %7, %8] as %input4: tensor<*xf32>, %11 as %input5: tensor<*xi64>) {n = 3 : i32} {
tf_device.return
}
return
@ -65,12 +68,32 @@ func @replicate_with_multiple_operands() {
// CHECK: %[[OP_G:[a-z0-9]*]] = "tf.opG"
// CHECK: %[[OP_H:[a-z0-9]*]] = "tf.opH"
// CHECK: %[[OP_I:[a-z0-9]*]] = "tf.opI"
// CHECK: %[[OP_J:[a-z0-9]*]] = "tf.opJ"
// CHECK: %[[OP_K:[a-z0-9]*]] = "tf.opK"
// CHECK: %[[OP_L:[a-z0-9]*]] = "tf.opL"
// CHECK: tf_device.replicate
// CHECK-SAME: ([%[[OP_A]], %[[OP_B]], %[[OP_C]]] as %{{[a-z0-9]*}}: tensor<*xi1>, [%[[OP_D]], %[[OP_E]], %[[OP_F]]] as %{{[a-z0-9]*}}: tensor<*xi32>, [%[[OP_G]], %[[OP_H]], %[[OP_I]]] as %{{[a-z0-9]*}}: tensor<*xf32>)
// CHECK-SAME: [%[[OP_A]], %[[OP_B]], %[[OP_C]]] as %{{[a-z0-9]*}}: tensor<*xi1>
// CHECK-SAME: [%[[OP_D]], %[[OP_E]], %[[OP_F]]] as %{{[a-z0-9]*}}: tensor<*xi32>
// CHECK-SAME: [%[[OP_G]], %[[OP_H]], %[[OP_I]]] as %{{[a-z0-9]*}}: tensor<*xf32>
// CHECK-SAME: %[[OP_J]] as %{{[a-z0-9]*}}: tensor<*xi8>
// CHECK-SAME: %[[OP_K]] as %{{[a-z0-9]*}}: tensor<*xi16>
// CHECK-SAME: %[[OP_L]] as %{{[a-z0-9]*}}: tensor<*xi64>
// CHECK-SAME: n = 3
// CHECK-NEXT: tf_device.return
}

// CHECK-LABEL: func @replicate_derived_operand_segment_sizes
func @replicate_derived_operand_segment_sizes() {
tf_device.replicate {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} {
}
return

// CHECK: tf_device.replicate
// CHECK-SAME: n = 2
// CHECK-NOT: operand_segment_sizes
// CHECK-NEXT: tf_device.return
}

// CHECK-LABEL: func @replicate_with_return
// CHECK-SAME: (%[[ARG_0:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_1:[a-z0-9]*]]: tensor<*xf32>, %[[ARG_2:[a-z0-9]*]]: tensor<*xi32>)
func @replicate_with_return(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<*xi32>) {

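The updated test above mixes the two operand kinds that tf_device.replicate now distinguishes in its custom form: a bracketed list such as [%0, %1, %2] binds one value per replica to a single block argument (a replicated input), while a bare value such as %9 binds the same value on every replica (a packed input). A minimal sketch under that reading, with hypothetical names not taken from this diff:

func @replicate_kinds_sketch(%a: tensor<*xi32>, %b: tensor<*xi32>, %c: tensor<*xf32>) {
  // [%a, %b] is a replicated input (one value per replica, n = 2); %c is packed.
  tf_device.replicate([%a, %b] as %ri: tensor<*xi32>, %c as %pi: tensor<*xf32>) {n = 2 : i32} {
    tf_device.return
  }
  return
}
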
@ -61,7 +61,7 @@ func @parser_replicate_terminator() {
func @verifier_replicate_no_block() {
"tf_device.replicate" () ({
// expected-error@-1 {{'tf_device.replicate' op region #0 ('body') failed to verify constraint: region with 1 blocks}}
}) {n = 2 : i32} : () -> ()
}) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
return
}

@ -72,7 +72,7 @@ func @verifier_replicate_empty_block() {
"tf_device.replicate" () ({
// expected-error@-1 {{'tf_device.replicate' op expects a non-empty block}}
^entry:
}) {n = 2 : i32} : () -> ()
}) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
return
}

@ -85,7 +85,7 @@ func @verifier_replicate_terminator() {
// expected-error@-1 {{'tf_device.replicate' op expects regions to end with 'tf_device.return', found 'std.return'}}
^entry:
return
}) {n = 2 : i32} : () -> ()
}) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
return
}

@ -97,7 +97,7 @@ func @verifier_replicate_n() {
// expected-error@-1 {{'tf_device.replicate' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 2}}
^entry:
tf_device.return
}) {n = 1 : i32} : () -> ()
}) {n = 1 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
}

// -----
@ -109,43 +109,66 @@ func @verifier_replicate_n_device() {
// expected-error@-1 {{'tf_device.replicate' op expects number of devices (2) to be equal to 'n' (3)}}
^entry:
tf_device.return
}) {n = 3 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}} : () -> ()
}) {devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"]}, n = 3 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
}

// -----

// Check that replicate op's `devices` attribute must consist of dictionary
// Check that replicate op's 'devices' attribute must consist of dictionary
// with values as list with size equal to 'n' attribute.
func @verifier_replicate_n_device_multiple_alias() {
"tf_device.replicate" () ({
// expected-error@-1 {{'tf_device.replicate' op expects number of devices (2) to be equal to 'n' (3)}}
^entry:
tf_device.return
}) {n = 3 : i32, devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2"]}} : () -> ()
}) {devices = {TPU_REPLICATED_CORE_0 = ["/DEVICE:0", "/DEVICE:1"], TPU_REPLICATED_CORE_1 = ["/DEVICE:2"]}, n = 3 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> ()
}

// -----

// Check that a replicate with mismatched operand and block arg counts is
// invalid.
func @verifier_replicate_operand_block_arg_count(%arg0: tensor<*xi32>) {
"tf_device.replicate" (%arg0, %arg0, %arg0) ({
// expected-error@-1 {{'tf_device.replicate' op expects number of operands (3) to be equal to 'n' * number of block arguments (2 * 1)}}
^entry(%input0: tensor<*xi32>):
// Check number of replicated inputs is evenly divisible by 'n'.
func @verifier_replicate_bad_operand_segment_sizes(%arg0: tensor<*xi32>) {
"tf_device.replicate" (%arg0, %arg0, %arg0, %arg0) ({
// expected-error@-1 {{'tf_device.replicate' op expects number of replicated inputs (4) to be evenly divisible by 'n' (3)}}
^entry(%input0: tensor<*xi32>, %input1: tensor<*xi32>):
tf_device.return
}) {n = 2 : i32} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> ()
}) {n = 3 : i32, operand_segment_sizes = dense<[4, 0]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> ()
}

// -----

// Check that a replicate with incompatible operand and block argument type is
// invalid.
func @verifier_replicate_operand_block_arg_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
// Check number of replicated inputs / 'n' + number of packed inputs matches the
// number of block arguments.
func @verifier_replicate_num_block_args(%arg0: tensor<*xi32>) {
"tf_device.replicate" (%arg0, %arg0, %arg0, %arg0, %arg0) ({
// expected-error@-1 {{'tf_device.replicate' op expects number of block arguments (2) to be equal to number of replicated inputs (3) / 'n' (3) + number of packed inputs (2)}}
^entry(%input0: tensor<*xi32>, %input1: tensor<*xi32>):
tf_device.return
}) {n = 3 : i32, operand_segment_sizes = dense<[3, 2]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xi32>) -> ()
}

// -----

// Check that a replicate with incompatible replicated operand and block
// argument type is invalid.
func @verifier_replicate_replicated_operand_block_arg_type(%arg0: tensor<*xi32>, %arg1: tensor<*xi1>) {
"tf_device.replicate" (%arg0, %arg1) ({
// expected-error@-1 {{'tf_device.replicate' op incompatible types for operand 1 and block argument 0}}
// expected-error@-1 {{'tf_device.replicate' op expects operand 1 ('tensor<*xi1>') and block argument 0 ('tensor<*xi32>') to have compatible types}}
^entry(%input0: tensor<*xi32>):
tf_device.return
}) {n = 2 : i32} : (tensor<*xi32>, tensor<*xi1>) -> ()
}) {n = 2 : i32, operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi1>) -> ()
}

// -----

// Check that a replicate with incompatible packed operand and block argument
// type is invalid.
func @verifier_replicate_packed_operand_block_arg_type(%arg0: tensor<*xi1>) {
"tf_device.replicate" (%arg0) ({
// expected-error@-1 {{'tf_device.replicate' op expects operand 0 ('tensor<*xi1>') and block argument 0 ('tensor<*xi32>') to have compatible types}}
^entry(%input0: tensor<*xi32>):
tf_device.return
}) {n = 2 : i32, operand_segment_sizes = dense<[0, 1]> : vector<2xi32>} : (tensor<*xi1>) -> ()
}

// -----
@ -157,7 +180,7 @@ func @verifier_replicate_result_return_operand_count(%arg0: tensor<*xi32>) {
// expected-error@-1 {{'tf_device.replicate' op expects number of results (3) to be equal to 'n' * number of terminator operands (2 * 1)}}
^entry:
tf_device.return %arg0 : tensor<*xi32>
}) {n = 2 : i32} : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>)
}) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>)
}

// -----
@ -169,7 +192,7 @@ func @verifier_replicate_result_return_operand_type(%arg0: tensor<*xi32>) {
// expected-error@-1 {{'tf_device.replicate' op incompatible types for result 1 and terminator operand 0}}
^entry:
tf_device.return %arg0 : tensor<*xi32>
}) {n = 2 : i32} : () -> (tensor<*xi32>, tensor<*xi1>)
}) {n = 2 : i32, operand_segment_sizes = dense<[0, 0]> : vector<2xi32>} : () -> (tensor<*xi32>, tensor<*xi1>)
}

// -----

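The verifier hunks above all add operand_segment_sizes to the generic form of tf_device.replicate: the attribute splits the flat operand list into the replicated group and the packed group, which is what the new divisibility and block-argument-count errors are phrased against. A minimal well-formed sketch under that reading (the function name is hypothetical): three replicated operands with n = 3 plus one packed operand, giving two block arguments.

func @segment_sizes_sketch(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>) {
  // operand_segment_sizes = [3, 1]: the first three operands are replicated
  // (3 / n = 1 block argument), the last one is packed (1 block argument).
  "tf_device.replicate" (%arg0, %arg0, %arg0, %arg1) ({
  ^entry(%input0: tensor<*xi32>, %input1: tensor<*xf32>):
    tf_device.return
  }) {n = 3 : i32, operand_segment_sizes = dense<[3, 1]> : vector<2xi32>} : (tensor<*xi32>, tensor<*xi32>, tensor<*xi32>, tensor<*xf32>) -> ()
  return
}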