Initial version of HLO translation tools.
Adds the initial version of tools to convert between HLO protos and the XLA dialect. Co-authored-by: Robert Suderman <suderman@google.com> Co-authored-by: Geoffrey Martin-Noble <gcmn@google.com> Co-authored-by: River Riddle <riverriddle@google.com> Co-authored-by: Lei Zhang <antiagainst@google.com> Co-authored-by: Himabindu Pucha <hpucha@google.com> PiperOrigin-RevId: 258691517
This commit is contained in:
parent
3d11f06b29
commit
4ad0cdbdeb
tensorflow/compiler/mlir
BUILD
tensorflow
xla
BUILDhlo_function_importer.cchlo_function_importer.hhlo_module_importer.cchlo_module_importer.hhlo_to_mlir_hlo.cchlo_to_mlir_hlo.hmlir_hlo_to_hlo.ccmlir_hlo_to_hlo.hoperator_writer_gen.cc
tests/translate
BUILDadd.hlotxtadd.mlirand.hlotxtbinary_op_broadcast.mlirbroadcast.mlirbroadcast_in_dim.hlotxtcall.hlotxtcomp.hlotxtconcat.hlotxtconst.hlotxtconv.hlotxtconvert.hlotxtdiv.hlotxtdot.hlotxtdynamic-update-slice.hlotxtfully_connected_reference_model.hlotxtiota.hlotxtmax.hlotxtmin.hlotxtmul.hlotxtpad.hlotxtreduce.hlotxtreverse.hlotxtscalar.hlotxtselect.hlotxtselect.mlirsimple.hlosimple.hlotxtsimple.mlirsub.hlotxttanh.hlotxttranspose.hlotxttranspose.mlirtuple.hlotxtunknown.hlotxtwhile.hlotxt
type_to_shape.cctype_to_shape.htype_to_shape_test.ccxla_mlir_translate.ccxla_mlir_translate.h@ -56,6 +56,25 @@ tf_cc_binary(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-mlir-translate",
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Translation",
|
||||
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "litfiles",
|
||||
srcs = glob(["runlit*py"]),
|
||||
|
@ -511,24 +511,6 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_binary(
|
||||
name = "tf-mlir-translate",
|
||||
deps = [
|
||||
":convert_graphdef",
|
||||
":mlir_roundtrip_flags",
|
||||
":translate_cl_options",
|
||||
":translate_lib",
|
||||
":translate_registration",
|
||||
":translate_tf_dialect_op",
|
||||
"//tensorflow/core:protos_all_proto_cc",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Translation",
|
||||
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "error_util_test",
|
||||
srcs = ["utils/error_util_test.cc"],
|
||||
|
@ -13,7 +13,7 @@ filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf-mlir-translate",
|
||||
"//tensorflow/compiler/mlir:tf-mlir-translate",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
|
||||
|
@ -13,7 +13,7 @@ filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir/tensorflow:tf-mlir-translate",
|
||||
"//tensorflow/compiler/mlir:tf-mlir-translate",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
|
||||
|
@ -1,4 +1,5 @@
|
||||
load("@local_config_mlir//:tblgen.bzl", "gentbl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
|
||||
|
||||
package(
|
||||
default_visibility = [":friends"],
|
||||
@ -166,3 +167,151 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "type_to_shape",
|
||||
srcs = ["type_to_shape.cc"],
|
||||
hdrs = ["type_to_shape.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Support",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "type_to_shape_test",
|
||||
srcs = ["type_to_shape_test.cc"],
|
||||
deps = [
|
||||
":type_to_shape",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/core:test_main",
|
||||
"@local_config_mlir//:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mlir_hlo_to_hlo",
|
||||
srcs = [
|
||||
"mlir_hlo_to_hlo.cc",
|
||||
"operator_writers.inc",
|
||||
],
|
||||
hdrs = ["mlir_hlo_to_hlo.h"],
|
||||
deps = [
|
||||
":type_to_shape",
|
||||
":xla",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:Analysis",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Pass",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
"@local_config_mlir//:TransformUtils",
|
||||
"@local_config_mlir//:Transforms",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_to_mlir_hlo",
|
||||
srcs = [
|
||||
"hlo_to_mlir_hlo.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"hlo_to_mlir_hlo.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo_module_importer",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_module_importer",
|
||||
srcs = [
|
||||
"hlo_function_importer.cc",
|
||||
"hlo_module_importer.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"hlo_function_importer.h",
|
||||
"hlo_module_importer.h",
|
||||
],
|
||||
deps = [
|
||||
":xla",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/xla:protobuf_util",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:StandardOps",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_mlir_translate",
|
||||
srcs = [
|
||||
"xla_mlir_translate.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_mlir_translate.h",
|
||||
],
|
||||
deps = [
|
||||
":hlo_to_mlir_hlo",
|
||||
":mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
"@llvm//:support",
|
||||
"@local_config_mlir//:IR",
|
||||
"@local_config_mlir//:Translation",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_native_cc_binary(
|
||||
name = "operator_writer_gen",
|
||||
srcs = [
|
||||
"operator_writer_gen.cc",
|
||||
],
|
||||
deps = [
|
||||
"@llvm//:config",
|
||||
"@llvm//:support",
|
||||
"@llvm//:tablegen",
|
||||
"@local_config_mlir//:TableGen",
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "operator_writer_inc",
|
||||
srcs = [
|
||||
"@local_config_mlir//:include/mlir/IR/OpBase.td",
|
||||
"//tensorflow/compiler/mlir/xla:ir/xla_ops.td",
|
||||
],
|
||||
outs = [
|
||||
"operator_writers.inc",
|
||||
],
|
||||
cmd = ("$(location :operator_writer_gen) " +
|
||||
"-I external/local_config_mlir/include " +
|
||||
"$(location //tensorflow/compiler/mlir/xla:ir/xla_ops.td) " + " -o $@"),
|
||||
tools = [":operator_writer_gen"],
|
||||
)
|
||||
|
528
tensorflow/compiler/mlir/xla/hlo_function_importer.cc
Normal file
528
tensorflow/compiler/mlir/xla/hlo_function_importer.cc
Normal file
@ -0,0 +1,528 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
using llvm::APInt;
|
||||
using llvm::makeArrayRef;
|
||||
using mlir::DenseElementsAttr;
|
||||
using mlir::DenseIntElementsAttr;
|
||||
using mlir::FuncOp;
|
||||
using mlir::NamedAttribute;
|
||||
using mlir::Operation;
|
||||
using mlir::ShapedType;
|
||||
using mlir::Type;
|
||||
using mlir::Value;
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
// Note: This sanitization function causes an irreversible many-to-one mapping
|
||||
// and any solution to mitigate this would cause issues with the reverse
|
||||
// direction. Longterm solution is to add a function attribute to maintain the
|
||||
// original HLO naming.
|
||||
string SanitizeFunctionName(llvm::StringRef name) {
|
||||
string output = name;
|
||||
llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
|
||||
return output;
|
||||
}
|
||||
|
||||
StatusOr<DenseElementsAttr> CreateDenseAttrFromLiteral(ShapedType type,
|
||||
const Literal& literal) {
|
||||
#define DENSE_ELEMENT_ATTR_BUILDER(xla_type, cpp_type) \
|
||||
case xla_type: { \
|
||||
auto data_span = literal.data<cpp_type>(); \
|
||||
return DenseElementsAttr::get( \
|
||||
type, llvm::makeArrayRef(data_span.data(), data_span.size())); \
|
||||
}
|
||||
|
||||
switch (literal.shape().element_type()) {
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::PRED, bool)
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F32, float)
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F64, double)
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S8, int8)
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S16, int16)
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S32, int32)
|
||||
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S64, int64)
|
||||
default:
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrCat("Unsupported type: ",
|
||||
PrimitiveType_Name(literal.shape().element_type())));
|
||||
}
|
||||
#undef DENSE_ELEMENT_ATTR_BUILDER
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
|
||||
mlir::ModuleOp module, mlir::Builder* builder,
|
||||
std::unordered_map<HloComputation*, FuncOp>* function_map,
|
||||
HloComputation* computation) {
|
||||
HloFunctionImporter importer(module, builder, function_map);
|
||||
return importer.ImportFunction(computation);
|
||||
}
|
||||
|
||||
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
|
||||
HloComputation* computation) {
|
||||
auto& imported = (*function_map_)[computation];
|
||||
if (imported) return imported;
|
||||
|
||||
llvm::SmallVector<Type, 4> args, rets;
|
||||
TF_RETURN_IF_ERROR(
|
||||
GetMlirTypes(computation->parameter_instructions(), &args));
|
||||
TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets));
|
||||
|
||||
auto func_type = mlir::FunctionType::get(args, rets, context_);
|
||||
|
||||
string computation_name =
|
||||
computation->parent()->entry_computation() == computation
|
||||
? "main"
|
||||
: SanitizeFunctionName(computation->name());
|
||||
|
||||
// Construct the MLIR function and map arguments.
|
||||
llvm::ArrayRef<mlir::NamedAttribute> attrs;
|
||||
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
|
||||
computation_name, func_type, attrs);
|
||||
module_.push_back(function);
|
||||
|
||||
// Add to the map right away for function calls.
|
||||
imported = function;
|
||||
|
||||
function.addEntryBlock();
|
||||
|
||||
// Setup the input parameters.
|
||||
const int num_parameters = computation->num_parameters();
|
||||
for (int i = 0; i < num_parameters; i++) {
|
||||
auto hlo_parameter = computation->parameter_instruction(i);
|
||||
instruction_value_map_[hlo_parameter] = function.getArgument(i);
|
||||
}
|
||||
|
||||
mlir::OpBuilder func_builder(function.getBody());
|
||||
for (auto instruction : computation->MakeInstructionPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(auto new_operation,
|
||||
ImportInstruction(instruction, &func_builder));
|
||||
if (new_operation) {
|
||||
instruction_value_map_[instruction] = new_operation->getResult(0);
|
||||
}
|
||||
}
|
||||
|
||||
// Setup the return type (HLO only supports a single return value).
|
||||
TF_ASSIGN_OR_RETURN(auto result,
|
||||
GetMlirValue(computation->root_instruction()));
|
||||
llvm::SmallVector<Value*, 1> return_values({result});
|
||||
// TODO(suderman): Add location tracking details.
|
||||
func_builder.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
|
||||
makeArrayRef(return_values));
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
|
||||
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
|
||||
TF_ASSIGN_OR_RETURN(auto result_type, ConvertType(instruction->shape()));
|
||||
llvm::SmallVector<NamedAttribute, 10> attributes = {builder_->getNamedAttr(
|
||||
"name", builder_->getStringAttr(instruction->name()))};
|
||||
mlir::Location loc = mlir::UnknownLoc::get(context_);
|
||||
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kParameter: {
|
||||
return nullptr;
|
||||
}
|
||||
case HloOpcode::kConstant: {
|
||||
auto attr = CreateDenseAttrFromLiteral(
|
||||
result_type.cast<mlir::TensorType>(), instruction->literal());
|
||||
if (!attr.ok()) return attr.status();
|
||||
mlir::Operation* new_operation =
|
||||
func_builder->create<mlir::ConstantOp>(loc, attr.ValueOrDie());
|
||||
for (auto attr : attributes) {
|
||||
new_operation->setAttr(attr.first, attr.second);
|
||||
}
|
||||
return new_operation;
|
||||
}
|
||||
case HloOpcode::kIota: {
|
||||
return func_builder
|
||||
->create<mlir::XLA::IotaOp>(
|
||||
loc, result_type,
|
||||
func_builder->getI64IntegerAttr(
|
||||
static_cast<HloIotaInstruction*>(instruction)
|
||||
->iota_dimension()))
|
||||
.getOperation();
|
||||
}
|
||||
#define MakeAndReturn(mlir_op) \
|
||||
{ \
|
||||
mlir::Operation* new_operation = func_builder->create<mlir::XLA::mlir_op>( \
|
||||
loc, result_type, operands, attributes); \
|
||||
return new_operation; \
|
||||
}
|
||||
case HloOpcode::kBroadcast: {
|
||||
// Note that the HLO broadcast is more powerful than the XLA broadcast op.
|
||||
// BroadcastInDim offers a superset of the HLO op's functionality.
|
||||
if (!instruction->dimensions().empty()) {
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"broadcast_dimensions",
|
||||
ConvertDimensions(instruction->dimensions())));
|
||||
}
|
||||
MakeAndReturn(BroadcastInDimOp);
|
||||
}
|
||||
case HloOpcode::kDot: {
|
||||
// TODO(b/129153247) Add support for batch and contracting dimensions.
|
||||
TF_RETURN_IF_ERROR(ValidateDotDimensions(instruction));
|
||||
|
||||
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
|
||||
// case and the parser sticks it in. Maybe we should too.
|
||||
attributes.push_back(ConvertPrecisionConfig(instruction));
|
||||
MakeAndReturn(DotOp);
|
||||
}
|
||||
case HloOpcode::kCall: {
|
||||
TF_ASSIGN_OR_RETURN(FuncOp function,
|
||||
ImportFunction(instruction->to_apply()));
|
||||
mlir::Operation* new_operation =
|
||||
func_builder->create<mlir::CallOp>(loc, function, operands);
|
||||
return new_operation;
|
||||
}
|
||||
case HloOpcode::kCompare: {
|
||||
attributes.push_back(ConvertComparisonDirection(instruction));
|
||||
MakeAndReturn(CompareOp);
|
||||
}
|
||||
case HloOpcode::kGather: {
|
||||
const auto& gather_dimensions = instruction->gather_dimension_numbers();
|
||||
std::vector<int64_t> offset_dims(gather_dimensions.offset_dims().begin(),
|
||||
gather_dimensions.offset_dims().end());
|
||||
|
||||
std::vector<int64_t> slice_sizes(
|
||||
instruction->gather_slice_sizes().begin(),
|
||||
instruction->gather_slice_sizes().end());
|
||||
|
||||
std::vector<int64_t> collapsed_slice_dims(
|
||||
gather_dimensions.collapsed_slice_dims().begin(),
|
||||
gather_dimensions.collapsed_slice_dims().end());
|
||||
|
||||
std::vector<int64_t> start_index_map(
|
||||
gather_dimensions.start_index_map().begin(),
|
||||
gather_dimensions.start_index_map().end());
|
||||
|
||||
// TODO(b/132057942): Change to explicitly passing an integer instead of
|
||||
// call getI64IntegerAttr here.
|
||||
return func_builder
|
||||
->create<mlir::XLA::GatherOp>(
|
||||
loc, result_type, operands[0], operands[1],
|
||||
func_builder->getI64IntegerAttr(
|
||||
gather_dimensions.index_vector_dim()),
|
||||
Convert(offset_dims), Convert(slice_sizes),
|
||||
Convert(collapsed_slice_dims), Convert(start_index_map))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kDynamicUpdateSlice: {
|
||||
return func_builder
|
||||
->create<mlir::XLA::DynamicUpdateSliceOp>(
|
||||
loc, result_type, operands[0], operands[1],
|
||||
llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kPad: {
|
||||
const auto& padding_config = instruction->padding_config();
|
||||
llvm::SmallVector<int64_t, 4> edge_padding_low;
|
||||
llvm::SmallVector<int64_t, 4> edge_padding_high;
|
||||
llvm::SmallVector<int64_t, 4> interior_padding;
|
||||
edge_padding_low.reserve(padding_config.dimensions_size());
|
||||
edge_padding_high.reserve(padding_config.dimensions_size());
|
||||
interior_padding.reserve(padding_config.dimensions_size());
|
||||
|
||||
for (const auto& dimension : padding_config.dimensions()) {
|
||||
edge_padding_low.push_back(dimension.edge_padding_low());
|
||||
edge_padding_high.push_back(dimension.edge_padding_high());
|
||||
interior_padding.push_back(dimension.interior_padding());
|
||||
}
|
||||
|
||||
return func_builder
|
||||
->create<mlir::XLA::PadOp>(loc, result_type, operands[0], operands[1],
|
||||
Convert(edge_padding_low),
|
||||
Convert(edge_padding_high),
|
||||
Convert(interior_padding))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kSlice: {
|
||||
return func_builder
|
||||
->create<mlir::XLA::SliceOp>(
|
||||
loc, result_type, operands[0],
|
||||
ConvertDimensions(instruction->slice_starts()),
|
||||
ConvertDimensions(instruction->slice_limits()))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kConcatenate: {
|
||||
// TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr
|
||||
// for concatenate dimension.
|
||||
return func_builder
|
||||
->create<mlir::XLA::ConcatenateOp>(
|
||||
loc, result_type, operands,
|
||||
builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kReduce: {
|
||||
TF_ASSIGN_OR_RETURN(auto reduction,
|
||||
ImportFunction(instruction->to_apply()));
|
||||
// TODO(b/132057942): Make more convenient constructors, e.g. pass
|
||||
// mlir function pointer instead of a function attr.
|
||||
return func_builder
|
||||
->create<mlir::XLA::ReduceOp>(
|
||||
loc, result_type, operands,
|
||||
func_builder->getSymbolRefAttr(reduction),
|
||||
ConvertDimensions(instruction->dimensions()))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kReverse: {
|
||||
return func_builder
|
||||
->create<mlir::XLA::ReverseOp>(
|
||||
loc, result_type, operands[0],
|
||||
ConvertDimensions(instruction->dimensions()))
|
||||
.getOperation();
|
||||
}
|
||||
case HloOpcode::kWhile: {
|
||||
TF_ASSIGN_OR_RETURN(auto body, ImportFunction(instruction->while_body()));
|
||||
TF_ASSIGN_OR_RETURN(auto cond,
|
||||
ImportFunction(instruction->while_condition()));
|
||||
|
||||
llvm::SmallVector<Type, 4> types;
|
||||
types.reserve(operands.size());
|
||||
for (auto operand : operands) {
|
||||
types.push_back(operand->getType());
|
||||
}
|
||||
|
||||
auto cond_attr = func_builder->getSymbolRefAttr(cond);
|
||||
auto body_attr = func_builder->getSymbolRefAttr(body);
|
||||
|
||||
Operation* op = func_builder->create<mlir::XLA::WhileOp>(
|
||||
loc, types, operands, cond_attr, body_attr);
|
||||
return op;
|
||||
}
|
||||
case HloOpcode::kGetTupleElement: {
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"index", builder_->getIntegerAttr(builder_->getIntegerType(32),
|
||||
instruction->tuple_index())));
|
||||
MakeAndReturn(GetTupleElementOp);
|
||||
};
|
||||
case HloOpcode::kTranspose: {
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"permutation", ConvertDimensions(instruction->dimensions())));
|
||||
MakeAndReturn(TransposeOp);
|
||||
}
|
||||
#define NoAttributeCase(hlo_op_code, mlir_op) \
|
||||
case HloOpcode::hlo_op_code: { \
|
||||
MakeAndReturn(mlir_op); \
|
||||
}
|
||||
|
||||
// broadcast dimensions are never added here because they don't exist as
|
||||
// part of the HLO instruction. They are only a convenience in the XLA
|
||||
// builder API.
|
||||
NoAttributeCase(kAdd, AddOp);
|
||||
NoAttributeCase(kAnd, AndOp);
|
||||
NoAttributeCase(kConvert, ConvertOp);
|
||||
NoAttributeCase(kDivide, DivOp);
|
||||
NoAttributeCase(kMaximum, MaxOp);
|
||||
NoAttributeCase(kMinimum, MinOp);
|
||||
NoAttributeCase(kMultiply, MulOp);
|
||||
NoAttributeCase(kSelect, SelectOp);
|
||||
NoAttributeCase(kSubtract, SubOp);
|
||||
NoAttributeCase(kTanh, TanhOp);
|
||||
NoAttributeCase(kTuple, TupleOp);
|
||||
// TODO(b/129422361) Copy needs special handling because it is not defined
|
||||
// in tensorflow/compiler/xla/client/xla_builder.h.
|
||||
// See operation semantics in
|
||||
// g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
|
||||
NoAttributeCase(kCopy, CopyOp);
|
||||
// TODO(b/129422361) Ops below need additional work to handle attributes.
|
||||
NoAttributeCase(kConvolution, ConvOp);
|
||||
NoAttributeCase(kReshape, ReshapeOp);
|
||||
#undef NoAttributeCase
|
||||
#undef MakeAndReturn
|
||||
case HloOpcode::kAddDependency:
|
||||
// Arbitrary op code that I suspect we will not implement for quite a
|
||||
// while and allows testing handling of unknown ops. Selected because it
|
||||
// is not mentioned in xla client anywhere or in the hlo of our sample
|
||||
// models.
|
||||
default: {
|
||||
mlir::OperationState result(loc, "xla.unknown");
|
||||
result.addOperands(operands);
|
||||
result.addTypes(result_type);
|
||||
for (auto attr : attributes) {
|
||||
result.attributes.push_back(attr);
|
||||
}
|
||||
|
||||
return func_builder->createOperation(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<llvm::SmallVector<mlir::Value*, 4>> HloFunctionImporter::GetOperands(
|
||||
HloInstruction* instruction) {
|
||||
llvm::SmallVector<mlir::Value*, 4> operands;
|
||||
for (const auto& operand : instruction->operands()) {
|
||||
auto input_it = instruction_value_map_.find(operand);
|
||||
if (input_it == instruction_value_map_.end()) {
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrCat("Could not find input value: ", operand->name(),
|
||||
" for instruction ", instruction->name()));
|
||||
}
|
||||
operands.push_back(input_it->second);
|
||||
}
|
||||
return operands;
|
||||
}
|
||||
|
||||
// TODO(suderman): Move to a general library when needed in other places.
|
||||
StatusOr<mlir::RankedTensorType> HloFunctionImporter::ConvertTensorType(
|
||||
const Shape& shape) {
|
||||
auto type = shape.element_type();
|
||||
|
||||
llvm::SmallVector<int64_t, 4> array;
|
||||
array.reserve(shape.dimensions_size());
|
||||
for (auto val : shape.dimensions()) {
|
||||
array.push_back(val);
|
||||
}
|
||||
|
||||
switch (type) {
|
||||
case PrimitiveType::PRED:
|
||||
return builder_->getTensorType(array, builder_->getI1Type());
|
||||
case PrimitiveType::F16:
|
||||
return builder_->getTensorType(array, builder_->getF16Type());
|
||||
case PrimitiveType::F32:
|
||||
return builder_->getTensorType(array, builder_->getF32Type());
|
||||
case PrimitiveType::F64:
|
||||
return builder_->getTensorType(array, builder_->getF64Type());
|
||||
case PrimitiveType::S8:
|
||||
return builder_->getTensorType(array, builder_->getIntegerType(8));
|
||||
case PrimitiveType::S16:
|
||||
return builder_->getTensorType(array, builder_->getIntegerType(16));
|
||||
case PrimitiveType::S32:
|
||||
return builder_->getTensorType(array, builder_->getIntegerType(32));
|
||||
case PrimitiveType::S64:
|
||||
return builder_->getTensorType(array, builder_->getIntegerType(64));
|
||||
default:
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrCat("Unsupported type: ", PrimitiveType_Name(type)));
|
||||
}
|
||||
}
|
||||
|
||||
StatusOr<mlir::Type> HloFunctionImporter::ConvertType(const Shape& shape) {
|
||||
if (shape.IsTuple()) {
|
||||
mlir::Type mlir_type;
|
||||
llvm::SmallVector<mlir::Type, 4> contents;
|
||||
contents.reserve(shape.tuple_shapes_size());
|
||||
for (const auto& subtype : shape.tuple_shapes()) {
|
||||
TF_ASSIGN_OR_RETURN(auto mlir_subtype, ConvertType(subtype));
|
||||
contents.push_back(mlir_subtype);
|
||||
}
|
||||
|
||||
return builder_->getTupleType(contents);
|
||||
}
|
||||
|
||||
return ConvertTensorType(shape);
|
||||
}
|
||||
|
||||
tensorflow::Status HloFunctionImporter::GetMlirTypes(
|
||||
const std::vector<HloInstruction*>& instructions,
|
||||
llvm::SmallVectorImpl<mlir::Type>* types) {
|
||||
for (auto instruction : instructions) {
|
||||
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertType(instruction->shape()));
|
||||
types->push_back(ret_type);
|
||||
}
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
StatusOr<Value*> HloFunctionImporter::GetMlirValue(
|
||||
HloInstruction* instruction) {
|
||||
auto lookup = instruction_value_map_.find(instruction);
|
||||
if (lookup != instruction_value_map_.end()) {
|
||||
return lookup->second;
|
||||
}
|
||||
|
||||
return tensorflow::errors::Internal(absl::StrCat(
|
||||
"Unable to find value for input: ", instruction->ToString()));
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertPrecisionConfig(
|
||||
HloInstruction* instruction) {
|
||||
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
|
||||
|
||||
for (auto prec : instruction->precision_config().operand_precision()) {
|
||||
operand_precision_attrs.push_back(
|
||||
builder_->getStringAttr(PrecisionConfig_Precision_Name(prec)));
|
||||
}
|
||||
return builder_->getNamedAttr(
|
||||
"precision_config", builder_->getArrayAttr(operand_precision_attrs));
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
|
||||
HloInstruction* instruction) {
|
||||
return builder_->getNamedAttr(
|
||||
"comparison_direction",
|
||||
builder_->getStringAttr(
|
||||
ComparisonDirectionToString(instruction->comparison_direction())));
|
||||
}
|
||||
|
||||
mlir::ElementsAttr HloFunctionImporter::ConvertDimensions(
|
||||
llvm::ArrayRef<int64> op_dimensions) {
|
||||
llvm::SmallVector<APInt, 8> dimensions;
|
||||
dimensions.reserve(op_dimensions.size());
|
||||
for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
|
||||
|
||||
return DenseIntElementsAttr::get(
|
||||
builder_->getTensorType(dimensions.size(), builder_->getIntegerType(64)),
|
||||
dimensions);
|
||||
}
|
||||
|
||||
mlir::ElementsAttr HloFunctionImporter::Convert(
|
||||
llvm::ArrayRef<int64_t> op_dimensions) {
|
||||
return builder_->getDenseIntElementsAttr(
|
||||
builder_->getTensorType(op_dimensions.size(),
|
||||
builder_->getIntegerType(64)),
|
||||
op_dimensions);
|
||||
}
|
||||
|
||||
Status HloFunctionImporter::ValidateDotDimensions(HloInstruction* instruction) {
|
||||
DotDimensionNumbers expected_dimension_numbers;
|
||||
expected_dimension_numbers.add_lhs_contracting_dimensions(
|
||||
instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
|
||||
expected_dimension_numbers.add_rhs_contracting_dimensions(0);
|
||||
if (!xla::protobuf_util::ProtobufEquals(instruction->dot_dimension_numbers(),
|
||||
expected_dimension_numbers)) {
|
||||
return tensorflow::errors::Internal(
|
||||
absl::StrCat("Dot operation has unsupported dimension numbers: ",
|
||||
instruction->dot_dimension_numbers().DebugString()));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace xla
|
112
tensorflow/compiler/mlir/xla/hlo_function_importer.h
Normal file
112
tensorflow/compiler/mlir/xla/hlo_function_importer.h
Normal file
@ -0,0 +1,112 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
class HloModule;
|
||||
class HloComputation;
|
||||
class HloInstruction;
|
||||
class Shape;
|
||||
|
||||
// Helper class for importing HloComputations.
|
||||
class HloFunctionImporter {
|
||||
public:
|
||||
static StatusOr<mlir::FuncOp> ImportFunction(
|
||||
mlir::ModuleOp module, mlir::Builder* builder,
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map,
|
||||
xla::HloComputation* computation);
|
||||
|
||||
private:
|
||||
HloFunctionImporter(
|
||||
mlir::ModuleOp module, mlir::Builder* builder,
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map)
|
||||
: context_(module.getContext()),
|
||||
module_(module),
|
||||
builder_(builder),
|
||||
function_map_(function_map) {}
|
||||
|
||||
StatusOr<mlir::FuncOp> ImportFunction(xla::HloComputation* computation);
|
||||
|
||||
// Imports an instruction.
|
||||
StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
|
||||
mlir::OpBuilder* func_builder);
|
||||
|
||||
// Gets the MLIR operand values from an HLO Instruction.
|
||||
StatusOr<llvm::SmallVector<mlir::Value*, 4>> GetOperands(
|
||||
xla::HloInstruction* instruction);
|
||||
|
||||
// Converts xla Tensor type to the corresponding MLIR type.
|
||||
StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
|
||||
|
||||
// Converts xla Primitive types to the corresponding MLIR type.
|
||||
StatusOr<mlir::Type> ConvertType(const xla::Shape& shape);
|
||||
|
||||
// Returns the output type of an HloInstruction.
|
||||
StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction);
|
||||
|
||||
// Takes a list of HloInstructions and generates the list of types used for
|
||||
// input, bypassing tuples to subsets.
|
||||
Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
|
||||
llvm::SmallVectorImpl<mlir::Type>* types);
|
||||
|
||||
// Returns the Mlir Value for the corresponding HloInstruction.
|
||||
StatusOr<mlir::Value*> GetMlirValue(xla::HloInstruction* instruction);
|
||||
|
||||
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
|
||||
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction);
|
||||
|
||||
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
|
||||
mlir::NamedAttribute ConvertComparisonDirection(
|
||||
xla::HloInstruction* instruction);
|
||||
|
||||
// Converts the dimensions of an HLO instruction into an MLIR attribute.
|
||||
mlir::ElementsAttr ConvertDimensions(llvm::ArrayRef<int64> op_dimensions);
|
||||
|
||||
// Converts Array ref to an ElementsAttr.
|
||||
mlir::ElementsAttr Convert(llvm::ArrayRef<int64_t> op_dimensions);
|
||||
|
||||
// Ensures dot instruction has only default contracting and batch dimensions.
|
||||
Status ValidateDotDimensions(xla::HloInstruction* instruction);
|
||||
|
||||
mlir::MLIRContext* context_;
|
||||
mlir::ModuleOp module_;
|
||||
mlir::Builder* builder_;
|
||||
|
||||
// Mapping from HloComputation to the created MLIR function.
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_;
|
||||
|
||||
// Mapping from HloInstructions to the associative MLIR values.
|
||||
std::unordered_map<xla::HloInstruction*, mlir::Value*> instruction_value_map_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
|
54
tensorflow/compiler/mlir/xla/hlo_module_importer.cc
Normal file
54
tensorflow/compiler/mlir/xla/hlo_module_importer.cc
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
|
||||
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status HloModuleImporter::Import(const xla::HloModule& module) {
|
||||
for (const auto& computation : module.computations()) {
|
||||
auto result = HloFunctionImporter::ImportFunction(
|
||||
module_, &builder_, &function_map_, computation);
|
||||
TF_RETURN_IF_ERROR(result.status());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status HloModuleImporter::Import(const xla::HloModuleProto& module_proto) {
|
||||
xla::DebugOptions debug_options;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto module_config,
|
||||
xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
|
||||
TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto(
|
||||
module_proto, module_config));
|
||||
|
||||
return Import(*module);
|
||||
}
|
||||
|
||||
} // namespace xla
|
62
tensorflow/compiler/mlir/xla/hlo_module_importer.h
Normal file
62
tensorflow/compiler/mlir/xla/hlo_module_importer.h
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
class HloModule;
|
||||
class HloModuleProto;
|
||||
class HloComputation;
|
||||
class HloInstruction;
|
||||
class Shape;
|
||||
|
||||
// Importer that takes an HloModule and imports it as an MLIR module in the XLA
|
||||
// dialect. HloModuleImporter does not take ownership.
|
||||
class HloModuleImporter {
|
||||
public:
|
||||
explicit HloModuleImporter(mlir::ModuleOp module)
|
||||
: module_(module), builder_(module.getContext()) {}
|
||||
|
||||
// Import the HloModule into the MLIR Module.
|
||||
Status Import(const xla::HloModule& module);
|
||||
|
||||
// Import the HloModuleProto into the MLIR Module.
|
||||
Status Import(const xla::HloModuleProto& module);
|
||||
|
||||
private:
|
||||
mlir::ModuleOp module_;
|
||||
mlir::Builder builder_;
|
||||
|
||||
// Map for tracking which MLIR function map to which HLO Computation. This
|
||||
// tracks functions as they are imported and provides a quick lookup for
|
||||
// functions invoked by control flow related operations (e.g. while, call).
|
||||
std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
|
35
tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc
Normal file
35
tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc
Normal file
@ -0,0 +1,35 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
Status ConvertHloToMlirHlo(mlir::ModuleOp module,
|
||||
xla::HloModuleProto* hlo_module_proto) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
return HloModuleImporter(module).Import(*hlo_module_proto);
|
||||
}
|
||||
|
||||
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
return HloModuleImporter(module).Import(*hlo_module);
|
||||
}
|
||||
|
||||
} // namespace xla
|
39
tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h
Normal file
39
tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h
Normal file
@ -0,0 +1,39 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
|
||||
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
|
||||
namespace mlir {
|
||||
class ModuleOp;
|
||||
} // namespace mlir
|
||||
|
||||
namespace xla {
|
||||
class HloModule;
|
||||
class HloModuleProto;
|
||||
|
||||
// Converts an HLO module proto to a MLIR module in HLO dialect.
|
||||
Status ConvertHloToMlirHlo(mlir::ModuleOp module,
|
||||
xla::HloModuleProto* hlo_module);
|
||||
|
||||
// Converts an HLO module to a MLIR module in HLO dialect.
|
||||
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
|
259
tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
Normal file
259
tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
Normal file
@ -0,0 +1,259 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/SMLoc.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Function.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Operation.h" // TF:local_config_mlir
|
||||
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/comparison_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
|
||||
llvm::ArrayRef<int64> raw_data = attr.getValues<int64>();
|
||||
if (attr.isSplat())
|
||||
return std::vector<int64>(attr.getType().getNumElements(), raw_data[0]);
|
||||
return raw_data;
|
||||
}
|
||||
|
||||
// Converts the broadcast_dimensions attribute into a span of dimension numbers
|
||||
// (empty if the attribute is absent).
|
||||
static std::vector<int64> Convert_broadcast_dimensions(
|
||||
llvm::Optional<mlir::ElementsAttr> broadcast_dimensions) {
|
||||
if (!broadcast_dimensions.hasValue()) return {};
|
||||
|
||||
return ConvertDenseIntAttr(
|
||||
broadcast_dimensions->cast<mlir::DenseIntElementsAttr>());
|
||||
}
|
||||
|
||||
// Converts the broadcast_sizes attribute into a span of dimension sizes.
|
||||
static std::vector<int64> Convert_broadcast_sizes(
|
||||
mlir::ElementsAttr broadcast_sizes) {
|
||||
return ConvertDenseIntAttr(
|
||||
broadcast_sizes.cast<mlir::DenseIntElementsAttr>());
|
||||
}
|
||||
|
||||
static std::vector<int64> Convert_permutation(mlir::ElementsAttr permutation) {
|
||||
return ConvertDenseIntAttr(permutation.cast<mlir::DenseIntElementsAttr>());
|
||||
}
|
||||
|
||||
// Converts the precision config array of strings attribute into the
|
||||
// corresponding XLA proto. All the strings are assumed to be valid names of the
|
||||
// Precision enum. This should have been checked in the op verify method.
|
||||
static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
|
||||
llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
|
||||
if (!optional_precision_config_attr.hasValue()) return nullptr;
|
||||
|
||||
auto precision_config = absl::make_unique<xla::PrecisionConfig>();
|
||||
for (auto attr : optional_precision_config_attr.getValue()) {
|
||||
xla::PrecisionConfig::Precision p;
|
||||
auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
|
||||
// TODO(jpienaar): Update this to ensure this is captured by verify.
|
||||
if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
|
||||
precision_config->add_operand_precision(p);
|
||||
} else {
|
||||
auto* context = attr.getContext();
|
||||
mlir::emitError(mlir::UnknownLoc::get(context))
|
||||
<< "unexpected operand precision " << operand_precision;
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
return precision_config;
|
||||
}
|
||||
|
||||
// Converts the comparison_direction string attribute into the XLA enum. The
|
||||
// string is assumed to correspond to exactly one of the allowed strings
|
||||
// representing the enum. This should have been checked in the op verify method.
|
||||
static xla::ComparisonDirection Convert_comparison_direction(
|
||||
llvm::StringRef comparison_direction_string) {
|
||||
return xla::StringToComparisonDirection(comparison_direction_string.str())
|
||||
.ValueOrDie();
|
||||
}
|
||||
|
||||
// Passes through everything except for unique_ptr, on which it calls get().
|
||||
// This exists to allow the generated code to call XLA functions that take a raw
|
||||
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
|
||||
// as a pointer and there is otherwise no way to avoid a memory leak.
|
||||
template <typename T>
|
||||
T Unwrap(T t) {
|
||||
return t;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
T* Unwrap(const std::unique_ptr<T>& t) {
|
||||
return t.get();
|
||||
}
|
||||
|
||||
// Convert APInt into an int.
|
||||
// TODO(hpucha): This should be consolidated into a general place.
|
||||
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
|
||||
|
||||
// Convert APFloat to double.
|
||||
static double ConvertAPFloat(llvm::APFloat value) {
|
||||
const auto& semantics = value.getSemantics();
|
||||
bool losesInfo = false;
|
||||
if (&semantics != &llvm::APFloat::IEEEdouble())
|
||||
value.convert(llvm::APFloat::IEEEdouble(),
|
||||
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
|
||||
return value.convertToDouble();
|
||||
}
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/operator_writers.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace {
|
||||
|
||||
class ConvertToHloModule {
|
||||
public:
|
||||
using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
|
||||
using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
|
||||
|
||||
explicit ConvertToHloModule(mlir::ModuleOp module)
|
||||
: module_(module), module_builder_("main") {}
|
||||
|
||||
// Perform the lowering to XLA. This function returns failure if an error was
|
||||
// encountered.
|
||||
LogicalResult Run() {
|
||||
for (auto func : module_.getOps<FuncOp>()) {
|
||||
if (func.empty()) continue;
|
||||
if (failed(RunOnFunction(func))) return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
// Perform the lowering on a specific function. This function returns failure
|
||||
// if an error was encountered.
|
||||
LogicalResult RunOnFunction(mlir::FuncOp f);
|
||||
|
||||
xla::HloModuleProto ConsumeMainProto() {
|
||||
return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
|
||||
.proto();
|
||||
}
|
||||
|
||||
private:
|
||||
// The module being lowered.
|
||||
mlir::ModuleOp module_;
|
||||
|
||||
// The top-level XlaBuilder.
|
||||
xla::XlaBuilder module_builder_;
|
||||
|
||||
// Map between function and lowered computation.
|
||||
FunctionLoweringMap lowered_computation_;
|
||||
};
|
||||
|
||||
LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
|
||||
ConvertToHloModule::FunctionLoweringMap* function_lowering,
|
||||
ConvertToHloModule::ValueLoweringMap* value_lowering) {
|
||||
if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success();
|
||||
|
||||
// TODO(riverriddle) We currently don't support lowering constant operations.
|
||||
if (isa<mlir::XLA::ConstOp>(inst)) {
|
||||
inst->emitError("unable to lower 'xla.constant' operation");
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto& value_map = *value_lowering;
|
||||
if (auto ret = dyn_cast<mlir::ReturnOp>(inst)) {
|
||||
// Construct the return value for the function. If there are multiple
|
||||
// values returned, then create a tuple, else return value directly.
|
||||
xla::XlaOp return_value;
|
||||
unsigned num_return_values = ret.getNumOperands();
|
||||
if (num_return_values > 1) {
|
||||
std::vector<xla::XlaOp> returns(num_return_values);
|
||||
for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) {
|
||||
returns[i] = value_map[ret.getOperand(i)];
|
||||
}
|
||||
return_value = xla::Tuple(builder, returns);
|
||||
} else if (num_return_values == 1) {
|
||||
return_value = value_map[ret.getOperand(0)];
|
||||
}
|
||||
|
||||
// Build the XlaComputation and check for failures.
|
||||
auto computation_or =
|
||||
return_value.valid() ? builder->Build(return_value) : builder->Build();
|
||||
if (!computation_or.ok()) {
|
||||
inst->emitError(llvm::Twine(computation_or.status().error_message()));
|
||||
return failure();
|
||||
}
|
||||
auto f = inst->getParentOfType<mlir::FuncOp>();
|
||||
(*function_lowering)[f] = std::move(computation_or.ValueOrDie());
|
||||
return success();
|
||||
}
|
||||
inst->emitError("unable to lower operation of type '" +
|
||||
inst->getName().getStringRef().str() + '\'');
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||
if (f.getBlocks().size() != 1) {
|
||||
return f.emitError("only single block Function suppored");
|
||||
}
|
||||
|
||||
// Create a sub-builder if this is not the main function.
|
||||
std::unique_ptr<xla::XlaBuilder> builder_up;
|
||||
bool entry_function = f.getName().str() == "main";
|
||||
if (!entry_function)
|
||||
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
|
||||
auto& builder = entry_function ? module_builder_ : *builder_up;
|
||||
|
||||
// Mapping from the Value to lowered XlaOp. The code below lowers in
|
||||
// program order and will fail if an operand is unseen. This can be improved.
|
||||
ValueLoweringMap lowering;
|
||||
for (auto& bb : f) {
|
||||
int num = 0;
|
||||
for (auto& arg : bb.getArguments()) {
|
||||
xla::Shape shape = xla::TypeToShape(arg->getType());
|
||||
lowering[arg] =
|
||||
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
|
||||
++num;
|
||||
}
|
||||
|
||||
for (auto& inst : bb)
|
||||
if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
ConvertToHloModule converter(module);
|
||||
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
|
||||
auto hlo_module = converter.ConsumeMainProto();
|
||||
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace mlir
|
37
tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
Normal file
37
tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h
Normal file
@ -0,0 +1,37 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
|
||||
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Converts a MLIR module in HLO dialect into a HloModuleProto.
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto);
|
||||
|
||||
// Creates XlaOp equivalent of a given MLIR operation using the operand info
|
||||
// from `value_lowering` map.
|
||||
llvm::Optional<xla::XlaOp> CreateXlaOperator(
|
||||
mlir::Operation* op,
|
||||
llvm::DenseMap<mlir::Value*, xla::XlaOp>* value_lowering);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
|
196
tensorflow/compiler/mlir/xla/operator_writer_gen.cc
Normal file
196
tensorflow/compiler/mlir/xla/operator_writer_gen.cc
Normal file
@ -0,0 +1,196 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <sstream>
|
||||
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/PrettyStackTrace.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Main.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
#include "llvm/TableGen/TableGenBackend.h"
|
||||
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
|
||||
|
||||
using llvm::dyn_cast;
|
||||
using llvm::LessRecord;
|
||||
using llvm::raw_ostream;
|
||||
using llvm::Record;
|
||||
using llvm::RecordKeeper;
|
||||
using llvm::StringRef;
|
||||
using mlir::tblgen::Operator;
|
||||
|
||||
// Returns the builder function name for the given op definition.
|
||||
// E.g., AddOp -> CreateAddOp
|
||||
static inline std::string GetOperatorBuilderName(StringRef op_name) {
|
||||
return "Create" + op_name.str();
|
||||
}
|
||||
|
||||
static std::string GetConversionFunction(
|
||||
mlir::tblgen::NamedAttribute named_attr) {
|
||||
auto storage_type = named_attr.attr.getStorageType();
|
||||
// For some attribute types we have a general conversion, so use that.
|
||||
if (storage_type.endswith("IntegerAttr") ||
|
||||
storage_type.endswith("FloatAttr")) {
|
||||
return "Convert" + named_attr.attr.getReturnType().str();
|
||||
}
|
||||
return "Convert_" + named_attr.name.str();
|
||||
}
|
||||
|
||||
using ArgumentName = string;
|
||||
using ArgumentDeclaration = string;
|
||||
using Argument = std::pair<ArgumentName, ArgumentDeclaration>;
|
||||
using ArgumentList = std::vector<Argument>;
|
||||
|
||||
static std::string BuildOperator(const Operator& op) {
|
||||
std::stringstream os;
|
||||
StringRef op_name = op.getCppClassName();
|
||||
std::string xla_op_name = op_name.drop_back(2).str();
|
||||
|
||||
// Signature.
|
||||
os << "static xla::XlaOp " << GetOperatorBuilderName(op_name)
|
||||
<< "(mlir::XLA::" << op_name.str() << " xla_op, "
|
||||
<< "llvm::DenseMap<mlir::Value*, xla::XlaOp>* "
|
||||
"value_lowering) {\n";
|
||||
|
||||
os << " auto& value_map = *value_lowering;\n"
|
||||
<< " auto result = xla_op.getResult();\n";
|
||||
|
||||
// Invoke the conversion function for each attribute.
|
||||
for (const auto& named_attr : op.getAttributes()) {
|
||||
os << " auto " << named_attr.name.str() << " = "
|
||||
<< GetConversionFunction(named_attr) << "("
|
||||
<< "xla_op." << named_attr.name.str() << "());\n";
|
||||
}
|
||||
|
||||
// Assumes that the client builder method names closely follow the op names
|
||||
// in the dialect. For e.g., AddOp -> xla::Add method.
|
||||
os << " auto xla_result = xla::" << xla_op_name << "(";
|
||||
|
||||
int num_operands = op.getNumOperands();
|
||||
if (num_operands == 1) {
|
||||
os << "value_map[xla_op.getOperand()]";
|
||||
} else {
|
||||
for (auto i = 0; i < num_operands; i++) {
|
||||
os << "value_map[xla_op.getOperand(" << i << ")]";
|
||||
if (i != num_operands - 1) {
|
||||
os << ", ";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const auto& named_attr : op.getAttributes()) {
|
||||
os << ", Unwrap(" << named_attr.name.str() << ")";
|
||||
}
|
||||
|
||||
os << ");\n";
|
||||
|
||||
os << " value_map[result] = xla_result;\n";
|
||||
os << " return xla_result;\n";
|
||||
os << "}\n\n";
|
||||
return os.str();
|
||||
}
|
||||
|
||||
// For each XLA op, emits a builder function that constructs the XLA op using
|
||||
// the HLO client builder.
|
||||
static void EmitOperatorBuilders(const RecordKeeper& record_keeper,
|
||||
const std::vector<Record*>& defs,
|
||||
raw_ostream* ostream) {
|
||||
raw_ostream& os = *ostream;
|
||||
|
||||
for (const auto* def : defs) {
|
||||
// Skip operations that have a custom converter.
|
||||
if (def->getValueAsBit("hasCustomHLOConverter")) continue;
|
||||
|
||||
Operator op(def);
|
||||
os << BuildOperator(op);
|
||||
}
|
||||
}
|
||||
|
||||
// Emits a builder function that returns the XlaOp object given a
|
||||
// mlir::Operation.
|
||||
//
|
||||
// The signature of the function is:
|
||||
//
|
||||
// llvm::Optional<xla::XlaOp>
|
||||
// mlir::CreateXlaOperator(
|
||||
// mlir::Operation* op,
|
||||
// llvm::DenseMap<mlir::Value*, xla::XlaOp>
|
||||
// *value_lowering);
|
||||
static void EmitBuilder(const std::vector<Record*>& defs,
|
||||
raw_ostream* ostream) {
|
||||
raw_ostream& os = *ostream;
|
||||
|
||||
// Signature
|
||||
os << "llvm::Optional<xla::XlaOp>\n"
|
||||
"mlir::CreateXlaOperator(mlir::Operation* op, "
|
||||
"llvm::DenseMap<mlir::Value*, xla::XlaOp> "
|
||||
"*value_lowering) {\n";
|
||||
|
||||
for (const auto* def : defs) {
|
||||
// Skip operations that have a custom converter.
|
||||
if (def->getValueAsBit("hasCustomHLOConverter")) continue;
|
||||
|
||||
StringRef op_name = def->getName().drop_front(4);
|
||||
|
||||
// Try to cast to each op and call the corresponding op builder.
|
||||
os << " if (auto xla_op = llvm::dyn_cast<mlir::XLA::" << op_name
|
||||
<< ">(op))\n return " << GetOperatorBuilderName(op_name)
|
||||
<< "(xla_op, value_lowering);\n";
|
||||
}
|
||||
|
||||
os << " return llvm::None;\n"
|
||||
"}\n";
|
||||
}
|
||||
|
||||
// The function below has a non-constant reference as that is required by LLVM's
|
||||
// TableGenMain.
|
||||
// NOLINTNEXTLINE
|
||||
static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
|
||||
emitSourceFileHeader("MLIR XLA Builders", os);
|
||||
|
||||
// Retrieve all the definitions derived from XLA_Op and sort by record name.
|
||||
std::vector<Record*> defs = records.getAllDerivedDefinitions("XLA_Op");
|
||||
llvm::sort(defs, LessRecord());
|
||||
|
||||
for (const auto* def : defs) {
|
||||
// XLA ops in the .td file are expected to follow the naming convention:
|
||||
// XLA_<OpName>Op.
|
||||
// The generated XLA op C++ class should be XLA::<OpName>Op.
|
||||
if (!def->getName().startswith("XLA_"))
|
||||
PrintFatalError(def->getLoc(),
|
||||
"unexpected op name format: 'XLA_' prefix missing");
|
||||
if (!def->getName().endswith("Op"))
|
||||
PrintFatalError(def->getLoc(),
|
||||
"unexpected op name format: 'Op' suffix missing");
|
||||
}
|
||||
|
||||
EmitOperatorBuilders(records, defs, &os);
|
||||
os << "\n\n";
|
||||
EmitBuilder(defs, &os);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
|
||||
llvm::PrettyStackTraceProgram X(argc, argv);
|
||||
|
||||
llvm::llvm_shutdown_obj Y;
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv);
|
||||
return TableGenMain(argv[0], &OperatorWritersMain);
|
||||
}
|
23
tensorflow/compiler/mlir/xla/tests/translate/BUILD
Normal file
23
tensorflow/compiler/mlir/xla/tests/translate/BUILD
Normal file
@ -0,0 +1,23 @@
|
||||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
glob_lit_tests(
|
||||
data = [":test_utilities"],
|
||||
driver = "@local_config_mlir//:run_lit.sh",
|
||||
test_file_exts = [
|
||||
"mlir",
|
||||
"hlo",
|
||||
"hlotxt",
|
||||
],
|
||||
)
|
||||
|
||||
# Bundle together all of the test utilities that are used by tests.
|
||||
filegroup(
|
||||
name = "test_utilities",
|
||||
testonly = True,
|
||||
data = [
|
||||
"//tensorflow/compiler/mlir:tf-mlir-translate",
|
||||
"@llvm//:FileCheck",
|
||||
],
|
||||
)
|
28
tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt
Normal file
28
tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt
Normal file
@ -0,0 +1,28 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
// This test is more thorough than those of the the other binary ops to test
|
||||
// their shared functionality.
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
%Arg_2.3 = f32[] parameter(2)
|
||||
%Arg_3.4 = f32[] parameter(3)
|
||||
|
||||
// Add two tensors
|
||||
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
|
||||
// Add two scalars
|
||||
// CHECK-NEXT: %1 = "xla.add"(%arg2, %arg3) {name = "add.4"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4)
|
||||
|
||||
// Add a tensor and scalar
|
||||
// CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %2 : tensor<4xf32>
|
||||
ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4)
|
||||
}
|
||||
|
14
tensorflow/compiler/mlir/xla/tests/translate/add.mlir
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/add.mlir
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
// CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0)
|
||||
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
%0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
|
||||
// CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2)
|
||||
%1 = "xla.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
14
tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %and.3 = f32[4] and(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
@ -0,0 +1,26 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] {
|
||||
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
|
||||
// Same rank degenerate broadcast
|
||||
// CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0)
|
||||
// CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1)
|
||||
// CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
|
||||
// CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
|
||||
// CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
|
||||
%0 = "xla.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
|
||||
|
||||
// Broadcast up rank
|
||||
// CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
|
||||
// CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
|
||||
// CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
|
||||
%1 = "xla.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||
|
||||
// Broadcast up rank + degenerate broadcast
|
||||
// CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
|
||||
// CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
|
||||
// CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
|
||||
// CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
|
||||
%2 = "xla.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
|
||||
return %2 : tensor<2x3x4xi32>
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] {
|
||||
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
|
||||
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
|
||||
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
|
||||
%0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
|
||||
return %0 : tensor<1x2x3x4xi32>
|
||||
}
|
@ -0,0 +1,20 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x2xf32>) -> tensor<3x1x2xf32> {
|
||||
ENTRY %main {
|
||||
%Arg_0.1 = f32[1, 2] parameter(0)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
|
||||
%broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
|
||||
|
||||
// Degenerate broadcast
|
||||
// CHECK-NEXT: %1 = "xla.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32>
|
||||
broadcast.3 = f32[3,2] broadcast(%Arg_0.1), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %2 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
|
||||
// CHECK-NEXT: return %2 : tensor<3x1x2xf32>
|
||||
ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
|
||||
}
|
||||
|
19
tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
Normal file
19
tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt
Normal file
@ -0,0 +1,19 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule foo
|
||||
|
||||
// CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
%call (arg_1: s64[]) -> s64[] {
|
||||
%arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
|
||||
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK-NEXT: return %0 : tensor<i64>
|
||||
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
ENTRY %foo (arg0.1: s64[]) -> s64[] {
|
||||
%arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
|
||||
// CHECK-NEXT: %0 = call @call(%arg0) : (tensor<i64>) -> tensor<i64>
|
||||
// CHECK-NEXT: return %0 : tensor<i64>
|
||||
ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call
|
||||
}
|
21
tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt
Normal file
21
tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt
Normal file
@ -0,0 +1,21 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> {
|
||||
ENTRY %main (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] {
|
||||
%Arg_0.1 = f32[3] parameter(0)
|
||||
%Arg_1.2 = f32[3] parameter(1)
|
||||
%Arg_2.3 = f32[1] parameter(2)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
|
||||
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
|
||||
|
||||
// Requires broadcast of compatible tensors.
|
||||
// CHECK-NEXT: %2 = "xla.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
|
||||
// CHECK-NEXT: return %2 : tensor<3xi1>
|
||||
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
|
||||
}
|
13
tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt
Normal file
13
tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt
Normal file
@ -0,0 +1,13 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] {
|
||||
%Arg_0.1 = f32[4, 1] parameter(0)
|
||||
%Arg_1.2 = f32[4, 2] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4x3xf32>
|
||||
ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1}
|
||||
}
|
21
tensorflow/compiler/mlir/xla/tests/translate/const.hlotxt
Normal file
21
tensorflow/compiler/mlir/xla/tests/translate/const.hlotxt
Normal file
@ -0,0 +1,21 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule tfcompile.7
|
||||
|
||||
// CHECK-LABEL: func @main() -> tensor<2x2x1x1xf32> {
|
||||
ENTRY %tfcompile.7 {
|
||||
|
||||
// Scalar/0D tensor constant
|
||||
// CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1> : tensor<i64>
|
||||
%constant.0 = s64[] constant(1)
|
||||
|
||||
// Note that double brackets "[[" have to be escaped as they denote variables
|
||||
// in FileCheck. The only way to do so is to drop into regex with "{{"
|
||||
// CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
|
||||
// CHECK-NEXT: return %cst_0 : tensor<2x2x1x1xf32>
|
||||
ROOT %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
31
tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt
Normal file
31
tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt
Normal file
@ -0,0 +1,31 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule tfcompile.7
|
||||
|
||||
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
|
||||
// implementations with attributes, etc.
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x1xf32>) -> tuple<tensor<1x16x16x1xf32>> {
|
||||
ENTRY %tfcompile.7 {
|
||||
%arg0.1 = f32[1,16,16,1]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"}
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%copy.1 = f32[1,16,16,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="XLA_Args"}
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%reshape.2 = f32[1,16,16,1]{2,1,3,0} reshape(%copy.1)
|
||||
|
||||
// Note that double brackets "[[" have to be escaped as they denote variables
|
||||
// in FileCheck. The only way to do so is to drop into regex with "{{"
|
||||
// CHECK-NEXT: %cst = constant {name = "constant.3"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
|
||||
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %2 = "xla.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%convolution.4 = f32[1,16,16,1]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
|
||||
|
||||
// CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
|
||||
%reshape.5 = f32[1,16,16,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="XLA_Retvals"}
|
||||
|
||||
// CHECK-NEXT: %4 = "xla.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple<tensor<1x16x16x1xf32>>
|
||||
// CHECK-NEXT: return %4 : tuple<tensor<1x16x16x1xf32>>
|
||||
ROOT %tuple.6 = (f32[1,16,16,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="XLA_Retvals"}
|
||||
}
|
20
tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt
Normal file
20
tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt
Normal file
@ -0,0 +1,20 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf64> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64>
|
||||
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.convert"(%arg1) {name = "convert.4"} : (tensor<f32>) -> tensor<f64>
|
||||
%convert.4 = f64[] convert(f32[] %Arg_1.2)
|
||||
|
||||
// CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor<f64>) -> tensor<4xf64>
|
||||
// CHECK-NEXT: return %2 : tensor<4xf64>
|
||||
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4)
|
||||
}
|
||||
|
14
tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
23
tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt
Normal file
23
tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt
Normal file
@ -0,0 +1,23 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<f32> {
|
||||
ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[] {
|
||||
%Arg_0.1 = f32[1, 4] parameter(0)
|
||||
%Arg_1.2 = f32[4, 1] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}
|
||||
|
||||
// CHECK-NEXT: %2 = "xla.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config inferred.
|
||||
// CHECK-NEXT: %3 = "xla.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return %3 : tensor<f32>
|
||||
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
}
|
@ -0,0 +1,26 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @dynamic.update.slice.1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4x4xf32> {
|
||||
%dynamic.update.slice.1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
|
||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||
%Arg_1.2 = f32[1, 4] parameter(1)
|
||||
%Arg_2.3 = f32[] parameter(2)
|
||||
%Arg_3.4 = f32[] parameter(3)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<f32>, tensor<f32>) -> tensor<4x4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4x4xf32>
|
||||
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<f32>) -> tensor<4xf32>
|
||||
%dynamic.update.slice.2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[2] parameter(1)
|
||||
%Arg_2.3 = f32[] parameter(2)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
|
||||
}
|
@ -0,0 +1,103 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
// This test comes from a fully connected reference model.
|
||||
|
||||
HloModule tfcompile.48
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
|
||||
ENTRY %tfcompile.48 {
|
||||
%arg0.1 = f32[1,300] parameter(0)
|
||||
%arg1.2 = f32[1,300,3,1] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
|
||||
%reshape.3 = f32[1,300] reshape(%arg0.1)
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
|
||||
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
|
||||
|
||||
// CHECK-NEXT: %2 = "xla.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
|
||||
%reshape.28 = f32[300,1,1] reshape(%transpose.27)
|
||||
|
||||
// CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
|
||||
%reshape.29 = f32[300,1] reshape(%reshape.28)
|
||||
|
||||
// CHECK-NEXT: %4 = "xla.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
|
||||
|
||||
// CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor<f32>
|
||||
%constant.8 = f32[] constant(1)
|
||||
|
||||
// CHECK-NEXT: %5 = "xla.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %6 = "xla.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
|
||||
|
||||
// CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor<f32>
|
||||
%constant.32 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %7 = "xla.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %8 = "xla.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
|
||||
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
|
||||
|
||||
// CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor<f32>
|
||||
%constant.10 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %9 = "xla.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
|
||||
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor<f32>
|
||||
%constant.40 = f32[] constant(0)
|
||||
|
||||
// CHECK-NEXT: %10 = "xla.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
|
||||
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
|
||||
|
||||
// CHECK-NEXT: %11 = "xla.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
%copy.1 = f32[1,300,3,1] copy(%arg1.2)
|
||||
|
||||
// CHECK-NEXT: %12 = "xla.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
|
||||
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)
|
||||
|
||||
// CHECK-NEXT: %13 = "xla.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
|
||||
%reshape.24 = f32[1,300,3] reshape(%reshape.4)
|
||||
|
||||
// CHECK-NEXT: %14 = "xla.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
|
||||
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
|
||||
|
||||
// CHECK-NEXT: %15 = "xla.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
|
||||
%reshape.26 = f32[300,3] reshape(%transpose.25)
|
||||
|
||||
// CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
|
||||
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config implied.
|
||||
// CHECK-NEXT: %16 = "xla.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
|
||||
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
|
||||
|
||||
// CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32>
|
||||
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
|
||||
|
||||
// CHECK-NEXT: %17 = "xla.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
|
||||
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
|
||||
|
||||
// CHECK-NEXT: %18 = "xla.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
|
||||
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)
|
||||
|
||||
// CHECK-NEXT: %19 = "xla.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
|
||||
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
|
||||
|
||||
// CHECK-NEXT: %20 = "xla.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
|
||||
%reshape.44 = f32[300,1,5] reshape(%maximum.42)
|
||||
|
||||
// CHECK-NEXT: %21 = "xla.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
|
||||
|
||||
// CHECK-NEXT: %22 = "xla.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
|
||||
%reshape.46 = f32[300,1,5] reshape(%select.45)
|
||||
|
||||
// CHECK-NEXT: %23 = "xla.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
|
||||
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
|
||||
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
|
||||
}
|
17
tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt
Normal file
17
tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt
Normal file
@ -0,0 +1,17 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main() -> tensor<4xf32> {
|
||||
ENTRY %iota.1 () -> f32[4] {
|
||||
// CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @iota.2() -> tensor<4x5xf32> {
|
||||
%iota.2 () -> f32[4, 5] {
|
||||
// CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4x5xf32>
|
||||
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
|
||||
}
|
14
tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
14
tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
14
tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
23
tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt
Normal file
23
tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt
Normal file
@ -0,0 +1,23 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
|
||||
ENTRY %padding.1 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @padding.2(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32> {
|
||||
%padding.2 (Arg_0.1: f32[4, 4, 4], Arg_1.2: f32[]) -> f32[7, 11, 15] {
|
||||
%Arg_0.1 = f32[4, 4, 4] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<7x11x15xf32>
|
||||
ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
|
||||
}
|
53
tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt
Normal file
53
tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt
Normal file
@ -0,0 +1,53 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
// This test is more thorough than those of the the other binary ops to test
|
||||
// their shared functionality.
|
||||
|
||||
HloModule main.5
|
||||
|
||||
%reduce_helper.1 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
ROOT %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
||||
%reduce_helper.2 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
|
||||
%Arg_0.1 = f32[] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
ROOT %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
|
||||
}
|
||||
|
||||
%reduce_helper.3 (Arg_0.1: f32[], Arg_1.2: f32[], Arg_2.3: f32[], Arg_3.4: f32[]) -> (f32[], f32[]) {
|
||||
%Arg_0.1 = f32[] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
%Arg_2.3 = f32[] parameter(2)
|
||||
%Arg_3.4 = f32[] parameter(3)
|
||||
%add.4 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_2.3)
|
||||
%add.5 = f32[] add(f32[] %Arg_1.2, f32[] %Arg_3.4)
|
||||
ROOT %tuple.6 = (f32[], f32[]) tuple(%add.4, %add.5)
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) {
|
||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
%Arg_2.3 = f32[] parameter(2)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
|
||||
%reduce.1 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.1
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%reduce.2 = f32[] reduce(%reduce.1, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.2
|
||||
|
||||
// CHECK-NEXT: %2 = "xla.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
|
||||
%reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.2
|
||||
|
||||
// CHECK-NEXT: %3 = "xla.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
|
||||
%reduce.4 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
|
||||
|
||||
// CHECK-NEXT: %4 = "xla.sub"(%2, %3) {name = "sub.5"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
%sub.5 = f32[] subtract(%reduce.2, %reduce.3)
|
||||
|
||||
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.4, %sub.5)
|
||||
}
|
||||
|
21
tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt
Normal file
21
tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt
Normal file
@ -0,0 +1,21 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %reverse.1 (Arg_0.1: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @reverse.2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
|
||||
%reverse.2 (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
|
||||
%Arg_0.1 = f32[4, 4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4x4xf32>
|
||||
ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1}
|
||||
}
|
@ -0,0 +1,9 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
ENTRY %main.5 (Arg_0.1: f32[]) -> f32[] {
|
||||
// CHECK-NEXT: return %arg0 : tensor<f32>
|
||||
ROOT %Arg_0.1 = f32[] parameter(0)
|
||||
}
|
15
tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt
Normal file
15
tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt
Normal file
@ -0,0 +1,15 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
ENTRY %main {
|
||||
%Arg_0.1 = pred[2,3] parameter(0)
|
||||
%Arg_1.2 = s32[2,3] parameter(1)
|
||||
%Arg_2.3 = s32[2,3] parameter(2)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
// CHECK-NEXT: return %0 : tensor<2x3xi32>
|
||||
ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
|
||||
}
|
||||
|
13
tensorflow/compiler/mlir/xla/tests/translate/select.mlir
Normal file
13
tensorflow/compiler/mlir/xla/tests/translate/select.mlir
Normal file
@ -0,0 +1,13 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main
|
||||
func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
|
||||
// CHECK-NEXT: %Arg_0.1 = pred[2,3] parameter(0)
|
||||
// CHECK-NEXT: %Arg_1.2 = s32[2,3] parameter(1)
|
||||
// CHECK-NEXT: %Arg_2.3 = s32[2,3] parameter(2)
|
||||
|
||||
// CHECK-NEXT: ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
|
||||
%0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
|
||||
return %0 : tensor<2x3xi32>
|
||||
}
|
||||
|
146
tensorflow/compiler/mlir/xla/tests/translate/simple.hlo
Normal file
146
tensorflow/compiler/mlir/xla/tests/translate/simple.hlo
Normal file
@ -0,0 +1,146 @@
|
||||
# RUN: tf-mlir-translate -hlo-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
name: "foo.5"
|
||||
entry_computation_name: "foo.5"
|
||||
computations {
|
||||
name: "foo.5"
|
||||
instructions {
|
||||
name: "Arg_0.1"
|
||||
opcode: "parameter"
|
||||
shape {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
metadata {
|
||||
}
|
||||
id: 1
|
||||
}
|
||||
instructions {
|
||||
name: "Arg_1.2"
|
||||
opcode: "parameter"
|
||||
shape {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
metadata {
|
||||
}
|
||||
parameter_number: 1
|
||||
id: 2
|
||||
}
|
||||
instructions {
|
||||
name: "add.3"
|
||||
opcode: "add"
|
||||
shape {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
metadata {
|
||||
}
|
||||
id: 3
|
||||
operand_ids: 1
|
||||
operand_ids: 2
|
||||
}
|
||||
instructions {
|
||||
name: "dot.4"
|
||||
opcode: "dot"
|
||||
shape {
|
||||
element_type: F32
|
||||
layout {
|
||||
format: DENSE
|
||||
}
|
||||
}
|
||||
metadata {
|
||||
}
|
||||
dot_dimension_numbers {
|
||||
lhs_contracting_dimensions: 0
|
||||
rhs_contracting_dimensions: 0
|
||||
}
|
||||
id: 4
|
||||
operand_ids: 3
|
||||
operand_ids: 2
|
||||
}
|
||||
program_shape {
|
||||
parameters {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
parameters {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
result {
|
||||
element_type: F32
|
||||
layout {
|
||||
format: DENSE
|
||||
}
|
||||
}
|
||||
parameter_names: "Arg_0"
|
||||
parameter_names: "Arg_1"
|
||||
}
|
||||
id: 5
|
||||
root_id: 4
|
||||
}
|
||||
host_program_shape {
|
||||
parameters {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
parameters {
|
||||
element_type: F32
|
||||
dimensions: 4
|
||||
layout {
|
||||
minor_to_major: 0
|
||||
format: DENSE
|
||||
}
|
||||
is_dynamic_dimension: false
|
||||
}
|
||||
result {
|
||||
element_type: F32
|
||||
layout {
|
||||
format: DENSE
|
||||
}
|
||||
}
|
||||
parameter_names: "Arg_0"
|
||||
parameter_names: "Arg_1"
|
||||
}
|
||||
id: 5
|
||||
entry_computation_id: 5
|
||||
dynamic_parameter_binding {
|
||||
}
|
||||
|
||||
# CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
|
||||
# CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
# TODO(b/129709049) consider making this default precision config inferred.
|
||||
# CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
|
||||
# CHECK-NEXT: return %1 : tensor<f32>
|
||||
# CHECK-NEXT: }
|
17
tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt
Normal file
17
tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt
Normal file
@ -0,0 +1,17 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule foo.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
|
||||
ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] {
|
||||
%Arg_0.1 = f32[4]{0} parameter(0)
|
||||
%Arg_1.2 = f32[4]{0} parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
|
||||
|
||||
// TODO(b/129709049) consider making this default precision config inferred.
|
||||
// CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
|
||||
// CHECK-NEXT: return %1 : tensor<f32>
|
||||
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
|
||||
}
|
21
tensorflow/compiler/mlir/xla/tests/translate/simple.mlir
Normal file
21
tensorflow/compiler/mlir/xla/tests/translate/simple.mlir
Normal file
@ -0,0 +1,21 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo %s | FileCheck %s
|
||||
|
||||
func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
|
||||
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
|
||||
%0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
%1 = "xla.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
return %1 : tensor<4xf32>
|
||||
}
|
||||
|
||||
// CHECK: name: "main
|
||||
// CHECK: entry_computation_name: "main
|
||||
// CHECK: computations {
|
||||
// CHECK: name: "main
|
||||
// CHECK: instructions {
|
||||
// CHECK: name: "Arg_
|
||||
// CHECK: opcode: "parameter"
|
||||
// CHECK: name: "add
|
||||
// CHECK: opcode: "add"
|
||||
// CHECK: name: "dot
|
||||
// CHECK: opcode: "dot"
|
||||
|
14
tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt
Normal file
14
tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt
Normal file
@ -0,0 +1,14 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main.5
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
|
||||
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
|
||||
%Arg_0.1 = f32[4] parameter(0)
|
||||
%Arg_1.2 = f32[4] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<4xf32>
|
||||
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
|
||||
}
|
||||
|
12
tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt
Normal file
12
tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt
Normal file
@ -0,0 +1,12 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule foo
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
|
||||
ENTRY %foo (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
|
||||
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"}
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
|
||||
// CHECK-NEXT: return %0 : tensor<1x16x16x3xf32>
|
||||
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
ENTRY %main {
|
||||
%Arg_0.1 = s32[1,2,3,4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.transpose"(%arg0) {name = "transpose.2", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
// CHECK-NEXT: return %0 : tensor<2x1x4x3xi32>
|
||||
ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
|
||||
}
|
||||
|
11
tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
Normal file
11
tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir
Normal file
@ -0,0 +1,11 @@
|
||||
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: ENTRY %main
|
||||
func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
|
||||
// CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0)
|
||||
|
||||
// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
|
||||
%0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
|
||||
return %0 : tensor<2x1x4x3xi32>
|
||||
}
|
||||
|
16
tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt
Normal file
16
tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt
Normal file
@ -0,0 +1,16 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
|
||||
ENTRY %main(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) {
|
||||
%Arg_0.1 = s32[1] parameter(0)
|
||||
%Arg_1.2 = f32[1, 2] parameter(1)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
|
||||
%tuple.3 = (s32[1]) tuple(%Arg_0.1)
|
||||
|
||||
// CHECK-NEXT: %1 = "xla.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
|
||||
// CHECK-NEXT: return %1 : tuple<tensor<1xi32>, tensor<1x2xf32>>
|
||||
ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
|
||||
}
|
11
tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt
Normal file
11
tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt
Normal file
@ -0,0 +1,11 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule main
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> {
|
||||
ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[1] {
|
||||
%Arg_0.1 = f32[1] parameter(0)
|
||||
|
||||
// CHECK-NEXT: %0 = "xla.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
|
||||
ROOT add-dependency.2 = f32[1] add-dependency(Arg_0.1, Arg_0.1)
|
||||
}
|
27
tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt
Normal file
27
tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt
Normal file
@ -0,0 +1,27 @@
|
||||
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
|
||||
|
||||
HloModule foo
|
||||
|
||||
// CHECK-LABEL: func @cond(%arg0: tensor<i64>) -> tensor<i1> {
|
||||
%cond (arg_1: s64[]) -> pred[] {
|
||||
%arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
|
||||
// CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
|
||||
// CHECK-NEXT: return %0 : tensor<i1>
|
||||
ROOT %compare.2 = pred[] compare(%arg_1, %arg_1), direction=LT, metadata={op_type="Less" op_name="Less"}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @loop(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
%loop (arg_1: s64[]) -> s64[] {
|
||||
%arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
|
||||
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
|
||||
// CHECK-NEXT: return %0 : tensor<i64>
|
||||
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
|
||||
ENTRY %foo (arg0.1: s64[]) -> s64[] {
|
||||
%arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
|
||||
// CHECK-NEXT: %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor<i64>) -> tensor<i64>
|
||||
// CHECK-NEXT: return %0 : tensor<i64>
|
||||
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
|
||||
}
|
145
tensorflow/compiler/mlir/xla/type_to_shape.cc
Normal file
145
tensorflow/compiler/mlir/xla/type_to_shape.cc
Normal file
@ -0,0 +1,145 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "absl/base/integral_types.h"
|
||||
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/Location.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
using mlir::IntegerType;
|
||||
using mlir::MemRefType;
|
||||
using mlir::RankedTensorType;
|
||||
using mlir::VectorType;
|
||||
using xla::PrimitiveType;
|
||||
using xla::ShapeUtil;
|
||||
|
||||
namespace xla {
|
||||
|
||||
PrimitiveType TypeToPrimitiveType(mlir::Type type) {
|
||||
switch (type.getKind()) {
|
||||
case mlir::StandardTypes::BF16:
|
||||
return PrimitiveType::BF16;
|
||||
case mlir::StandardTypes::F16:
|
||||
return PrimitiveType::F16;
|
||||
case mlir::StandardTypes::F32:
|
||||
return PrimitiveType::F32;
|
||||
case mlir::StandardTypes::F64:
|
||||
return PrimitiveType::F64;
|
||||
case mlir::StandardTypes::Integer: {
|
||||
const auto integer = type.cast<IntegerType>();
|
||||
switch (integer.getWidth()) {
|
||||
case 1:
|
||||
return PrimitiveType::PRED;
|
||||
case 8:
|
||||
return PrimitiveType::S8;
|
||||
case 16:
|
||||
return PrimitiveType::S16;
|
||||
case 32:
|
||||
return PrimitiveType::S32;
|
||||
case 64:
|
||||
return PrimitiveType::S64;
|
||||
default:
|
||||
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
|
||||
}
|
||||
}
|
||||
default:
|
||||
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
|
||||
}
|
||||
}
|
||||
|
||||
Shape TypeToShape(mlir::Type type) {
|
||||
PrimitiveType ptype = TypeToPrimitiveType(type);
|
||||
if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
|
||||
return ShapeUtil::MakeShape(ptype, {});
|
||||
|
||||
switch (type.getKind()) {
|
||||
case mlir::StandardTypes::BF16:
|
||||
case mlir::StandardTypes::F32:
|
||||
case mlir::StandardTypes::F64:
|
||||
case mlir::StandardTypes::Integer: {
|
||||
auto* context = type.getContext();
|
||||
mlir::emitError(mlir::UnknownLoc::get(context))
|
||||
<< "lowering should have been handled by primitive type lowering for "
|
||||
<< debugString(type);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::Vector: {
|
||||
const auto v = type.cast<VectorType>();
|
||||
llvm::SmallVector<int64, 4> span(v.getShape().begin(),
|
||||
v.getShape().end());
|
||||
mlir::Type element_type = v.getElementType();
|
||||
PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
|
||||
if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
|
||||
return ShapeUtil::MakeShape(primitive_type, span);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::MemRef: {
|
||||
const auto m = type.cast<MemRefType>();
|
||||
llvm::SmallVector<int64, 6> span(m.getShape().begin(),
|
||||
m.getShape().end());
|
||||
mlir::Type element_type = m.getElementType();
|
||||
// Treat a memref of a vector as if it was a memref of primitive type with
|
||||
// the vector dimensions at the end.
|
||||
if (auto v = element_type.dyn_cast<VectorType>()) {
|
||||
element_type = v.getElementType();
|
||||
span.insert(span.end(), v.getShape().begin(), v.getShape().end());
|
||||
}
|
||||
PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
|
||||
if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) break;
|
||||
// For the primitive type case, the shape of the memref is similar to the
|
||||
// vector type case (i.e., it is, modulo the layout, the same dimensions
|
||||
// and primitive type).
|
||||
// Currently we only return shapes for identity affine maps.
|
||||
// TODO(andydavis) Map affine map layout function to XLA layout.
|
||||
if (m.getAffineMaps().empty() ||
|
||||
(m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity()))
|
||||
return ShapeUtil::MakeShape(primitive_type, span);
|
||||
break;
|
||||
}
|
||||
case mlir::StandardTypes::RankedTensor: {
|
||||
// TODO(jpienaar): This is only handling the base case with primitive
|
||||
// element type.
|
||||
const auto t = type.cast<RankedTensorType>();
|
||||
llvm::SmallVector<int64, 4> span(t.getShape().begin(),
|
||||
t.getShape().end());
|
||||
// Only fully static shapes are supported.
|
||||
// TODO(b/115638799): Update once xla::Shape can support dynamic shapes.
|
||||
if (std::find(t.getShape().begin(), t.getShape().end(), -1) !=
|
||||
t.getShape().end())
|
||||
break;
|
||||
mlir::Type element_type = t.getElementType();
|
||||
PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
|
||||
// Only primitive element type supported.
|
||||
if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
|
||||
return ShapeUtil::MakeShape(primitive_type, span);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
// Return empty XLA shape to signify error. No MLIR Type maps to a empty
|
||||
// Shape.
|
||||
return {};
|
||||
}
|
||||
|
||||
} // namespace xla
|
34
tensorflow/compiler/mlir/xla/type_to_shape.h
Normal file
34
tensorflow/compiler/mlir/xla/type_to_shape.h
Normal file
@ -0,0 +1,34 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_
|
||||
|
||||
#include "mlir/IR/Types.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape.
|
||||
Shape TypeToShape(mlir::Type type);
|
||||
|
||||
// Returns a XLA PrimitiveType equivalent of a MLIR Type that represents a
|
||||
// primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID.
|
||||
PrimitiveType TypeToPrimitiveType(mlir::Type type);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_
|
110
tensorflow/compiler/mlir/xla/type_to_shape_test.cc
Normal file
110
tensorflow/compiler/mlir/xla/type_to_shape_test.cc
Normal file
@ -0,0 +1,110 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
|
||||
#include "mlir/IR/Builders.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
|
||||
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
using mlir::Builder;
|
||||
using mlir::MLIRContext;
|
||||
using ::testing::EqualsProto;
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
|
||||
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
|
||||
EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(1)), PrimitiveType::PRED);
|
||||
EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(17)),
|
||||
PrimitiveType::PRIMITIVE_TYPE_INVALID);
|
||||
}
|
||||
|
||||
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::IsScalarWithElementType(TypeToShape(b.getF32Type()), F32));
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getVectorType({8, 128}, b.getIntegerType(32))).ToProto(),
|
||||
EqualsProto(
|
||||
ShapeUtil::MakeShape(PrimitiveType::S32, {8, 128}).ToProto()));
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getVectorType({8, 128}, b.getF32Type())).ToProto(),
|
||||
EqualsProto(
|
||||
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
|
||||
|
||||
// MLIR Type that is not representable as XLA Shape.
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getVectorType({8, 128}, b.getIntegerType(17))).ToProto(),
|
||||
EqualsProto(Shape().ToProto()));
|
||||
}
|
||||
|
||||
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
|
||||
// Memref without any affine map. Note: memory space is ignored for shape.
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getMemRefType({8, 128}, b.getF32Type())).ToProto(),
|
||||
EqualsProto(
|
||||
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getMemRefType({100, 13, 210}, b.getF32Type())).ToProto(),
|
||||
EqualsProto(
|
||||
ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210}).ToProto()));
|
||||
|
||||
// Vector types are "flattened" into the end of the shape.
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getMemRefType({100, 13, 210},
|
||||
b.getVectorType({8, 128}, b.getF32Type())))
|
||||
.ToProto(),
|
||||
EqualsProto(
|
||||
ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210, 8, 128})
|
||||
.ToProto()));
|
||||
}
|
||||
|
||||
TEST(TypeToShapeTest, ConvertTensorTypeToTypes) {
|
||||
MLIRContext context;
|
||||
Builder b(&context);
|
||||
|
||||
EXPECT_THAT(
|
||||
TypeToShape(b.getTensorType({8, 128}, b.getF32Type())).ToProto(),
|
||||
EqualsProto(
|
||||
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
|
||||
|
||||
// Shape cannot represent dynamic shapes.
|
||||
// TODO(b/115638799): Update once Shape can support dynamic shapes.
|
||||
EXPECT_THAT(TypeToShape(b.getTensorType(b.getF32Type())).ToProto(),
|
||||
EqualsProto(Shape().ToProto()));
|
||||
|
||||
// TODO(jpienaar): Expand to handle more complicated tensor types.
|
||||
EXPECT_THAT(
|
||||
TypeToShape(
|
||||
b.getTensorType({8, 128}, b.getVectorType({16, 16}, b.getF32Type())))
|
||||
.ToProto(),
|
||||
EqualsProto(Shape().ToProto()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
192
tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
Normal file
192
tensorflow/compiler/mlir/xla/xla_mlir_translate.cc
Normal file
@ -0,0 +1,192 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
|
||||
|
||||
#include "google/protobuf/text_format.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
#include "mlir/IR/Module.h" // TF:local_config_mlir
|
||||
#include "mlir/Translation.h" // TF:local_config_mlir
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo.pb.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
// Error collector that simply ignores errors reported.
|
||||
class NoOpErrorCollector : public ::proto2::io::ErrorCollector {
|
||||
public:
|
||||
void AddError(int line, int column, const string& message) override {}
|
||||
};
|
||||
|
||||
bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
|
||||
::proto2::TextFormat::Parser parser;
|
||||
NoOpErrorCollector collector;
|
||||
parser.RecordErrorsTo(&collector);
|
||||
return hlo_proto->ParseFromString(contents) ||
|
||||
parser.ParseFromString(contents, hlo_proto) ||
|
||||
hlo_proto->mutable_hlo_module()->ParseFromString(contents) ||
|
||||
parser.ParseFromString(contents, hlo_proto->mutable_hlo_module());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
||||
llvm::StringRef input_filename, mlir::MLIRContext* context) {
|
||||
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.str());
|
||||
if (std::error_code error = file_or_err.getError()) {
|
||||
LOG(ERROR) << "Failure to read HLO module: " << error;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto& input_file = *file_or_err;
|
||||
HloProto hlo_proto;
|
||||
string content(input_file->getBufferStart(), input_file->getBufferSize());
|
||||
if (!LoadHloProto(content, &hlo_proto)) {
|
||||
LOG(ERROR) << "Failed to load proto: " << input_filename.str();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
auto status =
|
||||
ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Hlo module import failed: " << status;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
||||
llvm::StringRef input_filename, mlir::MLIRContext* context) {
|
||||
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.str());
|
||||
if (std::error_code error = file_or_err.getError()) {
|
||||
LOG(ERROR) << "Failure to open file: " << error;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto& input_file = *file_or_err;
|
||||
HloProto hlo_proto;
|
||||
string content(input_file->getBufferStart(), input_file->getBufferSize());
|
||||
|
||||
auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
|
||||
if (!hlo_module_error.ok()) {
|
||||
LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto hlo_module = std::move(hlo_module_error.ValueOrDie());
|
||||
mlir::OwningModuleRef module =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
|
||||
auto status = ConvertHloToMlirHlo(*module, hlo_module.get());
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "HLO Module import failed: " << status;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return module;
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirHloToHloTranslateFunction(
|
||||
mlir::ModuleOp module, llvm::StringRef output_filename) {
|
||||
if (!module) return mlir::failure();
|
||||
|
||||
std::error_code error;
|
||||
auto result = llvm::make_unique<llvm::ToolOutputFile>(output_filename, error,
|
||||
llvm::sys::fs::F_None);
|
||||
if (error) {
|
||||
LOG(ERROR) << error.message();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
HloProto hloProto;
|
||||
Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Module conversion failed: " << status;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
result->os() << hloProto.DebugString();
|
||||
result->keep();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
|
||||
const HloProto& hlo_proto) {
|
||||
const HloModuleProto& module_proto = hlo_proto.hlo_module();
|
||||
TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
|
||||
HloModule::CreateModuleConfigFromProto(
|
||||
module_proto, GetDebugOptionsFromFlags()));
|
||||
return HloModule::CreateFromProto(module_proto, module_config);
|
||||
}
|
||||
|
||||
static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
|
||||
mlir::ModuleOp module, llvm::StringRef output_filename) {
|
||||
if (!module) return mlir::failure();
|
||||
|
||||
std::error_code error;
|
||||
auto result = llvm::make_unique<llvm::ToolOutputFile>(output_filename, error,
|
||||
llvm::sys::fs::F_None);
|
||||
if (error) {
|
||||
LOG(ERROR) << error.message();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
HloProto hloProto;
|
||||
Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
|
||||
if (!status.ok()) {
|
||||
LOG(ERROR) << "Module conversion failed: " << status;
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
auto statusOrHloModule = HloModuleFromProto(hloProto);
|
||||
|
||||
if (!statusOrHloModule.ok()) {
|
||||
LOG(ERROR) << "Conversion to HLO module failed: "
|
||||
<< statusOrHloModule.status();
|
||||
return mlir::failure();
|
||||
}
|
||||
|
||||
result->os() << statusOrHloModule.ValueOrDie()->ToString(
|
||||
HloPrintOptions()
|
||||
// We don't interpret or use layouts
|
||||
.set_include_layout_in_shapes(false));
|
||||
result->keep();
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate(
|
||||
"mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction);
|
||||
|
||||
static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
|
||||
"mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction);
|
||||
|
||||
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
|
||||
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
|
||||
|
||||
static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate(
|
||||
"hlo-text-to-mlir-hlo", xla::HloTextToMlirHloTranslateFunction);
|
47
tensorflow/compiler/mlir/xla/xla_mlir_translate.h
Normal file
47
tensorflow/compiler/mlir/xla/xla_mlir_translate.h
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace llvm {
|
||||
class StringRef;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class OwningModuleRef;
|
||||
} // namespace mlir
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Converts a HloModuleProto stored in the file with the given `input_filename`
|
||||
// into a MLIR module. Creates MLIR entities into the given MLIR `context`.
|
||||
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
|
||||
llvm::StringRef input_filename, mlir::MLIRContext* context);
|
||||
|
||||
// Converts a HloModule stored in text form for a file with the given
|
||||
// `input_filename` into a MLIR module. Creates MLIR entities into the given
|
||||
// MLIR `context`.
|
||||
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
|
||||
llvm::StringRef input_filename, mlir::MLIRContext* context);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_
|
Loading…
Reference in New Issue
Block a user