RFC: Create separate lib for attribute converters from XLA to MLIR
These will be re-used in MlirHloBuilder separately. PiperOrigin-RevId: 307032293 Change-Id: I3517bb764b1046bffe39baf177bd44b103cfb640
This commit is contained in:
parent
2d676104cb
commit
0eb2537aa0
tensorflow/compiler/mlir/xla
@ -704,6 +704,7 @@ cc_library(
|
|||||||
"hlo_module_importer.h",
|
"hlo_module_importer.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
":attribute_importer",
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_utils",
|
":hlo_utils",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||||
@ -723,6 +724,18 @@ cc_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "attribute_importer",
|
||||||
|
srcs = ["attribute_importer.cc"],
|
||||||
|
hdrs = ["attribute_importer.h"],
|
||||||
|
deps = [
|
||||||
|
":hlo",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
"//tensorflow/core/platform:types",
|
||||||
|
"@llvm-project//mlir:IR",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "xla_mlir_translate",
|
name = "xla_mlir_translate",
|
||||||
srcs = ["xla_mlir_translate.cc"],
|
srcs = ["xla_mlir_translate.cc"],
|
||||||
|
124
tensorflow/compiler/mlir/xla/attribute_importer.cc
Normal file
124
tensorflow/compiler/mlir/xla/attribute_importer.cc
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
static mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements,
|
||||||
|
mlir::Builder* builder) {
|
||||||
|
return mlir::DenseIntElementsAttr::get(
|
||||||
|
mlir::RankedTensorType::get(elements.size(), builder->getIntegerType(64)),
|
||||||
|
elements);
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
|
||||||
|
mlir::Builder* builder) {
|
||||||
|
if (!config) return {};
|
||||||
|
|
||||||
|
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
|
||||||
|
// case and the parser sticks it in. Maybe we should too.
|
||||||
|
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
|
||||||
|
|
||||||
|
for (auto prec : config->operand_precision()) {
|
||||||
|
operand_precision_attrs.push_back(
|
||||||
|
builder->getStringAttr(PrecisionConfig_Precision_Name(prec)));
|
||||||
|
}
|
||||||
|
return builder->getArrayAttr(operand_precision_attrs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Converts the gather dimensions to attributes.
|
||||||
|
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
|
||||||
|
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||||
|
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
|
||||||
|
dnums.offset_dims().end());
|
||||||
|
std::vector<int64_t> collapsed_slice_dims(
|
||||||
|
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
|
||||||
|
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
|
||||||
|
dnums.start_index_map().end());
|
||||||
|
return mlir::xla_hlo::GatherDimensionNumbers::get(
|
||||||
|
Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
|
||||||
|
Convert(start_index_map, builder),
|
||||||
|
builder->getI64IntegerAttr(dnums.index_vector_dim()),
|
||||||
|
builder->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
|
||||||
|
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||||
|
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
|
||||||
|
dnums.update_window_dims().end());
|
||||||
|
std::vector<int64_t> inserted_window_dims(
|
||||||
|
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
|
||||||
|
std::vector<int64_t> scatter_dims_to_operand_dims(
|
||||||
|
dnums.scatter_dims_to_operand_dims().begin(),
|
||||||
|
dnums.scatter_dims_to_operand_dims().end());
|
||||||
|
return mlir::xla_hlo::ScatterDimensionNumbers::get(
|
||||||
|
Convert(update_window_dims, builder),
|
||||||
|
Convert(inserted_window_dims, builder),
|
||||||
|
Convert(scatter_dims_to_operand_dims, builder),
|
||||||
|
builder->getI64IntegerAttr(dnums.index_vector_dim()),
|
||||||
|
builder->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
|
||||||
|
const DotDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||||
|
std::vector<int64_t> rhs_contracting_dimensions(
|
||||||
|
dnums.rhs_contracting_dimensions().begin(),
|
||||||
|
dnums.rhs_contracting_dimensions().end());
|
||||||
|
std::vector<int64_t> lhs_contracting_dimensions(
|
||||||
|
dnums.lhs_contracting_dimensions().begin(),
|
||||||
|
dnums.lhs_contracting_dimensions().end());
|
||||||
|
std::vector<int64_t> rhs_batch_dimensions(
|
||||||
|
dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
|
||||||
|
std::vector<int64_t> lhs_batch_dimensions(
|
||||||
|
dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
|
||||||
|
|
||||||
|
// Push the attributes into our new DictionaryAttr.
|
||||||
|
auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions, builder);
|
||||||
|
auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions, builder);
|
||||||
|
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
|
||||||
|
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);
|
||||||
|
|
||||||
|
return mlir::xla_hlo::DotDimensionNumbers::get(
|
||||||
|
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
|
||||||
|
rhs_contracting_dims_attr, builder->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
|
||||||
|
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||||
|
llvm::SmallVector<int64_t, 4> input_spatial_dims(
|
||||||
|
dnums.input_spatial_dimensions().begin(),
|
||||||
|
dnums.input_spatial_dimensions().end());
|
||||||
|
llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
|
||||||
|
dnums.kernel_spatial_dimensions().begin(),
|
||||||
|
dnums.kernel_spatial_dimensions().end());
|
||||||
|
llvm::SmallVector<int64_t, 4> output_spatial_dims(
|
||||||
|
dnums.output_spatial_dimensions().begin(),
|
||||||
|
dnums.output_spatial_dimensions().end());
|
||||||
|
return mlir::xla_hlo::ConvDimensionNumbers::get(
|
||||||
|
builder->getI64IntegerAttr(dnums.input_batch_dimension()),
|
||||||
|
builder->getI64IntegerAttr(dnums.input_feature_dimension()),
|
||||||
|
Convert(input_spatial_dims, builder),
|
||||||
|
builder->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
|
||||||
|
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||||
|
Convert(kernel_spatial_dims, builder),
|
||||||
|
builder->getI64IntegerAttr(dnums.output_batch_dimension()),
|
||||||
|
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||||
|
Convert(output_spatial_dims, builder), builder->getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace xla
|
49
tensorflow/compiler/mlir/xla/attribute_importer.h
Normal file
49
tensorflow/compiler/mlir/xla/attribute_importer.h
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
|
||||||
|
#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
|
||||||
|
|
||||||
|
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
|
||||||
|
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
|
||||||
|
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
|
||||||
|
mlir::Builder* builder);
|
||||||
|
|
||||||
|
// Converts the gather dimensions to attributes.
|
||||||
|
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
|
||||||
|
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
|
||||||
|
|
||||||
|
// Converts the scatter dimensions to attributes.
|
||||||
|
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
|
||||||
|
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);
|
||||||
|
|
||||||
|
// Converts the dot dimensions to attributes.
|
||||||
|
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
|
||||||
|
const DotDimensionNumbers& dnums, mlir::Builder* builder);
|
||||||
|
|
||||||
|
// Converts the conv dimensions to attributes.
|
||||||
|
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
|
||||||
|
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
|
||||||
|
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "mlir/IR/Region.h" // from @llvm-project
|
#include "mlir/IR/Region.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
|
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||||
@ -56,6 +57,7 @@ using mlir::Value;
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
// Note: This sanitization function causes an irreversible many-to-one mapping
|
// Note: This sanitization function causes an irreversible many-to-one mapping
|
||||||
// and any solution to mitigate this would cause issues with the reverse
|
// and any solution to mitigate this would cause issues with the reverse
|
||||||
// direction. Longterm solution is to add a function attribute to maintain the
|
// direction. Longterm solution is to add a function attribute to maintain the
|
||||||
@ -230,15 +232,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
#undef MakeAndReturnBatchNormOp
|
#undef MakeAndReturnBatchNormOp
|
||||||
|
|
||||||
case HloOpcode::kDot: {
|
case HloOpcode::kDot: {
|
||||||
attributes.push_back(ConvertPrecisionConfig(instruction));
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
|
"precision_config",
|
||||||
|
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
|
||||||
|
|
||||||
// Consider consolidating DotOps together.
|
// Consider consolidating DotOps together.
|
||||||
if (DotIsDefault(instruction)) {
|
if (DotIsDefault(instruction)) {
|
||||||
MakeAndReturn(DotOp);
|
MakeAndReturn(DotOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
attributes.push_back(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers()));
|
"dot_dimension_numbers",
|
||||||
|
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
|
||||||
|
builder_)));
|
||||||
MakeAndReturn(DotGeneralOp);
|
MakeAndReturn(DotGeneralOp);
|
||||||
}
|
}
|
||||||
case HloOpcode::kCall: {
|
case HloOpcode::kCall: {
|
||||||
@ -278,8 +284,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
}
|
}
|
||||||
case HloOpcode::kGather: {
|
case HloOpcode::kGather: {
|
||||||
auto gather_instruction = Cast<HloGatherInstruction>(instruction);
|
auto gather_instruction = Cast<HloGatherInstruction>(instruction);
|
||||||
attributes.push_back(ConvertGatherDimensionNumbers(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
gather_instruction->gather_dimension_numbers()));
|
"dimension_numbers",
|
||||||
|
ConvertGatherDimensionNumbers(
|
||||||
|
gather_instruction->gather_dimension_numbers(), builder_)));
|
||||||
|
|
||||||
std::vector<int64_t> slice_sizes(
|
std::vector<int64_t> slice_sizes(
|
||||||
gather_instruction->gather_slice_sizes().begin(),
|
gather_instruction->gather_slice_sizes().begin(),
|
||||||
@ -345,8 +353,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
}
|
}
|
||||||
case HloOpcode::kScatter: {
|
case HloOpcode::kScatter: {
|
||||||
auto scatter = Cast<HloScatterInstruction>(instruction);
|
auto scatter = Cast<HloScatterInstruction>(instruction);
|
||||||
attributes.push_back(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
|
"scatter_dimension_numbers",
|
||||||
|
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
|
||||||
|
builder_)));
|
||||||
attributes.push_back(builder_->getNamedAttr(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
"indices_are_sorted",
|
"indices_are_sorted",
|
||||||
builder_->getBoolAttr(scatter->indices_are_sorted())));
|
builder_->getBoolAttr(scatter->indices_are_sorted())));
|
||||||
@ -577,15 +587,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
|||||||
builder_->getNamedAttr("lhs_dilations", Convert(lhs_dilations)));
|
builder_->getNamedAttr("lhs_dilations", Convert(lhs_dilations)));
|
||||||
attributes.push_back(
|
attributes.push_back(
|
||||||
builder_->getNamedAttr("rhs_dilations", Convert(rhs_dilations)));
|
builder_->getNamedAttr("rhs_dilations", Convert(rhs_dilations)));
|
||||||
attributes.push_back(ConvertConvDimensionNumbers(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
instruction->convolution_dimension_numbers()));
|
"dimension_numbers",
|
||||||
|
ConvertConvDimensionNumbers(
|
||||||
|
instruction->convolution_dimension_numbers(), builder_)));
|
||||||
attributes.push_back(builder_->getNamedAttr(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
"feature_group_count",
|
"feature_group_count",
|
||||||
builder_->getI64IntegerAttr(instruction->feature_group_count())));
|
builder_->getI64IntegerAttr(instruction->feature_group_count())));
|
||||||
attributes.push_back(builder_->getNamedAttr(
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
"batch_group_count",
|
"batch_group_count",
|
||||||
builder_->getI64IntegerAttr(instruction->batch_group_count())));
|
builder_->getI64IntegerAttr(instruction->batch_group_count())));
|
||||||
attributes.push_back(ConvertPrecisionConfig(instruction));
|
attributes.push_back(builder_->getNamedAttr(
|
||||||
|
"precision_config",
|
||||||
|
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
|
||||||
|
|
||||||
MakeAndReturn(ConvOp);
|
MakeAndReturn(ConvOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -717,20 +732,6 @@ StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
|
|||||||
"Unable to find value for input: ", instruction->ToString()));
|
"Unable to find value for input: ", instruction->ToString()));
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertPrecisionConfig(
|
|
||||||
HloInstruction* instruction) {
|
|
||||||
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
|
|
||||||
// case and the parser sticks it in. Maybe we should too.
|
|
||||||
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
|
|
||||||
|
|
||||||
for (auto prec : instruction->precision_config().operand_precision()) {
|
|
||||||
operand_precision_attrs.push_back(
|
|
||||||
builder_->getStringAttr(PrecisionConfig_Precision_Name(prec)));
|
|
||||||
}
|
|
||||||
return builder_->getNamedAttr(
|
|
||||||
"precision_config", builder_->getArrayAttr(operand_precision_attrs));
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
|
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
|
||||||
HloInstruction* instruction) {
|
HloInstruction* instruction) {
|
||||||
return builder_->getNamedAttr(
|
return builder_->getNamedAttr(
|
||||||
@ -751,10 +752,10 @@ mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
|
|||||||
}
|
}
|
||||||
|
|
||||||
mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
|
mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
|
||||||
llvm::ArrayRef<int64_t> op_dimensions) {
|
llvm::ArrayRef<int64_t> elements) {
|
||||||
return DenseIntElementsAttr::get(
|
return DenseIntElementsAttr::get(
|
||||||
RankedTensorType::get(op_dimensions.size(), builder_->getIntegerType(64)),
|
RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
|
||||||
op_dimensions);
|
elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
|
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
|
||||||
@ -766,86 +767,6 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
|
|||||||
return builder_->getNamedAttr("padding", attr);
|
return builder_->getNamedAttr("padding", attr);
|
||||||
}
|
}
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertDotDimensionNumbers(
|
|
||||||
const DotDimensionNumbers& dnums) {
|
|
||||||
std::vector<int64_t> rhs_contracting_dimensions(
|
|
||||||
dnums.rhs_contracting_dimensions().begin(),
|
|
||||||
dnums.rhs_contracting_dimensions().end());
|
|
||||||
std::vector<int64_t> lhs_contracting_dimensions(
|
|
||||||
dnums.lhs_contracting_dimensions().begin(),
|
|
||||||
dnums.lhs_contracting_dimensions().end());
|
|
||||||
std::vector<int64_t> rhs_batch_dimensions(
|
|
||||||
dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
|
|
||||||
std::vector<int64_t> lhs_batch_dimensions(
|
|
||||||
dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
|
|
||||||
|
|
||||||
// Push the attributes into our new DictionaryAttr.
|
|
||||||
auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions);
|
|
||||||
auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions);
|
|
||||||
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions);
|
|
||||||
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions);
|
|
||||||
|
|
||||||
auto attr = mlir::xla_hlo::DotDimensionNumbers::get(
|
|
||||||
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
|
|
||||||
rhs_contracting_dims_attr, context_);
|
|
||||||
return builder_->getNamedAttr("dot_dimension_numbers", attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers(
|
|
||||||
const xla::ConvolutionDimensionNumbers& dnums) {
|
|
||||||
llvm::SmallVector<int64_t, 4> input_spatial_dims(
|
|
||||||
dnums.input_spatial_dimensions().begin(),
|
|
||||||
dnums.input_spatial_dimensions().end());
|
|
||||||
llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
|
|
||||||
dnums.kernel_spatial_dimensions().begin(),
|
|
||||||
dnums.kernel_spatial_dimensions().end());
|
|
||||||
llvm::SmallVector<int64_t, 4> output_spatial_dims(
|
|
||||||
dnums.output_spatial_dimensions().begin(),
|
|
||||||
dnums.output_spatial_dimensions().end());
|
|
||||||
auto attr = mlir::xla_hlo::ConvDimensionNumbers::get(
|
|
||||||
builder_->getI64IntegerAttr(dnums.input_batch_dimension()),
|
|
||||||
builder_->getI64IntegerAttr(dnums.input_feature_dimension()),
|
|
||||||
Convert(input_spatial_dims),
|
|
||||||
builder_->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
|
|
||||||
builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
|
||||||
Convert(kernel_spatial_dims),
|
|
||||||
builder_->getI64IntegerAttr(dnums.output_batch_dimension()),
|
|
||||||
builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
|
||||||
Convert(output_spatial_dims), context_);
|
|
||||||
return builder_->getNamedAttr("dimension_numbers", attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
|
|
||||||
const xla::GatherDimensionNumbers& dnums) {
|
|
||||||
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
|
|
||||||
dnums.offset_dims().end());
|
|
||||||
std::vector<int64_t> collapsed_slice_dims(
|
|
||||||
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
|
|
||||||
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
|
|
||||||
dnums.start_index_map().end());
|
|
||||||
auto attr = mlir::xla_hlo::GatherDimensionNumbers::get(
|
|
||||||
Convert(offset_dims), Convert(collapsed_slice_dims),
|
|
||||||
Convert(start_index_map),
|
|
||||||
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
|
|
||||||
return builder_->getNamedAttr("dimension_numbers", attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
|
|
||||||
const xla::ScatterDimensionNumbers& dnums) {
|
|
||||||
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
|
|
||||||
dnums.update_window_dims().end());
|
|
||||||
std::vector<int64_t> inserted_window_dims(
|
|
||||||
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
|
|
||||||
std::vector<int64_t> scatter_dims_to_operand_dims(
|
|
||||||
dnums.scatter_dims_to_operand_dims().begin(),
|
|
||||||
dnums.scatter_dims_to_operand_dims().end());
|
|
||||||
auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
|
|
||||||
Convert(update_window_dims), Convert(inserted_window_dims),
|
|
||||||
Convert(scatter_dims_to_operand_dims),
|
|
||||||
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
|
|
||||||
return builder_->getNamedAttr("scatter_dimension_numbers", attr);
|
|
||||||
}
|
|
||||||
|
|
||||||
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
|
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
|
||||||
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
|
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
|
||||||
source_target_pairs) {
|
source_target_pairs) {
|
||||||
|
@ -89,9 +89,6 @@ class HloFunctionImporter {
|
|||||||
// Returns the Mlir Value for the corresponding HloInstruction.
|
// Returns the Mlir Value for the corresponding HloInstruction.
|
||||||
StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);
|
StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);
|
||||||
|
|
||||||
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
|
|
||||||
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction);
|
|
||||||
|
|
||||||
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
|
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
|
||||||
mlir::NamedAttribute ConvertComparisonDirection(
|
mlir::NamedAttribute ConvertComparisonDirection(
|
||||||
xla::HloInstruction* instruction);
|
xla::HloInstruction* instruction);
|
||||||
@ -101,28 +98,12 @@ class HloFunctionImporter {
|
|||||||
llvm::ArrayRef<tensorflow::int64> op_dimensions);
|
llvm::ArrayRef<tensorflow::int64> op_dimensions);
|
||||||
|
|
||||||
// Converts Array ref to an DenseIntElementsAttr.
|
// Converts Array ref to an DenseIntElementsAttr.
|
||||||
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> op_dimensions);
|
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);
|
||||||
|
|
||||||
// Converts Array ref to padding attribute. Input is a flattened list of
|
// Converts Array ref to padding attribute. Input is a flattened list of
|
||||||
// padding low and padding high for each of the spatial dimensions.
|
// padding low and padding high for each of the spatial dimensions.
|
||||||
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
|
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
|
||||||
|
|
||||||
// Converts the dot dimensions to attribute.
|
|
||||||
mlir::NamedAttribute ConvertDotDimensionNumbers(
|
|
||||||
const DotDimensionNumbers& dnums);
|
|
||||||
|
|
||||||
// Converts the conv dimensions to attributes.
|
|
||||||
mlir::NamedAttribute ConvertConvDimensionNumbers(
|
|
||||||
const xla::ConvolutionDimensionNumbers& dnums);
|
|
||||||
|
|
||||||
// Converts the gather dimensions to attributes.
|
|
||||||
mlir::NamedAttribute ConvertGatherDimensionNumbers(
|
|
||||||
const xla::GatherDimensionNumbers& dnums);
|
|
||||||
|
|
||||||
// Converts the scatter dimensions to attributes.
|
|
||||||
mlir::NamedAttribute ConvertScatterDimensionNumbers(
|
|
||||||
const xla::ScatterDimensionNumbers& dnums);
|
|
||||||
|
|
||||||
// Converts replica groups to attribute
|
// Converts replica groups to attribute
|
||||||
mlir::NamedAttribute ConvertReplicaGroups(
|
mlir::NamedAttribute ConvertReplicaGroups(
|
||||||
const std::vector<ReplicaGroup>& replica_groups);
|
const std::vector<ReplicaGroup>& replica_groups);
|
||||||
|
Loading…
Reference in New Issue
Block a user