Add MlirHloBuilder op implementations
PiperOrigin-RevId: 307472994 Change-Id: Ifbca316f653f44469cebd3aa5a507e8ccabf5001
This commit is contained in:
parent
63f2383118
commit
f65a9acfda
|
@ -165,6 +165,90 @@ StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
|
|||
/*attributes=*/{});
|
||||
}
|
||||
|
||||
XlaOp MlirHloBuilder::CreateToken() {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
|
||||
loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
|
||||
const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
|
||||
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
|
||||
ConvertShapeToType<mlir::RankedTensorType>(
|
||||
infeed_instruction_shape, builder_));
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
|
||||
loc_, result_type, GetValue(token),
|
||||
/*infeed_config=*/config));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
|
||||
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
|
||||
const string& outfeed_config) {
|
||||
auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
|
||||
loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::Type result_type,
|
||||
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
|
||||
auto mlir_operands = GetValues(operands);
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
|
||||
loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
|
||||
XlaOp tuple_data,
|
||||
int64 index) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::Type result_type,
|
||||
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
|
||||
loc_, result_type, GetValue(tuple_data),
|
||||
builder_.getI32IntegerAttr(index)));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
|
||||
const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
|
||||
loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
|
||||
GetI64ElementsAttr(limit_indices, &builder_),
|
||||
GetI64ElementsAttr(strides, &builder_)));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
|
||||
const Shape& shape, XlaOp operand, XlaOp padding_value,
|
||||
const PaddingConfig& padding_config) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
mlir::Type result_type,
|
||||
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
|
||||
std::vector<int64> low;
|
||||
std::vector<int64> high;
|
||||
std::vector<int64> internal;
|
||||
for (auto& dimension : padding_config.dimensions()) {
|
||||
low.push_back(dimension.edge_padding_low());
|
||||
high.push_back(dimension.edge_padding_high());
|
||||
internal.push_back(dimension.interior_padding());
|
||||
}
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
|
||||
loc_, result_type, GetValue(operand), GetValue(padding_value),
|
||||
GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
|
||||
GetI64ElementsAttr(internal, &builder_)));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> elements) {
|
||||
mlir::SmallVector<mlir::Value, 4> operands;
|
||||
for (auto& element : elements) {
|
||||
operands.push_back(GetValue(element));
|
||||
}
|
||||
return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
|
||||
const std::string& op_name, const Shape& shape,
|
||||
llvm::ArrayRef<XlaOp> operands,
|
||||
|
|
|
@ -54,6 +54,9 @@ class MlirHloBuilder : public XlaBuilder {
|
|||
// TODO(hinsu): Add a constructor to build a new MLIR function from scratch
|
||||
// and override Build methods.
|
||||
|
||||
MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc)
|
||||
: XlaBuilder(name), builder_(builder), loc_(loc) {}
|
||||
|
||||
MlirHloBuilder(const MlirHloBuilder&) = delete;
|
||||
MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
|
||||
|
||||
|
@ -75,6 +78,17 @@ class MlirHloBuilder : public XlaBuilder {
|
|||
return mlir::Value::getFromOpaquePointer(ptr);
|
||||
}
|
||||
|
||||
// Returns MLIR values corresponding to the given XLA ops.
|
||||
//
|
||||
// Requires that the ops were created by this builder.
|
||||
std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
|
||||
std::vector<mlir::Value> values;
|
||||
for (auto xla_op : ops) {
|
||||
values.push_back(GetValue(xla_op));
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
// Sets location for newly built ops, until reset.
|
||||
void SetLocation(mlir::Location loc) { loc_ = loc; }
|
||||
|
||||
|
@ -120,6 +134,34 @@ class MlirHloBuilder : public XlaBuilder {
|
|||
StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
|
||||
absl::Span<const XlaOp> operands) override;
|
||||
|
||||
XlaOp CreateToken() override;
|
||||
|
||||
StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
|
||||
XlaOp token,
|
||||
const string& config) override;
|
||||
StatusOr<XlaOp> OutfeedWithTokenInternal(
|
||||
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
|
||||
const string& outfeed_config) override;
|
||||
|
||||
StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
|
||||
absl::Span<const XlaOp> operands,
|
||||
int64 dimension) override;
|
||||
|
||||
StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
|
||||
int64 index) override;
|
||||
|
||||
StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
|
||||
absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices,
|
||||
absl::Span<const int64> strides) override;
|
||||
|
||||
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
|
||||
XlaOp padding_value,
|
||||
const PaddingConfig& padding_config) override;
|
||||
|
||||
StatusOr<XlaOp> TupleInternal(const Shape& shape,
|
||||
absl::Span<const XlaOp> elements) override;
|
||||
|
||||
// Creates HLO dialect op and returns the result as an XlaOp.
|
||||
StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
|
||||
llvm::ArrayRef<XlaOp> operands,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
package(licenses = ["notice"])
|
||||
|
||||
|
@ -18,3 +19,18 @@ filegroup(
|
|||
"@llvm-project//llvm:FileCheck",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "mlir_hlo_builder_test",
|
||||
srcs = ["mlir_hlo_builder_test.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/mlir/xla:hlo",
|
||||
"//tensorflow/compiler/mlir/xla:mlir_hlo_builder",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:IR",
|
||||
],
|
||||
)
|
||||
|
|
|
@ -0,0 +1,179 @@
|
|||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Dialect.h" // from @llvm-project
|
||||
#include "mlir/IR/Location.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) {
|
||||
EXPECT_TRUE(absl::StrContains(s, expected))
|
||||
<< s << " does not contain " << expected;
|
||||
}
|
||||
|
||||
class XlaBuilderTest : public ::testing::Test {
|
||||
protected:
|
||||
XlaBuilderTest()
|
||||
: name_(SetupTest()),
|
||||
context_(),
|
||||
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
|
||||
builder_(&module_->getBodyRegion()),
|
||||
xla_builder_(name_, builder_, module_->getLoc()) {}
|
||||
|
||||
string SetupTest() {
|
||||
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
|
||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||
}
|
||||
|
||||
// Retuns the MLIR op string representation of the given XlaOp.
|
||||
string GetMlirOpString(XlaOp xla_op) {
|
||||
string str;
|
||||
llvm::raw_string_ostream ostream{str};
|
||||
xla_builder_.GetValue(xla_op).print(ostream);
|
||||
ostream.flush();
|
||||
return str;
|
||||
}
|
||||
|
||||
string name_;
|
||||
mlir::MLIRContext context_;
|
||||
mlir::OwningModuleRef module_;
|
||||
mlir::OpBuilder builder_;
|
||||
MlirHloBuilder xla_builder_;
|
||||
};
|
||||
|
||||
TEST_F(XlaBuilderTest, CreateToken) {
|
||||
auto token = CreateToken(&xla_builder_);
|
||||
auto str = GetMlirOpString(token);
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
|
||||
ExpectHasSubstr(GetMlirOpString(token),
|
||||
R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, Infeed) {
|
||||
auto token = CreateToken(&xla_builder_);
|
||||
auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), "");
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(infeed),
|
||||
R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, Outfeed) {
|
||||
auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8});
|
||||
auto data = ConstantLiteral(
|
||||
&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions()));
|
||||
auto token = CreateToken(&xla_builder_);
|
||||
auto outfeed = OutfeedWithToken(data, token, outfeed_shape, "");
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(outfeed),
|
||||
R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, ConcatInDim) {
|
||||
auto data0 = ConstantLiteral(
|
||||
&xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5}));
|
||||
auto data1 = ConstantLiteral(
|
||||
&xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5}));
|
||||
auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1);
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(concat),
|
||||
R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, Tuple) {
|
||||
auto data0 = ConstantLiteral(&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
|
||||
auto data1 = ConstantLiteral(&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, {}));
|
||||
auto tuple = Tuple(&xla_builder_, {data0, data1});
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(tuple),
|
||||
R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, GetTupleElement) {
|
||||
auto data0 = ConstantLiteral(&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
|
||||
auto data1 = ConstantLiteral(&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, {}));
|
||||
auto tuple_data = Tuple(&xla_builder_, {data0, data1});
|
||||
auto gte = GetTupleElement(tuple_data, 1);
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(gte),
|
||||
R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, Slice) {
|
||||
auto data = ConstantLiteral(&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
|
||||
auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1});
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(slice),
|
||||
R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
|
||||
}
|
||||
|
||||
TEST_F(XlaBuilderTest, Pad) {
|
||||
auto data = ConstantLiteral(&xla_builder_,
|
||||
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
|
||||
auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32));
|
||||
|
||||
PaddingConfig padding_config;
|
||||
auto* dims0 = padding_config.add_dimensions();
|
||||
dims0->set_edge_padding_low(1);
|
||||
dims0->set_interior_padding(0);
|
||||
dims0->set_edge_padding_high(2);
|
||||
auto* dims1 = padding_config.add_dimensions();
|
||||
dims1->set_edge_padding_low(3);
|
||||
dims1->set_interior_padding(1);
|
||||
dims1->set_edge_padding_high(0);
|
||||
auto pad = Pad(data, zero, padding_config);
|
||||
|
||||
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
|
||||
ExpectHasSubstr(
|
||||
GetMlirOpString(pad),
|
||||
R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
|
@ -822,23 +822,29 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
|
|||
absl::Span<const int64> limit_indices,
|
||||
absl::Span<const int64> strides) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
|
||||
*operand_shape, start_indices,
|
||||
limit_indices, strides));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int i = 0; i < start_indices.size(); i++) {
|
||||
auto* slice_config = instr.add_slice_dimensions();
|
||||
slice_config->set_start(start_indices[i]);
|
||||
slice_config->set_limit(limit_indices[i]);
|
||||
slice_config->set_stride(strides[i]);
|
||||
}
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
|
||||
return SliceInternal(shape, operand, start_indices, limit_indices, strides);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
|
||||
absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices,
|
||||
absl::Span<const int64> strides) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int i = 0; i < start_indices.size(); i++) {
|
||||
auto* slice_config = instr.add_slice_dimensions();
|
||||
slice_config->set_start(start_indices[i]);
|
||||
slice_config->set_limit(limit_indices[i]);
|
||||
slice_config->set_stride(strides[i]);
|
||||
}
|
||||
return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
|
||||
int64 limit_index, int64 stride, int64 dimno) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -952,41 +958,49 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
|
|||
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
|
||||
int64 dimension) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
std::vector<const Shape*> operand_shape_ptrs;
|
||||
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
|
||||
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
|
||||
[](const Shape& shape) { return &shape; });
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
|
||||
operand_shape_ptrs, dimension));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.add_dimensions(dimension);
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
|
||||
return ConcatInDimInternal(shape, operands, dimension);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
|
||||
const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.add_dimensions(dimension);
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
|
||||
const PaddingConfig& padding_config) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
|
||||
TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
|
||||
GetShapePtr(padding_value));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, ShapeInference::InferPadShape(
|
||||
*operand_shape, *padding_value_shape, padding_config));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
*instr.mutable_padding_config() = padding_config;
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kPad,
|
||||
{operand, padding_value});
|
||||
return PadInternal(shape, operand, padding_value, padding_config);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
|
||||
XlaOp padding_value,
|
||||
const PaddingConfig& padding_config) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
*instr.mutable_padding_config() = padding_config;
|
||||
return AddInstruction(std::move(instr), HloOpcode::kPad,
|
||||
{operand, padding_value});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
|
||||
absl::Span<const int64> new_sizes,
|
||||
int64 inferred_dimension) {
|
||||
|
@ -1080,7 +1094,6 @@ XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
|
|||
|
||||
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
std::vector<const Shape*> operand_shape_ptrs;
|
||||
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
|
||||
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
|
||||
|
@ -1088,14 +1101,19 @@ XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
|
|||
TF_ASSIGN_OR_RETURN(const Shape shape,
|
||||
ShapeInference::InferVariadicOpShape(
|
||||
HloOpcode::kTuple, operand_shape_ptrs));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
|
||||
return TupleInternal(shape, elements);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
|
||||
absl::Span<const XlaOp> elements) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
|
||||
if (!tuple_shape->IsTuple()) {
|
||||
return InvalidArgument(
|
||||
|
@ -1107,16 +1125,22 @@ XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
|
|||
"GetTupleElement() index (%d) out of range for tuple shape %s", index,
|
||||
ShapeUtil::HumanString(*tuple_shape));
|
||||
}
|
||||
*instr.mutable_shape() =
|
||||
ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto();
|
||||
|
||||
instr.set_tuple_index(index);
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
|
||||
{tuple_data});
|
||||
return GetTupleElementInternal(
|
||||
ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
|
||||
index);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
|
||||
XlaOp tuple_data,
|
||||
int64 index) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_tuple_index(index);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
|
||||
{tuple_data});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
|
||||
const PrecisionConfig* precision_config) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -1407,14 +1431,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
|
|||
XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
|
||||
const string& config) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
if (!LayoutUtil::HasLayout(shape)) {
|
||||
return InvalidArgument("Given shape to Infeed must have a layout");
|
||||
}
|
||||
const Shape infeed_instruction_shape =
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
|
||||
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
|
||||
instr.set_infeed_config(config);
|
||||
|
||||
if (shape.IsArray() && sharding() &&
|
||||
sharding()->type() == OpSharding::OTHER) {
|
||||
|
@ -1427,11 +1448,18 @@ XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
|
|||
return InvalidArgument(
|
||||
"Replicated sharding is not yet supported for infeeds");
|
||||
}
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
|
||||
return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
|
||||
const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
|
||||
instr.set_infeed_config(config);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
|
||||
}
|
||||
|
||||
void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
|
||||
const string& outfeed_config) {
|
||||
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
|
@ -1488,10 +1516,6 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
|
|||
const Shape& shape_with_layout,
|
||||
const string& outfeed_config) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
|
||||
// Check and set outfeed shape.
|
||||
if (!LayoutUtil::HasLayout(shape_with_layout)) {
|
||||
return InvalidArgument("Given shape to Outfeed must have a layout");
|
||||
|
@ -1503,15 +1527,22 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
|
|||
ShapeUtil::HumanStringWithLayout(shape_with_layout),
|
||||
ShapeUtil::HumanStringWithLayout(*operand_shape));
|
||||
}
|
||||
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
|
||||
|
||||
instr.set_outfeed_config(outfeed_config);
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
|
||||
{operand, token});
|
||||
return OutfeedWithTokenInternal(operand, token, shape_with_layout,
|
||||
outfeed_config);
|
||||
});
|
||||
}
|
||||
|
||||
StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
|
||||
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
|
||||
const string& outfeed_config) {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
|
||||
instr.set_outfeed_config(outfeed_config);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
|
||||
{operand, token});
|
||||
}
|
||||
|
||||
XlaOp XlaBuilder::CreateToken() {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
|
|
@ -364,6 +364,10 @@ class XlaBuilder {
|
|||
Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
|
||||
string value);
|
||||
|
||||
// Returns shapes for the operands.
|
||||
StatusOr<std::vector<Shape>> GetOperandShapes(
|
||||
absl::Span<const XlaOp> operands) const;
|
||||
|
||||
private:
|
||||
// Build helper which takes the id of the root operation..
|
||||
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
|
||||
|
@ -391,6 +395,10 @@ class XlaBuilder {
|
|||
XlaOp Pad(XlaOp operand, XlaOp padding_value,
|
||||
const PaddingConfig& padding_config);
|
||||
|
||||
virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
|
||||
XlaOp padding_value,
|
||||
const PaddingConfig& padding_config);
|
||||
|
||||
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
|
||||
absl::Span<const int64> new_sizes,
|
||||
int64 inferred_dimension = -1);
|
||||
|
@ -406,9 +414,12 @@ class XlaBuilder {
|
|||
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices,
|
||||
absl::Span<const int64> strides);
|
||||
|
||||
XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
|
||||
int64 stride, int64 dimno);
|
||||
virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
|
||||
absl::Span<const int64> start_indices,
|
||||
absl::Span<const int64> limit_indices,
|
||||
absl::Span<const int64> strides);
|
||||
virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
|
||||
int64 stride, int64 dimno);
|
||||
|
||||
ABSL_DEPRECATED("Use span-of-indices form instead")
|
||||
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
|
||||
|
@ -422,14 +433,22 @@ class XlaBuilder {
|
|||
absl::Span<const XlaOp> start_indices);
|
||||
|
||||
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
|
||||
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
|
||||
absl::Span<const XlaOp> operands,
|
||||
int64 dimension);
|
||||
|
||||
void Trace(const string& tag, XlaOp operand);
|
||||
|
||||
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
|
||||
|
||||
XlaOp Tuple(absl::Span<const XlaOp> elements);
|
||||
virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
|
||||
absl::Span<const XlaOp> elements);
|
||||
|
||||
XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
|
||||
virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
|
||||
XlaOp tuple_data,
|
||||
int64 index);
|
||||
|
||||
XlaOp Dot(XlaOp lhs, XlaOp rhs,
|
||||
const PrecisionConfig* precision_config = nullptr);
|
||||
|
@ -476,15 +495,18 @@ class XlaBuilder {
|
|||
absl::Span<const int64> fft_length);
|
||||
|
||||
XlaOp Infeed(const Shape& shape, const string& config = "");
|
||||
XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
|
||||
const string& config = "");
|
||||
XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
|
||||
virtual StatusOr<XlaOp> InfeedWithTokenInternal(
|
||||
const Shape& infeed_instruction_shape, XlaOp token, const string& config);
|
||||
|
||||
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
|
||||
const string& outfeed_config);
|
||||
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
|
||||
const Shape& shape_with_layout,
|
||||
const string& outfeed_config);
|
||||
|
||||
virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
|
||||
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
|
||||
const string& outfeed_config);
|
||||
XlaOp Call(const XlaComputation& computation,
|
||||
absl::Span<const XlaOp> operands);
|
||||
|
||||
|
@ -624,7 +646,7 @@ class XlaBuilder {
|
|||
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
|
||||
const ChannelHandle& handle);
|
||||
|
||||
XlaOp CreateToken();
|
||||
virtual XlaOp CreateToken();
|
||||
|
||||
XlaOp AfterAll(absl::Span<const XlaOp> tokens);
|
||||
|
||||
|
@ -701,10 +723,6 @@ class XlaBuilder {
|
|||
// Returns the (inferred) result for the program shape using the given root.
|
||||
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
|
||||
|
||||
// Returns shapes for the operands.
|
||||
StatusOr<std::vector<Shape>> GetOperandShapes(
|
||||
absl::Span<const XlaOp> operands) const;
|
||||
|
||||
// A visitor which checks whether an operation is a compile-time constant,
|
||||
// meaning that it doesn't depend on any parameters, or on any stateful
|
||||
// operation such as `RngNormal` or `Infeed`. The visitor walks the
|
||||
|
|
Loading…
Reference in New Issue