RFC: Create separate lib for attribute converters from XLA to MLIR

These will be re-used in MlirHloBuilder separately.

PiperOrigin-RevId: 307032293
Change-Id: I3517bb764b1046bffe39baf177bd44b103cfb640
This commit is contained in:
Smit Hinsu 2020-04-17 06:27:40 -07:00 committed by TensorFlower Gardener
parent 2d676104cb
commit 0eb2537aa0
5 changed files with 215 additions and 127 deletions

View File

@ -704,6 +704,7 @@ cc_library(
"hlo_module_importer.h",
],
deps = [
":attribute_importer",
":hlo",
":hlo_utils",
"//tensorflow/compiler/mlir/tensorflow:error_util",
@ -723,6 +724,18 @@ cc_library(
],
)
cc_library(
name = "attribute_importer",
srcs = ["attribute_importer.cc"],
hdrs = ["attribute_importer.h"],
deps = [
":hlo",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/core/platform:types",
"@llvm-project//mlir:IR",
],
)
cc_library(
name = "xla_mlir_translate",
srcs = ["xla_mlir_translate.cc"],

View File

@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include <vector>
namespace xla {
static mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements,
mlir::Builder* builder) {
return mlir::DenseIntElementsAttr::get(
mlir::RankedTensorType::get(elements.size(), builder->getIntegerType(64)),
elements);
}
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
mlir::Builder* builder) {
if (!config) return {};
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
// case and the parser sticks it in. Maybe we should too.
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
for (auto prec : config->operand_precision()) {
operand_precision_attrs.push_back(
builder->getStringAttr(PrecisionConfig_Precision_Name(prec)));
}
return builder->getArrayAttr(operand_precision_attrs);
}
// Converts the gather dimensions to attributes.
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
std::vector<int64_t> collapsed_slice_dims(
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
return mlir::xla_hlo::GatherDimensionNumbers::get(
Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
Convert(start_index_map, builder),
builder->getI64IntegerAttr(dnums.index_vector_dim()),
builder->getContext());
}
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
std::vector<int64_t> inserted_window_dims(
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
return mlir::xla_hlo::ScatterDimensionNumbers::get(
Convert(update_window_dims, builder),
Convert(inserted_window_dims, builder),
Convert(scatter_dims_to_operand_dims, builder),
builder->getI64IntegerAttr(dnums.index_vector_dim()),
builder->getContext());
}
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums, mlir::Builder* builder) {
std::vector<int64_t> rhs_contracting_dimensions(
dnums.rhs_contracting_dimensions().begin(),
dnums.rhs_contracting_dimensions().end());
std::vector<int64_t> lhs_contracting_dimensions(
dnums.lhs_contracting_dimensions().begin(),
dnums.lhs_contracting_dimensions().end());
std::vector<int64_t> rhs_batch_dimensions(
dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
std::vector<int64_t> lhs_batch_dimensions(
dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
// Push the attributes into our new DictionaryAttr.
auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions, builder);
auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions, builder);
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);
return mlir::xla_hlo::DotDimensionNumbers::get(
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
rhs_contracting_dims_attr, builder->getContext());
}
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
llvm::SmallVector<int64_t, 4> input_spatial_dims(
dnums.input_spatial_dimensions().begin(),
dnums.input_spatial_dimensions().end());
llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
dnums.kernel_spatial_dimensions().begin(),
dnums.kernel_spatial_dimensions().end());
llvm::SmallVector<int64_t, 4> output_spatial_dims(
dnums.output_spatial_dimensions().begin(),
dnums.output_spatial_dimensions().end());
return mlir::xla_hlo::ConvDimensionNumbers::get(
builder->getI64IntegerAttr(dnums.input_batch_dimension()),
builder->getI64IntegerAttr(dnums.input_feature_dimension()),
Convert(input_spatial_dims, builder),
builder->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
Convert(kernel_spatial_dims, builder),
builder->getI64IntegerAttr(dnums.output_batch_dimension()),
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
Convert(output_spatial_dims, builder), builder->getContext());
}
} // namespace xla

View File

@ -0,0 +1,49 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
mlir::Builder* builder);
// Converts the gather dimensions to attributes.
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
// Converts the scatter dimensions to attributes.
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);
// Converts the dot dimensions to attributes.
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums, mlir::Builder* builder);
// Converts the conv dimensions to attributes.
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "mlir/IR/Region.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
@ -56,6 +57,7 @@ using mlir::Value;
namespace xla {
namespace {
// Note: This sanitization function causes an irreversible many-to-one mapping
// and any solution to mitigate this would cause issues with the reverse
// direction. Longterm solution is to add a function attribute to maintain the
@ -230,15 +232,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
#undef MakeAndReturnBatchNormOp
case HloOpcode::kDot: {
attributes.push_back(ConvertPrecisionConfig(instruction));
attributes.push_back(builder_->getNamedAttr(
"precision_config",
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
// Consider consolidating DotOps together.
if (DotIsDefault(instruction)) {
MakeAndReturn(DotOp);
}
attributes.push_back(
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers()));
attributes.push_back(builder_->getNamedAttr(
"dot_dimension_numbers",
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
builder_)));
MakeAndReturn(DotGeneralOp);
}
case HloOpcode::kCall: {
@ -278,8 +284,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kGather: {
auto gather_instruction = Cast<HloGatherInstruction>(instruction);
attributes.push_back(ConvertGatherDimensionNumbers(
gather_instruction->gather_dimension_numbers()));
attributes.push_back(builder_->getNamedAttr(
"dimension_numbers",
ConvertGatherDimensionNumbers(
gather_instruction->gather_dimension_numbers(), builder_)));
std::vector<int64_t> slice_sizes(
gather_instruction->gather_slice_sizes().begin(),
@ -345,8 +353,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
}
case HloOpcode::kScatter: {
auto scatter = Cast<HloScatterInstruction>(instruction);
attributes.push_back(
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
attributes.push_back(builder_->getNamedAttr(
"scatter_dimension_numbers",
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
builder_)));
attributes.push_back(builder_->getNamedAttr(
"indices_are_sorted",
builder_->getBoolAttr(scatter->indices_are_sorted())));
@ -577,15 +587,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
builder_->getNamedAttr("lhs_dilations", Convert(lhs_dilations)));
attributes.push_back(
builder_->getNamedAttr("rhs_dilations", Convert(rhs_dilations)));
attributes.push_back(ConvertConvDimensionNumbers(
instruction->convolution_dimension_numbers()));
attributes.push_back(builder_->getNamedAttr(
"dimension_numbers",
ConvertConvDimensionNumbers(
instruction->convolution_dimension_numbers(), builder_)));
attributes.push_back(builder_->getNamedAttr(
"feature_group_count",
builder_->getI64IntegerAttr(instruction->feature_group_count())));
attributes.push_back(builder_->getNamedAttr(
"batch_group_count",
builder_->getI64IntegerAttr(instruction->batch_group_count())));
attributes.push_back(ConvertPrecisionConfig(instruction));
attributes.push_back(builder_->getNamedAttr(
"precision_config",
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
MakeAndReturn(ConvOp);
}
@ -717,20 +732,6 @@ StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
"Unable to find value for input: ", instruction->ToString()));
}
mlir::NamedAttribute HloFunctionImporter::ConvertPrecisionConfig(
HloInstruction* instruction) {
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
// case and the parser sticks it in. Maybe we should too.
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
for (auto prec : instruction->precision_config().operand_precision()) {
operand_precision_attrs.push_back(
builder_->getStringAttr(PrecisionConfig_Precision_Name(prec)));
}
return builder_->getNamedAttr(
"precision_config", builder_->getArrayAttr(operand_precision_attrs));
}
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
HloInstruction* instruction) {
return builder_->getNamedAttr(
@ -751,10 +752,10 @@ mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
}
mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
llvm::ArrayRef<int64_t> op_dimensions) {
llvm::ArrayRef<int64_t> elements) {
return DenseIntElementsAttr::get(
RankedTensorType::get(op_dimensions.size(), builder_->getIntegerType(64)),
op_dimensions);
RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
elements);
}
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
@ -766,86 +767,6 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
return builder_->getNamedAttr("padding", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums) {
std::vector<int64_t> rhs_contracting_dimensions(
dnums.rhs_contracting_dimensions().begin(),
dnums.rhs_contracting_dimensions().end());
std::vector<int64_t> lhs_contracting_dimensions(
dnums.lhs_contracting_dimensions().begin(),
dnums.lhs_contracting_dimensions().end());
std::vector<int64_t> rhs_batch_dimensions(
dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
std::vector<int64_t> lhs_batch_dimensions(
dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
// Push the attributes into our new DictionaryAttr.
auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions);
auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions);
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions);
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions);
auto attr = mlir::xla_hlo::DotDimensionNumbers::get(
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
rhs_contracting_dims_attr, context_);
return builder_->getNamedAttr("dot_dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums) {
llvm::SmallVector<int64_t, 4> input_spatial_dims(
dnums.input_spatial_dimensions().begin(),
dnums.input_spatial_dimensions().end());
llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
dnums.kernel_spatial_dimensions().begin(),
dnums.kernel_spatial_dimensions().end());
llvm::SmallVector<int64_t, 4> output_spatial_dims(
dnums.output_spatial_dimensions().begin(),
dnums.output_spatial_dimensions().end());
auto attr = mlir::xla_hlo::ConvDimensionNumbers::get(
builder_->getI64IntegerAttr(dnums.input_batch_dimension()),
builder_->getI64IntegerAttr(dnums.input_feature_dimension()),
Convert(input_spatial_dims),
builder_->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
Convert(kernel_spatial_dims),
builder_->getI64IntegerAttr(dnums.output_batch_dimension()),
builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
Convert(output_spatial_dims), context_);
return builder_->getNamedAttr("dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums) {
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
dnums.offset_dims().end());
std::vector<int64_t> collapsed_slice_dims(
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
dnums.start_index_map().end());
auto attr = mlir::xla_hlo::GatherDimensionNumbers::get(
Convert(offset_dims), Convert(collapsed_slice_dims),
Convert(start_index_map),
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
return builder_->getNamedAttr("dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums) {
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
dnums.update_window_dims().end());
std::vector<int64_t> inserted_window_dims(
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
std::vector<int64_t> scatter_dims_to_operand_dims(
dnums.scatter_dims_to_operand_dims().begin(),
dnums.scatter_dims_to_operand_dims().end());
auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
Convert(update_window_dims), Convert(inserted_window_dims),
Convert(scatter_dims_to_operand_dims),
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
return builder_->getNamedAttr("scatter_dimension_numbers", attr);
}
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
source_target_pairs) {

View File

@ -89,9 +89,6 @@ class HloFunctionImporter {
// Returns the Mlir Value for the corresponding HloInstruction.
StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction);
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
mlir::NamedAttribute ConvertComparisonDirection(
xla::HloInstruction* instruction);
@ -101,28 +98,12 @@ class HloFunctionImporter {
llvm::ArrayRef<tensorflow::int64> op_dimensions);
// Converts Array ref to an DenseIntElementsAttr.
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> op_dimensions);
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);
// Converts Array ref to padding attribute. Input is a flattened list of
// padding low and padding high for each of the spatial dimensions.
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
// Converts the dot dimensions to attribute.
mlir::NamedAttribute ConvertDotDimensionNumbers(
const DotDimensionNumbers& dnums);
// Converts the conv dimensions to attributes.
mlir::NamedAttribute ConvertConvDimensionNumbers(
const xla::ConvolutionDimensionNumbers& dnums);
// Converts the gather dimensions to attributes.
mlir::NamedAttribute ConvertGatherDimensionNumbers(
const xla::GatherDimensionNumbers& dnums);
// Converts the scatter dimensions to attributes.
mlir::NamedAttribute ConvertScatterDimensionNumbers(
const xla::ScatterDimensionNumbers& dnums);
// Converts replica groups to attribute
mlir::NamedAttribute ConvertReplicaGroups(
const std::vector<ReplicaGroup>& replica_groups);