RFC: Create separate lib for attribute converters from XLA to MLIR
These will be re-used in MlirHloBuilder separately. PiperOrigin-RevId: 307032293 Change-Id: I3517bb764b1046bffe39baf177bd44b103cfb640
This commit is contained in:
parent
2d676104cb
commit
0eb2537aa0
@ -704,6 +704,7 @@ cc_library(
|
||||
"hlo_module_importer.h",
|
||||
],
|
||||
deps = [
|
||||
":attribute_importer",
|
||||
":hlo",
|
||||
":hlo_utils",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
@ -723,6 +724,18 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "attribute_importer",
|
||||
srcs = ["attribute_importer.cc"],
|
||||
hdrs = ["attribute_importer.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/core/platform:types",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "xla_mlir_translate",
|
||||
srcs = ["xla_mlir_translate.cc"],
|
||||
|
124
tensorflow/compiler/mlir/xla/attribute_importer.cc
Normal file
124
tensorflow/compiler/mlir/xla/attribute_importer.cc
Normal file
@ -0,0 +1,124 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace xla {
|
||||
|
||||
static mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements,
|
||||
mlir::Builder* builder) {
|
||||
return mlir::DenseIntElementsAttr::get(
|
||||
mlir::RankedTensorType::get(elements.size(), builder->getIntegerType(64)),
|
||||
elements);
|
||||
}
|
||||
|
||||
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
|
||||
mlir::Builder* builder) {
|
||||
if (!config) return {};
|
||||
|
||||
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
|
||||
// case and the parser sticks it in. Maybe we should too.
|
||||
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
|
||||
|
||||
for (auto prec : config->operand_precision()) {
|
||||
operand_precision_attrs.push_back(
|
||||
builder->getStringAttr(PrecisionConfig_Precision_Name(prec)));
|
||||
}
|
||||
return builder->getArrayAttr(operand_precision_attrs);
|
||||
}
|
||||
|
||||
// Converts the gather dimensions to attributes.
|
||||
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
|
||||
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
|
||||
dnums.offset_dims().end());
|
||||
std::vector<int64_t> collapsed_slice_dims(
|
||||
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
|
||||
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
|
||||
dnums.start_index_map().end());
|
||||
return mlir::xla_hlo::GatherDimensionNumbers::get(
|
||||
Convert(offset_dims, builder), Convert(collapsed_slice_dims, builder),
|
||||
Convert(start_index_map, builder),
|
||||
builder->getI64IntegerAttr(dnums.index_vector_dim()),
|
||||
builder->getContext());
|
||||
}
|
||||
|
||||
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
|
||||
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
|
||||
dnums.update_window_dims().end());
|
||||
std::vector<int64_t> inserted_window_dims(
|
||||
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
|
||||
std::vector<int64_t> scatter_dims_to_operand_dims(
|
||||
dnums.scatter_dims_to_operand_dims().begin(),
|
||||
dnums.scatter_dims_to_operand_dims().end());
|
||||
return mlir::xla_hlo::ScatterDimensionNumbers::get(
|
||||
Convert(update_window_dims, builder),
|
||||
Convert(inserted_window_dims, builder),
|
||||
Convert(scatter_dims_to_operand_dims, builder),
|
||||
builder->getI64IntegerAttr(dnums.index_vector_dim()),
|
||||
builder->getContext());
|
||||
}
|
||||
|
||||
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
|
||||
const DotDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||
std::vector<int64_t> rhs_contracting_dimensions(
|
||||
dnums.rhs_contracting_dimensions().begin(),
|
||||
dnums.rhs_contracting_dimensions().end());
|
||||
std::vector<int64_t> lhs_contracting_dimensions(
|
||||
dnums.lhs_contracting_dimensions().begin(),
|
||||
dnums.lhs_contracting_dimensions().end());
|
||||
std::vector<int64_t> rhs_batch_dimensions(
|
||||
dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
|
||||
std::vector<int64_t> lhs_batch_dimensions(
|
||||
dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
|
||||
|
||||
// Push the attributes into our new DictionaryAttr.
|
||||
auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions, builder);
|
||||
auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions, builder);
|
||||
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions, builder);
|
||||
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions, builder);
|
||||
|
||||
return mlir::xla_hlo::DotDimensionNumbers::get(
|
||||
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
|
||||
rhs_contracting_dims_attr, builder->getContext());
|
||||
}
|
||||
|
||||
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
|
||||
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder) {
|
||||
llvm::SmallVector<int64_t, 4> input_spatial_dims(
|
||||
dnums.input_spatial_dimensions().begin(),
|
||||
dnums.input_spatial_dimensions().end());
|
||||
llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
|
||||
dnums.kernel_spatial_dimensions().begin(),
|
||||
dnums.kernel_spatial_dimensions().end());
|
||||
llvm::SmallVector<int64_t, 4> output_spatial_dims(
|
||||
dnums.output_spatial_dimensions().begin(),
|
||||
dnums.output_spatial_dimensions().end());
|
||||
return mlir::xla_hlo::ConvDimensionNumbers::get(
|
||||
builder->getI64IntegerAttr(dnums.input_batch_dimension()),
|
||||
builder->getI64IntegerAttr(dnums.input_feature_dimension()),
|
||||
Convert(input_spatial_dims, builder),
|
||||
builder->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
|
||||
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||
Convert(kernel_spatial_dims, builder),
|
||||
builder->getI64IntegerAttr(dnums.output_batch_dimension()),
|
||||
builder->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||
Convert(output_spatial_dims, builder), builder->getContext());
|
||||
}
|
||||
|
||||
} // namespace xla
|
49
tensorflow/compiler/mlir/xla/attribute_importer.h
Normal file
49
tensorflow/compiler/mlir/xla/attribute_importer.h
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
|
||||
#define TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
|
||||
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
|
||||
mlir::ArrayAttr ConvertPrecisionConfig(const PrecisionConfig* config,
|
||||
mlir::Builder* builder);
|
||||
|
||||
// Converts the gather dimensions to attributes.
|
||||
mlir::xla_hlo::GatherDimensionNumbers ConvertGatherDimensionNumbers(
|
||||
const xla::GatherDimensionNumbers& dnums, mlir::Builder* builder);
|
||||
|
||||
// Converts the scatter dimensions to attributes.
|
||||
mlir::xla_hlo::ScatterDimensionNumbers ConvertScatterDimensionNumbers(
|
||||
const xla::ScatterDimensionNumbers& dnums, mlir::Builder* builder);
|
||||
|
||||
// Converts the dot dimensions to attributes.
|
||||
mlir::xla_hlo::DotDimensionNumbers ConvertDotDimensionNumbers(
|
||||
const DotDimensionNumbers& dnums, mlir::Builder* builder);
|
||||
|
||||
// Converts the conv dimensions to attributes.
|
||||
mlir::xla_hlo::ConvDimensionNumbers ConvertConvDimensionNumbers(
|
||||
const xla::ConvolutionDimensionNumbers& dnums, mlir::Builder* builder);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_IMPORTER_H_
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "mlir/IR/Region.h" // from @llvm-project
|
||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/mlir/xla/attribute_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/protobuf_util.h"
|
||||
@ -56,6 +57,7 @@ using mlir::Value;
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// Note: This sanitization function causes an irreversible many-to-one mapping
|
||||
// and any solution to mitigate this would cause issues with the reverse
|
||||
// direction. Longterm solution is to add a function attribute to maintain the
|
||||
@ -230,15 +232,19 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
#undef MakeAndReturnBatchNormOp
|
||||
|
||||
case HloOpcode::kDot: {
|
||||
attributes.push_back(ConvertPrecisionConfig(instruction));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"precision_config",
|
||||
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
|
||||
|
||||
// Consider consolidating DotOps together.
|
||||
if (DotIsDefault(instruction)) {
|
||||
MakeAndReturn(DotOp);
|
||||
}
|
||||
|
||||
attributes.push_back(
|
||||
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers()));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"dot_dimension_numbers",
|
||||
ConvertDotDimensionNumbers(instruction->dot_dimension_numbers(),
|
||||
builder_)));
|
||||
MakeAndReturn(DotGeneralOp);
|
||||
}
|
||||
case HloOpcode::kCall: {
|
||||
@ -278,8 +284,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
}
|
||||
case HloOpcode::kGather: {
|
||||
auto gather_instruction = Cast<HloGatherInstruction>(instruction);
|
||||
attributes.push_back(ConvertGatherDimensionNumbers(
|
||||
gather_instruction->gather_dimension_numbers()));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"dimension_numbers",
|
||||
ConvertGatherDimensionNumbers(
|
||||
gather_instruction->gather_dimension_numbers(), builder_)));
|
||||
|
||||
std::vector<int64_t> slice_sizes(
|
||||
gather_instruction->gather_slice_sizes().begin(),
|
||||
@ -345,8 +353,10 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
}
|
||||
case HloOpcode::kScatter: {
|
||||
auto scatter = Cast<HloScatterInstruction>(instruction);
|
||||
attributes.push_back(
|
||||
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers()));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"scatter_dimension_numbers",
|
||||
ConvertScatterDimensionNumbers(scatter->scatter_dimension_numbers(),
|
||||
builder_)));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"indices_are_sorted",
|
||||
builder_->getBoolAttr(scatter->indices_are_sorted())));
|
||||
@ -577,15 +587,20 @@ StatusOr<mlir::Operation*> HloFunctionImporter::ImportInstruction(
|
||||
builder_->getNamedAttr("lhs_dilations", Convert(lhs_dilations)));
|
||||
attributes.push_back(
|
||||
builder_->getNamedAttr("rhs_dilations", Convert(rhs_dilations)));
|
||||
attributes.push_back(ConvertConvDimensionNumbers(
|
||||
instruction->convolution_dimension_numbers()));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"dimension_numbers",
|
||||
ConvertConvDimensionNumbers(
|
||||
instruction->convolution_dimension_numbers(), builder_)));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"feature_group_count",
|
||||
builder_->getI64IntegerAttr(instruction->feature_group_count())));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"batch_group_count",
|
||||
builder_->getI64IntegerAttr(instruction->batch_group_count())));
|
||||
attributes.push_back(ConvertPrecisionConfig(instruction));
|
||||
attributes.push_back(builder_->getNamedAttr(
|
||||
"precision_config",
|
||||
ConvertPrecisionConfig(&instruction->precision_config(), builder_)));
|
||||
|
||||
MakeAndReturn(ConvOp);
|
||||
}
|
||||
|
||||
@ -717,20 +732,6 @@ StatusOr<Value> HloFunctionImporter::GetMlirValue(HloInstruction* instruction) {
|
||||
"Unable to find value for input: ", instruction->ToString()));
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertPrecisionConfig(
|
||||
HloInstruction* instruction) {
|
||||
// TODO(b/129709049) The HLO text format elides this in the all DEFAULT
|
||||
// case and the parser sticks it in. Maybe we should too.
|
||||
llvm::SmallVector<mlir::Attribute, 4> operand_precision_attrs;
|
||||
|
||||
for (auto prec : instruction->precision_config().operand_precision()) {
|
||||
operand_precision_attrs.push_back(
|
||||
builder_->getStringAttr(PrecisionConfig_Precision_Name(prec)));
|
||||
}
|
||||
return builder_->getNamedAttr(
|
||||
"precision_config", builder_->getArrayAttr(operand_precision_attrs));
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertComparisonDirection(
|
||||
HloInstruction* instruction) {
|
||||
return builder_->getNamedAttr(
|
||||
@ -751,10 +752,10 @@ mlir::DenseIntElementsAttr HloFunctionImporter::ConvertDimensions(
|
||||
}
|
||||
|
||||
mlir::DenseIntElementsAttr HloFunctionImporter::Convert(
|
||||
llvm::ArrayRef<int64_t> op_dimensions) {
|
||||
llvm::ArrayRef<int64_t> elements) {
|
||||
return DenseIntElementsAttr::get(
|
||||
RankedTensorType::get(op_dimensions.size(), builder_->getIntegerType(64)),
|
||||
op_dimensions);
|
||||
RankedTensorType::get(elements.size(), builder_->getIntegerType(64)),
|
||||
elements);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
|
||||
@ -766,86 +767,6 @@ mlir::NamedAttribute HloFunctionImporter::ConvertPadding(
|
||||
return builder_->getNamedAttr("padding", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertDotDimensionNumbers(
|
||||
const DotDimensionNumbers& dnums) {
|
||||
std::vector<int64_t> rhs_contracting_dimensions(
|
||||
dnums.rhs_contracting_dimensions().begin(),
|
||||
dnums.rhs_contracting_dimensions().end());
|
||||
std::vector<int64_t> lhs_contracting_dimensions(
|
||||
dnums.lhs_contracting_dimensions().begin(),
|
||||
dnums.lhs_contracting_dimensions().end());
|
||||
std::vector<int64_t> rhs_batch_dimensions(
|
||||
dnums.rhs_batch_dimensions().begin(), dnums.rhs_batch_dimensions().end());
|
||||
std::vector<int64_t> lhs_batch_dimensions(
|
||||
dnums.lhs_batch_dimensions().begin(), dnums.lhs_batch_dimensions().end());
|
||||
|
||||
// Push the attributes into our new DictionaryAttr.
|
||||
auto lhs_batch_dims_attr = Convert(lhs_batch_dimensions);
|
||||
auto rhs_batch_dims_attr = Convert(rhs_batch_dimensions);
|
||||
auto lhs_contracting_dims_attr = Convert(lhs_contracting_dimensions);
|
||||
auto rhs_contracting_dims_attr = Convert(rhs_contracting_dimensions);
|
||||
|
||||
auto attr = mlir::xla_hlo::DotDimensionNumbers::get(
|
||||
lhs_batch_dims_attr, rhs_batch_dims_attr, lhs_contracting_dims_attr,
|
||||
rhs_contracting_dims_attr, context_);
|
||||
return builder_->getNamedAttr("dot_dimension_numbers", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertConvDimensionNumbers(
|
||||
const xla::ConvolutionDimensionNumbers& dnums) {
|
||||
llvm::SmallVector<int64_t, 4> input_spatial_dims(
|
||||
dnums.input_spatial_dimensions().begin(),
|
||||
dnums.input_spatial_dimensions().end());
|
||||
llvm::SmallVector<int64_t, 4> kernel_spatial_dims(
|
||||
dnums.kernel_spatial_dimensions().begin(),
|
||||
dnums.kernel_spatial_dimensions().end());
|
||||
llvm::SmallVector<int64_t, 4> output_spatial_dims(
|
||||
dnums.output_spatial_dimensions().begin(),
|
||||
dnums.output_spatial_dimensions().end());
|
||||
auto attr = mlir::xla_hlo::ConvDimensionNumbers::get(
|
||||
builder_->getI64IntegerAttr(dnums.input_batch_dimension()),
|
||||
builder_->getI64IntegerAttr(dnums.input_feature_dimension()),
|
||||
Convert(input_spatial_dims),
|
||||
builder_->getI64IntegerAttr(dnums.kernel_input_feature_dimension()),
|
||||
builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||
Convert(kernel_spatial_dims),
|
||||
builder_->getI64IntegerAttr(dnums.output_batch_dimension()),
|
||||
builder_->getI64IntegerAttr(dnums.kernel_output_feature_dimension()),
|
||||
Convert(output_spatial_dims), context_);
|
||||
return builder_->getNamedAttr("dimension_numbers", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertGatherDimensionNumbers(
|
||||
const xla::GatherDimensionNumbers& dnums) {
|
||||
std::vector<int64_t> offset_dims(dnums.offset_dims().begin(),
|
||||
dnums.offset_dims().end());
|
||||
std::vector<int64_t> collapsed_slice_dims(
|
||||
dnums.collapsed_slice_dims().begin(), dnums.collapsed_slice_dims().end());
|
||||
std::vector<int64_t> start_index_map(dnums.start_index_map().begin(),
|
||||
dnums.start_index_map().end());
|
||||
auto attr = mlir::xla_hlo::GatherDimensionNumbers::get(
|
||||
Convert(offset_dims), Convert(collapsed_slice_dims),
|
||||
Convert(start_index_map),
|
||||
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
|
||||
return builder_->getNamedAttr("dimension_numbers", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertScatterDimensionNumbers(
|
||||
const xla::ScatterDimensionNumbers& dnums) {
|
||||
std::vector<int64_t> update_window_dims(dnums.update_window_dims().begin(),
|
||||
dnums.update_window_dims().end());
|
||||
std::vector<int64_t> inserted_window_dims(
|
||||
dnums.inserted_window_dims().begin(), dnums.inserted_window_dims().end());
|
||||
std::vector<int64_t> scatter_dims_to_operand_dims(
|
||||
dnums.scatter_dims_to_operand_dims().begin(),
|
||||
dnums.scatter_dims_to_operand_dims().end());
|
||||
auto attr = mlir::xla_hlo::ScatterDimensionNumbers::get(
|
||||
Convert(update_window_dims), Convert(inserted_window_dims),
|
||||
Convert(scatter_dims_to_operand_dims),
|
||||
builder_->getI64IntegerAttr(dnums.index_vector_dim()), context_);
|
||||
return builder_->getNamedAttr("scatter_dimension_numbers", attr);
|
||||
}
|
||||
|
||||
mlir::NamedAttribute HloFunctionImporter::ConvertSourceTargetPairs(
|
||||
const std::vector<std::pair<tensorflow::int64, tensorflow::int64>>&
|
||||
source_target_pairs) {
|
||||
|
@ -89,9 +89,6 @@ class HloFunctionImporter {
|
||||
// Returns the Mlir Value for the corresponding HloInstruction.
|
||||
StatusOr<mlir::Value> GetMlirValue(xla::HloInstruction* instruction);
|
||||
|
||||
// Converts an XLA PrecisionConfig to the corresponding MLIR attribute.
|
||||
mlir::NamedAttribute ConvertPrecisionConfig(xla::HloInstruction* instruction);
|
||||
|
||||
// Converts an XLA ComparisonDirection to the corresponding MLIR attribute.
|
||||
mlir::NamedAttribute ConvertComparisonDirection(
|
||||
xla::HloInstruction* instruction);
|
||||
@ -101,28 +98,12 @@ class HloFunctionImporter {
|
||||
llvm::ArrayRef<tensorflow::int64> op_dimensions);
|
||||
|
||||
// Converts Array ref to an DenseIntElementsAttr.
|
||||
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> op_dimensions);
|
||||
mlir::DenseIntElementsAttr Convert(llvm::ArrayRef<int64_t> elements);
|
||||
|
||||
// Converts Array ref to padding attribute. Input is a flattened list of
|
||||
// padding low and padding high for each of the spatial dimensions.
|
||||
mlir::NamedAttribute ConvertPadding(llvm::ArrayRef<int64_t> padding);
|
||||
|
||||
// Converts the dot dimensions to attribute.
|
||||
mlir::NamedAttribute ConvertDotDimensionNumbers(
|
||||
const DotDimensionNumbers& dnums);
|
||||
|
||||
// Converts the conv dimensions to attributes.
|
||||
mlir::NamedAttribute ConvertConvDimensionNumbers(
|
||||
const xla::ConvolutionDimensionNumbers& dnums);
|
||||
|
||||
// Converts the gather dimensions to attributes.
|
||||
mlir::NamedAttribute ConvertGatherDimensionNumbers(
|
||||
const xla::GatherDimensionNumbers& dnums);
|
||||
|
||||
// Converts the scatter dimensions to attributes.
|
||||
mlir::NamedAttribute ConvertScatterDimensionNumbers(
|
||||
const xla::ScatterDimensionNumbers& dnums);
|
||||
|
||||
// Converts replica groups to attribute
|
||||
mlir::NamedAttribute ConvertReplicaGroups(
|
||||
const std::vector<ReplicaGroup>& replica_groups);
|
||||
|
Loading…
Reference in New Issue
Block a user