diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index f758d40b152..247bb83e7f7 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -56,6 +56,25 @@ tf_cc_binary( ], ) +tf_cc_binary( + name = "tf-mlir-translate", + deps = [ + "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow:translate_cl_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/compiler/mlir/tensorflow:translate_registration", + "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", + "//tensorflow/compiler/mlir/xla:xla_mlir_translate", + "//tensorflow/core:protos_all_proto_cc", + "//tensorflow/stream_executor/lib", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Translation", + "@local_config_mlir//:tools/mlir-translate/mlir-translate", + ], +) + filegroup( name = "litfiles", srcs = glob(["runlit*py"]), diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 105ee63a3e0..abe8df63b20 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -511,24 +511,6 @@ cc_library( alwayslink = 1, ) -tf_cc_binary( - name = "tf-mlir-translate", - deps = [ - ":convert_graphdef", - ":mlir_roundtrip_flags", - ":translate_cl_options", - ":translate_lib", - ":translate_registration", - ":translate_tf_dialect_op", - "//tensorflow/core:protos_all_proto_cc", - "//tensorflow/stream_executor/lib", - "@llvm//:support", - "@local_config_mlir//:IR", - "@local_config_mlir//:Translation", - "@local_config_mlir//:tools/mlir-translate/mlir-translate", - ], -) - tf_cc_test( name = "error_util_test", srcs = ["utils/error_util_test.cc"], diff --git a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD index 14d0e30fcd5..39ab5ef0811 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/BUILD @@ -13,7 +13,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir/tensorflow:tf-mlir-translate", + "//tensorflow/compiler/mlir:tf-mlir-translate", "@llvm//:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD index f71ea50b0e2..a4b7baeb90c 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/tests/mlir2graphdef/BUILD @@ -13,7 +13,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir/tensorflow:tf-mlir-translate", + "//tensorflow/compiler/mlir:tf-mlir-translate", "@llvm//:FileCheck", ], ) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index e949b20ce98..c36299ee263 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -1,4 +1,5 @@ load("@local_config_mlir//:tblgen.bzl", "gentbl") +load("//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_native_cc_binary") package( default_visibility = [":friends"], @@ -166,3 +167,151 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "type_to_shape", + srcs = ["type_to_shape.cc"], + hdrs = ["type_to_shape.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "@com_google_absl//absl/base:core_headers", + "@local_config_mlir//:IR", + "@local_config_mlir//:Support", + ], +) + +tf_cc_test( + name = "type_to_shape_test", + srcs = ["type_to_shape_test.cc"], + deps = [ + ":type_to_shape", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:test_main", + "@local_config_mlir//:IR", + ], +) + +cc_library( + name = "mlir_hlo_to_hlo", + srcs = [ + "mlir_hlo_to_hlo.cc", + "operator_writers.inc", + ], + hdrs = ["mlir_hlo_to_hlo.h"], + deps = [ + ":type_to_shape", + ":xla", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/xla:comparison_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "@llvm//:support", + "@local_config_mlir//:Analysis", + "@local_config_mlir//:IR", + "@local_config_mlir//:Pass", + "@local_config_mlir//:StandardOps", + "@local_config_mlir//:TransformUtils", + "@local_config_mlir//:Transforms", + ], +) + +cc_library( + name = "hlo_to_mlir_hlo", + srcs = [ + "hlo_to_mlir_hlo.cc", + ], + hdrs = [ + "hlo_to_mlir_hlo.h", + ], + deps = [ + ":hlo_module_importer", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "hlo_module_importer", + srcs = [ + "hlo_function_importer.cc", + "hlo_module_importer.cc", + ], + hdrs = [ + "hlo_function_importer.h", + "hlo_module_importer.h", + ], + deps = [ + ":xla", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/service:hlo", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:StandardOps", + ], +) + +cc_library( + name = "xla_mlir_translate", + srcs = [ + "xla_mlir_translate.cc", + ], + hdrs = [ + "xla_mlir_translate.h", + ], + deps = [ + ":hlo_to_mlir_hlo", + ":mlir_hlo_to_hlo", + "//tensorflow/compiler/xla:debug_options_flags", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:lib", + "@com_google_protobuf//:protobuf_headers", + "@llvm//:support", + "@local_config_mlir//:IR", + "@local_config_mlir//:Translation", + ], + alwayslink = 1, +) + +tf_native_cc_binary( + name = "operator_writer_gen", + srcs = [ + "operator_writer_gen.cc", + ], + deps = [ + "@llvm//:config", + "@llvm//:support", + "@llvm//:tablegen", + "@local_config_mlir//:TableGen", + ], +) + +genrule( + name = "operator_writer_inc", + srcs = [ + "@local_config_mlir//:include/mlir/IR/OpBase.td", + "//tensorflow/compiler/mlir/xla:ir/xla_ops.td", + ], + outs = [ + "operator_writers.inc", + ], + cmd = ("$(location :operator_writer_gen) " + + "-I external/local_config_mlir/include " + + "$(location //tensorflow/compiler/mlir/xla:ir/xla_ops.td) " + " -o $@"), + tools = [":operator_writer_gen"], +) diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.cc b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc new file mode 100644 index 00000000000..b9ba5fcb9fb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.cc @@ -0,0 +1,528 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Identifier.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +using llvm::APInt; +using llvm::makeArrayRef; +using mlir::DenseElementsAttr; +using mlir::DenseIntElementsAttr; +using mlir::FuncOp; +using mlir::NamedAttribute; +using mlir::Operation; +using mlir::ShapedType; +using mlir::Type; +using mlir::Value; + +namespace xla { + +namespace { +// Note: This sanitization function causes an irreversible many-to-one mapping +// and any solution to mitigate this would cause issues with the reverse +// direction. Longterm solution is to add a function attribute to maintain the +// original HLO naming. +string SanitizeFunctionName(llvm::StringRef name) { + string output = name; + llvm::for_each(output, [](char& x) { x = x == '-' ? '_' : x; }); + return output; +} + +StatusOr CreateDenseAttrFromLiteral(ShapedType type, + const Literal& literal) { +#define DENSE_ELEMENT_ATTR_BUILDER(xla_type, cpp_type) \ + case xla_type: { \ + auto data_span = literal.data(); \ + return DenseElementsAttr::get( \ + type, llvm::makeArrayRef(data_span.data(), data_span.size())); \ + } + + switch (literal.shape().element_type()) { + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::PRED, bool) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F32, float) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::F64, double) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S8, int8) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S16, int16) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S32, int32) + DENSE_ELEMENT_ATTR_BUILDER(PrimitiveType::S64, int64) + default: + return tensorflow::errors::Internal( + absl::StrCat("Unsupported type: ", + PrimitiveType_Name(literal.shape().element_type()))); + } +#undef DENSE_ELEMENT_ATTR_BUILDER +} +} // namespace + +StatusOr HloFunctionImporter::ImportFunction( + mlir::ModuleOp module, mlir::Builder* builder, + std::unordered_map* function_map, + HloComputation* computation) { + HloFunctionImporter importer(module, builder, function_map); + return importer.ImportFunction(computation); +} + +StatusOr HloFunctionImporter::ImportFunction( + HloComputation* computation) { + auto& imported = (*function_map_)[computation]; + if (imported) return imported; + + llvm::SmallVector args, rets; + TF_RETURN_IF_ERROR( + GetMlirTypes(computation->parameter_instructions(), &args)); + TF_RETURN_IF_ERROR(GetMlirTypes({computation->root_instruction()}, &rets)); + + auto func_type = mlir::FunctionType::get(args, rets, context_); + + string computation_name = + computation->parent()->entry_computation() == computation + ? "main" + : SanitizeFunctionName(computation->name()); + + // Construct the MLIR function and map arguments. + llvm::ArrayRef attrs; + auto function = mlir::FuncOp::create(mlir::UnknownLoc::get(context_), + computation_name, func_type, attrs); + module_.push_back(function); + + // Add to the map right away for function calls. + imported = function; + + function.addEntryBlock(); + + // Setup the input parameters. + const int num_parameters = computation->num_parameters(); + for (int i = 0; i < num_parameters; i++) { + auto hlo_parameter = computation->parameter_instruction(i); + instruction_value_map_[hlo_parameter] = function.getArgument(i); + } + + mlir::OpBuilder func_builder(function.getBody()); + for (auto instruction : computation->MakeInstructionPostOrder()) { + TF_ASSIGN_OR_RETURN(auto new_operation, + ImportInstruction(instruction, &func_builder)); + if (new_operation) { + instruction_value_map_[instruction] = new_operation->getResult(0); + } + } + + // Setup the return type (HLO only supports a single return value). + TF_ASSIGN_OR_RETURN(auto result, + GetMlirValue(computation->root_instruction())); + llvm::SmallVector return_values({result}); + // TODO(suderman): Add location tracking details. + func_builder.create(mlir::UnknownLoc::get(context_), + makeArrayRef(return_values)); + + return function; +} + +StatusOr HloFunctionImporter::ImportInstruction( + HloInstruction* instruction, mlir::OpBuilder* func_builder) { + TF_ASSIGN_OR_RETURN(auto operands, GetOperands(instruction)); + TF_ASSIGN_OR_RETURN(auto result_type, ConvertType(instruction->shape())); + llvm::SmallVector attributes = {builder_->getNamedAttr( + "name", builder_->getStringAttr(instruction->name()))}; + mlir::Location loc = mlir::UnknownLoc::get(context_); + + switch (instruction->opcode()) { + case HloOpcode::kParameter: { + return nullptr; + } + case HloOpcode::kConstant: { + auto attr = CreateDenseAttrFromLiteral( + result_type.cast(), instruction->literal()); + if (!attr.ok()) return attr.status(); + mlir::Operation* new_operation = + func_builder->create(loc, attr.ValueOrDie()); + for (auto attr : attributes) { + new_operation->setAttr(attr.first, attr.second); + } + return new_operation; + } + case HloOpcode::kIota: { + return func_builder + ->create( + loc, result_type, + func_builder->getI64IntegerAttr( + static_cast(instruction) + ->iota_dimension())) + .getOperation(); + } +#define MakeAndReturn(mlir_op) \ + { \ + mlir::Operation* new_operation = func_builder->create( \ + loc, result_type, operands, attributes); \ + return new_operation; \ + } + case HloOpcode::kBroadcast: { + // Note that the HLO broadcast is more powerful than the XLA broadcast op. + // BroadcastInDim offers a superset of the HLO op's functionality. + if (!instruction->dimensions().empty()) { + attributes.push_back(builder_->getNamedAttr( + "broadcast_dimensions", + ConvertDimensions(instruction->dimensions()))); + } + MakeAndReturn(BroadcastInDimOp); + } + case HloOpcode::kDot: { + // TODO(b/129153247) Add support for batch and contracting dimensions. + TF_RETURN_IF_ERROR(ValidateDotDimensions(instruction)); + + // TODO(b/129709049) The HLO text format elides this in the all DEFAULT + // case and the parser sticks it in. Maybe we should too. + attributes.push_back(ConvertPrecisionConfig(instruction)); + MakeAndReturn(DotOp); + } + case HloOpcode::kCall: { + TF_ASSIGN_OR_RETURN(FuncOp function, + ImportFunction(instruction->to_apply())); + mlir::Operation* new_operation = + func_builder->create(loc, function, operands); + return new_operation; + } + case HloOpcode::kCompare: { + attributes.push_back(ConvertComparisonDirection(instruction)); + MakeAndReturn(CompareOp); + } + case HloOpcode::kGather: { + const auto& gather_dimensions = instruction->gather_dimension_numbers(); + std::vector offset_dims(gather_dimensions.offset_dims().begin(), + gather_dimensions.offset_dims().end()); + + std::vector slice_sizes( + instruction->gather_slice_sizes().begin(), + instruction->gather_slice_sizes().end()); + + std::vector collapsed_slice_dims( + gather_dimensions.collapsed_slice_dims().begin(), + gather_dimensions.collapsed_slice_dims().end()); + + std::vector start_index_map( + gather_dimensions.start_index_map().begin(), + gather_dimensions.start_index_map().end()); + + // TODO(b/132057942): Change to explicitly passing an integer instead of + // call getI64IntegerAttr here. + return func_builder + ->create( + loc, result_type, operands[0], operands[1], + func_builder->getI64IntegerAttr( + gather_dimensions.index_vector_dim()), + Convert(offset_dims), Convert(slice_sizes), + Convert(collapsed_slice_dims), Convert(start_index_map)) + .getOperation(); + } + case HloOpcode::kDynamicUpdateSlice: { + return func_builder + ->create( + loc, result_type, operands[0], operands[1], + llvm::ArrayRef(operands.begin() + 2, operands.end())) + .getOperation(); + } + case HloOpcode::kPad: { + const auto& padding_config = instruction->padding_config(); + llvm::SmallVector edge_padding_low; + llvm::SmallVector edge_padding_high; + llvm::SmallVector interior_padding; + edge_padding_low.reserve(padding_config.dimensions_size()); + edge_padding_high.reserve(padding_config.dimensions_size()); + interior_padding.reserve(padding_config.dimensions_size()); + + for (const auto& dimension : padding_config.dimensions()) { + edge_padding_low.push_back(dimension.edge_padding_low()); + edge_padding_high.push_back(dimension.edge_padding_high()); + interior_padding.push_back(dimension.interior_padding()); + } + + return func_builder + ->create(loc, result_type, operands[0], operands[1], + Convert(edge_padding_low), + Convert(edge_padding_high), + Convert(interior_padding)) + .getOperation(); + } + case HloOpcode::kSlice: { + return func_builder + ->create( + loc, result_type, operands[0], + ConvertDimensions(instruction->slice_starts()), + ConvertDimensions(instruction->slice_limits())) + .getOperation(); + } + case HloOpcode::kConcatenate: { + // TODO(b/132057942): Support taking an uint64_t instead of an IntegerAttr + // for concatenate dimension. + return func_builder + ->create( + loc, result_type, operands, + builder_->getI64IntegerAttr(instruction->concatenate_dimension())) + .getOperation(); + } + case HloOpcode::kReduce: { + TF_ASSIGN_OR_RETURN(auto reduction, + ImportFunction(instruction->to_apply())); + // TODO(b/132057942): Make more convenient constructors, e.g. pass + // mlir function pointer instead of a function attr. + return func_builder + ->create( + loc, result_type, operands, + func_builder->getSymbolRefAttr(reduction), + ConvertDimensions(instruction->dimensions())) + .getOperation(); + } + case HloOpcode::kReverse: { + return func_builder + ->create( + loc, result_type, operands[0], + ConvertDimensions(instruction->dimensions())) + .getOperation(); + } + case HloOpcode::kWhile: { + TF_ASSIGN_OR_RETURN(auto body, ImportFunction(instruction->while_body())); + TF_ASSIGN_OR_RETURN(auto cond, + ImportFunction(instruction->while_condition())); + + llvm::SmallVector types; + types.reserve(operands.size()); + for (auto operand : operands) { + types.push_back(operand->getType()); + } + + auto cond_attr = func_builder->getSymbolRefAttr(cond); + auto body_attr = func_builder->getSymbolRefAttr(body); + + Operation* op = func_builder->create( + loc, types, operands, cond_attr, body_attr); + return op; + } + case HloOpcode::kGetTupleElement: { + attributes.push_back(builder_->getNamedAttr( + "index", builder_->getIntegerAttr(builder_->getIntegerType(32), + instruction->tuple_index()))); + MakeAndReturn(GetTupleElementOp); + }; + case HloOpcode::kTranspose: { + attributes.push_back(builder_->getNamedAttr( + "permutation", ConvertDimensions(instruction->dimensions()))); + MakeAndReturn(TransposeOp); + } +#define NoAttributeCase(hlo_op_code, mlir_op) \ + case HloOpcode::hlo_op_code: { \ + MakeAndReturn(mlir_op); \ + } + + // broadcast dimensions are never added here because they don't exist as + // part of the HLO instruction. They are only a convenience in the XLA + // builder API. + NoAttributeCase(kAdd, AddOp); + NoAttributeCase(kAnd, AndOp); + NoAttributeCase(kConvert, ConvertOp); + NoAttributeCase(kDivide, DivOp); + NoAttributeCase(kMaximum, MaxOp); + NoAttributeCase(kMinimum, MinOp); + NoAttributeCase(kMultiply, MulOp); + NoAttributeCase(kSelect, SelectOp); + NoAttributeCase(kSubtract, SubOp); + NoAttributeCase(kTanh, TanhOp); + NoAttributeCase(kTuple, TupleOp); + // TODO(b/129422361) Copy needs special handling because it is not defined + // in tensorflow/compiler/xla/client/xla_builder.h. + // See operation semantics in + // g3doc/platforms/xla/g3doc/internal/hlo_semantics#copy + NoAttributeCase(kCopy, CopyOp); + // TODO(b/129422361) Ops below need additional work to handle attributes. + NoAttributeCase(kConvolution, ConvOp); + NoAttributeCase(kReshape, ReshapeOp); +#undef NoAttributeCase +#undef MakeAndReturn + case HloOpcode::kAddDependency: + // Arbitrary op code that I suspect we will not implement for quite a + // while and allows testing handling of unknown ops. Selected because it + // is not mentioned in xla client anywhere or in the hlo of our sample + // models. + default: { + mlir::OperationState result(loc, "xla.unknown"); + result.addOperands(operands); + result.addTypes(result_type); + for (auto attr : attributes) { + result.attributes.push_back(attr); + } + + return func_builder->createOperation(result); + } + } +} + +StatusOr> HloFunctionImporter::GetOperands( + HloInstruction* instruction) { + llvm::SmallVector operands; + for (const auto& operand : instruction->operands()) { + auto input_it = instruction_value_map_.find(operand); + if (input_it == instruction_value_map_.end()) { + return tensorflow::errors::Internal( + absl::StrCat("Could not find input value: ", operand->name(), + " for instruction ", instruction->name())); + } + operands.push_back(input_it->second); + } + return operands; +} + +// TODO(suderman): Move to a general library when needed in other places. +StatusOr HloFunctionImporter::ConvertTensorType( + const Shape& shape) { + auto type = shape.element_type(); + + llvm::SmallVector array; + array.reserve(shape.dimensions_size()); + for (auto val : shape.dimensions()) { + array.push_back(val); + } + + switch (type) { + case PrimitiveType::PRED: + return builder_->getTensorType(array, builder_->getI1Type()); + case PrimitiveType::F16: + return builder_->getTensorType(array, builder_->getF16Type()); + case PrimitiveType::F32: + return builder_->getTensorType(array, builder_->getF32Type()); + case PrimitiveType::F64: + return builder_->getTensorType(array, builder_->getF64Type()); + case PrimitiveType::S8: + return builder_->getTensorType(array, builder_->getIntegerType(8)); + case PrimitiveType::S16: + return builder_->getTensorType(array, builder_->getIntegerType(16)); + case PrimitiveType::S32: + return builder_->getTensorType(array, builder_->getIntegerType(32)); + case PrimitiveType::S64: + return builder_->getTensorType(array, builder_->getIntegerType(64)); + default: + return tensorflow::errors::Internal( + absl::StrCat("Unsupported type: ", PrimitiveType_Name(type))); + } +} + +StatusOr HloFunctionImporter::ConvertType(const Shape& shape) { + if (shape.IsTuple()) { + mlir::Type mlir_type; + llvm::SmallVector contents; + contents.reserve(shape.tuple_shapes_size()); + for (const auto& subtype : shape.tuple_shapes()) { + TF_ASSIGN_OR_RETURN(auto mlir_subtype, ConvertType(subtype)); + contents.push_back(mlir_subtype); + } + + return builder_->getTupleType(contents); + } + + return ConvertTensorType(shape); +} + +tensorflow::Status HloFunctionImporter::GetMlirTypes( + const std::vector& instructions, + llvm::SmallVectorImpl* types) { + for (auto instruction : instructions) { + TF_ASSIGN_OR_RETURN(auto ret_type, ConvertType(instruction->shape())); + types->push_back(ret_type); + } + return tensorflow::Status::OK(); +} + +StatusOr HloFunctionImporter::GetMlirValue( + HloInstruction* instruction) { + auto lookup = instruction_value_map_.find(instruction); + if (lookup != instruction_value_map_.end()) { + return lookup->second; + } + + return tensorflow::errors::Internal(absl::StrCat( + "Unable to find value for input: ", instruction->ToString())); +} + +mlir::NamedAttribute HloFunctionImporter::ConvertPrecisionConfig( + HloInstruction* instruction) { + llvm::SmallVector operand_precision_attrs; + + for (auto prec : instruction->precision_config().operand_precision()) { + operand_precision_attrs.push_back( + builder_->getStringAttr(PrecisionConfig_Precision_Name(prec))); + } + return builder_->getNamedAttr( + "precision_config", builder_->getArrayAttr(operand_precision_attrs)); +} + +mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection( + HloInstruction* instruction) { + return builder_->getNamedAttr( + "comparison_direction", + builder_->getStringAttr( + ComparisonDirectionToString(instruction->comparison_direction()))); +} + +mlir::ElementsAttr HloFunctionImporter::ConvertDimensions( + llvm::ArrayRef op_dimensions) { + llvm::SmallVector dimensions; + dimensions.reserve(op_dimensions.size()); + for (auto value : op_dimensions) dimensions.emplace_back(APInt(64, value)); + + return DenseIntElementsAttr::get( + builder_->getTensorType(dimensions.size(), builder_->getIntegerType(64)), + dimensions); +} + +mlir::ElementsAttr HloFunctionImporter::Convert( + llvm::ArrayRef op_dimensions) { + return builder_->getDenseIntElementsAttr( + builder_->getTensorType(op_dimensions.size(), + builder_->getIntegerType(64)), + op_dimensions); +} + +Status HloFunctionImporter::ValidateDotDimensions(HloInstruction* instruction) { + DotDimensionNumbers expected_dimension_numbers; + expected_dimension_numbers.add_lhs_contracting_dimensions( + instruction->operand(0)->shape().dimensions_size() == 1 ? 0 : 1); + expected_dimension_numbers.add_rhs_contracting_dimensions(0); + if (!xla::protobuf_util::ProtobufEquals(instruction->dot_dimension_numbers(), + expected_dimension_numbers)) { + return tensorflow::errors::Internal( + absl::StrCat("Dot operation has unsupported dimension numbers: ", + instruction->dot_dimension_numbers().DebugString())); + } + return Status::OK(); +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_function_importer.h b/tensorflow/compiler/mlir/xla/hlo_function_importer.h new file mode 100644 index 00000000000..ee321432f4d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/hlo_function_importer.h @@ -0,0 +1,112 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_ + +#include + +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class HloModule; +class HloComputation; +class HloInstruction; +class Shape; + +// Helper class for importing HloComputations. +class HloFunctionImporter { + public: + static StatusOr ImportFunction( + mlir::ModuleOp module, mlir::Builder* builder, + std::unordered_map* function_map, + xla::HloComputation* computation); + + private: + HloFunctionImporter( + mlir::ModuleOp module, mlir::Builder* builder, + std::unordered_map* function_map) + : context_(module.getContext()), + module_(module), + builder_(builder), + function_map_(function_map) {} + + StatusOr ImportFunction(xla::HloComputation* computation); + + // Imports an instruction. + StatusOr ImportInstruction(xla::HloInstruction* instruction, + mlir::OpBuilder* func_builder); + + // Gets the MLIR operand values from an HLO Instruction. + StatusOr> GetOperands( + xla::HloInstruction* instruction); + + // Converts xla Tensor type to the corresponding MLIR type. + StatusOr ConvertTensorType(const xla::Shape& shape); + + // Converts xla Primitive types to the corresponding MLIR type. + StatusOr ConvertType(const xla::Shape& shape); + + // Returns the output type of an HloInstruction. + StatusOr GetReturnType(xla::HloInstruction* instruction); + + // Takes a list of HloInstructions and generates the list of types used for + // input, bypassing tuples to subsets. + Status GetMlirTypes(const std::vector& instructions, + llvm::SmallVectorImpl* types); + + // Returns the Mlir Value for the corresponding HloInstruction. + StatusOr GetMlirValue(xla::HloInstruction* instruction); + + // Converts an XLA PrecisionConfig to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction); + + // Converts an XLA ComparisonDirection to the corresponding MLIR attribute. + mlir::NamedAttribute ConvertComparisonDirection( + xla::HloInstruction* instruction); + + // Converts the dimensions of an HLO instruction into an MLIR attribute. + mlir::ElementsAttr ConvertDimensions(llvm::ArrayRef op_dimensions); + + // Converts Array ref to an ElementsAttr. + mlir::ElementsAttr Convert(llvm::ArrayRef op_dimensions); + + // Ensures dot instruction has only default contracting and batch dimensions. + Status ValidateDotDimensions(xla::HloInstruction* instruction); + + mlir::MLIRContext* context_; + mlir::ModuleOp module_; + mlir::Builder* builder_; + + // Mapping from HloComputation to the created MLIR function. + std::unordered_map* function_map_; + + // Mapping from HloInstructions to the associative MLIR values. + std::unordered_map instruction_value_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_FUNCTION_IMPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.cc b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc new file mode 100644 index 00000000000..f11e06a56f9 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.cc @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" + +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/OperationSupport.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" +#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { + +Status HloModuleImporter::Import(const xla::HloModule& module) { + for (const auto& computation : module.computations()) { + auto result = HloFunctionImporter::ImportFunction( + module_, &builder_, &function_map_, computation); + TF_RETURN_IF_ERROR(result.status()); + } + + return Status::OK(); +} + +Status HloModuleImporter::Import(const xla::HloModuleProto& module_proto) { + xla::DebugOptions debug_options; + TF_ASSIGN_OR_RETURN( + auto module_config, + xla::HloModule::CreateModuleConfigFromProto(module_proto, debug_options)); + TF_ASSIGN_OR_RETURN(auto module, xla::HloModule::CreateFromProto( + module_proto, module_config)); + + return Import(*module); +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_module_importer.h b/tensorflow/compiler/mlir/xla/hlo_module_importer.h new file mode 100644 index 00000000000..6603ef8500f --- /dev/null +++ b/tensorflow/compiler/mlir/xla/hlo_module_importer.h @@ -0,0 +1,62 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_ + +#include + +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +class HloModule; +class HloModuleProto; +class HloComputation; +class HloInstruction; +class Shape; + +// Importer that takes an HloModule and imports it as an MLIR module in the XLA +// dialect. HloModuleImporter does not take ownership. +class HloModuleImporter { + public: + explicit HloModuleImporter(mlir::ModuleOp module) + : module_(module), builder_(module.getContext()) {} + + // Import the HloModule into the MLIR Module. + Status Import(const xla::HloModule& module); + + // Import the HloModuleProto into the MLIR Module. + Status Import(const xla::HloModuleProto& module); + + private: + mlir::ModuleOp module_; + mlir::Builder builder_; + + // Map for tracking which MLIR function map to which HLO Computation. This + // tracks functions as they are imported and provides a quick lookup for + // functions invoked by control flow related operations (e.g. while, call). + std::unordered_map function_map_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_MODULE_IMPORTER_H_ diff --git a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc new file mode 100644 index 00000000000..d9ffa166289 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc @@ -0,0 +1,35 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" + +#include "tensorflow/compiler/mlir/xla/hlo_module_importer.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +Status ConvertHloToMlirHlo(mlir::ModuleOp module, + xla::HloModuleProto* hlo_module_proto) { + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + return HloModuleImporter(module).Import(*hlo_module_proto); +} + +Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module) { + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + return HloModuleImporter(module).Import(*hlo_module); +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h new file mode 100644 index 00000000000..5f212ffc893 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h @@ -0,0 +1,39 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_ + +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/xla/status.h" + +namespace mlir { +class ModuleOp; +} // namespace mlir + +namespace xla { +class HloModule; +class HloModuleProto; + +// Converts an HLO module proto to a MLIR module in HLO dialect. +Status ConvertHloToMlirHlo(mlir::ModuleOp module, + xla::HloModuleProto* hlo_module); + +// Converts an HLO module to a MLIR module in HLO dialect. +Status ConvertHloToMlirHlo(mlir::ModuleOp module, xla::HloModule* hlo_module); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_TO_MLIR_HLO_H_ diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc new file mode 100644 index 00000000000..2ec1324a1cf --- /dev/null +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -0,0 +1,259 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" + +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SMLoc.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // TF:local_config_mlir +#include "mlir/IR/Function.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/IR/Operation.h" // TF:local_config_mlir +#include "mlir/StandardOps/Ops.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/ir/xla_ops.h" +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/comparison_util.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +static std::vector ConvertDenseIntAttr(mlir::DenseIntElementsAttr attr) { + llvm::ArrayRef raw_data = attr.getValues(); + if (attr.isSplat()) + return std::vector(attr.getType().getNumElements(), raw_data[0]); + return raw_data; +} + +// Converts the broadcast_dimensions attribute into a span of dimension numbers +// (empty if the attribute is absent). +static std::vector Convert_broadcast_dimensions( + llvm::Optional broadcast_dimensions) { + if (!broadcast_dimensions.hasValue()) return {}; + + return ConvertDenseIntAttr( + broadcast_dimensions->cast()); +} + +// Converts the broadcast_sizes attribute into a span of dimension sizes. +static std::vector Convert_broadcast_sizes( + mlir::ElementsAttr broadcast_sizes) { + return ConvertDenseIntAttr( + broadcast_sizes.cast()); +} + +static std::vector Convert_permutation(mlir::ElementsAttr permutation) { + return ConvertDenseIntAttr(permutation.cast()); +} + +// Converts the precision config array of strings attribute into the +// corresponding XLA proto. All the strings are assumed to be valid names of the +// Precision enum. This should have been checked in the op verify method. +static std::unique_ptr Convert_precision_config( + llvm::Optional optional_precision_config_attr) { + if (!optional_precision_config_attr.hasValue()) return nullptr; + + auto precision_config = absl::make_unique(); + for (auto attr : optional_precision_config_attr.getValue()) { + xla::PrecisionConfig::Precision p; + auto operand_precision = attr.cast().getValue().str(); + // TODO(jpienaar): Update this to ensure this is captured by verify. + if (xla::PrecisionConfig::Precision_Parse(operand_precision, &p)) { + precision_config->add_operand_precision(p); + } else { + auto* context = attr.getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "unexpected operand precision " << operand_precision; + return nullptr; + } + } + + return precision_config; +} + +// Converts the comparison_direction string attribute into the XLA enum. The +// string is assumed to correspond to exactly one of the allowed strings +// representing the enum. This should have been checked in the op verify method. +static xla::ComparisonDirection Convert_comparison_direction( + llvm::StringRef comparison_direction_string) { + return xla::StringToComparisonDirection(comparison_direction_string.str()) + .ValueOrDie(); +} + +// Passes through everything except for unique_ptr, on which it calls get(). +// This exists to allow the generated code to call XLA functions that take a raw +// pointer. In particular, PrecisionConfig is passed to xla::Dot and xla::Conv +// as a pointer and there is otherwise no way to avoid a memory leak. +template +T Unwrap(T t) { + return t; +} + +template +T* Unwrap(const std::unique_ptr& t) { + return t.get(); +} + +// Convert APInt into an int. +// TODO(hpucha): This should be consolidated into a general place. +static int ConvertAPInt(llvm::APInt i) { return i.getSExtValue(); } + +// Convert APFloat to double. +static double ConvertAPFloat(llvm::APFloat value) { + const auto& semantics = value.getSemantics(); + bool losesInfo = false; + if (&semantics != &llvm::APFloat::IEEEdouble()) + value.convert(llvm::APFloat::IEEEdouble(), + llvm::APFloat::rmNearestTiesToEven, &losesInfo); + return value.convertToDouble(); +} + +#include "tensorflow/compiler/mlir/xla/operator_writers.inc" + +namespace mlir { +namespace { + +class ConvertToHloModule { + public: + using ValueLoweringMap = llvm::DenseMap; + using FunctionLoweringMap = llvm::DenseMap; + + explicit ConvertToHloModule(mlir::ModuleOp module) + : module_(module), module_builder_("main") {} + + // Perform the lowering to XLA. This function returns failure if an error was + // encountered. + LogicalResult Run() { + for (auto func : module_.getOps()) { + if (func.empty()) continue; + if (failed(RunOnFunction(func))) return failure(); + } + return success(); + } + + // Perform the lowering on a specific function. This function returns failure + // if an error was encountered. + LogicalResult RunOnFunction(mlir::FuncOp f); + + xla::HloModuleProto ConsumeMainProto() { + return lowered_computation_[module_.lookupSymbol("main")] + .proto(); + } + + private: + // The module being lowered. + mlir::ModuleOp module_; + + // The top-level XlaBuilder. + xla::XlaBuilder module_builder_; + + // Map between function and lowered computation. + FunctionLoweringMap lowered_computation_; +}; + +LogicalResult Lower(mlir::Operation* inst, xla::XlaBuilder* builder, + ConvertToHloModule::FunctionLoweringMap* function_lowering, + ConvertToHloModule::ValueLoweringMap* value_lowering) { + if (auto xla_op = CreateXlaOperator(inst, value_lowering)) return success(); + + // TODO(riverriddle) We currently don't support lowering constant operations. + if (isa(inst)) { + inst->emitError("unable to lower 'xla.constant' operation"); + return failure(); + } + + auto& value_map = *value_lowering; + if (auto ret = dyn_cast(inst)) { + // Construct the return value for the function. If there are multiple + // values returned, then create a tuple, else return value directly. + xla::XlaOp return_value; + unsigned num_return_values = ret.getNumOperands(); + if (num_return_values > 1) { + std::vector returns(num_return_values); + for (unsigned i = 0, e = ret.getNumOperands(); i != e; ++i) { + returns[i] = value_map[ret.getOperand(i)]; + } + return_value = xla::Tuple(builder, returns); + } else if (num_return_values == 1) { + return_value = value_map[ret.getOperand(0)]; + } + + // Build the XlaComputation and check for failures. + auto computation_or = + return_value.valid() ? builder->Build(return_value) : builder->Build(); + if (!computation_or.ok()) { + inst->emitError(llvm::Twine(computation_or.status().error_message())); + return failure(); + } + auto f = inst->getParentOfType(); + (*function_lowering)[f] = std::move(computation_or.ValueOrDie()); + return success(); + } + inst->emitError("unable to lower operation of type '" + + inst->getName().getStringRef().str() + '\''); + return failure(); +} + +LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { + if (f.getBlocks().size() != 1) { + return f.emitError("only single block Function suppored"); + } + + // Create a sub-builder if this is not the main function. + std::unique_ptr builder_up; + bool entry_function = f.getName().str() == "main"; + if (!entry_function) + builder_up = module_builder_.CreateSubBuilder(f.getName().str()); + auto& builder = entry_function ? module_builder_ : *builder_up; + + // Mapping from the Value to lowered XlaOp. The code below lowers in + // program order and will fail if an operand is unseen. This can be improved. + ValueLoweringMap lowering; + for (auto& bb : f) { + int num = 0; + for (auto& arg : bb.getArguments()) { + xla::Shape shape = xla::TypeToShape(arg->getType()); + lowering[arg] = + xla::Parameter(&builder, num, shape, absl::StrCat("Arg_", num)); + ++num; + } + + for (auto& inst : bb) + if (failed(Lower(&inst, &builder, &lowered_computation_, &lowering))) + return failure(); + } + + return success(); +} + +} // namespace + +Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto) { + mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); + ConvertToHloModule converter(module); + if (failed(converter.Run())) return diag_handler.ConsumeStatus(); + auto hlo_module = converter.ConsumeMainProto(); + hlo_proto->mutable_hlo_module()->Swap(&hlo_module); + return Status::OK(); +} + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h new file mode 100644 index 00000000000..b16636f039c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -0,0 +1,37 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ + +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" + +namespace mlir { + +// Converts a MLIR module in HLO dialect into a HloModuleProto. +Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto); + +// Creates XlaOp equivalent of a given MLIR operation using the operand info +// from `value_lowering` map. +llvm::Optional CreateXlaOperator( + mlir::Operation* op, + llvm::DenseMap* value_lowering); + +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_MLIR_HLO_TO_HLO_H_ diff --git a/tensorflow/compiler/mlir/xla/operator_writer_gen.cc b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc new file mode 100644 index 00000000000..0fb315b90f9 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/operator_writer_gen.cc @@ -0,0 +1,196 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/PrettyStackTrace.h" +#include "llvm/Support/Signals.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" +#include "mlir/TableGen/Operator.h" // TF:local_config_mlir + +using llvm::dyn_cast; +using llvm::LessRecord; +using llvm::raw_ostream; +using llvm::Record; +using llvm::RecordKeeper; +using llvm::StringRef; +using mlir::tblgen::Operator; + +// Returns the builder function name for the given op definition. +// E.g., AddOp -> CreateAddOp +static inline std::string GetOperatorBuilderName(StringRef op_name) { + return "Create" + op_name.str(); +} + +static std::string GetConversionFunction( + mlir::tblgen::NamedAttribute named_attr) { + auto storage_type = named_attr.attr.getStorageType(); + // For some attribute types we have a general conversion, so use that. + if (storage_type.endswith("IntegerAttr") || + storage_type.endswith("FloatAttr")) { + return "Convert" + named_attr.attr.getReturnType().str(); + } + return "Convert_" + named_attr.name.str(); +} + +using ArgumentName = string; +using ArgumentDeclaration = string; +using Argument = std::pair; +using ArgumentList = std::vector; + +static std::string BuildOperator(const Operator& op) { + std::stringstream os; + StringRef op_name = op.getCppClassName(); + std::string xla_op_name = op_name.drop_back(2).str(); + + // Signature. + os << "static xla::XlaOp " << GetOperatorBuilderName(op_name) + << "(mlir::XLA::" << op_name.str() << " xla_op, " + << "llvm::DenseMap* " + "value_lowering) {\n"; + + os << " auto& value_map = *value_lowering;\n" + << " auto result = xla_op.getResult();\n"; + + // Invoke the conversion function for each attribute. + for (const auto& named_attr : op.getAttributes()) { + os << " auto " << named_attr.name.str() << " = " + << GetConversionFunction(named_attr) << "(" + << "xla_op." << named_attr.name.str() << "());\n"; + } + + // Assumes that the client builder method names closely follow the op names + // in the dialect. For e.g., AddOp -> xla::Add method. + os << " auto xla_result = xla::" << xla_op_name << "("; + + int num_operands = op.getNumOperands(); + if (num_operands == 1) { + os << "value_map[xla_op.getOperand()]"; + } else { + for (auto i = 0; i < num_operands; i++) { + os << "value_map[xla_op.getOperand(" << i << ")]"; + if (i != num_operands - 1) { + os << ", "; + } + } + } + + for (const auto& named_attr : op.getAttributes()) { + os << ", Unwrap(" << named_attr.name.str() << ")"; + } + + os << ");\n"; + + os << " value_map[result] = xla_result;\n"; + os << " return xla_result;\n"; + os << "}\n\n"; + return os.str(); +} + +// For each XLA op, emits a builder function that constructs the XLA op using +// the HLO client builder. +static void EmitOperatorBuilders(const RecordKeeper& record_keeper, + const std::vector& defs, + raw_ostream* ostream) { + raw_ostream& os = *ostream; + + for (const auto* def : defs) { + // Skip operations that have a custom converter. + if (def->getValueAsBit("hasCustomHLOConverter")) continue; + + Operator op(def); + os << BuildOperator(op); + } +} + +// Emits a builder function that returns the XlaOp object given a +// mlir::Operation. +// +// The signature of the function is: +// +// llvm::Optional +// mlir::CreateXlaOperator( +// mlir::Operation* op, +// llvm::DenseMap +// *value_lowering); +static void EmitBuilder(const std::vector& defs, + raw_ostream* ostream) { + raw_ostream& os = *ostream; + + // Signature + os << "llvm::Optional\n" + "mlir::CreateXlaOperator(mlir::Operation* op, " + "llvm::DenseMap " + "*value_lowering) {\n"; + + for (const auto* def : defs) { + // Skip operations that have a custom converter. + if (def->getValueAsBit("hasCustomHLOConverter")) continue; + + StringRef op_name = def->getName().drop_front(4); + + // Try to cast to each op and call the corresponding op builder. + os << " if (auto xla_op = llvm::dyn_cast(op))\n return " << GetOperatorBuilderName(op_name) + << "(xla_op, value_lowering);\n"; + } + + os << " return llvm::None;\n" + "}\n"; +} + +// The function below has a non-constant reference as that is required by LLVM's +// TableGenMain. +// NOLINTNEXTLINE +static bool OperatorWritersMain(raw_ostream& os, RecordKeeper& records) { + emitSourceFileHeader("MLIR XLA Builders", os); + + // Retrieve all the definitions derived from XLA_Op and sort by record name. + std::vector defs = records.getAllDerivedDefinitions("XLA_Op"); + llvm::sort(defs, LessRecord()); + + for (const auto* def : defs) { + // XLA ops in the .td file are expected to follow the naming convention: + // XLA_Op. + // The generated XLA op C++ class should be XLA::Op. + if (!def->getName().startswith("XLA_")) + PrintFatalError(def->getLoc(), + "unexpected op name format: 'XLA_' prefix missing"); + if (!def->getName().endswith("Op")) + PrintFatalError(def->getLoc(), + "unexpected op name format: 'Op' suffix missing"); + } + + EmitOperatorBuilders(records, defs, &os); + os << "\n\n"; + EmitBuilder(defs, &os); + + return false; +} + +int main(int argc, char** argv) { + llvm::sys::PrintStackTraceOnErrorSignal(argv[0]); + llvm::PrettyStackTraceProgram X(argc, argv); + + llvm::llvm_shutdown_obj Y; + llvm::cl::ParseCommandLineOptions(argc, argv); + return TableGenMain(argv[0], &OperatorWritersMain); +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/BUILD b/tensorflow/compiler/mlir/xla/tests/translate/BUILD new file mode 100644 index 00000000000..deb4e2ce231 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/BUILD @@ -0,0 +1,23 @@ +load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") + +package(licenses = ["notice"]) + +glob_lit_tests( + data = [":test_utilities"], + driver = "@local_config_mlir//:run_lit.sh", + test_file_exts = [ + "mlir", + "hlo", + "hlotxt", + ], +) + +# Bundle together all of the test utilities that are used by tests. +filegroup( + name = "test_utilities", + testonly = True, + data = [ + "//tensorflow/compiler/mlir:tf-mlir-translate", + "@llvm//:FileCheck", + ], +) diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt new file mode 100644 index 00000000000..d285df18bc9 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.hlotxt @@ -0,0 +1,28 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +// This test is more thorough than those of the the other binary ops to test +// their shared functionality. + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>, %arg2: tensor, %arg3: tensor) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + %Arg_2.3 = f32[] parameter(2) + %Arg_3.4 = f32[] parameter(3) + + // Add two tensors + // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + + // Add two scalars + // CHECK-NEXT: %1 = "xla.add"(%arg2, %arg3) {name = "add.4"} : (tensor, tensor) -> tensor + %add.4 = f32[] add(f32[] %Arg_2.3, f32[] %Arg_3.4) + + // Add a tensor and scalar + // CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: return %2 : tensor<4xf32> + ROOT %add.5 = f32[4] add(f32[4] %add.3, f32[] %add.4) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/add.mlir b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir new file mode 100644 index 00000000000..4009759f3b8 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/add.mlir @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { +func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + // CHECK-NEXT: %Arg_0.1 = f32[4] parameter(0) + // CHECK-NEXT: %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) + %0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK-NEXT: ROOT %add.4 = f32[4] add(f32[4] %add.3, f32[4] %Arg_1.2) + %1 = "xla.add"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt new file mode 100644 index 00000000000..1826809db63 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/and.hlotxt @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %0 = "xla.and"(%arg0, %arg1) {name = "and.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %and.3 = f32[4] and(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir new file mode 100644 index 00000000000..9aff6393e86 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/binary_op_broadcast.mlir @@ -0,0 +1,26 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main.13 (Arg_0.1: s32[1,4], Arg_1.2: s32[2,4], Arg_2.3: s32[2,3,4]) -> s32[2,3,4] { +func @main(%arg0: tensor<1x4xi32>, %arg1: tensor<2x4xi32>, %arg2: tensor<2x3x4xi32>) -> tensor<2x3x4xi32> { + // Same rank degenerate broadcast + // CHECK-NEXT: %Arg_0.1 = s32[1,4] parameter(0) + // CHECK-NEXT: %reshape.4 = s32[4] reshape(s32[1,4] %Arg_0.1) + // CHECK-NEXT: %broadcast.5 = s32[2,4] broadcast(s32[4] %reshape.4) + // CHECK-NEXT: %Arg_1.2 = s32[2,4] parameter(1) + // CHECK-NEXT: %add.6 = s32[2,4] add(s32[2,4] %broadcast.5, s32[2,4] %Arg_1.2) + %0 = "xla.add"(%arg0, %arg1) : (tensor<1x4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + + // Broadcast up rank + // CHECK-NEXT: %broadcast.7 = s32[2,3,4] broadcast(s32[2,4] %Arg_1.2), dimensions={0,2} + // CHECK-NEXT: %Arg_2.3 = s32[2,3,4] parameter(2) + // CHECK-NEXT: %add.8 = s32[2,3,4] add(s32[2,3,4] %broadcast.7, s32[2,3,4] %Arg_2.3) + %1 = "xla.add"(%arg1, %arg2) {broadcast_dimensions = dense<[0,2]> : tensor<2xi64>} : (tensor<2x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> + + // Broadcast up rank + degenerate broadcast + // CHECK-NEXT: %broadcast.9 = s32[2,1,4] broadcast(s32[1,4] %Arg_0.1), dimensions={1,2} + // CHECK-NEXT: %reshape.10 = s32[2,4] reshape(s32[2,1,4] %broadcast.9) + // CHECK-NEXT: %broadcast.11 = s32[2,3,4] broadcast(s32[2,4] %reshape.10), dimensions={0,2} + // CHECK-NEXT: ROOT %add.12 = s32[2,3,4] add(s32[2,3,4] %broadcast.11, s32[2,3,4] %Arg_2.3) + %2 = "xla.add"(%arg0, %arg2) {broadcast_dimensions = dense<[1,2]> : tensor<2xi64>} : (tensor<1x4xi32>, tensor<2x3x4xi32>) -> tensor<2x3x4xi32> + return %2 : tensor<2x3x4xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir new file mode 100644 index 00000000000..1d231535703 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast.mlir @@ -0,0 +1,9 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main.3 (Arg_0.1: s32[4]) -> s32[1,2,3,4] { +func @main(%arg0: tensor<4xi32>) -> tensor<1x2x3x4xi32> { + // CHECK-NEXT: %Arg_0.1 = s32[4] parameter(0) + // CHECK-NEXT: ROOT %broadcast.2 = s32[1,2,3,4] broadcast(s32[4] %Arg_0.1), dimensions={3} + %0 = "xla.broadcast"(%arg0) {broadcast_sizes = dense<[1,2,3]> : tensor<3xi64>} : (tensor<4xi32>) -> tensor<1x2x3x4xi32> + return %0 : tensor<1x2x3x4xi32> +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt new file mode 100644 index 00000000000..d9c2e9fe094 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/broadcast_in_dim.hlotxt @@ -0,0 +1,20 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<1x2xf32>) -> tensor<3x1x2xf32> { +ENTRY %main { + %Arg_0.1 = f32[1, 2] parameter(0) + + // CHECK-NEXT: %0 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.2"} : (tensor<1x2xf32>) -> tensor<1x2x3xf32> + %broadcast.2 = f32[1,2,3] broadcast(%Arg_0.1), dimensions={0,1} + + // Degenerate broadcast + // CHECK-NEXT: %1 = "xla.broadcast_in_dim"(%arg0) {name = "broadcast.3"} : (tensor<1x2xf32>) -> tensor<3x2xf32> + broadcast.3 = f32[3,2] broadcast(%Arg_0.1), dimensions={} + + // CHECK-NEXT: %2 = "xla.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>, name = "broadcast.4"} : (tensor<1x2xf32>) -> tensor<3x1x2xf32> + // CHECK-NEXT: return %2 : tensor<3x1x2xf32> + ROOT broadcast.4 = f32[3,1,2] broadcast(%Arg_0.1), dimensions={1, 2} +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt new file mode 100644 index 00000000000..c7ea0f9637e --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/call.hlotxt @@ -0,0 +1,19 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @call(%arg0: tensor) -> tensor { +%call (arg_1: s64[]) -> s64[] { + %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"} + // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.2"} : (tensor, tensor) -> tensor + // CHECK-NEXT: return %0 : tensor + ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"} +} + +// CHECK-LABEL: func @main(%arg0: tensor) -> tensor { +ENTRY %foo (arg0.1: s64[]) -> s64[] { + %arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"} + // CHECK-NEXT: %0 = call @call(%arg0) : (tensor) -> tensor + // CHECK-NEXT: return %0 : tensor + ROOT %call.2 = s64[] call(%arg0.1), to_apply=%call +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt new file mode 100644 index 00000000000..ed3019b81cb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/comp.hlotxt @@ -0,0 +1,21 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>, %arg2: tensor<1xf32>) -> tensor<3xi1> { +ENTRY %main (Arg_0.1: f32[3], Arg_1.2: f32[3], Arg_2.3: f32[1]) -> pred[3] { + %Arg_0.1 = f32[3] parameter(0) + %Arg_1.2 = f32[3] parameter(1) + %Arg_2.3 = f32[1] parameter(2) + + // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg1) {comparison_direction = "EQ", name = "compare.4"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + %compare.4 = pred[3] compare(Arg_0.1, Arg_1.2), direction=EQ + + // CHECK-NEXT: %1 = "xla.compare"(%arg0, %arg1) {comparison_direction = "LE", name = "compare.5"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xi1> + %compare.5 = pred[3] compare(Arg_0.1, Arg_1.2), direction=LE + + // Requires broadcast of compatible tensors. + // CHECK-NEXT: %2 = "xla.compare"(%arg0, %arg2) {comparison_direction = "GT", name = "compare.6"} : (tensor<3xf32>, tensor<1xf32>) -> tensor<3xi1> + // CHECK-NEXT: return %2 : tensor<3xi1> + ROOT %compare.6 = pred[3] compare(Arg_0.1, Arg_2.3), direction=GT +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt new file mode 100644 index 00000000000..e73447d768d --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/concat.hlotxt @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4x1xf32>, %arg1: tensor<4x2xf32>) -> tensor<4x3xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4, 1], Arg_1.2: f32[4, 2]) -> f32[4, 3] { + %Arg_0.1 = f32[4, 1] parameter(0) + %Arg_1.2 = f32[4, 2] parameter(1) + + // CHECK-NEXT: %0 = "xla.concatenate"(%arg0, %arg1) {dimension = 1 : i64} : (tensor<4x1xf32>, tensor<4x2xf32>) -> tensor<4x3xf32> + // CHECK-NEXT: return %0 : tensor<4x3xf32> + ROOT %concatenate.3 = f32[4, 3] concatenate(f32[4, 1] %Arg_0.1, f32[4, 2] %Arg_1.2), dimensions={1} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/const.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/const.hlotxt new file mode 100644 index 00000000000..244ad31975a --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/const.hlotxt @@ -0,0 +1,21 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule tfcompile.7 + +// CHECK-LABEL: func @main() -> tensor<2x2x1x1xf32> { +ENTRY %tfcompile.7 { + + // Scalar/0D tensor constant + // CHECK-NEXT: %cst = constant {name = "constant.0"} dense<1> : tensor + %constant.0 = s64[] constant(1) + + // Note that double brackets "[[" have to be escaped as they denote variables + // in FileCheck. The only way to do so is to drop into regex with "{{" + // CHECK-NEXT: %cst_0 = constant {name = "constant.1"} dense<{{\[\[\[\[}}1.000000e+00]], {{\[\[}}2.000000e+00]]], {{\[\[\[}}3.000000e+00]], {{\[\[}}4.000000e+00]]]]> : tensor<2x2x1x1xf32> + // CHECK-NEXT: return %cst_0 : tensor<2x2x1x1xf32> + ROOT %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({{{{1.0}},{{2.0}}},{{{3.0}},{{4.0}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} +} + + + + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt new file mode 100644 index 00000000000..0de3ac6bffe --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/conv.hlotxt @@ -0,0 +1,31 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule tfcompile.7 + +// TODO(b/129422361) Potentially update when copy, reshape, and conv have actual +// implementations with attributes, etc. +// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x1xf32>) -> tuple> { +ENTRY %tfcompile.7 { + %arg0.1 = f32[1,16,16,1]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"} + + // CHECK-NEXT: %0 = "xla.copy"(%arg0) {name = "copy.1"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %copy.1 = f32[1,16,16,1]{2,1,3,0} copy(%arg0.1), metadata={op_name="XLA_Args"} + + // CHECK-NEXT: %1 = "xla.reshape"(%0) {name = "reshape.2"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %reshape.2 = f32[1,16,16,1]{2,1,3,0} reshape(%copy.1) + + // Note that double brackets "[[" have to be escaped as they denote variables + // in FileCheck. The only way to do so is to drop into regex with "{{" + // CHECK-NEXT: %cst = constant {name = "constant.3"} dense<{{\[\[\[\[}}5.000000e-01]], {{\[\[}}-6.000000e-01]]], {{\[\[\[}}3.000000e-01]], {{\[\[}}-1.000000e-01]]]]> : tensor<2x2x1x1xf32> + %constant.3 = f32[2,2,1,1]{3,2,1,0} constant({{{{0.5}}, {{-0.6}}}, {{{0.3}}, {{-0.1}}}}), metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} + + // CHECK-NEXT: %2 = "xla.conv"(%1, %cst) {name = "convolution.4"} : (tensor<1x16x16x1xf32>, tensor<2x2x1x1xf32>) -> tensor<1x16x16x1xf32> + %convolution.4 = f32[1,16,16,1]{2,1,3,0} convolution(%reshape.2, %constant.3), window={size=2x2 pad=0_1x0_1}, dim_labels=b01f_01io->b01f, metadata={op_type="Conv2D" op_name="embedded_inference/conv_model/conv_0/Conv2D"} + + // CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.5"} : (tensor<1x16x16x1xf32>) -> tensor<1x16x16x1xf32> + %reshape.5 = f32[1,16,16,1]{3,2,1,0} reshape(%convolution.4), metadata={op_name="XLA_Retvals"} + + // CHECK-NEXT: %4 = "xla.tuple"(%3) {name = "tuple.6"} : (tensor<1x16x16x1xf32>) -> tuple> + // CHECK-NEXT: return %4 : tuple> + ROOT %tuple.6 = (f32[1,16,16,1]{3,2,1,0}) tuple(%reshape.5), metadata={op_name="XLA_Retvals"} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt new file mode 100644 index 00000000000..3c0c7a9c1d1 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/convert.hlotxt @@ -0,0 +1,20 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf64> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f64[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[] parameter(1) + + // CHECK-NEXT: %0 = "xla.convert"(%arg0) {name = "convert.3"} : (tensor<4xf32>) -> tensor<4xf64> + %convert.3 = f64[4] convert(f32[4] %Arg_0.1) + + // CHECK-NEXT: %1 = "xla.convert"(%arg1) {name = "convert.4"} : (tensor) -> tensor + %convert.4 = f64[] convert(f32[] %Arg_1.2) + + // CHECK-NEXT: %2 = "xla.add"(%0, %1) {name = "add.5"} : (tensor<4xf64>, tensor) -> tensor<4xf64> + // CHECK-NEXT: return %2 : tensor<4xf64> + ROOT %add.5 = f64[4] add(f64[4] %convert.3, f64[] %convert.4) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt new file mode 100644 index 00000000000..602ad96b852 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/div.hlotxt @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %0 = "xla.div"(%arg0, %arg1) {name = "divide.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %divide.3 = f32[4] divide(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt new file mode 100644 index 00000000000..5b7d0c6c2ef --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/dot.hlotxt @@ -0,0 +1,23 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x1xf32>) -> tensor { +ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[] { + %Arg_0.1 = f32[1, 4] parameter(0) + %Arg_1.2 = f32[4, 1] parameter(1) + + // CHECK-NEXT: %0 = "xla.dot"(%arg0, %arg1) {name = "dot.3", precision_config = ["HIGH", "HIGHEST"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + dot.3 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={high,highest} + + // CHECK-NEXT: %1 = "xla.dot"(%arg0, %arg1) {name = "dot.4", precision_config = ["HIGHEST", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + dot.4 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={highest,default} + + // CHECK-NEXT: %2 = "xla.dot"(%arg0, %arg1) {name = "dot.5", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + %dot.5 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0}, operand_precision={default,default} + + // TODO(b/129709049) consider making this default precision config inferred. + // CHECK-NEXT: %3 = "xla.dot"(%arg0, %arg1) {name = "dot.6", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<1x4xf32>, tensor<4x1xf32>) -> tensor + // CHECK-NEXT: return %3 : tensor + ROOT %dot.6 = f32[] dot(Arg_0.1, Arg_1.2), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt new file mode 100644 index 00000000000..d31160cfb21 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/dynamic-update-slice.hlotxt @@ -0,0 +1,26 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @dynamic.update.slice.1(%arg0: tensor<4x4xf32>, %arg1: tensor<1x4xf32>, %arg2: tensor, %arg3: tensor) -> tensor<4x4xf32> { +%dynamic.update.slice.1 (Arg_0.1: f32[4, 4], Arg_1.2: f32[1, 4], Arg_2.3: f32[], Arg_3.4: f32[]) -> f32[4, 4] { + %Arg_0.1 = f32[4, 4] parameter(0) + %Arg_1.2 = f32[1, 4] parameter(1) + %Arg_2.3 = f32[] parameter(2) + %Arg_3.4 = f32[] parameter(3) + + // CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2, %arg3) : (tensor<4x4xf32>, tensor<1x4xf32>, tensor, tensor) -> tensor<4x4xf32> + // CHECK-NEXT: return %0 : tensor<4x4xf32> + ROOT %dynamic-update-slice.5 = f32[4, 4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3, %Arg_3.4) +} + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<2xf32>, %arg2: tensor) -> tensor<4xf32> +%dynamic.update.slice.2 (Arg_0.1: f32[4], Arg_1.2: f32[2], Arg_2.3: f32[]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[2] parameter(1) + %Arg_2.3 = f32[] parameter(2) + + // CHECK-NEXT: %0 = "xla.dynamic-update-slice"(%arg0, %arg1, %arg2) : (tensor<4xf32>, tensor<2xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %dynamic-update-slice.5 = f32[4] dynamic-update-slice(%Arg_0.1, %Arg_1.2, %Arg_2.3) +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt new file mode 100644 index 00000000000..a4e5b19e1e1 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/fully_connected_reference_model.hlotxt @@ -0,0 +1,103 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +// This test comes from a fully connected reference model. + +HloModule tfcompile.48 + +// CHECK-LABEL: func @main(%arg0: tensor<1x300xf32>, %arg1: tensor<1x300x3x1xf32>) -> tuple> { +ENTRY %tfcompile.48 { + %arg0.1 = f32[1,300] parameter(0) + %arg1.2 = f32[1,300,3,1] parameter(1) + + // CHECK-NEXT: %0 = "xla.reshape"(%arg0) {name = "reshape.3"} : (tensor<1x300xf32>) -> tensor<1x300xf32> + %reshape.3 = f32[1,300] reshape(%arg0.1) + + // CHECK-NEXT: %1 = "xla.transpose"(%0) {name = "transpose.27", permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<1x300xf32>) -> tensor<300x1xf32> + %transpose.27 = f32[300,1] transpose(%reshape.3), dimensions={1,0} + + // CHECK-NEXT: %2 = "xla.reshape"(%1) {name = "reshape.28"} : (tensor<300x1xf32>) -> tensor<300x1x1xf32> + %reshape.28 = f32[300,1,1] reshape(%transpose.27) + + // CHECK-NEXT: %3 = "xla.reshape"(%2) {name = "reshape.29"} : (tensor<300x1x1xf32>) -> tensor<300x1xf32> + %reshape.29 = f32[300,1] reshape(%reshape.28) + + // CHECK-NEXT: %4 = "xla.broadcast_in_dim"(%3) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>, name = "broadcast.30"} : (tensor<300x1xf32>) -> tensor<300x1x5xf32> + %broadcast.30 = f32[300,1,5] broadcast(%reshape.29), dimensions={0,1} + + // CHECK-NEXT: %cst = constant {name = "constant.8"} dense<1.000000e+00> : tensor + %constant.8 = f32[] constant(1) + + // CHECK-NEXT: %5 = "xla.broadcast_in_dim"(%cst) {name = "broadcast.9"} : (tensor) -> tensor<300x1x5xf32> + %broadcast.9 = f32[300,1,5] broadcast(%constant.8), dimensions={} + + // CHECK-NEXT: %6 = "xla.mul"(%4, %5) {name = "multiply.31"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + %multiply.31 = f32[300,1,5] multiply(%broadcast.30, %broadcast.9) + + // CHECK-NEXT: %cst_0 = constant {name = "constant.32"} dense<0.000000e+00> : tensor + %constant.32 = f32[] constant(0) + + // CHECK-NEXT: %7 = "xla.broadcast_in_dim"(%cst_0) {name = "broadcast.33"} : (tensor) -> tensor<300x1x5xf32> + %broadcast.33 = f32[300,1,5] broadcast(%constant.32), dimensions={} + + // CHECK-NEXT: %8 = "xla.compare"(%6, %7) {comparison_direction = "GT", name = "compare.34"} : (tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xi1> + %compare.34 = pred[300,1,5] compare(%multiply.31, %broadcast.33), direction=GT + + // CHECK-NEXT: %cst_1 = constant {name = "constant.10"} dense<0.000000e+00> : tensor + %constant.10 = f32[] constant(0) + + // CHECK-NEXT: %9 = "xla.broadcast_in_dim"(%cst_1) {name = "broadcast.11"} : (tensor) -> tensor<300x1x5xf32> + %broadcast.11 = f32[300,1,5] broadcast(%constant.10), dimensions={} + + // CHECK-NEXT: %cst_2 = constant {name = "constant.40"} dense<0.000000e+00> : tensor + %constant.40 = f32[] constant(0) + + // CHECK-NEXT: %10 = "xla.broadcast_in_dim"(%cst_2) {name = "broadcast.41"} : (tensor) -> tensor<300x5xf32> + %broadcast.41 = f32[300,5] broadcast(%constant.40), dimensions={} + + // CHECK-NEXT: %11 = "xla.copy"(%arg1) {name = "copy.1"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + %copy.1 = f32[1,300,3,1] copy(%arg1.2) + + // CHECK-NEXT: %12 = "xla.reshape"(%11) {name = "reshape.4"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3x1xf32> + %reshape.4 = f32[1,300,3,1] reshape(%copy.1) + + // CHECK-NEXT: %13 = "xla.reshape"(%12) {name = "reshape.24"} : (tensor<1x300x3x1xf32>) -> tensor<1x300x3xf32> + %reshape.24 = f32[1,300,3] reshape(%reshape.4) + + // CHECK-NEXT: %14 = "xla.transpose"(%13) {name = "transpose.25", permutation = dense<[1, 0, 2]> : tensor<3xi64>} : (tensor<1x300x3xf32>) -> tensor<300x1x3xf32> + %transpose.25 = f32[300,1,3] transpose(%reshape.24), dimensions={1,0,2} + + // CHECK-NEXT: %15 = "xla.reshape"(%14) {name = "reshape.26"} : (tensor<300x1x3xf32>) -> tensor<300x3xf32> + %reshape.26 = f32[300,3] reshape(%transpose.25) + + // CHECK-NEXT: %cst_3 = constant {name = "constant.35"} dense<{{\[\[}}-1.060230e-01, 1.215050e-01, 8.002390e-01, -7.688850e-01, 0.0966112986], [6.890140e-01, -4.070560e-01, -0.797852993, 3.789250e-03, -2.088810e-01], [-6.085290e-01, 2.766170e-02, 2.685570e-01, 5.774010e-01, -4.284370e-01]]> : tensor<3x5xf32> + %constant.35 = f32[3,5] constant({ { -0.106023, 0.121505, 0.800239, -0.768885, 0.0966113 }, { 0.689014, -0.407056, -0.797853, 0.00378925, -0.208881 }, { -0.608529, 0.0276617, 0.268557, 0.577401, -0.428437 } }) + + // TODO(b/129709049) consider making this default precision config implied. + // CHECK-NEXT: %16 = "xla.dot"(%15, %cst_3) {name = "dot.36", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<300x3xf32>, tensor<3x5xf32>) -> tensor<300x5xf32> + %dot.36 = f32[300,5] dot(%reshape.26, %constant.35), lhs_contracting_dims={1}, rhs_contracting_dims={0} + + // CHECK-NEXT: %cst_4 = constant {name = "constant.37"} dense<0.000000e+00> : tensor<5xf32> + %constant.37 = f32[5]{0} constant({0, 0, 0, 0, 0}) + + // CHECK-NEXT: %17 = "xla.broadcast_in_dim"(%cst_4) {broadcast_dimensions = dense<1> : tensor<1xi64>, name = "broadcast.38"} : (tensor<5xf32>) -> tensor<300x5xf32> + %broadcast.38 = f32[300,5] broadcast(%constant.37), dimensions={1} + + // CHECK-NEXT: %18 = "xla.add"(%16, %17) {name = "add.39"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32> + %add.39 = f32[300,5] add(%dot.36, %broadcast.38) + + // CHECK-NEXT: %19 = "xla.max"(%10, %18) {name = "maximum.42"} : (tensor<300x5xf32>, tensor<300x5xf32>) -> tensor<300x5xf32> + %maximum.42 = f32[300,5] maximum(%broadcast.41, %add.39) + + // CHECK-NEXT: %20 = "xla.reshape"(%19) {name = "reshape.44"} : (tensor<300x5xf32>) -> tensor<300x1x5xf32> + %reshape.44 = f32[300,1,5] reshape(%maximum.42) + + // CHECK-NEXT: %21 = "xla.select"(%8, %9, %20) {name = "select.45"} : (tensor<300x1x5xi1>, tensor<300x1x5xf32>, tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + %select.45 = f32[300,1,5] select(%compare.34, %broadcast.11, %reshape.44) + + // CHECK-NEXT: %22 = "xla.reshape"(%21) {name = "reshape.46"} : (tensor<300x1x5xf32>) -> tensor<300x1x5xf32> + %reshape.46 = f32[300,1,5] reshape(%select.45) + + // CHECK-NEXT: %23 = "xla.tuple"(%22) {name = "tuple.47"} : (tensor<300x1x5xf32>) -> tuple> + // CHECK-NEXT: return %23 : tuple> + ROOT %tuple.47 = (f32[300,1,5]) tuple(%reshape.46) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt new file mode 100644 index 00000000000..9a4944d414e --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/iota.hlotxt @@ -0,0 +1,17 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main() -> tensor<4xf32> { +ENTRY %iota.1 () -> f32[4] { + // CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 0 : i64} : () -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %iota.0 = f32[4] iota(), iota_dimension=0 +} + +// CHECK-LABEL: func @iota.2() -> tensor<4x5xf32> { +%iota.2 () -> f32[4, 5] { + // CHECK-NEXT: %0 = "xla.iota"() {iota_dimension = 1 : i64} : () -> tensor<4x5xf32> + // CHECK-NEXT: return %0 : tensor<4x5xf32> + ROOT %iota.0 = f32[4, 5] iota(), iota_dimension=1 +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt new file mode 100644 index 00000000000..dd6c0f504f5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/max.hlotxt @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %0 = "xla.max"(%arg0, %arg1) {name = "maximum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %maximum.3 = f32[4] maximum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt new file mode 100644 index 00000000000..5efe44aa53a --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/min.hlotxt @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %0 = "xla.min"(%arg0, %arg1) {name = "minimum.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %minimum.3 = f32[4] minimum(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt new file mode 100644 index 00000000000..1bfb6662124 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/mul.hlotxt @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %0 = "xla.mul"(%arg0, %arg1) {name = "multiply.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %multiply.3 = f32[4] multiply(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt new file mode 100644 index 00000000000..412f267ce42 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/pad.hlotxt @@ -0,0 +1,23 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor) -> tensor<4xf32> { +ENTRY %padding.1 (Arg_0.1: f32[4], Arg_1.2: f32[]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[] parameter(1) + + // CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<0> : tensor<1xi64>, edge_padding_low = dense<0> : tensor<1xi64>, interior_padding = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %pad.3 = f32[4] pad(%Arg_0.1, %Arg_1.2), padding=0_0_0 +} + +// CHECK-LABEL: func @padding.2(%arg0: tensor<4x4x4xf32>, %arg1: tensor) -> tensor<7x11x15xf32> { +%padding.2 (Arg_0.1: f32[4, 4, 4], Arg_1.2: f32[]) -> f32[7, 11, 15] { + %Arg_0.1 = f32[4, 4, 4] parameter(0) + %Arg_1.2 = f32[] parameter(1) + + // CHECK-NEXT: %0 = "xla.pad"(%arg0, %arg1) {edge_padding_high = dense<[2, 4, 6]> : tensor<3xi64>, edge_padding_low = dense<[1, 3, 5]> : tensor<3xi64>, interior_padding = dense<0> : tensor<3xi64>} : (tensor<4x4x4xf32>, tensor) -> tensor<7x11x15xf32> + // CHECK-NEXT: return %0 : tensor<7x11x15xf32> + ROOT %pad.3 = f32[7, 11, 15] pad(%Arg_0.1, %Arg_1.2), padding=1_2x3_4x5_6 +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt new file mode 100644 index 00000000000..37e638eb1f7 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/reduce.hlotxt @@ -0,0 +1,53 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +// This test is more thorough than those of the the other binary ops to test +// their shared functionality. + +HloModule main.5 + +%reduce_helper.1 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + ROOT %add.3 = f32[4] add(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + +%reduce_helper.2 (Arg_0.1: f32[], Arg_1.2: f32[]) -> f32[] { + %Arg_0.1 = f32[] parameter(0) + %Arg_1.2 = f32[] parameter(1) + ROOT %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2) +} + +%reduce_helper.3 (Arg_0.1: f32[], Arg_1.2: f32[], Arg_2.3: f32[], Arg_3.4: f32[]) -> (f32[], f32[]) { + %Arg_0.1 = f32[] parameter(0) + %Arg_1.2 = f32[] parameter(1) + %Arg_2.3 = f32[] parameter(2) + %Arg_3.4 = f32[] parameter(3) + %add.4 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_2.3) + %add.5 = f32[] add(f32[] %Arg_1.2, f32[] %Arg_3.4) + ROOT %tuple.6 = (f32[], f32[]) tuple(%add.4, %add.5) +} + +// CHECK-LABEL: func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<4xf32>, %arg2: tensor) -> tuple, tensor>, tensor> { +ENTRY %foo.5 (Arg_0.1: f32[4, 4], Arg_1.2: f32[4], Arg_2.3: f32[]) -> ((f32[], f32[]), f32[]) { + %Arg_0.1 = f32[4, 4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + %Arg_2.3 = f32[] parameter(2) + + // CHECK-NEXT: %0 = "xla.reduce"(%arg0, %arg0, %arg2, %arg2) {computation = @reduce_helper.3, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor<4x4xf32>, tensor, tensor) -> tuple, tensor> + %reduce.1 = f32[4] reduce(%Arg_0.1, %Arg_1.2), dimensions={0}, to_apply=%reduce_helper.1 + + // CHECK-NEXT: %1 = "xla.reduce"(%arg0, %arg1) {computation = @reduce_helper.1, dimensions = dense<0> : tensor<1xi64>} : (tensor<4x4xf32>, tensor<4xf32>) -> tensor<4xf32> + %reduce.2 = f32[] reduce(%reduce.1, %Arg_2.3), dimensions={0}, to_apply=%reduce_helper.2 + + // CHECK-NEXT: %2 = "xla.reduce"(%1, %arg2) {computation = @reduce_helper.2, dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>, tensor) -> tensor + %reduce.3 = f32[] reduce(%Arg_0.1, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.2 + + // CHECK-NEXT: %3 = "xla.reduce"(%arg0, %arg2) {computation = @reduce_helper.2, dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>, tensor) -> tensor + %reduce.4 = (f32[], f32[]) reduce(%Arg_0.1, %Arg_0.1, %Arg_2.3, %Arg_2.3), dimensions={0, 1}, to_apply=%reduce_helper.3 + + // CHECK-NEXT: %4 = "xla.sub"(%2, %3) {name = "sub.5"} : (tensor, tensor) -> tensor + %sub.5 = f32[] subtract(%reduce.2, %reduce.3) + + ROOT %tuple.6 = ((f32[], f32[]), f32[]) tuple(%reduce.4, %sub.5) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt new file mode 100644 index 00000000000..7c8303d5966 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/reverse.hlotxt @@ -0,0 +1,21 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %reverse.1 (Arg_0.1: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + + // CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<0> : tensor<1xi64>} : (tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT reverse.2 = f32[4] reverse(%Arg_0.1), dimensions={0} +} + +// CHECK-LABEL: func @reverse.2(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> { +%reverse.2 (Arg_0.1: f32[4, 4]) -> f32[4, 4] { + %Arg_0.1 = f32[4, 4] parameter(0) + + // CHECK-NEXT: %0 = "xla.reverse"(%arg0) {dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<4x4xf32>) -> tensor<4x4xf32> + // CHECK-NEXT: return %0 : tensor<4x4xf32> + ROOT reverse.2 = f32[4, 4] reverse(%Arg_0.1), dimensions={0, 1} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/scalar.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/scalar.hlotxt new file mode 100644 index 00000000000..7b7ce4006f5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/scalar.hlotxt @@ -0,0 +1,9 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor) -> tensor { +ENTRY %main.5 (Arg_0.1: f32[]) -> f32[] { + // CHECK-NEXT: return %arg0 : tensor + ROOT %Arg_0.1 = f32[] parameter(0) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt new file mode 100644 index 00000000000..b9ae08d8c8c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/select.hlotxt @@ -0,0 +1,15 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { +ENTRY %main { + %Arg_0.1 = pred[2,3] parameter(0) + %Arg_1.2 = s32[2,3] parameter(1) + %Arg_2.3 = s32[2,3] parameter(2) + + // CHECK-NEXT: %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + // CHECK-NEXT: return %0 : tensor<2x3xi32> + ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/select.mlir b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir new file mode 100644 index 00000000000..4990ae712f8 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/select.mlir @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main +func @main(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xi32>) -> tensor<2x3xi32> { + // CHECK-NEXT: %Arg_0.1 = pred[2,3] parameter(0) + // CHECK-NEXT: %Arg_1.2 = s32[2,3] parameter(1) + // CHECK-NEXT: %Arg_2.3 = s32[2,3] parameter(2) + + // CHECK-NEXT: ROOT %select.4 = s32[2,3] select(pred[2,3] %Arg_0.1, s32[2,3] %Arg_1.2, s32[2,3] %Arg_2.3) + %0 = "xla.select"(%arg0, %arg1, %arg2) {name = "select.4"} : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + return %0 : tensor<2x3xi32> +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo new file mode 100644 index 00000000000..83d85f7d45e --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlo @@ -0,0 +1,146 @@ +# RUN: tf-mlir-translate -hlo-to-mlir-hlo %s -o - | FileCheck %s + +name: "foo.5" +entry_computation_name: "foo.5" +computations { + name: "foo.5" + instructions { + name: "Arg_0.1" + opcode: "parameter" + shape { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + metadata { + } + id: 1 + } + instructions { + name: "Arg_1.2" + opcode: "parameter" + shape { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + metadata { + } + parameter_number: 1 + id: 2 + } + instructions { + name: "add.3" + opcode: "add" + shape { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + metadata { + } + id: 3 + operand_ids: 1 + operand_ids: 2 + } + instructions { + name: "dot.4" + opcode: "dot" + shape { + element_type: F32 + layout { + format: DENSE + } + } + metadata { + } + dot_dimension_numbers { + lhs_contracting_dimensions: 0 + rhs_contracting_dimensions: 0 + } + id: 4 + operand_ids: 3 + operand_ids: 2 + } + program_shape { + parameters { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + parameters { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + result { + element_type: F32 + layout { + format: DENSE + } + } + parameter_names: "Arg_0" + parameter_names: "Arg_1" + } + id: 5 + root_id: 4 +} +host_program_shape { + parameters { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + parameters { + element_type: F32 + dimensions: 4 + layout { + minor_to_major: 0 + format: DENSE + } + is_dynamic_dimension: false + } + result { + element_type: F32 + layout { + format: DENSE + } + } + parameter_names: "Arg_0" + parameter_names: "Arg_1" +} +id: 5 +entry_computation_id: 5 +dynamic_parameter_binding { +} + +# CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { +# CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> +# TODO(b/129709049) consider making this default precision config inferred. +# CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor +# CHECK-NEXT: return %1 : tensor +# CHECK-NEXT: } diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt new file mode 100644 index 00000000000..09462625bbb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.hlotxt @@ -0,0 +1,17 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor { +ENTRY %main.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[] { + %Arg_0.1 = f32[4]{0} parameter(0) + %Arg_1.2 = f32[4]{0} parameter(1) + + // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg1) {name = "add.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %add.3 = f32[4]{0} add(f32[4]{0} %Arg_0.1, f32[4]{0} %Arg_1.2) + + // TODO(b/129709049) consider making this default precision config inferred. + // CHECK-NEXT: %1 = "xla.dot"(%0, %arg1) {name = "dot.4", precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<4xf32>, tensor<4xf32>) -> tensor + // CHECK-NEXT: return %1 : tensor + ROOT %dot.4 = f32[] dot(f32[4]{0} %add.3, f32[4]{0} %Arg_1.2), lhs_contracting_dims={0}, rhs_contracting_dims={0} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir new file mode 100644 index 00000000000..f6e277c97de --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/simple.mlir @@ -0,0 +1,21 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo %s | FileCheck %s + +func @main(tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> { +^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>): + %0 = "xla.add"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %1 = "xla.dot"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + return %1 : tensor<4xf32> +} + +// CHECK: name: "main +// CHECK: entry_computation_name: "main +// CHECK: computations { +// CHECK: name: "main +// CHECK: instructions { +// CHECK: name: "Arg_ +// CHECK: opcode: "parameter" +// CHECK: name: "add +// CHECK: opcode: "add" +// CHECK: name: "dot +// CHECK: opcode: "dot" + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt new file mode 100644 index 00000000000..6fc493aa764 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/sub.hlotxt @@ -0,0 +1,14 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main.5 + +// CHECK-LABEL: func @main(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { +ENTRY %foo.5 (Arg_0.1: f32[4], Arg_1.2: f32[4]) -> f32[4] { + %Arg_0.1 = f32[4] parameter(0) + %Arg_1.2 = f32[4] parameter(1) + + // CHECK-NEXT: %0 = "xla.sub"(%arg0, %arg1) {name = "subtract.3"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + // CHECK-NEXT: return %0 : tensor<4xf32> + ROOT %subtract.3 = f32[4] subtract(f32[4] %Arg_0.1, f32[4] %Arg_1.2) +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt new file mode 100644 index 00000000000..54dc0faef09 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/tanh.hlotxt @@ -0,0 +1,12 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @main(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> { +ENTRY %foo (arg0.1: f32[1,16,16,3]) -> f32[1,16,16,3] { + %arg0.1 = f32[1,16,16,3]{3,2,1,0} parameter(0), metadata={op_name="XLA_Args"} + + // CHECK-NEXT: %0 = "xla.tanh"(%arg0) {name = "tanh.3"} : (tensor<1x16x16x3xf32>) -> tensor<1x16x16x3xf32> + // CHECK-NEXT: return %0 : tensor<1x16x16x3xf32> + ROOT %tanh.3 = f32[1,16,16,3]{3,2,1,0} tanh(f32[1,16,16,3]{3,2,1,0} %arg0.1), metadata={op_type="Tanh" op_name="embedded_inference/tanh_model/Tanh"} +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt new file mode 100644 index 00000000000..335e54669eb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/transpose.hlotxt @@ -0,0 +1,13 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { +ENTRY %main { + %Arg_0.1 = s32[1,2,3,4] parameter(0) + + // CHECK-NEXT: %0 = "xla.transpose"(%arg0) {name = "transpose.2", permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + // CHECK-NEXT: return %0 : tensor<2x1x4x3xi32> + ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir new file mode 100644 index 00000000000..e28d0a37d84 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/transpose.mlir @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -mlir-hlo-to-hlo-text %s | FileCheck %s + +// CHECK-LABEL: ENTRY %main +func @main(%arg0: tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> { + // CHECK-NEXT: %Arg_0.1 = s32[1,2,3,4] parameter(0) + + // CHECK-NEXT: ROOT %transpose.2 = s32[2,1,4,3] transpose(s32[1,2,3,4] %Arg_0.1), dimensions={1,0,3,2} + %0 = "xla.transpose"(%arg0) {permutation = dense<[1, 0, 3, 2]> : tensor<4xi64>} : (tensor<1x2x3x4xi32>) -> tensor<2x1x4x3xi32> + return %0 : tensor<2x1x4x3xi32> +} + diff --git a/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt new file mode 100644 index 00000000000..c98fa93fcd9 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/tuple.hlotxt @@ -0,0 +1,16 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<1xi32>, %arg1: tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> { +ENTRY %main(Arg_0.1: s32[1], Arg_1.2: f32[1, 2]) -> (s32[1], f32[1,2]) { + %Arg_0.1 = s32[1] parameter(0) + %Arg_1.2 = f32[1, 2] parameter(1) + + // CHECK-NEXT: %0 = "xla.tuple"(%arg0) {name = "tuple.3"} : (tensor<1xi32>) -> tuple> + %tuple.3 = (s32[1]) tuple(%Arg_0.1) + + // CHECK-NEXT: %1 = "xla.tuple"(%arg0, %arg1) {name = "tuple.4"} : (tensor<1xi32>, tensor<1x2xf32>) -> tuple, tensor<1x2xf32>> + // CHECK-NEXT: return %1 : tuple, tensor<1x2xf32>> + ROOT %tuple.4 = (s32[1], f32[1,2]) tuple(%Arg_0.1, %Arg_1.2) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt new file mode 100644 index 00000000000..42d52fd78c8 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/unknown.hlotxt @@ -0,0 +1,11 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule main + +// CHECK-LABEL: func @main(%arg0: tensor<1xf32>) -> tensor<1xf32> { +ENTRY %main (Arg_0.1: f32[1, 4], Arg_1.2: f32[4, 1]) -> f32[1] { + %Arg_0.1 = f32[1] parameter(0) + + // CHECK-NEXT: %0 = "xla.unknown"(%arg0, %arg0) {name = "add-dependency.2"} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + ROOT add-dependency.2 = f32[1] add-dependency(Arg_0.1, Arg_0.1) +} diff --git a/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt new file mode 100644 index 00000000000..a6d2a48797e --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/translate/while.hlotxt @@ -0,0 +1,27 @@ +// RUN: tf-mlir-translate -hlo-text-to-mlir-hlo %s -o - | FileCheck %s + +HloModule foo + +// CHECK-LABEL: func @cond(%arg0: tensor) -> tensor { +%cond (arg_1: s64[]) -> pred[] { + %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"} + // CHECK-NEXT: %0 = "xla.compare"(%arg0, %arg0) {comparison_direction = "LT", name = "compare.2"} : (tensor, tensor) -> tensor + // CHECK-NEXT: return %0 : tensor + ROOT %compare.2 = pred[] compare(%arg_1, %arg_1), direction=LT, metadata={op_type="Less" op_name="Less"} +} + +// CHECK-LABEL: func @loop(%arg0: tensor) -> tensor { +%loop (arg_1: s64[]) -> s64[] { + %arg_1 = s64[] parameter(0), metadata={op_name="XLA_Args"} + // CHECK-NEXT: %0 = "xla.add"(%arg0, %arg0) {name = "compare.0"} : (tensor, tensor) -> tensor + // CHECK-NEXT: return %0 : tensor + ROOT %compare.2 = s64[] add(%arg_1, %arg_1), metadata={op_type="Less" op_name="Less"} +} + +// CHECK-LABEL: func @main(%arg0: tensor) -> tensor { +ENTRY %foo (arg0.1: s64[]) -> s64[] { + %arg0.1 = s64[] parameter(0), metadata={op_name="XLA_Args"} + // CHECK-NEXT: %0 = "xla.while"(%arg0) {body = @loop, cond = @cond} : (tensor) -> tensor + // CHECK-NEXT: return %0 : tensor + ROOT %while.2 = s64[] while(%arg0.1), body=%loop, condition=%cond +} diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.cc b/tensorflow/compiler/mlir/xla/type_to_shape.cc new file mode 100644 index 00000000000..40c896fef9c --- /dev/null +++ b/tensorflow/compiler/mlir/xla/type_to_shape.cc @@ -0,0 +1,145 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" + +#include + +#include "absl/base/integral_types.h" +#include "mlir/IR/AffineMap.h" // TF:local_config_mlir +#include "mlir/IR/Diagnostics.h" // TF:local_config_mlir +#include "mlir/IR/Location.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "mlir/Support/DebugStringHelper.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +using mlir::IntegerType; +using mlir::MemRefType; +using mlir::RankedTensorType; +using mlir::VectorType; +using xla::PrimitiveType; +using xla::ShapeUtil; + +namespace xla { + +PrimitiveType TypeToPrimitiveType(mlir::Type type) { + switch (type.getKind()) { + case mlir::StandardTypes::BF16: + return PrimitiveType::BF16; + case mlir::StandardTypes::F16: + return PrimitiveType::F16; + case mlir::StandardTypes::F32: + return PrimitiveType::F32; + case mlir::StandardTypes::F64: + return PrimitiveType::F64; + case mlir::StandardTypes::Integer: { + const auto integer = type.cast(); + switch (integer.getWidth()) { + case 1: + return PrimitiveType::PRED; + case 8: + return PrimitiveType::S8; + case 16: + return PrimitiveType::S16; + case 32: + return PrimitiveType::S32; + case 64: + return PrimitiveType::S64; + default: + return PrimitiveType::PRIMITIVE_TYPE_INVALID; + } + } + default: + return PrimitiveType::PRIMITIVE_TYPE_INVALID; + } +} + +Shape TypeToShape(mlir::Type type) { + PrimitiveType ptype = TypeToPrimitiveType(type); + if (ptype != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(ptype, {}); + + switch (type.getKind()) { + case mlir::StandardTypes::BF16: + case mlir::StandardTypes::F32: + case mlir::StandardTypes::F64: + case mlir::StandardTypes::Integer: { + auto* context = type.getContext(); + mlir::emitError(mlir::UnknownLoc::get(context)) + << "lowering should have been handled by primitive type lowering for " + << debugString(type); + break; + } + case mlir::StandardTypes::Vector: { + const auto v = type.cast(); + llvm::SmallVector span(v.getShape().begin(), + v.getShape().end()); + mlir::Type element_type = v.getElementType(); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(primitive_type, span); + break; + } + case mlir::StandardTypes::MemRef: { + const auto m = type.cast(); + llvm::SmallVector span(m.getShape().begin(), + m.getShape().end()); + mlir::Type element_type = m.getElementType(); + // Treat a memref of a vector as if it was a memref of primitive type with + // the vector dimensions at the end. + if (auto v = element_type.dyn_cast()) { + element_type = v.getElementType(); + span.insert(span.end(), v.getShape().begin(), v.getShape().end()); + } + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + if (primitive_type == PrimitiveType::PRIMITIVE_TYPE_INVALID) break; + // For the primitive type case, the shape of the memref is similar to the + // vector type case (i.e., it is, modulo the layout, the same dimensions + // and primitive type). + // Currently we only return shapes for identity affine maps. + // TODO(andydavis) Map affine map layout function to XLA layout. + if (m.getAffineMaps().empty() || + (m.getAffineMaps().size() == 1 && m.getAffineMaps()[0].isIdentity())) + return ShapeUtil::MakeShape(primitive_type, span); + break; + } + case mlir::StandardTypes::RankedTensor: { + // TODO(jpienaar): This is only handling the base case with primitive + // element type. + const auto t = type.cast(); + llvm::SmallVector span(t.getShape().begin(), + t.getShape().end()); + // Only fully static shapes are supported. + // TODO(b/115638799): Update once xla::Shape can support dynamic shapes. + if (std::find(t.getShape().begin(), t.getShape().end(), -1) != + t.getShape().end()) + break; + mlir::Type element_type = t.getElementType(); + PrimitiveType primitive_type = TypeToPrimitiveType(element_type); + // Only primitive element type supported. + if (primitive_type != PrimitiveType::PRIMITIVE_TYPE_INVALID) + return ShapeUtil::MakeShape(primitive_type, span); + break; + } + default: + break; + } + // Return empty XLA shape to signify error. No MLIR Type maps to a empty + // Shape. + return {}; +} + +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/type_to_shape.h b/tensorflow/compiler/mlir/xla/type_to_shape.h new file mode 100644 index 00000000000..6bd5384f857 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/type_to_shape.h @@ -0,0 +1,34 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ + +#include "mlir/IR/Types.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/shape.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// Returns a XLA Shape equivalent of a MLIR Type, else returns empty shape. +Shape TypeToShape(mlir::Type type); + +// Returns a XLA PrimitiveType equivalent of a MLIR Type that represents a +// primitive type (e.g., i8, f32), else returns PRIMITIVE_TYPE_INVALID. +PrimitiveType TypeToPrimitiveType(mlir::Type type); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_TYPE_TO_SHAPE_H_ diff --git a/tensorflow/compiler/mlir/xla/type_to_shape_test.cc b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc new file mode 100644 index 00000000000..9a77be947d5 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/type_to_shape_test.cc @@ -0,0 +1,110 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/type_to_shape.h" + +#include "mlir/IR/Builders.h" // TF:local_config_mlir +#include "mlir/IR/MLIRContext.h" // TF:local_config_mlir +#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +using mlir::Builder; +using mlir::MLIRContext; +using ::testing::EqualsProto; + +namespace xla { +namespace { + +TEST(TypeToShapeTest, ConvertPrimitiveTypes) { + MLIRContext context; + Builder b(&context); + + EXPECT_EQ(TypeToPrimitiveType(b.getF32Type()), PrimitiveType::F32); + EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(1)), PrimitiveType::PRED); + EXPECT_EQ(TypeToPrimitiveType(b.getIntegerType(17)), + PrimitiveType::PRIMITIVE_TYPE_INVALID); +} + +TEST(TypeToShapeTest, ConvertBasicTypesToTypes) { + MLIRContext context; + Builder b(&context); + + EXPECT_TRUE( + ShapeUtil::IsScalarWithElementType(TypeToShape(b.getF32Type()), F32)); + EXPECT_THAT( + TypeToShape(b.getVectorType({8, 128}, b.getIntegerType(32))).ToProto(), + EqualsProto( + ShapeUtil::MakeShape(PrimitiveType::S32, {8, 128}).ToProto())); + EXPECT_THAT( + TypeToShape(b.getVectorType({8, 128}, b.getF32Type())).ToProto(), + EqualsProto( + ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto())); + + // MLIR Type that is not representable as XLA Shape. + EXPECT_THAT( + TypeToShape(b.getVectorType({8, 128}, b.getIntegerType(17))).ToProto(), + EqualsProto(Shape().ToProto())); +} + +TEST(TypeToShapeTest, ConvertMemRefTypeToTypes) { + MLIRContext context; + Builder b(&context); + + // Memref without any affine map. Note: memory space is ignored for shape. + EXPECT_THAT( + TypeToShape(b.getMemRefType({8, 128}, b.getF32Type())).ToProto(), + EqualsProto( + ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto())); + EXPECT_THAT( + TypeToShape(b.getMemRefType({100, 13, 210}, b.getF32Type())).ToProto(), + EqualsProto( + ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210}).ToProto())); + + // Vector types are "flattened" into the end of the shape. + EXPECT_THAT( + TypeToShape(b.getMemRefType({100, 13, 210}, + b.getVectorType({8, 128}, b.getF32Type()))) + .ToProto(), + EqualsProto( + ShapeUtil::MakeShape(PrimitiveType::F32, {100, 13, 210, 8, 128}) + .ToProto())); +} + +TEST(TypeToShapeTest, ConvertTensorTypeToTypes) { + MLIRContext context; + Builder b(&context); + + EXPECT_THAT( + TypeToShape(b.getTensorType({8, 128}, b.getF32Type())).ToProto(), + EqualsProto( + ShapeUtil::MakeShape(PrimitiveType::F32, {8, 128}).ToProto())); + + // Shape cannot represent dynamic shapes. + // TODO(b/115638799): Update once Shape can support dynamic shapes. + EXPECT_THAT(TypeToShape(b.getTensorType(b.getF32Type())).ToProto(), + EqualsProto(Shape().ToProto())); + + // TODO(jpienaar): Expand to handle more complicated tensor types. + EXPECT_THAT( + TypeToShape( + b.getTensorType({8, 128}, b.getVectorType({16, 16}, b.getF32Type()))) + .ToProto(), + EqualsProto(Shape().ToProto())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc new file mode 100644 index 00000000000..9804858c084 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.cc @@ -0,0 +1,192 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/xla_mlir_translate.h" + +#include "google/protobuf/text_format.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/ToolOutputFile.h" +#include "mlir/IR/Module.h" // TF:local_config_mlir +#include "mlir/Translation.h" // TF:local_config_mlir +#include "tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h" +#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/service/hlo_parser.h" +#include "tensorflow/core/lib/core/errors.h" + +using stream_executor::port::Status; +using stream_executor::port::StatusOr; // NOLINT TODO(b/130822468) fix this + +namespace xla { + +namespace { +// Error collector that simply ignores errors reported. +class NoOpErrorCollector : public ::proto2::io::ErrorCollector { + public: + void AddError(int line, int column, const string& message) override {} +}; + +bool LoadHloProto(const std::string& contents, HloProto* hlo_proto) { + ::proto2::TextFormat::Parser parser; + NoOpErrorCollector collector; + parser.RecordErrorsTo(&collector); + return hlo_proto->ParseFromString(contents) || + parser.ParseFromString(contents, hlo_proto) || + hlo_proto->mutable_hlo_module()->ParseFromString(contents) || + parser.ParseFromString(contents, hlo_proto->mutable_hlo_module()); +} + +} // namespace + +mlir::OwningModuleRef HloToMlirHloTranslateFunction( + llvm::StringRef input_filename, mlir::MLIRContext* context) { + auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.str()); + if (std::error_code error = file_or_err.getError()) { + LOG(ERROR) << "Failure to read HLO module: " << error; + return nullptr; + } + + auto& input_file = *file_or_err; + HloProto hlo_proto; + string content(input_file->getBufferStart(), input_file->getBufferSize()); + if (!LoadHloProto(content, &hlo_proto)) { + LOG(ERROR) << "Failed to load proto: " << input_filename.str(); + return nullptr; + } + + mlir::OwningModuleRef module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + auto status = + ConvertHloToMlirHlo(module.get(), hlo_proto.mutable_hlo_module()); + if (!status.ok()) { + LOG(ERROR) << "Hlo module import failed: " << status; + return nullptr; + } + + return module; +} + +mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( + llvm::StringRef input_filename, mlir::MLIRContext* context) { + auto file_or_err = llvm::MemoryBuffer::getFileOrSTDIN(input_filename.str()); + if (std::error_code error = file_or_err.getError()) { + LOG(ERROR) << "Failure to open file: " << error; + return nullptr; + } + + auto& input_file = *file_or_err; + HloProto hlo_proto; + string content(input_file->getBufferStart(), input_file->getBufferSize()); + + auto hlo_module_error = ParseAndReturnUnverifiedModule(content); + if (!hlo_module_error.ok()) { + LOG(ERROR) << "HLO Module loading failed: " << hlo_module_error.status(); + return nullptr; + } + + auto hlo_module = std::move(hlo_module_error.ValueOrDie()); + mlir::OwningModuleRef module = + mlir::ModuleOp::create(mlir::UnknownLoc::get(context)); + auto status = ConvertHloToMlirHlo(*module, hlo_module.get()); + if (!status.ok()) { + LOG(ERROR) << "HLO Module import failed: " << status; + return nullptr; + } + + return module; +} + +static mlir::LogicalResult MlirHloToHloTranslateFunction( + mlir::ModuleOp module, llvm::StringRef output_filename) { + if (!module) return mlir::failure(); + + std::error_code error; + auto result = llvm::make_unique(output_filename, error, + llvm::sys::fs::F_None); + if (error) { + LOG(ERROR) << error.message(); + return mlir::failure(); + } + + HloProto hloProto; + Status status = mlir::ConvertMlirHloToHlo(module, &hloProto); + if (!status.ok()) { + LOG(ERROR) << "Module conversion failed: " << status; + return mlir::failure(); + } + + result->os() << hloProto.DebugString(); + result->keep(); + return mlir::success(); +} + +static StatusOr> HloModuleFromProto( + const HloProto& hlo_proto) { + const HloModuleProto& module_proto = hlo_proto.hlo_module(); + TF_ASSIGN_OR_RETURN(const HloModuleConfig module_config, + HloModule::CreateModuleConfigFromProto( + module_proto, GetDebugOptionsFromFlags())); + return HloModule::CreateFromProto(module_proto, module_config); +} + +static mlir::LogicalResult MlirHloToHloTextTranslateFunction( + mlir::ModuleOp module, llvm::StringRef output_filename) { + if (!module) return mlir::failure(); + + std::error_code error; + auto result = llvm::make_unique(output_filename, error, + llvm::sys::fs::F_None); + if (error) { + LOG(ERROR) << error.message(); + return mlir::failure(); + } + + HloProto hloProto; + Status status = mlir::ConvertMlirHloToHlo(module, &hloProto); + if (!status.ok()) { + LOG(ERROR) << "Module conversion failed: " << status; + return mlir::failure(); + } + + auto statusOrHloModule = HloModuleFromProto(hloProto); + + if (!statusOrHloModule.ok()) { + LOG(ERROR) << "Conversion to HLO module failed: " + << statusOrHloModule.status(); + return mlir::failure(); + } + + result->os() << statusOrHloModule.ValueOrDie()->ToString( + HloPrintOptions() + // We don't interpret or use layouts + .set_include_layout_in_shapes(false)); + result->keep(); + return mlir::success(); +} + +} // namespace xla + +static mlir::TranslateFromMLIRRegistration MlirHloToHloTranslate( + "mlir-hlo-to-hlo", xla::MlirHloToHloTranslateFunction); + +static mlir::TranslateFromMLIRRegistration MlirHloToHloTextTranslate( + "mlir-hlo-to-hlo-text", xla::MlirHloToHloTextTranslateFunction); + +static mlir::TranslateToMLIRRegistration HloToHloMlirTranslate( + "hlo-to-mlir-hlo", xla::HloToMlirHloTranslateFunction); + +static mlir::TranslateToMLIRRegistration HloTextToHloMlirTranslate( + "hlo-text-to-mlir-hlo", xla::HloTextToMlirHloTranslateFunction); diff --git a/tensorflow/compiler/mlir/xla/xla_mlir_translate.h b/tensorflow/compiler/mlir/xla/xla_mlir_translate.h new file mode 100644 index 00000000000..6b3d79b97cb --- /dev/null +++ b/tensorflow/compiler/mlir/xla/xla_mlir_translate.h @@ -0,0 +1,47 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_ +#define TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_ + +#include + +#include "tensorflow/compiler/xla/statusor.h" + +namespace llvm { +class StringRef; +} // namespace llvm + +namespace mlir { +class MLIRContext; +class OwningModuleRef; +} // namespace mlir + +namespace xla { + +// Converts a HloModuleProto stored in the file with the given `input_filename` +// into a MLIR module. Creates MLIR entities into the given MLIR `context`. +mlir::OwningModuleRef HloToMlirHloTranslateFunction( + llvm::StringRef input_filename, mlir::MLIRContext* context); + +// Converts a HloModule stored in text form for a file with the given +// `input_filename` into a MLIR module. Creates MLIR entities into the given +// MLIR `context`. +mlir::OwningModuleRef HloTextToMlirHloTranslateFunction( + llvm::StringRef input_filename, mlir::MLIRContext* context); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_MLIR_XLA_XLA_MLIR_TRANSLATE_H_