Add MlirHloBuilder op implementations

PiperOrigin-RevId: 307472994
Change-Id: Ifbca316f653f44469cebd3aa5a507e8ccabf5001
This commit is contained in:
HyoukJoong Lee 2020-04-20 14:08:59 -07:00 committed by TensorFlower Gardener
parent 63f2383118
commit f65a9acfda
6 changed files with 431 additions and 61 deletions

View File

@ -165,6 +165,90 @@ StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
/*attributes=*/{});
}
XlaOp MlirHloBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
});
}
StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(
infeed_instruction_shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
loc_, result_type, GetValue(token),
/*infeed_config=*/config));
}
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) {
auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
}
StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto mlir_operands = GetValues(operands);
return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
}
StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
XlaOp tuple_data,
int64 index) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
loc_, result_type, GetValue(tuple_data),
builder_.getI32IntegerAttr(index)));
}
StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
GetI64ElementsAttr(limit_indices, &builder_),
GetI64ElementsAttr(strides, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
const Shape& shape, XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
std::vector<int64> low;
std::vector<int64> high;
std::vector<int64> internal;
for (auto& dimension : padding_config.dimensions()) {
low.push_back(dimension.edge_padding_low());
high.push_back(dimension.edge_padding_high());
internal.push_back(dimension.interior_padding());
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
loc_, result_type, GetValue(operand), GetValue(padding_value),
GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
GetI64ElementsAttr(internal, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
const Shape& shape, absl::Span<const XlaOp> elements) {
mlir::SmallVector<mlir::Value, 4> operands;
for (auto& element : elements) {
operands.push_back(GetValue(element));
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
}
StatusOr<XlaOp> MlirHloBuilder::CreateOp(
const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands,

View File

@ -54,6 +54,9 @@ class MlirHloBuilder : public XlaBuilder {
// TODO(hinsu): Add a constructor to build a new MLIR function from scratch
// and override Build methods.
MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc)
: XlaBuilder(name), builder_(builder), loc_(loc) {}
MlirHloBuilder(const MlirHloBuilder&) = delete;
MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
@ -75,6 +78,17 @@ class MlirHloBuilder : public XlaBuilder {
return mlir::Value::getFromOpaquePointer(ptr);
}
// Returns MLIR values corresponding to the given XLA ops.
//
// Requires that the ops were created by this builder.
std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
std::vector<mlir::Value> values;
for (auto xla_op : ops) {
values.push_back(GetValue(xla_op));
}
return values;
}
// Sets location for newly built ops, until reset.
void SetLocation(mlir::Location loc) { loc_ = loc; }
@ -120,6 +134,34 @@ class MlirHloBuilder : public XlaBuilder {
StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands) override;
XlaOp CreateToken() override;
StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
XlaOp token,
const string& config) override;
StatusOr<XlaOp> OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) override;
StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
int64 dimension) override;
StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
int64 index) override;
StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) override;
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config) override;
StatusOr<XlaOp> TupleInternal(const Shape& shape,
absl::Span<const XlaOp> elements) override;
// Creates HLO dialect op and returns the result as an XlaOp.
StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands,

View File

@ -1,4 +1,5 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(licenses = ["notice"])
@ -18,3 +19,18 @@ filegroup(
"@llvm-project//llvm:FileCheck",
],
)
tf_cc_test(
name = "mlir_hlo_builder_test",
srcs = ["mlir_hlo_builder_test.cc"],
deps = [
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:mlir_hlo_builder",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)

View File

@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) {
EXPECT_TRUE(absl::StrContains(s, expected))
<< s << " does not contain " << expected;
}
class XlaBuilderTest : public ::testing::Test {
protected:
XlaBuilderTest()
: name_(SetupTest()),
context_(),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
builder_(&module_->getBodyRegion()),
xla_builder_(name_, builder_, module_->getLoc()) {}
string SetupTest() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
// Retuns the MLIR op string representation of the given XlaOp.
string GetMlirOpString(XlaOp xla_op) {
string str;
llvm::raw_string_ostream ostream{str};
xla_builder_.GetValue(xla_op).print(ostream);
ostream.flush();
return str;
}
string name_;
mlir::MLIRContext context_;
mlir::OwningModuleRef module_;
mlir::OpBuilder builder_;
MlirHloBuilder xla_builder_;
};
TEST_F(XlaBuilderTest, CreateToken) {
auto token = CreateToken(&xla_builder_);
auto str = GetMlirOpString(token);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(GetMlirOpString(token),
R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
}
TEST_F(XlaBuilderTest, Infeed) {
auto token = CreateToken(&xla_builder_);
auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), "");
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(infeed),
R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
}
TEST_F(XlaBuilderTest, Outfeed) {
auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8});
auto data = ConstantLiteral(
&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions()));
auto token = CreateToken(&xla_builder_);
auto outfeed = OutfeedWithToken(data, token, outfeed_shape, "");
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(outfeed),
R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
}
TEST_F(XlaBuilderTest, ConcatInDim) {
auto data0 = ConstantLiteral(
&xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5}));
auto data1 = ConstantLiteral(
&xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5}));
auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(concat),
R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
}
TEST_F(XlaBuilderTest, Tuple) {
auto data0 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto data1 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {}));
auto tuple = Tuple(&xla_builder_, {data0, data1});
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(tuple),
R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
}
TEST_F(XlaBuilderTest, GetTupleElement) {
auto data0 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto data1 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {}));
auto tuple_data = Tuple(&xla_builder_, {data0, data1});
auto gte = GetTupleElement(tuple_data, 1);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(gte),
R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
}
TEST_F(XlaBuilderTest, Slice) {
auto data = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1});
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(slice),
R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
}
TEST_F(XlaBuilderTest, Pad) {
auto data = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32));
PaddingConfig padding_config;
auto* dims0 = padding_config.add_dimensions();
dims0->set_edge_padding_low(1);
dims0->set_interior_padding(0);
dims0->set_edge_padding_high(2);
auto* dims1 = padding_config.add_dimensions();
dims1->set_edge_padding_low(3);
dims1->set_interior_padding(1);
dims1->set_edge_padding_high(0);
auto pad = Pad(data, zero, padding_config);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(pad),
R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
}
} // namespace
} // namespace xla

View File

@ -822,23 +822,29 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
*operand_shape, start_indices,
limit_indices, strides));
*instr.mutable_shape() = shape.ToProto();
for (int i = 0; i < start_indices.size(); i++) {
auto* slice_config = instr.add_slice_dimensions();
slice_config->set_start(start_indices[i]);
slice_config->set_limit(limit_indices[i]);
slice_config->set_stride(strides[i]);
}
return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
return SliceInternal(shape, operand, start_indices, limit_indices, strides);
});
}
StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
for (int i = 0; i < start_indices.size(); i++) {
auto* slice_config = instr.add_slice_dimensions();
slice_config->set_start(start_indices[i]);
slice_config->set_limit(limit_indices[i]);
slice_config->set_stride(strides[i]);
}
return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}
XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -952,41 +958,49 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
operand_shape_ptrs, dimension));
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
return ConcatInDimInternal(shape, operands, dimension);
});
}
StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}
XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
GetShapePtr(padding_value));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferPadShape(
*operand_shape, *padding_value_shape, padding_config));
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_padding_config() = padding_config;
return AddInstruction(std::move(instr), HloOpcode::kPad,
{operand, padding_value});
return PadInternal(shape, operand, padding_value, padding_config);
});
}
StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_padding_config() = padding_config;
return AddInstruction(std::move(instr), HloOpcode::kPad,
{operand, padding_value});
}
XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes,
int64 inferred_dimension) {
@ -1080,7 +1094,6 @@ XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
@ -1088,14 +1101,19 @@ XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
return TupleInternal(shape, elements);
});
}
StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
absl::Span<const XlaOp> elements) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}
XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
if (!tuple_shape->IsTuple()) {
return InvalidArgument(
@ -1107,16 +1125,22 @@ XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
"GetTupleElement() index (%d) out of range for tuple shape %s", index,
ShapeUtil::HumanString(*tuple_shape));
}
*instr.mutable_shape() =
ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto();
instr.set_tuple_index(index);
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
{tuple_data});
return GetTupleElementInternal(
ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
index);
});
}
StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
XlaOp tuple_data,
int64 index) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_tuple_index(index);
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
{tuple_data});
}
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -1407,14 +1431,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
const string& config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("Given shape to Infeed must have a layout");
}
const Shape infeed_instruction_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
instr.set_infeed_config(config);
if (shape.IsArray() && sharding() &&
sharding()->type() == OpSharding::OTHER) {
@ -1427,11 +1448,18 @@ XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
return InvalidArgument(
"Replicated sharding is not yet supported for infeeds");
}
return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
});
}
StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
HloInstructionProto instr;
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
instr.set_infeed_config(config);
return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}
void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
const string& outfeed_config) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -1488,10 +1516,6 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
const Shape& shape_with_layout,
const string& outfeed_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
// Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) {
return InvalidArgument("Given shape to Outfeed must have a layout");
@ -1503,15 +1527,22 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
ShapeUtil::HumanStringWithLayout(shape_with_layout),
ShapeUtil::HumanStringWithLayout(*operand_shape));
}
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
instr.set_outfeed_config(outfeed_config);
return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
{operand, token});
return OutfeedWithTokenInternal(operand, token, shape_with_layout,
outfeed_config);
});
}
StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
instr.set_outfeed_config(outfeed_config);
return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
{operand, token});
}
XlaOp XlaBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;

View File

@ -364,6 +364,10 @@ class XlaBuilder {
Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
string value);
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
absl::Span<const XlaOp> operands) const;
private:
// Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@ -391,6 +395,10 @@ class XlaBuilder {
XlaOp Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config);
virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config);
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes,
int64 inferred_dimension = -1);
@ -406,9 +414,12 @@ class XlaBuilder {
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);
XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);
virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
@ -422,14 +433,22 @@ class XlaBuilder {
absl::Span<const XlaOp> start_indices);
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
int64 dimension);
void Trace(const string& tag, XlaOp operand);
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
XlaOp Tuple(absl::Span<const XlaOp> elements);
virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
absl::Span<const XlaOp> elements);
XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
XlaOp tuple_data,
int64 index);
XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr);
@ -476,15 +495,18 @@ class XlaBuilder {
absl::Span<const int64> fft_length);
XlaOp Infeed(const Shape& shape, const string& config = "");
XlaOp InfeedWithToken(XlaOp token, const Shape& shape,
const string& config = "");
XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
virtual StatusOr<XlaOp> InfeedWithTokenInternal(
const Shape& infeed_instruction_shape, XlaOp token, const string& config);
void Outfeed(XlaOp operand, const Shape& shape_with_layout,
const string& outfeed_config);
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
const Shape& shape_with_layout,
const string& outfeed_config);
virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config);
XlaOp Call(const XlaComputation& computation,
absl::Span<const XlaOp> operands);
@ -624,7 +646,7 @@ class XlaBuilder {
XlaOp RecvFromHost(XlaOp token, const Shape& shape,
const ChannelHandle& handle);
XlaOp CreateToken();
virtual XlaOp CreateToken();
XlaOp AfterAll(absl::Span<const XlaOp> tokens);
@ -701,10 +723,6 @@ class XlaBuilder {
// Returns the (inferred) result for the program shape using the given root.
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
absl::Span<const XlaOp> operands) const;
// A visitor which checks whether an operation is a compile-time constant,
// meaning that it doesn't depend on any parameters, or on any stateful
// operation such as `RngNormal` or `Infeed`. The visitor walks the