Add MlirHloBuilder op implementations

PiperOrigin-RevId: 307472994
Change-Id: Ifbca316f653f44469cebd3aa5a507e8ccabf5001
This commit is contained in:
HyoukJoong Lee 2020-04-20 14:08:59 -07:00 committed by TensorFlower Gardener
parent 63f2383118
commit f65a9acfda
6 changed files with 431 additions and 61 deletions

View File

@ -165,6 +165,90 @@ StatusOr<XlaOp> MlirHloBuilder::AddOpWithShape(
/*attributes=*/{}); /*attributes=*/{});
} }
XlaOp MlirHloBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
return MakeXlaOp(builder_.create<mlir::xla_hlo::CreateTokenOp>(
loc_, mlir::xla_hlo::TokenType::get(builder_.getContext())));
});
}
StatusOr<XlaOp> MlirHloBuilder::InfeedWithTokenInternal(
const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
TF_ASSIGN_OR_RETURN(mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(
infeed_instruction_shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::InfeedOp>(
loc_, result_type, GetValue(token),
/*infeed_config=*/config));
}
StatusOr<XlaOp> MlirHloBuilder::OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) {
auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext());
return MakeXlaOp(builder_.create<mlir::xla_hlo::OutfeedOp>(
loc_, token_type, GetValue(operand), GetValue(token), outfeed_config));
}
StatusOr<XlaOp> MlirHloBuilder::ConcatInDimInternal(
const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
auto mlir_operands = GetValues(operands);
return MakeXlaOp(builder_.create<mlir::xla_hlo::ConcatenateOp>(
loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension)));
}
StatusOr<XlaOp> MlirHloBuilder::GetTupleElementInternal(const Shape& shape,
XlaOp tuple_data,
int64 index) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
return MakeXlaOp(builder_.create<mlir::xla_hlo::GetTupleElementOp>(
loc_, result_type, GetValue(tuple_data),
builder_.getI32IntegerAttr(index)));
}
StatusOr<XlaOp> MlirHloBuilder::SliceInternal(
const Shape& shape, XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> strides) {
return MakeXlaOp(builder_.create<mlir::xla_hlo::SliceOp>(
loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_),
GetI64ElementsAttr(limit_indices, &builder_),
GetI64ElementsAttr(strides, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::PadInternal(
const Shape& shape, XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) {
TF_ASSIGN_OR_RETURN(
mlir::Type result_type,
ConvertShapeToType<mlir::RankedTensorType>(shape, builder_));
std::vector<int64> low;
std::vector<int64> high;
std::vector<int64> internal;
for (auto& dimension : padding_config.dimensions()) {
low.push_back(dimension.edge_padding_low());
high.push_back(dimension.edge_padding_high());
internal.push_back(dimension.interior_padding());
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::PadOp>(
loc_, result_type, GetValue(operand), GetValue(padding_value),
GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_),
GetI64ElementsAttr(internal, &builder_)));
}
StatusOr<XlaOp> MlirHloBuilder::TupleInternal(
const Shape& shape, absl::Span<const XlaOp> elements) {
mlir::SmallVector<mlir::Value, 4> operands;
for (auto& element : elements) {
operands.push_back(GetValue(element));
}
return MakeXlaOp(builder_.create<mlir::xla_hlo::TupleOp>(loc_, operands));
}
StatusOr<XlaOp> MlirHloBuilder::CreateOp( StatusOr<XlaOp> MlirHloBuilder::CreateOp(
const std::string& op_name, const Shape& shape, const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands, llvm::ArrayRef<XlaOp> operands,

View File

@ -54,6 +54,9 @@ class MlirHloBuilder : public XlaBuilder {
// TODO(hinsu): Add a constructor to build a new MLIR function from scratch // TODO(hinsu): Add a constructor to build a new MLIR function from scratch
// and override Build methods. // and override Build methods.
MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc)
: XlaBuilder(name), builder_(builder), loc_(loc) {}
MlirHloBuilder(const MlirHloBuilder&) = delete; MlirHloBuilder(const MlirHloBuilder&) = delete;
MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; MlirHloBuilder& operator=(const MlirHloBuilder&) = delete;
@ -75,6 +78,17 @@ class MlirHloBuilder : public XlaBuilder {
return mlir::Value::getFromOpaquePointer(ptr); return mlir::Value::getFromOpaquePointer(ptr);
} }
// Returns MLIR values corresponding to the given XLA ops.
//
// Requires that the ops were created by this builder.
std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) {
std::vector<mlir::Value> values;
for (auto xla_op : ops) {
values.push_back(GetValue(xla_op));
}
return values;
}
// Sets location for newly built ops, until reset. // Sets location for newly built ops, until reset.
void SetLocation(mlir::Location loc) { loc_ = loc; } void SetLocation(mlir::Location loc) { loc_ = loc; }
@ -120,6 +134,34 @@ class MlirHloBuilder : public XlaBuilder {
StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape, StatusOr<XlaOp> AddOpWithShape(HloOpcode opcode, const Shape& shape,
absl::Span<const XlaOp> operands) override; absl::Span<const XlaOp> operands) override;
XlaOp CreateToken() override;
StatusOr<XlaOp> InfeedWithTokenInternal(const Shape& infeed_instruction_shape,
XlaOp token,
const string& config) override;
StatusOr<XlaOp> OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) override;
StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
int64 dimension) override;
StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape, XlaOp tuple_data,
int64 index) override;
StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) override;
StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config) override;
StatusOr<XlaOp> TupleInternal(const Shape& shape,
absl::Span<const XlaOp> elements) override;
// Creates HLO dialect op and returns the result as an XlaOp. // Creates HLO dialect op and returns the result as an XlaOp.
StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape, StatusOr<XlaOp> CreateOp(const std::string& op_name, const Shape& shape,
llvm::ArrayRef<XlaOp> operands, llvm::ArrayRef<XlaOp> operands,

View File

@ -1,4 +1,5 @@
load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(licenses = ["notice"]) package(licenses = ["notice"])
@ -18,3 +19,18 @@ filegroup(
"@llvm-project//llvm:FileCheck", "@llvm-project//llvm:FileCheck",
], ],
) )
tf_cc_test(
name = "mlir_hlo_builder_test",
srcs = ["mlir_hlo_builder_test.cc"],
deps = [
"//tensorflow/compiler/mlir/xla:hlo",
"//tensorflow/compiler/mlir/xla:mlir_hlo_builder",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
],
)

View File

@ -0,0 +1,179 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h"
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/IR/Location.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace {
static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) {
EXPECT_TRUE(absl::StrContains(s, expected))
<< s << " does not contain " << expected;
}
class XlaBuilderTest : public ::testing::Test {
protected:
XlaBuilderTest()
: name_(SetupTest()),
context_(),
module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))),
builder_(&module_->getBodyRegion()),
xla_builder_(name_, builder_, module_->getLoc()) {}
string SetupTest() {
mlir::registerDialect<mlir::xla_hlo::XlaHloDialect>();
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
// Retuns the MLIR op string representation of the given XlaOp.
string GetMlirOpString(XlaOp xla_op) {
string str;
llvm::raw_string_ostream ostream{str};
xla_builder_.GetValue(xla_op).print(ostream);
ostream.flush();
return str;
}
string name_;
mlir::MLIRContext context_;
mlir::OwningModuleRef module_;
mlir::OpBuilder builder_;
MlirHloBuilder xla_builder_;
};
TEST_F(XlaBuilderTest, CreateToken) {
auto token = CreateToken(&xla_builder_);
auto str = GetMlirOpString(token);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(GetMlirOpString(token),
R"("xla_hlo.create_token"() : () -> !xla_hlo.token)");
}
TEST_F(XlaBuilderTest, Infeed) {
auto token = CreateToken(&xla_builder_);
auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), "");
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(infeed),
R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple<tensor<4x8xf32>, !xla_hlo.token>)");
}
TEST_F(XlaBuilderTest, Outfeed) {
auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8});
auto data = ConstantLiteral(
&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions()));
auto token = CreateToken(&xla_builder_);
auto outfeed = OutfeedWithToken(data, token, outfeed_shape, "");
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(outfeed),
R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)");
}
TEST_F(XlaBuilderTest, ConcatInDim) {
auto data0 = ConstantLiteral(
&xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5}));
auto data1 = ConstantLiteral(
&xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5}));
auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(concat),
R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)");
}
TEST_F(XlaBuilderTest, Tuple) {
auto data0 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto data1 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {}));
auto tuple = Tuple(&xla_builder_, {data0, data1});
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(tuple),
R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor<f32>) -> tuple<tensor<3x7xf32>, tensor<f32>>)");
}
TEST_F(XlaBuilderTest, GetTupleElement) {
auto data0 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto data1 = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {}));
auto tuple_data = Tuple(&xla_builder_, {data0, data1});
auto gte = GetTupleElement(tuple_data, 1);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(gte),
R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple<tensor<3x7xf32>, tensor<f32>>) -> tensor<f32>)");
}
TEST_F(XlaBuilderTest, Slice) {
auto data = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1});
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(slice),
R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)");
}
TEST_F(XlaBuilderTest, Pad) {
auto data = ConstantLiteral(&xla_builder_,
LiteralUtil::CreateFromDimensions(F32, {3, 7}));
auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32));
PaddingConfig padding_config;
auto* dims0 = padding_config.add_dimensions();
dims0->set_edge_padding_low(1);
dims0->set_interior_padding(0);
dims0->set_edge_padding_high(2);
auto* dims1 = padding_config.add_dimensions();
dims1->set_edge_padding_low(3);
dims1->set_interior_padding(1);
dims1->set_edge_padding_high(0);
auto pad = Pad(data, zero, padding_config);
TF_ASSERT_OK(xla_builder_.GetCurrentStatus());
ExpectHasSubstr(
GetMlirOpString(pad),
R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor<f32>) -> tensor<6x16xf32>)");
}
} // namespace
} // namespace xla

View File

@ -822,23 +822,29 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) { absl::Span<const int64> strides) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape( TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape(
*operand_shape, start_indices, *operand_shape, start_indices,
limit_indices, strides)); limit_indices, strides));
*instr.mutable_shape() = shape.ToProto(); return SliceInternal(shape, operand, start_indices, limit_indices, strides);
for (int i = 0; i < start_indices.size(); i++) {
auto* slice_config = instr.add_slice_dimensions();
slice_config->set_start(start_indices[i]);
slice_config->set_limit(limit_indices[i]);
slice_config->set_stride(strides[i]);
}
return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}); });
} }
StatusOr<XlaOp> XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand,
absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices,
absl::Span<const int64> strides) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
for (int i = 0; i < start_indices.size(); i++) {
auto* slice_config = instr.add_slice_dimensions();
slice_config->set_start(start_indices[i]);
slice_config->set_limit(limit_indices[i]);
slice_config->set_stride(strides[i]);
}
return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand});
}
XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index, XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index,
int64 limit_index, int64 stride, int64 dimno) { int64 limit_index, int64 stride, int64 dimno) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -952,41 +958,49 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update,
XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands, XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
int64 dimension) { int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs; std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; }); [](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape( TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
operand_shape_ptrs, dimension)); operand_shape_ptrs, dimension));
*instr.mutable_shape() = shape.ToProto(); return ConcatInDimInternal(shape, operands, dimension);
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}); });
} }
StatusOr<XlaOp> XlaBuilder::ConcatInDimInternal(
const Shape& shape, absl::Span<const XlaOp> operands, int64 dimension) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands);
}
XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value, XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config) { const PaddingConfig& padding_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape, TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape,
GetShapePtr(padding_value)); GetShapePtr(padding_value));
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferPadShape( Shape shape, ShapeInference::InferPadShape(
*operand_shape, *padding_value_shape, padding_config)); *operand_shape, *padding_value_shape, padding_config));
*instr.mutable_shape() = shape.ToProto(); return PadInternal(shape, operand, padding_value, padding_config);
*instr.mutable_padding_config() = padding_config;
return AddInstruction(std::move(instr), HloOpcode::kPad,
{operand, padding_value});
}); });
} }
StatusOr<XlaOp> XlaBuilder::PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_padding_config() = padding_config;
return AddInstruction(std::move(instr), HloOpcode::kPad,
{operand, padding_value});
}
XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions, XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes, absl::Span<const int64> new_sizes,
int64 inferred_dimension) { int64 inferred_dimension) {
@ -1080,7 +1094,6 @@ XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) {
XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) { XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
std::vector<const Shape*> operand_shape_ptrs; std::vector<const Shape*> operand_shape_ptrs;
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
@ -1088,14 +1101,19 @@ XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
TF_ASSIGN_OR_RETURN(const Shape shape, TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferVariadicOpShape( ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs)); HloOpcode::kTuple, operand_shape_ptrs));
*instr.mutable_shape() = shape.ToProto(); return TupleInternal(shape, elements);
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}); });
} }
StatusOr<XlaOp> XlaBuilder::TupleInternal(const Shape& shape,
absl::Span<const XlaOp> elements) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
}
XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) { XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data)); TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data));
if (!tuple_shape->IsTuple()) { if (!tuple_shape->IsTuple()) {
return InvalidArgument( return InvalidArgument(
@ -1107,16 +1125,22 @@ XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) {
"GetTupleElement() index (%d) out of range for tuple shape %s", index, "GetTupleElement() index (%d) out of range for tuple shape %s", index,
ShapeUtil::HumanString(*tuple_shape)); ShapeUtil::HumanString(*tuple_shape));
} }
*instr.mutable_shape() = return GetTupleElementInternal(
ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto(); ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data,
index);
instr.set_tuple_index(index);
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
{tuple_data});
}); });
} }
StatusOr<XlaOp> XlaBuilder::GetTupleElementInternal(const Shape& shape,
XlaOp tuple_data,
int64 index) {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
instr.set_tuple_index(index);
return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
{tuple_data});
}
XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config) { const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -1407,14 +1431,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
const string& config) { const string& config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (!LayoutUtil::HasLayout(shape)) { if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("Given shape to Infeed must have a layout"); return InvalidArgument("Given shape to Infeed must have a layout");
} }
const Shape infeed_instruction_shape = const Shape infeed_instruction_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
instr.set_infeed_config(config);
if (shape.IsArray() && sharding() && if (shape.IsArray() && sharding() &&
sharding()->type() == OpSharding::OTHER) { sharding()->type() == OpSharding::OTHER) {
@ -1427,11 +1448,18 @@ XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape,
return InvalidArgument( return InvalidArgument(
"Replicated sharding is not yet supported for infeeds"); "Replicated sharding is not yet supported for infeeds");
} }
return InfeedWithTokenInternal(infeed_instruction_shape, token, config);
return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}); });
} }
StatusOr<XlaOp> XlaBuilder::InfeedWithTokenInternal(
const Shape& infeed_instruction_shape, XlaOp token, const string& config) {
HloInstructionProto instr;
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
instr.set_infeed_config(config);
return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token});
}
void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout,
const string& outfeed_config) { const string& outfeed_config) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@ -1488,10 +1516,6 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
const Shape& shape_with_layout, const Shape& shape_with_layout,
const string& outfeed_config) { const string& outfeed_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
// Check and set outfeed shape. // Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) { if (!LayoutUtil::HasLayout(shape_with_layout)) {
return InvalidArgument("Given shape to Outfeed must have a layout"); return InvalidArgument("Given shape to Outfeed must have a layout");
@ -1503,15 +1527,22 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token,
ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(shape_with_layout),
ShapeUtil::HumanStringWithLayout(*operand_shape)); ShapeUtil::HumanStringWithLayout(*operand_shape));
} }
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); return OutfeedWithTokenInternal(operand, token, shape_with_layout,
outfeed_config);
instr.set_outfeed_config(outfeed_config);
return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
{operand, token});
}); });
} }
StatusOr<XlaOp> XlaBuilder::OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config) {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
instr.set_outfeed_config(outfeed_config);
return AddInstruction(std::move(instr), HloOpcode::kOutfeed,
{operand, token});
}
XlaOp XlaBuilder::CreateToken() { XlaOp XlaBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr; HloInstructionProto instr;

View File

@ -364,6 +364,10 @@ class XlaBuilder {
Status SetInstructionFrontendAttribute(XlaOp op, string attribute, Status SetInstructionFrontendAttribute(XlaOp op, string attribute,
string value); string value);
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
absl::Span<const XlaOp> operands) const;
private: private:
// Build helper which takes the id of the root operation.. // Build helper which takes the id of the root operation..
StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions); StatusOr<XlaComputation> Build(int64 root_id, bool remove_dynamic_dimensions);
@ -391,6 +395,10 @@ class XlaBuilder {
XlaOp Pad(XlaOp operand, XlaOp padding_value, XlaOp Pad(XlaOp operand, XlaOp padding_value,
const PaddingConfig& padding_config); const PaddingConfig& padding_config);
virtual StatusOr<XlaOp> PadInternal(const Shape& shape, XlaOp operand,
XlaOp padding_value,
const PaddingConfig& padding_config);
XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions, XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
absl::Span<const int64> new_sizes, absl::Span<const int64> new_sizes,
int64 inferred_dimension = -1); int64 inferred_dimension = -1);
@ -406,9 +414,12 @@ class XlaBuilder {
XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices, XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
absl::Span<const int64> limit_indices, absl::Span<const int64> limit_indices,
absl::Span<const int64> strides); absl::Span<const int64> strides);
virtual StatusOr<XlaOp> SliceInternal(const Shape& shape, XlaOp operand,
XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, absl::Span<const int64> start_indices,
int64 stride, int64 dimno); absl::Span<const int64> limit_indices,
absl::Span<const int64> strides);
virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index,
int64 stride, int64 dimno);
ABSL_DEPRECATED("Use span-of-indices form instead") ABSL_DEPRECATED("Use span-of-indices form instead")
XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices,
@ -422,14 +433,22 @@ class XlaBuilder {
absl::Span<const XlaOp> start_indices); absl::Span<const XlaOp> start_indices);
XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension); XlaOp ConcatInDim(absl::Span<const XlaOp> operands, int64 dimension);
virtual StatusOr<XlaOp> ConcatInDimInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
int64 dimension);
void Trace(const string& tag, XlaOp operand); void Trace(const string& tag, XlaOp operand);
XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false);
XlaOp Tuple(absl::Span<const XlaOp> elements); XlaOp Tuple(absl::Span<const XlaOp> elements);
virtual StatusOr<XlaOp> TupleInternal(const Shape& shape,
absl::Span<const XlaOp> elements);
XlaOp GetTupleElement(XlaOp tuple_data, int64 index); XlaOp GetTupleElement(XlaOp tuple_data, int64 index);
virtual StatusOr<XlaOp> GetTupleElementInternal(const Shape& shape,
XlaOp tuple_data,
int64 index);
XlaOp Dot(XlaOp lhs, XlaOp rhs, XlaOp Dot(XlaOp lhs, XlaOp rhs,
const PrecisionConfig* precision_config = nullptr); const PrecisionConfig* precision_config = nullptr);
@ -476,15 +495,18 @@ class XlaBuilder {
absl::Span<const int64> fft_length); absl::Span<const int64> fft_length);
XlaOp Infeed(const Shape& shape, const string& config = ""); XlaOp Infeed(const Shape& shape, const string& config = "");
XlaOp InfeedWithToken(XlaOp token, const Shape& shape, XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config);
const string& config = ""); virtual StatusOr<XlaOp> InfeedWithTokenInternal(
const Shape& infeed_instruction_shape, XlaOp token, const string& config);
void Outfeed(XlaOp operand, const Shape& shape_with_layout, void Outfeed(XlaOp operand, const Shape& shape_with_layout,
const string& outfeed_config); const string& outfeed_config);
XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, XlaOp OutfeedWithToken(XlaOp operand, XlaOp token,
const Shape& shape_with_layout, const Shape& shape_with_layout,
const string& outfeed_config); const string& outfeed_config);
virtual StatusOr<XlaOp> OutfeedWithTokenInternal(
XlaOp operand, XlaOp token, const Shape& shape_with_layout,
const string& outfeed_config);
XlaOp Call(const XlaComputation& computation, XlaOp Call(const XlaComputation& computation,
absl::Span<const XlaOp> operands); absl::Span<const XlaOp> operands);
@ -624,7 +646,7 @@ class XlaBuilder {
XlaOp RecvFromHost(XlaOp token, const Shape& shape, XlaOp RecvFromHost(XlaOp token, const Shape& shape,
const ChannelHandle& handle); const ChannelHandle& handle);
XlaOp CreateToken(); virtual XlaOp CreateToken();
XlaOp AfterAll(absl::Span<const XlaOp> tokens); XlaOp AfterAll(absl::Span<const XlaOp> tokens);
@ -701,10 +723,6 @@ class XlaBuilder {
// Returns the (inferred) result for the program shape using the given root. // Returns the (inferred) result for the program shape using the given root.
StatusOr<ProgramShape> GetProgramShape(int64 root_id) const; StatusOr<ProgramShape> GetProgramShape(int64 root_id) const;
// Returns shapes for the operands.
StatusOr<std::vector<Shape>> GetOperandShapes(
absl::Span<const XlaOp> operands) const;
// A visitor which checks whether an operation is a compile-time constant, // A visitor which checks whether an operation is a compile-time constant,
// meaning that it doesn't depend on any parameters, or on any stateful // meaning that it doesn't depend on any parameters, or on any stateful
// operation such as `RngNormal` or `Infeed`. The visitor walks the // operation such as `RngNormal` or `Infeed`. The visitor walks the