Initial version of HLO translation tools.

Adds the initial version of tools to convert between HLO protos and the XLA dialect.
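
A minimal sketch of how the two new conversion entry points compose (only
ConvertHloToMlirHlo and ConvertMlirHloToHlo come from this change; the
surrounding harness, includes, and names are illustrative):

#include "mlir/IR/Location.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"

// Round-trips an HLO module proto through the MLIR XLA dialect (sketch only;
// module ownership/cleanup is elided).
xla::Status RoundTrip(xla::HloModuleProto* hlo_module_proto,
                      xla::HloProto* hlo_proto) {
  mlir::MLIRContext context;
  // Import the HLO proto into a fresh, empty MLIR module.
  mlir::ModuleOp module =
      mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
  auto import_status = xla::ConvertHloToMlirHlo(module, hlo_module_proto);
  if (!import_status.ok()) return import_status;
  // Export the MLIR module back to an HLO proto.
  return mlir::ConvertMlirHloToHlo(module, hlo_proto);
}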

Co-authored-by: Robert Suderman <suderman@google.com>
Co-authored-by: Geoffrey Martin-Noble <gcmn@google.com>
Co-authored-by: River Riddle <riverriddle@google.com>
Co-authored-by: Lei Zhang <antiagainst@google.com>
Co-authored-by: Himabindu Pucha <hpucha@google.com>
PiperOrigin-RevId: 258691517
Jacques Pienaar 2019-07-17 19:53:01 -07:00 committed by TensorFlower Gardener
parent 3d11f06b29
commit 4ad0cdbdeb
56 changed files with 2926 additions and 20 deletions

View File

@@ -56,6 +56,25 @@ tf_cc_binary(
],
)
tf_cc_binary(
name = "tf-mlir-translate",
deps = [
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:translate_cl_options",
"//tensorflow/compiler/mlir/tensorflow:translate_lib",
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Translation",
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
],
)
filegroup(
name = "litfiles",
srcs = glob(["runlit*py"]),

View File

@@ -511,24 +511,6 @@ cc_library(
alwayslink = 1,
)
tf_cc_binary(
name = "tf-mlir-translate",
deps = [
":convert_graphdef",
":mlir_roundtrip_flags",
":translate_cl_options",
":translate_lib",
":translate_registration",
":translate_tf_dialect_op",
"//tensorflow/core:protos_all_proto_cc",
"//tensorflow/stream_executor/lib",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Translation",
"@local_config_mlir//:tools/mlir-translate/mlir-translate",
],
)
tf_cc_test(
name = "error_util_test",
srcs = ["utils/error_util_test.cc"],

View File

@@ -13,7 +13,7 @@ filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/tensorflow:tf-mlir-translate",
"//tensorflow/compiler/mlir:tf-mlir-translate",
"@llvm//:FileCheck",
],
)

View File

@@ -13,7 +13,7 @@ filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir/tensorflow:tf-mlir-translate",
"//tensorflow/compiler/mlir:tf-mlir-translate",
"@llvm//:FileCheck",
],
)

View File

@@ -1,4 +1,5 @@
load("@local_config_mlir//:tblgen.bzl", "gentbl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary")
package(
default_visibility = [":friends"],
@@ -166,3 +167,151 @@ cc_library(
],
alwayslink = 1,
)
cc_library(
name = "type_to_shape",
srcs = ["type_to_shape.cc"],
hdrs = ["type_to_shape.h"],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/base:core_headers",
"@local_config_mlir//:IR",
"@local_config_mlir//:Support",
],
)
tf_cc_test(
name = "type_to_shape_test",
srcs = ["type_to_shape_test.cc"],
deps = [
":type_to_shape",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:test_main",
"@local_config_mlir//:IR",
],
)
cc_library(
name = "mlir_hlo_to_hlo",
srcs = [
"mlir_hlo_to_hlo.cc",
"operator_writers.inc",
],
hdrs = ["mlir_hlo_to_hlo.h"],
deps = [
":type_to_shape",
":xla",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"@llvm//:support",
"@local_config_mlir//:Analysis",
"@local_config_mlir//:IR",
"@local_config_mlir//:Pass",
"@local_config_mlir//:StandardOps",
"@local_config_mlir//:TransformUtils",
"@local_config_mlir//:Transforms",
],
)
cc_library(
name = "hlo_to_mlir_hlo",
srcs = [
"hlo_to_mlir_hlo.cc",
],
hdrs = [
"hlo_to_mlir_hlo.h",
],
deps = [
":hlo_module_importer",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/core:lib",
],
)
cc_library(
name = "hlo_module_importer",
srcs = [
"hlo_function_importer.cc",
"hlo_module_importer.cc",
],
hdrs = [
"hlo_function_importer.h",
"hlo_module_importer.h",
],
deps = [
":xla",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/service:hlo",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:StandardOps",
],
)
cc_library(
name = "xla_mlir_translate",
srcs = [
"xla_mlir_translate.cc",
],
hdrs = [
"xla_mlir_translate.h",
],
deps = [
":hlo_to_mlir_hlo",
":mlir_hlo_to_hlo",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:lib",
"@com_google_protobuf//:protobuf_headers",
"@llvm//:support",
"@local_config_mlir//:IR",
"@local_config_mlir//:Translation",
],
alwayslink = 1,
)
tf_native_cc_binary(
name = "operator_writer_gen",
srcs = [
"operator_writer_gen.cc",
],
deps = [
"@llvm//:config",
"@llvm//:support",
"@llvm//:tablegen",
"@local_config_mlir//:TableGen",
],
)
genrule(
name = "operator_writer_inc",
srcs = [
"@local_config_mlir//:include/mlir/IR/OpBase.td",
"//tensorflow/compiler/mlir/xla:ir/xla_ops.td",
],
outs = [
"operator_writers.inc",
],
cmd = ("$(location :operator_writer_gen) " +
"-I external/local_config_mlir/include " +
"$(location //tensorflow/compiler/mlir/xla:ir/xla_ops.td) " + " -o $@"),
tools = [":operator_writer_gen"],
)

View File

@@ -0,0 +1,528 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Identifier.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
using llvm::APInt;
using llvm::makeArrayRef;
using mlir::DenseElementsAttr;
using mlir::DenseIntElementsAttr;
using mlir::FuncOp;
using mlir::NamedAttribute;
using mlir::Operation;
using mlir::ShapedType;
using mlir::Type;
using mlir::Value;
namespace xla {
namespace {
// Note: This sanitization function causes an irreversible many-to-one mapping
// and any solution to mitigate this would cause issues with the reverse
// direction. The long-term solution is to add a function attribute to maintain
// original HLO naming.
string SanitizeFunctionName(llvm::StringRef name) {
string output = name;
llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; });
return output;
}
StatusOr<DenseElementsAttr> CreateDenseAttrFromLiteral(ShapedType type,
const Literal& literal) {
#define DENSE_ELEMENT_ATTR_BUILDER(xla_type, cpp_type) \
case xla_type: { \
auto data_span = literal.data<cpp_type>(); \
return DenseElementsAttr::get( \
type, llvm::makeArrayRef(data_span.data(), data_span.size())); \
}
switch (literal.shape().element_type()) {
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::PRED, bool)
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F32, float)
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F64, double)
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S8, int8)
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S16, int16)
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S32, int32)
DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S64, int64)
default:
return tensorflow::errors::Internal(
absl::StrCat("Unsupported type: ",
PrimitiveType_Name(literal.shape().element_type())));
}
#undef DENSE_ELEMENT_ATTR_BUILDER
}
} // namespace
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
mlir::ModuleOp module, mlir::Builder* builder,
std::unordered_map<HloComputation*, FuncOp>* function_map,
HloComputation* computation) {
HloFunctionImporter importer(module, builder, function_map);
return importer.ImportFunction(computation);
}
StatusOr<mlir::FuncOp> HloFunctionImporter::ImportFunction(
HloComputation* computation) {
auto& imported = (*function_map_)[computation];
if (imported) return imported;
llvm::SmallVector<Type, 4> args, rets;
TF_RETURN_IF_ERROR(
GetMlirTypes(computation->parameter_instructions(), &args));
TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets));
auto func_type = mlir::FunctionType::get(args, rets, context_);
string computation_name =
computation->parent()->entry_computation() == computation
? "main"
: SanitizeFunctionName(computation->name());
// Construct the MLIR function and map arguments.
llvm::ArrayRef<mlir::NamedAttribute> attrs;
auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_),
computation_name, func_type, attrs);
module_.push_back(function);
// Add to the map right away for function calls.
imported = function;
function.addEntryBlock();
// Setup the input parameters.
const int num_parameters = computation->num_parameters();
for (int i = 0; i < num_parameters; i++) {
auto hlo_parameter = computation->parameter_instruction(i);
instruction_value_map_[hlo_parameter] = function.getArgument(i);
}
mlir::OpBuilder func_builder(function.getBody());
for (auto instruction : computation->MakeInstructionPostOrder()) {
TF_ASSIGN_OR_RETURN(auto new_operation,
ImportInstruction(instruction, &func_builder));
if (new_operation) {
instruction_value_map_[instruction] = new_operation->getResult(0);
}
}
// Setup the return type (HLO only supports a single return value).
TF_ASSIGN_OR_RETURN(auto result,
GetMlirValue(computation->root_instruction()));
llvm::SmallVector<Value*, 1> return_values({result});
// TODO(suderman): Add location tracking details.
func_builder.create<mlir::ReturnOp>(mlir::UnknownLoc::get(context_),
makeArrayRef(return_values));
return function;
}
StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
HloInstruction* instruction, mlir::OpBuilder* func_builder) {
TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction));
TF_ASSIGN_OR_RETURN(auto result_type, ConvertType(instruction->shape()));
llvm::SmallVector<NamedAttribute, 10> attributes = {builder_->getNamedAttr(
"name", builder_->getStringAttr(instruction->name()))};
mlir::Location loc = mlir::UnknownLoc::get(context_);
switch (instruction->opcode()) {
case HloOpcode::kParameter: {
return nullptr;
}
case HloOpcode::kConstant: {
auto attr = CreateDenseAttrFromLiteral(
result_type.cast<mlir::TensorType>(), instruction->literal());
if (!attr.ok()) return attr.status();
mlir::Operation* new_operation =
func_builder->create<mlir::ConstantOp>(loc, attr.ValueOrDie());
for (auto attr : attributes) {
new_operation->setAttr(attr.first, attr.second);
}
return new_operation;
}
case HloOpcode::kIota: {
return func_builder
->create<mlir::XLA::IotaOp>(
loc, result_type,
func_builder->getI64IntegerAttr(
static_cast<HloIotaInstruction*>(instruction)
->iota_dimension()))
.getOperation();
}
#define MakeAndReturn(mlir_op) \
{ \
mlir::Operation* new_operation = func_builder->create<mlir::XLA::mlir_op>( \
loc, result_type, operands, attributes); \
return new_operation; \
}
case HloOpcode::kBroadcast: {
// Note that the HLO broadcast is more powerful than the XLA broadcast op.
// BroadcastInDim offers a superset of the HLO op's functionality.
if (!instruction->dimensions().empty()) {
attributes.push_back(builder_->getNamedAttr(
"broadcast_dimensions",
ConvertDimensions(instruction->dimensions())));
}
MakeAndReturn(BroadcastInDimOp);
}
case HloOpcode::kDot: {
// TODO(b/129153247) Add support for batch and contracting dimensions.
TF_RETURN_IF_ERROR(ValidateDotDimensions(instruction));
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
// case and the parser sticks it in. Maybe we should too.
attributes.push_back(ConvertPrecisionConfig(instruction));
MakeAndReturn(DotOp);
}
case HloOpcode::kCall: {
TF_ASSIGN_OR_RETURN(FuncOp function,
ImportFunction(instruction->to_apply()));
mlir::Operation* new_operation =
func_builder->create<mlir::CallOp>(loc, function, operands);
return new_operation;
}
case HloOpcode::kCompare: {
attributes.push_back(ConvertComparisonDirection(instruction));
MakeAndReturn(CompareOp);
}
case HloOpcode::kGather: {
const auto& gather_dimensions = instruction->gather_dimension_numbers();
std::vector<int64_t> offset_dims(gather_dimensions.offset_dims().begin(),
gather_dimensions.offset_dims().end());
std::vector<int64_t> slice_sizes(
instruction->gather_slice_sizes().begin(),
instruction->gather_slice_sizes().end());
std::vector<int64_t> collapsed_slice_dims(
gather_dimensions.collapsed_slice_dims().begin(),
gather_dimensions.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(
gather_dimensions.start_index_map().begin(),
gather_dimensions.start_index_map().end());
// TODO(b/132057942): Change to explicitly passing an integer instead of
// call getI64IntegerAttr here.
return func_builder
->create<mlir::XLA::GatherOp>(
loc, result_type, operands[0], operands[1],
func_builder->getI64IntegerAttr(
gather_dimensions.index_vector_dim()),
Convert(offset_dims), Convert(slice_sizes),
Convert(collapsed_slice_dims), Convert(start_index_map))
.getOperation();
}
case HloOpcode::kDynamicUpdateSlice: {
return func_builder
->create<mlir::XLA::DynamicUpdateSliceOp>(
loc, result_type, operands[0], operands[1],
llvm::ArrayRef<Value*>(operands.begin() + 2, operands.end()))
.getOperation();
}
case HloOpcode::kPad: {
const auto& padding_config = instruction->padding_config();
llvm::SmallVector<int64_t, 4> edge_padding_low;
llvm::SmallVector<int64_t, 4> edge_padding_high;
llvm::SmallVector<int64_t, 4> interior_padding;
edge_padding_low.reserve(padding_config.dimensions_size());
edge_padding_high.reserve(padding_config.dimensions_size());
interior_padding.reserve(padding_config.dimensions_size());
for (const auto& dimension : padding_config.dimensions()) {
edge_padding_low.push_back(dimension.edge_padding_low());
edge_padding_high.push_back(dimension.edge_padding_high());
interior_padding.push_back(dimension.interior_padding());
}
return func_builder
->create<mlir::XLA::PadOp>(loc, result_type, operands[0], operands[1],
Convert(edge_padding_low),
Convert(edge_padding_high),
Convert(interior_padding))
.getOperation();
}
case HloOpcode::kSlice: {
return func_builder
->create<mlir::XLA::SliceOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->slice_starts()),
ConvertDimensions(instruction->slice_limits()))
.getOperation();
}
case HloOpcode::kConcatenate: {
// TODO(b/132057942): Support taking a uint64_t instead of an IntegerAttr
// for concatenate dimension.
return func_builder
->create<mlir::XLA::ConcatenateOp>(
loc, result_type, operands,
builder_->getI64IntegerAttr(instruction->concatenate_dimension()))
.getOperation();
}
case HloOpcode::kReduce: {
TF_ASSIGN_OR_RETURN(auto reduction,
ImportFunction(instruction->to_apply()));
// TODO(b/132057942): Make more convenient constructors, e.g. pass
// mlir function pointer instead of a function attr.
return func_builder
->create<mlir::XLA::ReduceOp>(
loc, result_type, operands,
func_builder->getSymbolRefAttr(reduction),
ConvertDimensions(instruction->dimensions()))
.getOperation();
}
case HloOpcode::kReverse: {
return func_builder
->create<mlir::XLA::ReverseOp>(
loc, result_type, operands[0],
ConvertDimensions(instruction->dimensions()))
.getOperation();
}
case HloOpcode::kWhile: {
TF_ASSIGN_OR_RETURN(auto body, ImportFunction(instruction->while_body()));
TF_ASSIGN_OR_RETURN(auto cond,
ImportFunction(instruction->while_condition()));
llvm::SmallVector<Type, 4> types;
types.reserve(operands.size());
for (auto operand : operands) {
types.push_back(operand->getType());
}
auto cond_attr = func_builder->getSymbolRefAttr(cond);
auto body_attr = func_builder->getSymbolRefAttr(body);
Operation* op = func_builder->create<mlir::XLA::WhileOp>(
loc, types, operands, cond_attr, body_attr);
return op;
}
case HloOpcode::kGetTupleElement: {
attributes.push_back(builder_->getNamedAttr(
"index", builder_->getIntegerAttr(builder_->getIntegerType(32),
instruction->tuple_index())));
MakeAndReturn(GetTupleElementOp);
};
case HloOpcode::kTranspose: {
attributes.push_back(builder_->getNamedAttr(
"permutation", ConvertDimensions(instruction->dimensions())));
MakeAndReturn(TransposeOp);
}
#define NoAttributeCase(hlo_op_code, mlir_op) \
case HloOpcode::hlo_op_code: { \
MakeAndReturn(mlir_op); \
}
// broadcast dimensions are never added here because they don't exist as
// part of the HLO instruction. They are only a convenience in the XLA
// builder API.
NoAttributeCase(kAdd, AddOp);
NoAttributeCase(kAnd, AndOp);
NoAttributeCase(kConvert, ConvertOp);
NoAttributeCase(kDivide, DivOp);
NoAttributeCase(kMaximum, MaxOp);
NoAttributeCase(kMinimum, MinOp);
NoAttributeCase(kMultiply, MulOp);
NoAttributeCase(kSelect, SelectOp);
NoAttributeCase(kSubtract, SubOp);
NoAttributeCase(kTanh, TanhOp);
NoAttributeCase(kTuple, TupleOp);
// TODO(b/129422361) Copy needs special handling because it is not defined
// in tensorflow/compiler/xla/client/xla_builder.h.
// See operation semantics in
// g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy
NoAttributeCase(kCopy, CopyOp);
// TODO(b/129422361) Ops below need additional work to handle attributes.
NoAttributeCase(kConvolution, ConvOp);
NoAttributeCase(kReshape, ReshapeOp);
#undef NoAttributeCase
#undef MakeAndReturn
case HloOpcode::kAddDependency:
// Arbitrary op code that I suspect we will not implement for quite a
// while; it allows testing the handling of unknown ops. Selected because it
// is not mentioned in xla client anywhere or in the hlo of our sample
// models.
default: {
mlir::OperationState result(loc, "xla.unknown");
result.addOperands(operands);
result.addTypes(result_type);
for (auto attr : attributes) {
result.attributes.push_back(attr);
}
return func_builder->createOperation(result);
}
}
}
StatusOr<llvm::SmallVector<mlir::Value*, 4>> HloFunctionImporter::GetOperands(
HloInstruction* instruction) {
llvm::SmallVector<mlir::Value*, 4> operands;
for (const auto& operand : instruction->operands()) {
auto input_it = instruction_value_map_.find(operand);
if (input_it == instruction_value_map_.end()) {
return tensorflow::errors::Internal(
absl::StrCat("Could not find input value: ", operand->name(),
" for instruction ", instruction->name()));
}
operands.push_back(input_it->second);
}
return operands;
}
// TODO(suderman): Move to a general library when needed in other places.
StatusOr<mlir::RankedTensorType> HloFunctionImporter::ConvertTensorType(
const Shape& shape) {
auto type = shape.element_type();
llvm::SmallVector<int64_t, 4> array;
array.reserve(shape.dimensions_size());
for (auto val : shape.dimensions()) {
array.push_back(val);
}
switch (type) {
case PrimitiveType::PRED:
return builder_->getTensorType(array, builder_->getI1Type());
case PrimitiveType::F16:
return builder_->getTensorType(array, builder_->getF16Type());
case PrimitiveType::F32:
return builder_->getTensorType(array, builder_->getF32Type());
case PrimitiveType::F64:
return builder_->getTensorType(array, builder_->getF64Type());
case PrimitiveType::S8:
return builder_->getTensorType(array, builder_->getIntegerType(8));
case PrimitiveType::S16:
return builder_->getTensorType(array, builder_->getIntegerType(16));
case PrimitiveType::S32:
return builder_->getTensorType(array, builder_->getIntegerType(32));
case PrimitiveType::S64:
return builder_->getTensorType(array, builder_->getIntegerType(64));
default:
return tensorflow::errors::Internal(
absl::StrCat("Unsupported type: ", PrimitiveType_Name(type)));
}
}
StatusOr<mlir::Type> HloFunctionImporter::ConvertType(const Shape& shape) {
if (shape.IsTuple()) {
mlir::Type mlir_type;
llvm::SmallVector<mlir::Type, 4> contents;
contents.reserve(shape.tuple_shapes_size());
for (const auto& subtype : shape.tuple_shapes()) {
TF_ASSIGN_OR_RETURN(auto mlir_subtype, ConvertType(subtype));
contents.push_back(mlir_subtype);
}
return builder_->getTupleType(contents);
}
return ConvertTensorType(shape);
}
tensorflow::Status HloFunctionImporter::GetMlirTypes(
const std::vector<HloInstruction*>& instructions,
llvm::SmallVectorImpl<mlir::Type>* types) {
for (auto instruction : instructions) {
TF_ASSIGN_OR_RETURN(auto ret_type, ConvertType(instruction->shape()));
types->push_back(ret_type);
}
return tensorflow::Status::OK();
}
StatusOr<Value*> HloFunctionImporter::GetMlirValue(
HloInstruction* instruction) {
auto lookup = instruction_value_map_.find(instruction);
if (lookup != instruction_value_map_.end()) {
return lookup->second;
}
return tensorflow::errors::Internal(absl::StrCat(
"Unable to find value for input: ", instruction->ToString()));
}
mlir::NamedAttribute HloFunctionImporter::ConvertPrecisionConfig(
HloInstruction* instruction) {
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
for (auto prec : instruction->precision_config().operand_precision()) {
operand_precision_attrs.push_back(
builder_->getStringAttr(PrecisionConfig_Precision_Name(prec)));
}
return builder_->getNamedAttr(
"precision_config", builder_->getArrayAttr(operand_precision_attrs));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
HloInstruction* instruction) {
return builder_->getNamedAttr(
"comparison_direction",
builder_->getStringAttr(
ComparisonDirectionToString(instruction->comparison_direction())));
}
mlir::ElementsAttr HloFunctionImporter::ConvertDimensions(
llvm::ArrayRef<int64> op_dimensions) {
llvm::SmallVector<APInt, 8> dimensions;
dimensions.reserve(op_dimensions.size());
for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value));
return DenseIntElementsAttr::get(
builder_->getTensorType(dimensions.size(), builder_->getIntegerType(64)),
dimensions);
}
mlir::ElementsAttr HloFunctionImporter::Convert(
llvm::ArrayRef<int64_t> op_dimensions) {
return builder_->getDenseIntElementsAttr(
builder_->getTensorType(op_dimensions.size(),
builder_->getIntegerType(64)),
op_dimensions);
}
Status HloFunctionImporter::ValidateDotDimensions(HloInstruction* instruction) {
DotDimensionNumbers expected_dimension_numbers;
expected_dimension_numbers.add_lhs_contracting_dimensions(
instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1);
expected_dimension_numbers.add_rhs_contracting_dimensions(0);
if (!xla::protobuf_util::ProtobufEquals(instruction->dot_dimension_numbers(),
expected_dimension_numbers)) {
return tensorflow::errors::Internal(
absl::StrCat("Dot operation has unsupported dimension numbers: ",
instruction->dot_dimension_numbers().DebugString()));
}
return Status::OK();
}
} // namespace xla

View File

@@ -0,0 +1,112 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
#include <unordered_map>
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
class HloModule;
class HloComputation;
class HloInstruction;
class Shape;
// Helper class for importing HloComputations.
class HloFunctionImporter {
public:
static StatusOr<mlir::FuncOp> ImportFunction(
mlir::ModuleOp module, mlir::Builder* builder,
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map,
xla::HloComputation* computation);
private:
HloFunctionImporter(
mlir::ModuleOp module, mlir::Builder* builder,
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map)
: context_(module.getContext()),
module_(module),
builder_(builder),
function_map_(function_map) {}
StatusOr<mlir::FuncOp> ImportFunction(xla::HloComputation* computation);
// Imports an instruction.
StatusOr<mlir::Operation*> ImportInstruction(xla::HloInstruction* instruction,
mlir::OpBuilder* func_builder);
// Gets the MLIR operand values from an HLO Instruction.
StatusOr<llvm::SmallVector<mlir::Value*, 4>> GetOperands(
xla::HloInstruction* instruction);
// Converts xla Tensor type to the corresponding MLIR type.
StatusOr<mlir::RankedTensorType> ConvertTensorType(const xla::Shape& shape);
// Converts xla Primitive types to the corresponding MLIR type.
StatusOr<mlir::Type> ConvertType(const xla::Shape& shape);
// Returns the output type of an HloInstruction.
StatusOr<mlir::Type> GetReturnType(xla::HloInstruction* instruction);
// Takes a list of HloInstructions and generates the list of types used for
// input, bypassing tuples to subsets.
Status GetMlirTypes(const std::vector<xla::HloInstruction*>& instructions,
llvm::SmallVectorImpl<mlir::Type>* types);
// Returns the Mlir Value for the corresponding HloInstruction.
StatusOr<mlir::Value*> GetMlirValue(xla::HloInstruction* instruction);
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction);
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonDirection(
xla::HloInstruction* instruction);
// Converts the dimensions of an HLO instruction into an MLIR attribute.
mlir::ElementsAttr ConvertDimensions(llvm::ArrayRef<int64> op_dimensions);
// Converts Array ref to an ElementsAttr.
mlir::ElementsAttr Convert(llvm::ArrayRef<int64_t> op_dimensions);
// Ensures dot instruction has only default contracting and batch dimensions.
Status ValidateDotDimensions(xla::HloInstruction* instruction);
mlir::MLIRContext* context_;
mlir::ModuleOp module_;
mlir::Builder* builder_;
// Mapping from HloComputation to the created MLIR function.
std::unordered_map<xla::HloComputation*, mlir::FuncOp>* function_map_;
// Mapping from HloInstructions to the associated MLIR values.
std::unordered_map<xla::HloInstruction*, mlir::Value*> instruction_value_map_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_
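
For reference, a sketch of importing a single computation through the static
entry point above, mirroring its use in HloModuleImporter::Import below (the
wrapper function and its name are illustrative):

#include <unordered_map>
#include "mlir/IR/Builders.h"
#include "mlir/IR/Module.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"

// Imports one HloComputation into `module` and returns the created FuncOp.
xla::StatusOr<mlir::FuncOp> ImportComputation(mlir::ModuleOp module,
                                              xla::HloComputation* computation) {
  mlir::Builder builder(module.getContext());
  // The map lets computations referenced by call/while reuse already imported
  // functions; a fresh map is fine for a one-off import.
  std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map;
  return xla::HloFunctionImporter::ImportFunction(module, &builder,
                                                  &function_map, computation);
}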

View File

@@ -0,0 +1,54 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/xla.pb.h"
namespace xla {
Status HloModuleImporter::Import(const xla::HloModule& module) {
for (const auto& computation : module.computations()) {
auto result = HloFunctionImporter::ImportFunction(
module_, &builder_, &function_map_, computation);
TF_RETURN_IF_ERROR(result.status());
}
return Status::OK();
}
Status HloModuleImporter::Import(const xla::HloModuleProto& module_proto) {
xla::DebugOptions debug_options;
TF_ASSIGN_OR_RETURN(
auto module_config,
xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options));
TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto(
module_proto, module_config));
return Import(*module);
}
} // namespace xla

View File

@@ -0,0 +1,62 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
#include <unordered_map>
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
class HloModule;
class HloModuleProto;
class HloComputation;
class HloInstruction;
class Shape;
// Importer that takes an HloModule and imports it as an MLIR module in the XLA
// dialect. HloModuleImporter does not take ownership.
class HloModuleImporter {
public:
explicit HloModuleImporter(mlir::ModuleOp module)
: module_(module), builder_(module.getContext()) {}
// Import the HloModule into the MLIR Module.
Status Import(const xla::HloModule& module);
// Import the HloModuleProto into the MLIR Module.
Status Import(const xla::HloModuleProto& module);
private:
mlir::ModuleOp module_;
mlir::Builder builder_;
// Map tracking which MLIR function corresponds to which HLO computation. This
// tracks functions as they are imported and provides a quick lookup for
// functions invoked by control flow related operations (e.g. while, call).
std::unordered_map<xla::HloComputation*, mlir::FuncOp> function_map_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_
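
A sketch of driving the importer directly (the helper below and its name are
illustrative; only the HloModuleImporter API above comes from this change):

#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"

// Appends one MLIR function per HLO computation of `hlo_module` to `module`.
xla::Status ImportInto(mlir::ModuleOp module,
                       const xla::HloModule& hlo_module) {
  return xla::HloModuleImporter(module).Import(hlo_module);
}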

View File

@@ -0,0 +1,35 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
Status ConvertHloToMlirHlo(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module_proto) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
return HloModuleImporter(module).Import(*hlo_module_proto);
}
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
return HloModuleImporter(module).Import(*hlo_module);
}
} // namespace xla

View File

@@ -0,0 +1,39 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/status.h"
namespace mlir {
class ModuleOp;
} // namespace mlir
namespace xla {
class HloModule;
class HloModuleProto;
// Converts an HLO module proto to an MLIR module in the HLO dialect.
Status ConvertHloToMlirHlo(mlir::ModuleOp module,
xla::HloModuleProto* hlo_module);
// Converts an HLO module to an MLIR module in the HLO dialect.
Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_

View File

@@ -0,0 +1,259 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include <memory>
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SMLoc.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/Function.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/IR/Operation.h" // TF:local_config_mlir
#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
static std::vector<int64> ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) {
llvm::ArrayRef<int64> raw_data = attr.getValues<int64>();
if (attr.isSplat())
return std::vector<int64>(attr.getType().getNumElements(), raw_data[0]);
return raw_data;
}
// Converts the broadcast_dimensions attribute into a span of dimension numbers
// (empty if the attribute is absent).
static std::vector<int64> Convert_broadcast_dimensions(
llvm::Optional<mlir::ElementsAttr> broadcast_dimensions) {
if (!broadcast_dimensions.hasValue()) return {};
return ConvertDenseIntAttr(
broadcast_dimensions->cast<mlir::DenseIntElementsAttr>());
}
// Converts the broadcast_sizes attribute into a span of dimension sizes.
static std::vector<int64> Convert_broadcast_sizes(
mlir::ElementsAttr broadcast_sizes) {
return ConvertDenseIntAttr(
broadcast_sizes.cast<mlir::DenseIntElementsAttr>());
}
static std::vector<int64> Convert_permutation(mlir::ElementsAttr permutation) {
return ConvertDenseIntAttr(permutation.cast<mlir::DenseIntElementsAttr>());
}
// Converts the precision config array of strings attribute into the
// corresponding XLA proto. All the strings are assumed to be valid names of the
// Precision enum. This should have been checked in the op verify method.
static std::unique_ptr<xla::PrecisionConfig> Convert_precision_config(
llvm::Optional<mlir::ArrayAttr> optional_precision_config_attr) {
if (!optional_precision_config_attr.hasValue()) return nullptr;
auto precision_config = absl::make_unique<xla::PrecisionConfig>();
for (auto attr : optional_precision_config_attr.getValue()) {
xla::PrecisionConfig::Precision p;
auto operand_precision = attr.cast<mlir::StringAttr>().getValue().str();
// TODO(jpienaar): Update this to ensure this is captured by verify.
if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) {
precision_config->add_operand_precision(p);
} else {
auto* context = attr.getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "unexpected operand precision " << operand_precision;
return nullptr;
}
}
return precision_config;
}
// Converts the comparison_direction string attribute into the XLA enum. The
// string is assumed to correspond to exactly one of the allowed strings
// representing the enum. This should have been checked in the op verify method.
static xla::ComparisonDirection Convert_comparison_direction(
llvm::StringRef comparison_direction_string) {
return xla::StringToComparisonDirection(comparison_direction_string.str())
.ValueOrDie();
}
// Passes through everything except for unique_ptr, on which it calls get().
// This exists to allow the generated code to call XLA functions that take a raw
// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv
// as a pointer and there is otherwise no way to avoid a memory leak.
template <typename T>
T Unwrap(T t) {
return t;
}
template <typename T>
T* Unwrap(const std::unique_ptr<T>& t) {
return t.get();
}
// Convert APInt into an int.
// TODO(hpucha): This should be consolidated into a general place.
static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); }
// Convert APFloat to double.
static double ConvertAPFloat(llvm::APFloat value) {
const auto& semantics = value.getSemantics();
bool losesInfo = false;
if (&semantics != &llvm::APFloat::IEEEdouble())
value.convert(llvm::APFloat::IEEEdouble(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return value.convertToDouble();
}
#include "tensorflow/compiler/mlir/xla/operator_writers.inc"
namespace mlir {
namespace {
class ConvertToHloModule {
public:
using ValueLoweringMap = llvm::DenseMap<Value*, xla::XlaOp>;
using FunctionLoweringMap = llvm::DenseMap<mlir::FuncOp, xla::XlaComputation>;
explicit ConvertToHloModule(mlir::ModuleOp module)
: module_(module), module_builder_("main") {}
// Perform the lowering to XLA. This function returns failure if an error was
// encountered.
LogicalResult Run() {
for (auto func : module_.getOps<FuncOp>()) {
if (func.empty()) continue;
if (failed(RunOnFunction(func))) return failure();
}
return success();
}
// Perform the lowering on a specific function. This function returns failure
// if an error was encountered.
LogicalResult RunOnFunction(mlir::FuncOp f);
xla::HloModuleProto ConsumeMainProto() {
return lowered_computation_[module_.lookupSymbol<mlir::FuncOp>("main")]
.proto();
}
private:
// The module being lowered.
mlir::ModuleOp module_;
// The top-level XlaBuilder.
xla::XlaBuilder module_builder_;
// Map between function and lowered computation.
FunctionLoweringMap lowered_computation_;
};
LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder,
ConvertToHloModule::FunctionLoweringMap* function_lowering,
ConvertToHloModule::ValueLoweringMap* value_lowering) {
if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success();
// TODO(riverriddle) We currently don't support lowering constant operations.
if (isa<mlir::XLA::ConstOp>(inst)) {
inst->emitError("unable to lower 'xla.constant' operation");
return failure();
}
auto& value_map = *value_lowering;
if (auto ret = dyn_cast<mlir::ReturnOp>(inst)) {
// Construct the return value for the function. If there are multiple
// values returned, then create a tuple, else return value directly.
xla::XlaOp return_value;
unsigned num_return_values = ret.getNumOperands();
if (num_return_values > 1) {
std::vector<xla::XlaOp> returns(num_return_values);
for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) {
returns[i] = value_map[ret.getOperand(i)];
}
return_value = xla::Tuple(builder, returns);
} else if (num_return_values == 1) {
return_value = value_map[ret.getOperand(0)];
}
// Build the XlaComputation and check for failures.
auto computation_or =
return_value.valid() ? builder->Build(return_value) : builder->Build();
if (!computation_or.ok()) {
inst->emitError(llvm::Twine(computation_or.status().error_message()));
return failure();
}
auto f = inst->getParentOfType<mlir::FuncOp>();
(*function_lowering)[f] = std::move(computation_or.ValueOrDie());
return success();
}
inst->emitError("unable to lower operation of type '" +
inst->getName().getStringRef().str() + '\'');
return failure();
}
LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
if (f.getBlocks().size() != 1) {
return f.emitError("only single-block functions supported");
}
// Create a sub-builder if this is not the main function.
std::unique_ptr<xla::XlaBuilder> builder_up;
bool entry_function = f.getName().str() == "main";
if (!entry_function)
builder_up = module_builder_.CreateSubBuilder(f.getName().str());
auto& builder = entry_function ? module_builder_ : *builder_up;
// Mapping from the Value to lowered XlaOp. The code below lowers in
// program order and will fail if an operand is unseen. This can be improved.
ValueLoweringMap lowering;
for (auto& bb : f) {
int num = 0;
for (auto& arg : bb.getArguments()) {
xla::Shape shape = xla::TypeToShape(arg->getType());
lowering[arg] =
xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num));
++num;
}
for (auto& inst : bb)
if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering)))
return failure();
}
return success();
}
} // namespace
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
ConvertToHloModule converter(module);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
return Status::OK();
}
} // namespace mlir

View File

@@ -0,0 +1,37 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
namespace mlir {
// Converts an MLIR module in the HLO dialect into an HloModuleProto.
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto);
// Creates XlaOp equivalent of a given MLIR operation using the operand info
// from `value_lowering` map.
llvm::Optional<xla::XlaOp> CreateXlaOperator(
mlir::Operation* op,
llvm::DenseMap<mlir::Value*, xla::XlaOp>* value_lowering);
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_

View File

@@ -0,0 +1,196 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/PrettyStackTrace.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h" // TF:local_config_mlir
using llvm::dyn_cast;
using llvm::LessRecord;
using llvm::raw_ostream;
using llvm::Record;
using llvm::RecordKeeper;
using llvm::StringRef;
using mlir::tblgen::Operator;
// Returns the builder function name for the given op definition.
// E.g., AddOp -> CreateAddOp
static inline std::string GetOperatorBuilderName(StringRef op_name) {
return "Create" + op_name.str();
}
static std::string GetConversionFunction(
mlir::tblgen::NamedAttribute named_attr) {
auto storage_type = named_attr.attr.getStorageType();
// For some attribute types we have a general conversion, so use that.
if (storage_type.endswith("IntegerAttr") ||
storage_type.endswith("FloatAttr")) {
return "Convert" + named_attr.attr.getReturnType().str();
}
return "Convert_" + named_attr.name.str();
}
using ArgumentName = string;
using ArgumentDeclaration = string;
using Argument = std::pair<ArgumentName, ArgumentDeclaration>;
using ArgumentList = std::vector<Argument>;
static std::string BuildOperator(const Operator& op) {
std::stringstream os;
StringRef op_name = op.getCppClassName();
std::string xla_op_name = op_name.drop_back(2).str();
// Signature.
os << "static xla::XlaOp " << GetOperatorBuilderName(op_name)
<< "(mlir::XLA::" << op_name.str() << " xla_op, "
<< "llvm::DenseMap<mlir::Value*, xla::XlaOp>* "
"value_lowering) {\n";
os << " auto& value_map = *value_lowering;\n"
<< " auto result = xla_op.getResult();\n";
// Invoke the conversion function for each attribute.
for (const auto& named_attr : op.getAttributes()) {
os << " auto " << named_attr.name.str() << " = "
<< GetConversionFunction(named_attr) << "("
<< "xla_op." << named_attr.name.str() << "());\n";
}
// Assumes that the client builder method names closely follow the op names
// in the dialect, e.g. AddOp -> xla::Add.
os << " auto xla_result = xla::" << xla_op_name << "(";
int num_operands = op.getNumOperands();
if (num_operands == 1) {
os << "value_map[xla_op.getOperand()]";
} else {
for (auto i = 0; i < num_operands; i++) {
os << "value_map[xla_op.getOperand(" << i << ")]";
if (i != num_operands - 1) {
os << ", ";
}
}
}
for (const auto& named_attr : op.getAttributes()) {
os << ", Unwrap(" << named_attr.name.str() << ")";
}
os << ");\n";
os << " value_map[result] = xla_result;\n";
os << " return xla_result;\n";
os << "}\n\n";
return os.str();
}
// For each XLA op, emits a builder function that constructs the XLA op using
// the HLO client builder.
static void EmitOperatorBuilders(const RecordKeeper& record_keeper,
const std::vector<Record*>& defs,
raw_ostream* ostream) {
raw_ostream& os = *ostream;
for (const auto* def : defs) {
// Skip operations that have a custom converter.
if (def->getValueAsBit("hasCustomHLOConverter")) continue;
Operator op(def);
os << BuildOperator(op);
}
}
// Emits a builder function that returns the XlaOp object given a
// mlir::Operation.
//
// The signature of the function is:
//
// llvm::Optional<xla::XlaOp>
// mlir::CreateXlaOperator(
// mlir::Operation* op,
// llvm::DenseMap<mlir::Value*, xla::XlaOp>
// *value_lowering);
static void EmitBuilder(const std::vector<Record*>& defs,
raw_ostream* ostream) {
raw_ostream& os = *ostream;
// Signature
os << "llvm::Optional<xla::XlaOp>\n"
"mlir::CreateXlaOperator(mlir::Operation* op, "
"llvm::DenseMap<mlir::Value*, xla::XlaOp> "
"*value_lowering) {\n";
for (const auto* def : defs) {
// Skip operations that have a custom converter.
if (def->getValueAsBit("hasCustomHLOConverter")) continue;
StringRef op_name = def->getName().drop_front(4);
// Try to cast to each op and call the corresponding op builder.
os << " if (auto xla_op = llvm::dyn_cast<mlir::XLA::" << op_name
<< ">(op))\n return " << GetOperatorBuilderName(op_name)
<< "(xla_op, value_lowering);\n";
}
os << " return llvm::None;\n"
"}\n";
}
// The function below takes a non-constant reference because that is required by LLVM's
// TableGenMain.
// NOLINTNEXTLINE
static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) {
emitSourceFileHeader("MLIR XLA Builders", os);
// Retrieve all the definitions derived from XLA_Op and sort by record name.
std::vector<Record*> defs = records.getAllDerivedDefinitions("XLA_Op");
llvm::sort(defs, LessRecord());
for (const auto* def : defs) {
// XLA ops in the .td file are expected to follow the naming convention:
// XLA_<OpName>Op.
// The generated XLA op C++ class should be XLA::<OpName>Op.
if (!def->getName().startswith("XLA_"))
PrintFatalError(def->getLoc(),
"unexpected op name format: 'XLA_' prefix missing");
if (!def->getName().endswith("Op"))
PrintFatalError(def->getLoc(),
"unexpected op name format: 'Op' suffix missing");
}
EmitOperatorBuilders(records, defs, &os);
os << "\n\n";
EmitBuilder(defs, &os);
return false;
}
int main(int argc, char** argv) {
llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
llvm::PrettyStackTraceProgram X(argc, argv);
llvm::llvm_shutdown_obj Y;
llvm::cl::ParseCommandLineOptions(argc, argv);
return TableGenMain(argv[0], &OperatorWritersMain);
}
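
For illustration, the operator_writers.inc emitted by this backend for a
hypothetical attribute-free, two-operand op definition XLA_FooOp (C++ class
mlir::XLA::FooOp, client call xla::Foo — names invented here) would look
roughly as follows, reconstructed from BuildOperator and EmitBuilder above;
ops that declare attributes additionally get one conversion and Unwrap()
argument per attribute:

// Generated builder for a single op: looks up the XlaOp operands, calls the
// matching XLA client builder function, and records the result.
static xla::XlaOp CreateFooOp(mlir::XLA::FooOp xla_op,
                              llvm::DenseMap<mlir::Value*, xla::XlaOp>* value_lowering) {
  auto& value_map = *value_lowering;
  auto result = xla_op.getResult();
  auto xla_result = xla::Foo(value_map[xla_op.getOperand(0)],
                             value_map[xla_op.getOperand(1)]);
  value_map[result] = xla_result;
  return xla_result;
}

// Generated dispatcher: tries to cast the operation to each supported op and
// delegates to its builder.
llvm::Optional<xla::XlaOp>
mlir::CreateXlaOperator(mlir::Operation* op,
                        llvm::DenseMap<mlir::Value*, xla::XlaOp> *value_lowering) {
  if (auto xla_op = llvm::dyn_cast<mlir::XLA::FooOp>(op))
    return CreateFooOp(xla_op, value_lowering);
  return llvm::None;
}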

View File

@@ -0,0 +1,23 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
package(licenses = ["notice"])
glob_lit_tests(
data = [":test_utilities"],
driver = "@local_config_mlir//:run_lit.sh",
test_file_exts = [
"mlir",
"hlo",
"hlotxt",
],
)
# Bundle together all of the test utilities that are used by tests.
filegroup(
name = "test_utilities",
testonly = True,
data = [
"//tensorflow/compiler/mlir:tf-mlir-translate",
"@llvm//:FileCheck",
],
)

View File

@@ -0,0 +1,28 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// This test is more thorough than those of the other binary ops in order to
// exercise their shared functionality.
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)
// Add two tensors
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
// Add two scalars
// CHECK-NEXT: %1 = "xla.add"(%arg2, %arg3) {name = "add.4"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4)
// Add a tensor and scalar
// CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: return %2 : tensor<4xf32>
ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4)
}

View File

@@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
// CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0)
// CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
%0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2)
%1 = "xla.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}

View File

@@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %and.3 = f32[4] and(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}

View File

@@ -0,0 +1,26 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] {
func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> {
// Same rank degenerate broadcast
// CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0)
// CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1)
// CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4)
// CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1)
// CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2)
%0 = "xla.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32>
// Broadcast up rank
// CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2}
// CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2)
// CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3)
%1 = "xla.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
// Broadcast up rank + degenerate broadcast
// CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2}
// CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9)
// CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2}
// CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3)
%2 = "xla.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32>
return %2 : tensor<2x3x4xi32>
}


@ -0,0 +1,9 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] {
func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> {
// CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0)
// CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3}
%0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32>
return %0 : tensor<1x2x3x4xi32>
}


@ -0,0 +1,20 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<1x2xf32>) -> tensor<3x1x2xf32> {
ENTRY %main {
%Arg_0.1 = f32[1, 2] parameter(0)
// CHECK-NEXT: %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32>
%broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1}
// Degenerate broadcast
// CHECK-NEXT: %1 = "xla.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32>
broadcast.3 = f32[3,2] broadcast(%Arg_0.1), dimensions={}
// CHECK-NEXT: %2 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32>
// CHECK-NEXT: return %2 : tensor<3x1x2xf32>
ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2}
}


@ -0,0 +1,19 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule foo
// CHECK-LABEL: func @call(%arg0: tensor<i64>) -> tensor<i64> {
%call (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return %0 : tensor<i64>
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
ENTRY %foo (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = call @call(%arg0) : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return %0 : tensor<i64>
ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call
}


@ -0,0 +1,21 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> {
ENTRY %main (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] {
%Arg_0.1 = f32[3] parameter(0)
%Arg_1.2 = f32[3] parameter(1)
%Arg_2.3 = f32[1] parameter(2)
// CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ
// CHECK-NEXT: %1 = "xla.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1>
%compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE
// Requires broadcast of compatible tensors.
// CHECK-NEXT: %2 = "xla.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1>
// CHECK-NEXT: return %2 : tensor<3xi1>
ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT
}


@ -0,0 +1,13 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] {
%Arg_0.1 = f32[4, 1] parameter(0)
%Arg_1.2 = f32[4, 2] parameter(1)
// CHECK-NEXT: %0 = "xla.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32>
// CHECK-NEXT: return %0 : tensor<4x3xf32>
ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1}
}


@ -0,0 +1,21 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule tfcompile.7
// CHECK-LABEL: func @main() -> tensor<2x2x1x1xf32> {
ENTRY %tfcompile.7 {
// Scalar/0D tensor constant
// CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1> : tensor<i64>
%constant.0 = s64[] constant(1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
// CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32>
// CHECK-NEXT: return %cst_0 : tensor<2x2x1x1xf32>
ROOT %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
}


@ -0,0 +1,31 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule tfcompile.7
// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual
// implementations with attributes, etc.
// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x1xf32>) -> tuple<tensor<1x16x16x1xf32>> {
ENTRY %tfcompile.7 {
%arg0.1 = f32[1,16,16,1]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = "xla.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%copy.1 = f32[1,16,16,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %1 = "xla.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%reshape.2 = f32[1,16,16,1]{2,1,3,0} reshape(%copy.1)
// Note that double brackets "[[" have to be escaped as they denote variables
// in FileCheck. The only way to do so is to drop into regex with "{{"
// CHECK-NEXT: %cst = constant {name = "constant.3"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32>
%constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %2 = "xla.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32>
%convolution.4 = f32[1,16,16,1]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"}
// CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32>
%reshape.5 = f32[1,16,16,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="XLA_Retvals"}
// CHECK-NEXT: %4 = "xla.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple<tensor<1x16x16x1xf32>>
// CHECK-NEXT: return %4 : tuple<tensor<1x16x16x1xf32>>
ROOT %tuple.6 = (f32[1,16,16,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="XLA_Retvals"}
}


@ -0,0 +1,20 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf64> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: %0 = "xla.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64>
%convert.3 = f64[4] convert(f32[4] %Arg_0.1)
// CHECK-NEXT: %1 = "xla.convert"(%arg1) {name = "convert.4"} : (tensor<f32>) -> tensor<f64>
%convert.4 = f64[] convert(f32[] %Arg_1.2)
// CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor<f64>) -> tensor<4xf64>
// CHECK-NEXT: return %2 : tensor<4xf64>
ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4)
}


@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}


@ -0,0 +1,23 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor<f32> {
ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[] {
%Arg_0.1 = f32[1, 4] parameter(0)
%Arg_1.2 = f32[4, 1] parameter(1)
// CHECK-NEXT: %0 = "xla.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest}
// CHECK-NEXT: %1 = "xla.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default}
// CHECK-NEXT: %2 = "xla.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
%dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default}
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: %3 = "xla.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor<f32>
// CHECK-NEXT: return %3 : tensor<f32>
ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}


@ -0,0 +1,26 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @dynamic.update.slice.1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<4x4xf32> {
%dynamic.update.slice.1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[1, 4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)
// CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor<f32>, tensor<f32>) -> tensor<4x4xf32>
// CHECK-NEXT: return %0 : tensor<4x4xf32>
ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4)
}
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor<f32>) -> tensor<4xf32>
%dynamic.update.slice.2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[2] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3)
}


@ -0,0 +1,103 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// This test comes from a fully connected reference model.
HloModule tfcompile.48
// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple<tensor<300x1x5xf32>> {
ENTRY %tfcompile.48 {
%arg0.1 = f32[1,300] parameter(0)
%arg1.2 = f32[1,300,3,1] parameter(1)
// CHECK-NEXT: %0 = "xla.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32>
%reshape.3 = f32[1,300] reshape(%arg0.1)
// CHECK-NEXT: %1 = "xla.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32>
%transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0}
// CHECK-NEXT: %2 = "xla.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32>
%reshape.28 = f32[300,1,1] reshape(%transpose.27)
// CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32>
%reshape.29 = f32[300,1] reshape(%reshape.28)
// CHECK-NEXT: %4 = "xla.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32>
%broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1}
// CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor<f32>
%constant.8 = f32[] constant(1)
// CHECK-NEXT: %5 = "xla.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={}
// CHECK-NEXT: %6 = "xla.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9)
// CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor<f32>
%constant.32 = f32[] constant(0)
// CHECK-NEXT: %7 = "xla.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={}
// CHECK-NEXT: %8 = "xla.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1>
%compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT
// CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor<f32>
%constant.10 = f32[] constant(0)
// CHECK-NEXT: %9 = "xla.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor<f32>) -> tensor<300x1x5xf32>
%broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={}
// CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor<f32>
%constant.40 = f32[] constant(0)
// CHECK-NEXT: %10 = "xla.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor<f32>) -> tensor<300x5xf32>
%broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={}
// CHECK-NEXT: %11 = "xla.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%copy.1 = f32[1,300,3,1] copy(%arg1.2)
// CHECK-NEXT: %12 = "xla.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32>
%reshape.4 = f32[1,300,3,1] reshape(%copy.1)
// CHECK-NEXT: %13 = "xla.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32>
%reshape.24 = f32[1,300,3] reshape(%reshape.4)
// CHECK-NEXT: %14 = "xla.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32>
%transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2}
// CHECK-NEXT: %15 = "xla.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32>
%reshape.26 = f32[300,3] reshape(%transpose.25)
// CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32>
%constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } })
// TODO(b/129709049) consider making this default precision config implied.
// CHECK-NEXT: %16 = "xla.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32>
%dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0}
// CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32>
%constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0})
// CHECK-NEXT: %17 = "xla.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32>
%broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1}
// CHECK-NEXT: %18 = "xla.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
%add.39 = f32[300,5] add(%dot.36, %broadcast.38)
// CHECK-NEXT: %19 = "xla.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32>
%maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39)
// CHECK-NEXT: %20 = "xla.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32>
%reshape.44 = f32[300,1,5] reshape(%maximum.42)
// CHECK-NEXT: %21 = "xla.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44)
// CHECK-NEXT: %22 = "xla.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32>
%reshape.46 = f32[300,1,5] reshape(%select.45)
// CHECK-NEXT: %23 = "xla.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple<tensor<300x1x5xf32>>
// CHECK-NEXT: return %23 : tuple<tensor<300x1x5xf32>>
ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46)
}


@ -0,0 +1,17 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main() -> tensor<4xf32> {
ENTRY %iota.1 () -> f32[4] {
// CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %iota.0 = f32[4] iota(), iota_dimension=0
}
// CHECK-LABEL: func @iota.2() -> tensor<4x5xf32> {
%iota.2 () -> f32[4, 5] {
// CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32>
// CHECK-NEXT: return %0 : tensor<4x5xf32>
ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1
}


@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}


@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}


@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}


@ -0,0 +1,23 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<f32>) -> tensor<4xf32> {
ENTRY %padding.1 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0
}
// CHECK-LABEL: func @padding.2(%arg0: tensor<4x4x4xf32>, %arg1: tensor<f32>) -> tensor<7x11x15xf32> {
%padding.2 (Arg_0.1: f32[4, 4, 4], Arg_1.2: f32[]) -> f32[7, 11, 15] {
%Arg_0.1 = f32[4, 4, 4] parameter(0)
%Arg_1.2 = f32[] parameter(1)
// CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor<f32>) -> tensor<7x11x15xf32>
// CHECK-NEXT: return %0 : tensor<7x11x15xf32>
ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6
}


@ -0,0 +1,53 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
// This test is more thorough than those of the other binary ops to test
// their shared functionality.
HloModule main.5
%reduce_helper.1 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
ROOT %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}
%reduce_helper.2 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
ROOT %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
}
%reduce_helper.3 (Arg_0.1: f32[], Arg_1.2: f32[], Arg_2.3: f32[], Arg_3.4: f32[]) -> (f32[], f32[]) {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%Arg_2.3 = f32[] parameter(2)
%Arg_3.4 = f32[] parameter(3)
%add.4 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_2.3)
%add.5 = f32[] add(f32[] %Arg_1.2, f32[] %Arg_3.4)
ROOT %tuple.6 = (f32[], f32[]) tuple(%add.4, %add.5)
}
// CHECK-LABEL: func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor<f32>) -> tuple<tuple<tensor<f32>, tensor<f32>>, tensor<f32>> {
ENTRY %foo.5 (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) {
%Arg_0.1 = f32[4, 4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
%Arg_2.3 = f32[] parameter(2)
// CHECK-NEXT: %0 = "xla.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor<f32>, tensor<f32>) -> tuple<tensor<f32>, tensor<f32>>
%reduce.1 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.1
// CHECK-NEXT: %1 = "xla.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32>
%reduce.2 = f32[] reduce(%reduce.1, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.2
// CHECK-NEXT: %2 = "xla.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor<f32>) -> tensor<f32>
%reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.2
// CHECK-NEXT: %3 = "xla.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<f32>) -> tensor<f32>
%reduce.4 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3
// CHECK-NEXT: %4 = "xla.sub"(%2, %3) {name = "sub.5"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
%sub.5 = f32[] subtract(%reduce.2, %reduce.3)
ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.4, %sub.5)
}


@ -0,0 +1,21 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %reverse.1 (Arg_0.1: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
// CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0}
}
// CHECK-LABEL: func @reverse.2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
%reverse.2 (Arg_0.1: f32[4, 4]) -> f32[4, 4] {
%Arg_0.1 = f32[4, 4] parameter(0)
// CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32>
// CHECK-NEXT: return %0 : tensor<4x4xf32>
ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1}
}


@ -0,0 +1,9 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<f32>) -> tensor<f32> {
ENTRY %main.5 (Arg_0.1: f32[]) -> f32[] {
// CHECK-NEXT: return %arg0 : tensor<f32>
ROOT %Arg_0.1 = f32[] parameter(0)
}


@ -0,0 +1,15 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
ENTRY %main {
%Arg_0.1 = pred[2,3] parameter(0)
%Arg_1.2 = s32[2,3] parameter(1)
%Arg_2.3 = s32[2,3] parameter(2)
// CHECK-NEXT: %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
// CHECK-NEXT: return %0 : tensor<2x3xi32>
ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
}


@ -0,0 +1,13 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main
func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> {
// CHECK-NEXT: %Arg_0.1 = pred[2,3] parameter(0)
// CHECK-NEXT: %Arg_1.2 = s32[2,3] parameter(1)
// CHECK-NEXT: %Arg_2.3 = s32[2,3] parameter(2)
// CHECK-NEXT: ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3)
%0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
return %0 : tensor<2x3xi32>
}


@ -0,0 +1,146 @@
# RUN: tf-mlir-translate -hlo-to-mlir-hlo %s -o - | FileCheck %s
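# Text-format HloProto: computation foo.5 adds its two f32[4] parameters
# (add.3) and then contracts the sum with the second parameter (dot.4),
# producing a scalar result.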
name: "foo.5"
entry_computation_name: "foo.5"
computations {
name: "foo.5"
instructions {
name: "Arg_0.1"
opcode: "parameter"
shape {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
metadata {
}
id: 1
}
instructions {
name: "Arg_1.2"
opcode: "parameter"
shape {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
metadata {
}
parameter_number: 1
id: 2
}
instructions {
name: "add.3"
opcode: "add"
shape {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
metadata {
}
id: 3
operand_ids: 1
operand_ids: 2
}
instructions {
name: "dot.4"
opcode: "dot"
shape {
element_type: F32
layout {
format: DENSE
}
}
metadata {
}
dot_dimension_numbers {
lhs_contracting_dimensions: 0
rhs_contracting_dimensions: 0
}
id: 4
operand_ids: 3
operand_ids: 2
}
program_shape {
parameters {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
parameters {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
result {
element_type: F32
layout {
format: DENSE
}
}
parameter_names: "Arg_0"
parameter_names: "Arg_1"
}
id: 5
root_id: 4
}
host_program_shape {
parameters {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
parameters {
element_type: F32
dimensions: 4
layout {
minor_to_major: 0
format: DENSE
}
is_dynamic_dimension: false
}
result {
element_type: F32
layout {
format: DENSE
}
}
parameter_names: "Arg_0"
parameter_names: "Arg_1"
}
id: 5
entry_computation_id: 5
dynamic_parameter_binding {
}
# CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
# CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
# TODO(b/129709049) consider making this default precision config inferred.
# CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
# CHECK-NEXT: return %1 : tensor<f32>
# CHECK-NEXT: }


@ -0,0 +1,17 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule foo.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<f32> {
ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] {
%Arg_0.1 = f32[4]{0} parameter(0)
%Arg_1.2 = f32[4]{0} parameter(1)
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2)
// TODO(b/129709049) consider making this default precision config inferred.
// CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor<f32>
// CHECK-NEXT: return %1 : tensor<f32>
ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0}
}


@ -0,0 +1,21 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo %s | FileCheck %s
func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> {
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
%0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
%1 = "xla.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %1 : tensor<4xf32>
}
// CHECK: name: "main
// CHECK: entry_computation_name: "main
// CHECK: computations {
// CHECK: name: "main
// CHECK: instructions {
// CHECK: name: "Arg_
// CHECK: opcode: "parameter"
// CHECK: name: "add
// CHECK: opcode: "add"
// CHECK: name: "dot
// CHECK: opcode: "dot"


@ -0,0 +1,14 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main.5
// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> {
ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] {
%Arg_0.1 = f32[4] parameter(0)
%Arg_1.2 = f32[4] parameter(1)
// CHECK-NEXT: %0 = "xla.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %0 : tensor<4xf32>
ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2)
}


@ -0,0 +1,12 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule foo
// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> {
ENTRY %foo (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] {
%arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = "xla.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32>
// CHECK-NEXT: return %0 : tensor<1x16x16x3xf32>
ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"}
}


@ -0,0 +1,13 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
ENTRY %main {
%Arg_0.1 = s32[1,2,3,4] parameter(0)
// CHECK-NEXT: %0 = "xla.transpose"(%arg0) {name = "transpose.2", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
// CHECK-NEXT: return %0 : tensor<2x1x4x3xi32>
ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
}


@ -0,0 +1,11 @@
// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s
// CHECK-LABEL: ENTRY %main
func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> {
// CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0)
// CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2}
%0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32>
return %0 : tensor<2x1x4x3xi32>
}


@ -0,0 +1,16 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>> {
ENTRY %main(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) {
%Arg_0.1 = s32[1] parameter(0)
%Arg_1.2 = f32[1, 2] parameter(1)
// CHECK-NEXT: %0 = "xla.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple<tensor<1xi32>>
%tuple.3 = (s32[1]) tuple(%Arg_0.1)
// CHECK-NEXT: %1 = "xla.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple<tensor<1xi32>, tensor<1x2xf32>>
// CHECK-NEXT: return %1 : tuple<tensor<1xi32>, tensor<1x2xf32>>
ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2)
}


@ -0,0 +1,11 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule main
// CHECK-LABEL: func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> {
ENTRY %main (Arg_0.1: f32[1]) -> f32[1] {
%Arg_0.1 = f32[1] parameter(0)
// CHECK-NEXT: %0 = "xla.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
ROOT add-dependency.2 = f32[1] add-dependency(Arg_0.1, Arg_0.1)
}


@ -0,0 +1,27 @@
// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s
HloModule foo
// CHECK-LABEL: func @cond(%arg0: tensor<i64>) -> tensor<i1> {
%cond (arg_1: s64[]) -> pred[] {
%arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor<i64>, tensor<i64>) -> tensor<i1>
// CHECK-NEXT: return %0 : tensor<i1>
ROOT %compare.2 = pred[] compare(%arg_1, %arg_1), direction=LT, metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func @loop(%arg0: tensor<i64>) -> tensor<i64> {
%loop (arg_1: s64[]) -> s64[] {
%arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor<i64>, tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return %0 : tensor<i64>
ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"}
}
// CHECK-LABEL: func @main(%arg0: tensor<i64>) -> tensor<i64> {
ENTRY %foo (arg0.1: s64[]) -> s64[] {
%arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"}
// CHECK-NEXT: %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor<i64>) -> tensor<i64>
// CHECK-NEXT: return %0 : tensor<i64>
ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond
}


@ -0,0 +1,145 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include <string>
#include "absl/base/integral_types.h"
#include "mlir/IR/AffineMap.h" // TF:local_config_mlir
#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir
#include "mlir/IR/Location.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/logging.h"
using mlir::IntegerType;
using mlir::MemRefType;
using mlir::RankedTensorType;
using mlir::VectorType;
using xla::PrimitiveType;
using xla::ShapeUtil;
namespace xla {
PrimitiveType TypeToPrimitiveType(mlir::Type type) {
switch (type.getKind()) {
case mlir::StandardTypes::BF16:
return PrimitiveType::BF16;
case mlir::StandardTypes::F16:
return PrimitiveType::F16;
case mlir::StandardTypes::F32:
return PrimitiveType::F32;
case mlir::StandardTypes::F64:
return PrimitiveType::F64;
case mlir::StandardTypes::Integer: {
const auto integer = type.cast<IntegerType>();
switch (integer.getWidth()) {
case 1:
return PrimitiveType::PRED;
case 8:
return PrimitiveType::S8;
case 16:
return PrimitiveType::S16;
case 32:
return PrimitiveType::S32;
case 64:
return PrimitiveType::S64;
default:
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
}
}
default:
return PrimitiveType::PRIMITIVE_TYPE_INVALID;
}
}
Shape TypeToShape(mlir::Type type) {
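// If `type` is itself a primitive (scalar) type, it maps to a rank-0 shape of
// that element type.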
PrimitiveType ptype = TypeToPrimitiveType(type);
if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID)
return ShapeUtil::MakeShape(ptype, {});
switch (type.getKind()) {
case mlir::StandardTypes::BF16:
case mlir::StandardTypes::F32:
case mlir::StandardTypes::F64:
case mlir::StandardTypes::Integer: {
auto* context = type.getContext();
mlir::emitError(mlir::UnknownLoc::get(context))
<< "lowering should have been handled by primitive type lowering for "
<< debugString(type);
break;
}
case mlir::StandardTypes::Vector: {
const auto v = type.cast<VectorType>();
llvm::SmallVector<int64, 4> span(v.getShape().begin(),
v.getShape().end());
mlir::Type element_type = v.getElementType();
PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
return ShapeUtil::MakeShape(primitive_type, span);
break;
}
case mlir::StandardTypes::MemRef: {
const auto m = type.cast<MemRefType>();
llvm::SmallVector<int64, 6> span(m.getShape().begin(),
m.getShape().end());
mlir::Type element_type = m.getElementType();
// Treat a memref of a vector as if it were a memref of primitive type with
// the vector dimensions at the end.
if (auto v = element_type.dyn_cast<VectorType>()) {
element_type = v.getElementType();
span.insert(span.end(), v.getShape().begin(), v.getShape().end());
}
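// For example, memref<100x13x210xvector<8x128xf32>> is mapped to
// f32[100,13,210,8,128] (see ConvertMemRefTypeToTypes in type_to_shape_test.cc).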
PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) break;
// For the primitive type case, the shape of the memref is similar to the
// vector type case (i.e., it is, modulo the layout, the same dimensions
// and primitive type).
// Currently we only return shapes for identity affine maps.
// TODO(andydavis) Map affine map layout function to XLA layout.
if (m.getAffineMaps().empty() ||
(m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity()))
return ShapeUtil::MakeShape(primitive_type, span);
break;
}
case mlir::StandardTypes::RankedTensor: {
// TODO(jpienaar): This is only handling the base case with primitive
// element type.
const auto t = type.cast<RankedTensorType>();
llvm::SmallVector<int64, 4> span(t.getShape().begin(),
t.getShape().end());
// Only fully static shapes are supported.
// TODO(b/115638799): Update once xla::Shape can support dynamic shapes.
if (std::find(t.getShape().begin(), t.getShape().end(), -1) !=
t.getShape().end())
break;
mlir::Type element_type = t.getElementType();
PrimitiveType primitive_type = TypeToPrimitiveType(element_type);
// Only primitive element type supported.
if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID)
return ShapeUtil::MakeShape(primitive_type, span);
break;
}
default:
break;
}
// Return an empty XLA shape to signify an error. No MLIR Type maps to an empty
// Shape.
return {};
}
} // namespace xla


@ -0,0 +1,34 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Returns an XLA Shape equivalent of an MLIR Type, else returns an empty shape.
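// For example, tensor<8x128xf32> and vector<8x128xf32> both convert to the XLA
// shape f32[8,128].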
Shape TypeToShape(mlir::Type type);
// Returns an XLA PrimitiveType equivalent of an MLIR Type that represents a
// primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID.
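// For example, i1 maps to PRED and f32 maps to F32.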
PrimitiveType TypeToPrimitiveType(mlir::Type type);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_


@ -0,0 +1,110 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "mlir/IR/Builders.h" // TF:local_config_mlir
#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
using mlir::Builder;
using mlir::MLIRContext;
using ::testing::EqualsProto;
namespace xla {
namespace {
TEST(TypeToShapeTest, ConvertPrimitiveTypes) {
MLIRContext context;
Builder b(&context);
EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32);
EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(1)), PrimitiveType::PRED);
EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(17)),
PrimitiveType::PRIMITIVE_TYPE_INVALID);
}
TEST(TypeToShapeTest, ConvertBasicTypesToTypes) {
MLIRContext context;
Builder b(&context);
EXPECT_TRUE(
ShapeUtil::IsScalarWithElementType(TypeToShape(b.getF32Type()), F32));
EXPECT_THAT(
TypeToShape(b.getVectorType({8, 128}, b.getIntegerType(32))).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::S32, {8, 128}).ToProto()));
EXPECT_THAT(
TypeToShape(b.getVectorType({8, 128}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
// MLIR Type that is not representable as XLA Shape.
EXPECT_THAT(
TypeToShape(b.getVectorType({8, 128}, b.getIntegerType(17))).ToProto(),
EqualsProto(Shape().ToProto()));
}
TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) {
MLIRContext context;
Builder b(&context);
// Memref without any affine map. Note: memory space is ignored for shape.
EXPECT_THAT(
TypeToShape(b.getMemRefType({8, 128}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
EXPECT_THAT(
TypeToShape(b.getMemRefType({100, 13, 210}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210}).ToProto()));
// Vector types are "flattened" into the end of the shape.
EXPECT_THAT(
TypeToShape(b.getMemRefType({100, 13, 210},
b.getVectorType({8, 128}, b.getF32Type())))
.ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210, 8, 128})
.ToProto()));
}
TEST(TypeToShapeTest, ConvertTensorTypeToTypes) {
MLIRContext context;
Builder b(&context);
EXPECT_THAT(
TypeToShape(b.getTensorType({8, 128}, b.getF32Type())).ToProto(),
EqualsProto(
ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto()));
// Shape cannot represent dynamic shapes.
// TODO(b/115638799): Update once Shape can support dynamic shapes.
EXPECT_THAT(TypeToShape(b.getTensorType(b.getF32Type())).ToProto(),
EqualsProto(Shape().ToProto()));
// TODO(jpienaar): Expand to handle more complicated tensor types.
EXPECT_THAT(
TypeToShape(
b.getTensorType({8, 128}, b.getVectorType({16, 16}, b.getF32Type())))
.ToProto(),
EqualsProto(Shape().ToProto()));
}
} // namespace
} // namespace xla


@ -0,0 +1,192 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h"
#include "google/protobuf/text_format.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/ToolOutputFile.h"
#include "mlir/IR/Module.h" // TF:local_config_mlir
#include "mlir/Translation.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/core/lib/core/errors.h"
using stream_executor::port::Status;
using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this
namespace xla {
namespace {
// Error collector that simply ignores errors reported.
class NoOpErrorCollector : public ::proto2::io::ErrorCollector {
public:
void AddError(int line, int column, const string& message) override {}
};
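// Tries to parse `contents` as an HloProto, first as a binary and then as a
// text proto, and falls back to parsing it as an HloModuleProto in both
// encodings.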
bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) {
::proto2::TextFormat::Parser parser;
NoOpErrorCollector collector;
parser.RecordErrorsTo(&collector);
return hlo_proto->ParseFromString(contents) ||
parser.ParseFromString(contents, hlo_proto) ||
hlo_proto->mutable_hlo_module()->ParseFromString(contents) ||
parser.ParseFromString(contents, hlo_proto->mutable_hlo_module());
}
} // namespace
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
llvm::StringRef input_filename, mlir::MLIRContext* context) {
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.str());
if (std::error_code error = file_or_err.getError()) {
LOG(ERROR) << "Failure to read HLO module: " << error;
return nullptr;
}
auto& input_file = *file_or_err;
HloProto hlo_proto;
string content(input_file->getBufferStart(), input_file->getBufferSize());
if (!LoadHloProto(content, &hlo_proto)) {
LOG(ERROR) << "Failed to load proto: " << input_filename.str();
return nullptr;
}
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
auto status =
ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module());
if (!status.ok()) {
LOG(ERROR) << "Hlo module import failed: " << status;
return nullptr;
}
return module;
}
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
llvm::StringRef input_filename, mlir::MLIRContext* context) {
auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.str());
if (std::error_code error = file_or_err.getError()) {
LOG(ERROR) << "Failure to open file: " << error;
return nullptr;
}
auto& input_file = *file_or_err;
HloProto hlo_proto;
string content(input_file->getBufferStart(), input_file->getBufferSize());
auto hlo_module_error = ParseAndReturnUnverifiedModule(content);
if (!hlo_module_error.ok()) {
LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status();
return nullptr;
}
auto hlo_module = std::move(hlo_module_error.ValueOrDie());
mlir::OwningModuleRef module =
mlir::ModuleOp::create(mlir::UnknownLoc::get(context));
auto status = ConvertHloToMlirHlo(*module, hlo_module.get());
if (!status.ok()) {
LOG(ERROR) << "HLO Module import failed: " << status;
return nullptr;
}
return module;
}
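// Converts the MLIR module to an HloProto and writes the proto's text-format
// debug string to `output_filename`.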
static mlir::LogicalResult MlirHloToHloTranslateFunction(
mlir::ModuleOp module, llvm::StringRef output_filename) {
if (!module) return mlir::failure();
std::error_code error;
auto result = llvm::make_unique<llvm::ToolOutputFile>(output_filename, error,
llvm::sys::fs::F_None);
if (error) {
LOG(ERROR) << error.message();
return mlir::failure();
}
HloProto hloProto;
Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
if (!status.ok()) {
LOG(ERROR) << "Module conversion failed: " << status;
return mlir::failure();
}
result->os() << hloProto.DebugString();
result->keep();
return mlir::success();
}
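// Reconstructs an HloModule from the given HloProto, building the module
// config from the proto and the debug options currently set via flags.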
static StatusOr<std::unique_ptr<HloModule>> HloModuleFromProto(
const HloProto& hlo_proto) {
const HloModuleProto& module_proto = hlo_proto.hlo_module();
TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config,
HloModule::CreateModuleConfigFromProto(
module_proto, GetDebugOptionsFromFlags()));
return HloModule::CreateFromProto(module_proto, module_config);
}
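// Converts the MLIR module to an HloProto, rebuilds an HloModule from it, and
// prints that module in HLO text form (omitting layouts in shapes).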
static mlir::LogicalResult MlirHloToHloTextTranslateFunction(
mlir::ModuleOp module, llvm::StringRef output_filename) {
if (!module) return mlir::failure();
std::error_code error;
auto result = llvm::make_unique<llvm::ToolOutputFile>(output_filename, error,
llvm::sys::fs::F_None);
if (error) {
LOG(ERROR) << error.message();
return mlir::failure();
}
HloProto hloProto;
Status status = mlir::ConvertMlirHloToHlo(module, &hloProto);
if (!status.ok()) {
LOG(ERROR) << "Module conversion failed: " << status;
return mlir::failure();
}
auto statusOrHloModule = HloModuleFromProto(hloProto);
if (!statusOrHloModule.ok()) {
LOG(ERROR) << "Conversion to HLO module failed: "
<< statusOrHloModule.status();
return mlir::failure();
}
result->os() << statusOrHloModule.ValueOrDie()->ToString(
HloPrintOptions()
// We don't interpret or use layouts
.set_include_layout_in_shapes(false));
result->keep();
return mlir::success();
}
} // namespace xla
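// Registers these conversions with mlir-translate under the flag names used in
// the RUN lines of the tests above (e.g. -mlir-hlo-to-hlo-text,
// -hlo-text-to-mlir-hlo).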
static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate(
"mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction);
static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate(
"mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction);
static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate(
"hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction);
static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate(
"hlo-text-to-mlir-hlo", xla::HloTextToMlirHloTranslateFunction);


@ -0,0 +1,47 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_
#include <memory>
#include "tensorflow/compiler/xla/statusor.h"
namespace llvm {
class StringRef;
} // namespace llvm
namespace mlir {
class MLIRContext;
class OwningModuleRef;
} // namespace mlir
namespace xla {
// Converts an HloModuleProto stored in the file with the given `input_filename`
// into an MLIR module. Creates the MLIR entities in the given MLIR `context`.
mlir::OwningModuleRef HloToMlirHloTranslateFunction(
llvm::StringRef input_filename, mlir::MLIRContext* context);
// Converts an HloModule stored in text form in a file with the given
// `input_filename` into an MLIR module. Creates the MLIR entities in the given
// MLIR `context`.
mlir::OwningModuleRef HloTextToMlirHloTranslateFunction(
llvm::StringRef input_filename, mlir::MLIRContext* context);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_