From f65a9acfdaaa4585cfa32c057c64b6d84fc28371 Mon Sep 17 00:00:00 2001 From: HyoukJoong Lee Date: Mon, 20 Apr 2020 14:08:59 -0700 Subject: [PATCH] Add MlirHloBuilder op implementations PiperOrigin-RevId: 307472994 Change-Id: Ifbca316f653f44469cebd3aa5a507e8ccabf5001 --- .../compiler/mlir/xla/ir/mlir_hlo_builder.cc | 84 ++++++++ .../compiler/mlir/xla/ir/mlir_hlo_builder.h | 42 ++++ tensorflow/compiler/mlir/xla/tests/BUILD | 16 ++ .../mlir/xla/tests/mlir_hlo_builder_test.cc | 179 ++++++++++++++++++ tensorflow/compiler/xla/client/xla_builder.cc | 131 ++++++++----- tensorflow/compiler/xla/client/xla_builder.h | 40 ++-- 6 files changed, 431 insertions(+), 61 deletions(-) create mode 100644 tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 739f19e9625..cfa8c1b6bfc 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -165,6 +165,90 @@ StatusOr MlirHloBuilder::AddOpWithShape( /*attributes=*/{}); } +XlaOp MlirHloBuilder::CreateToken() { + return ReportErrorOrReturn([&]() -> StatusOr { + return MakeXlaOp(builder_.create( + loc_, mlir::xla_hlo::TokenType::get(builder_.getContext()))); + }); +} + +StatusOr MlirHloBuilder::InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, const string& config) { + TF_ASSIGN_OR_RETURN(mlir::Type result_type, + ConvertShapeToType( + infeed_instruction_shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_type, GetValue(token), + /*infeed_config=*/config)); +} + +StatusOr MlirHloBuilder::OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config) { + auto token_type = mlir::xla_hlo::TokenType::get(builder_.getContext()); + return MakeXlaOp(builder_.create( + loc_, token_type, GetValue(operand), GetValue(token), outfeed_config)); +} + +StatusOr MlirHloBuilder::ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64 dimension) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_type, + ConvertShapeToType(shape, builder_)); + auto mlir_operands = GetValues(operands); + return MakeXlaOp(builder_.create( + loc_, result_type, mlir_operands, builder_.getI64IntegerAttr(dimension))); +} + +StatusOr MlirHloBuilder::GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64 index) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_type, + ConvertShapeToType(shape, builder_)); + return MakeXlaOp(builder_.create( + loc_, result_type, GetValue(tuple_data), + builder_.getI32IntegerAttr(index))); +} + +StatusOr MlirHloBuilder::SliceInternal( + const Shape& shape, XlaOp operand, absl::Span start_indices, + absl::Span limit_indices, absl::Span strides) { + return MakeXlaOp(builder_.create( + loc_, GetValue(operand), GetI64ElementsAttr(start_indices, &builder_), + GetI64ElementsAttr(limit_indices, &builder_), + GetI64ElementsAttr(strides, &builder_))); +} + +StatusOr MlirHloBuilder::PadInternal( + const Shape& shape, XlaOp operand, XlaOp padding_value, + const PaddingConfig& padding_config) { + TF_ASSIGN_OR_RETURN( + mlir::Type result_type, + ConvertShapeToType(shape, builder_)); + std::vector low; + std::vector high; + std::vector internal; + for (auto& dimension : padding_config.dimensions()) { + low.push_back(dimension.edge_padding_low()); + high.push_back(dimension.edge_padding_high()); + internal.push_back(dimension.interior_padding()); + } + return MakeXlaOp(builder_.create( + loc_, result_type, GetValue(operand), GetValue(padding_value), + GetI64ElementsAttr(low, &builder_), GetI64ElementsAttr(high, &builder_), + GetI64ElementsAttr(internal, &builder_))); +} + +StatusOr MlirHloBuilder::TupleInternal( + const Shape& shape, absl::Span elements) { + mlir::SmallVector operands; + for (auto& element : elements) { + operands.push_back(GetValue(element)); + } + return MakeXlaOp(builder_.create(loc_, operands)); +} + StatusOr MlirHloBuilder::CreateOp( const std::string& op_name, const Shape& shape, llvm::ArrayRef operands, diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h index 95dafbd35f2..c0ef645a731 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h @@ -54,6 +54,9 @@ class MlirHloBuilder : public XlaBuilder { // TODO(hinsu): Add a constructor to build a new MLIR function from scratch // and override Build methods. + MlirHloBuilder(std::string name, mlir::OpBuilder builder, mlir::Location loc) + : XlaBuilder(name), builder_(builder), loc_(loc) {} + MlirHloBuilder(const MlirHloBuilder&) = delete; MlirHloBuilder& operator=(const MlirHloBuilder&) = delete; @@ -75,6 +78,17 @@ class MlirHloBuilder : public XlaBuilder { return mlir::Value::getFromOpaquePointer(ptr); } + // Returns MLIR values corresponding to the given XLA ops. + // + // Requires that the ops were created by this builder. + std::vector GetValues(absl::Span ops) { + std::vector values; + for (auto xla_op : ops) { + values.push_back(GetValue(xla_op)); + } + return values; + } + // Sets location for newly built ops, until reset. void SetLocation(mlir::Location loc) { loc_ = loc; } @@ -120,6 +134,34 @@ class MlirHloBuilder : public XlaBuilder { StatusOr AddOpWithShape(HloOpcode opcode, const Shape& shape, absl::Span operands) override; + XlaOp CreateToken() override; + + StatusOr InfeedWithTokenInternal(const Shape& infeed_instruction_shape, + XlaOp token, + const string& config) override; + StatusOr OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config) override; + + StatusOr ConcatInDimInternal(const Shape& shape, + absl::Span operands, + int64 dimension) override; + + StatusOr GetTupleElementInternal(const Shape& shape, XlaOp tuple_data, + int64 index) override; + + StatusOr SliceInternal(const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) override; + + StatusOr PadInternal(const Shape& shape, XlaOp operand, + XlaOp padding_value, + const PaddingConfig& padding_config) override; + + StatusOr TupleInternal(const Shape& shape, + absl::Span elements) override; + // Creates HLO dialect op and returns the result as an XlaOp. StatusOr CreateOp(const std::string& op_name, const Shape& shape, llvm::ArrayRef operands, diff --git a/tensorflow/compiler/mlir/xla/tests/BUILD b/tensorflow/compiler/mlir/xla/tests/BUILD index 989b846f561..ad69383bd98 100644 --- a/tensorflow/compiler/mlir/xla/tests/BUILD +++ b/tensorflow/compiler/mlir/xla/tests/BUILD @@ -1,4 +1,5 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") package(licenses = ["notice"]) @@ -18,3 +19,18 @@ filegroup( "@llvm-project//llvm:FileCheck", ], ) + +tf_cc_test( + name = "mlir_hlo_builder_test", + srcs = ["mlir_hlo_builder_test.cc"], + deps = [ + "//tensorflow/compiler/mlir/xla:hlo", + "//tensorflow/compiler/mlir/xla:mlir_hlo_builder", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "@llvm-project//llvm:support", + "@llvm-project//mlir:IR", + ], +) diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc new file mode 100644 index 00000000000..54791e15cf4 --- /dev/null +++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc @@ -0,0 +1,179 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.h" + +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Module.h" // from @llvm-project +#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +namespace { + +static void ExpectHasSubstr(absl::string_view s, absl::string_view expected) { + EXPECT_TRUE(absl::StrContains(s, expected)) + << s << " does not contain " << expected; +} + +class XlaBuilderTest : public ::testing::Test { + protected: + XlaBuilderTest() + : name_(SetupTest()), + context_(), + module_(mlir::ModuleOp::create(mlir::UnknownLoc::get(&context_))), + builder_(&module_->getBodyRegion()), + xla_builder_(name_, builder_, module_->getLoc()) {} + + string SetupTest() { + mlir::registerDialect(); + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); + } + + // Retuns the MLIR op string representation of the given XlaOp. + string GetMlirOpString(XlaOp xla_op) { + string str; + llvm::raw_string_ostream ostream{str}; + xla_builder_.GetValue(xla_op).print(ostream); + ostream.flush(); + return str; + } + + string name_; + mlir::MLIRContext context_; + mlir::OwningModuleRef module_; + mlir::OpBuilder builder_; + MlirHloBuilder xla_builder_; +}; + +TEST_F(XlaBuilderTest, CreateToken) { + auto token = CreateToken(&xla_builder_); + auto str = GetMlirOpString(token); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + + ExpectHasSubstr(GetMlirOpString(token), + R"("xla_hlo.create_token"() : () -> !xla_hlo.token)"); +} + +TEST_F(XlaBuilderTest, Infeed) { + auto token = CreateToken(&xla_builder_); + auto infeed = InfeedWithToken(token, ShapeUtil::MakeShape(F32, {4, 8}), ""); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(infeed), + R"("xla_hlo.infeed"(%0) {infeed_config = ""} : (!xla_hlo.token) -> tuple, !xla_hlo.token>)"); +} + +TEST_F(XlaBuilderTest, Outfeed) { + auto outfeed_shape = ShapeUtil::MakeShape(F32, {4, 8}); + auto data = ConstantLiteral( + &xla_builder_, + LiteralUtil::CreateFromDimensions(F32, outfeed_shape.dimensions())); + auto token = CreateToken(&xla_builder_); + auto outfeed = OutfeedWithToken(data, token, outfeed_shape, ""); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(outfeed), + R"("xla_hlo.outfeed"(%0, %1) {outfeed_config = ""} : (tensor<4x8xf32>, !xla_hlo.token) -> !xla_hlo.token)"); +} + +TEST_F(XlaBuilderTest, ConcatInDim) { + auto data0 = ConstantLiteral( + &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 4, 5})); + auto data1 = ConstantLiteral( + &xla_builder_, LiteralUtil::CreateFromDimensions(F32, {2, 6, 5})); + auto concat = ConcatInDim(&xla_builder_, {data0, data1}, 1); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(concat), + R"("xla_hlo.concatenate"(%0, %1) {dimension = 1 : i64} : (tensor<2x4x5xf32>, tensor<2x6x5xf32>) -> tensor<2x10x5xf32>)"); +} + +TEST_F(XlaBuilderTest, Tuple) { + auto data0 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto data1 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {})); + auto tuple = Tuple(&xla_builder_, {data0, data1}); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(tuple), + R"("xla_hlo.tuple"(%0, %1) : (tensor<3x7xf32>, tensor) -> tuple, tensor>)"); +} + +TEST_F(XlaBuilderTest, GetTupleElement) { + auto data0 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto data1 = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {})); + auto tuple_data = Tuple(&xla_builder_, {data0, data1}); + auto gte = GetTupleElement(tuple_data, 1); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(gte), + R"("xla_hlo.get_tuple_element"(%2) {index = 1 : i32} : (tuple, tensor>) -> tensor)"); +} + +TEST_F(XlaBuilderTest, Slice) { + auto data = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto slice = Slice(data, {0, 1}, {2, 5}, {1, 1}); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(slice), + R"("xla_hlo.slice"(%0) {limit_indices = dense<[2, 5]> : tensor<2xi64>, start_indices = dense<[0, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} : (tensor<3x7xf32>) -> tensor<2x4xf32>)"); +} + +TEST_F(XlaBuilderTest, Pad) { + auto data = ConstantLiteral(&xla_builder_, + LiteralUtil::CreateFromDimensions(F32, {3, 7})); + auto zero = ConstantLiteral(&xla_builder_, LiteralUtil::Zero(F32)); + + PaddingConfig padding_config; + auto* dims0 = padding_config.add_dimensions(); + dims0->set_edge_padding_low(1); + dims0->set_interior_padding(0); + dims0->set_edge_padding_high(2); + auto* dims1 = padding_config.add_dimensions(); + dims1->set_edge_padding_low(3); + dims1->set_interior_padding(1); + dims1->set_edge_padding_high(0); + auto pad = Pad(data, zero, padding_config); + + TF_ASSERT_OK(xla_builder_.GetCurrentStatus()); + ExpectHasSubstr( + GetMlirOpString(pad), + R"("xla_hlo.pad"(%0, %1) {edge_padding_high = dense<[2, 0]> : tensor<2xi64>, edge_padding_low = dense<[1, 3]> : tensor<2xi64>, interior_padding = dense<[0, 1]> : tensor<2xi64>} : (tensor<3x7xf32>, tensor) -> tensor<6x16xf32>)"); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 2a690234e3c..ea93880f288 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -822,23 +822,29 @@ XlaOp XlaBuilder::Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferSliceShape( *operand_shape, start_indices, limit_indices, strides)); - *instr.mutable_shape() = shape.ToProto(); - for (int i = 0; i < start_indices.size(); i++) { - auto* slice_config = instr.add_slice_dimensions(); - slice_config->set_start(start_indices[i]); - slice_config->set_limit(limit_indices[i]); - slice_config->set_stride(strides[i]); - } - - return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); + return SliceInternal(shape, operand, start_indices, limit_indices, strides); }); } +StatusOr XlaBuilder::SliceInternal(const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + for (int i = 0; i < start_indices.size(); i++) { + auto* slice_config = instr.add_slice_dimensions(); + slice_config->set_start(start_indices[i]); + slice_config->set_limit(limit_indices[i]); + slice_config->set_stride(strides[i]); + } + return AddInstruction(std::move(instr), HloOpcode::kSlice, {operand}); +} + XlaOp XlaBuilder::SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, int64 stride, int64 dimno) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -952,41 +958,49 @@ XlaOp XlaBuilder::DynamicUpdateSlice(XlaOp operand, XlaOp update, XlaOp XlaBuilder::ConcatInDim(absl::Span operands, int64 dimension) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape( operand_shape_ptrs, dimension)); - *instr.mutable_shape() = shape.ToProto(); - - instr.add_dimensions(dimension); - - return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); + return ConcatInDimInternal(shape, operands, dimension); }); } +StatusOr XlaBuilder::ConcatInDimInternal( + const Shape& shape, absl::Span operands, int64 dimension) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + + instr.add_dimensions(dimension); + + return AddInstruction(std::move(instr), HloOpcode::kConcatenate, operands); +} + XlaOp XlaBuilder::Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand)); TF_ASSIGN_OR_RETURN(const Shape* padding_value_shape, GetShapePtr(padding_value)); TF_ASSIGN_OR_RETURN( Shape shape, ShapeInference::InferPadShape( *operand_shape, *padding_value_shape, padding_config)); - *instr.mutable_shape() = shape.ToProto(); - *instr.mutable_padding_config() = padding_config; - - return AddInstruction(std::move(instr), HloOpcode::kPad, - {operand, padding_value}); + return PadInternal(shape, operand, padding_value, padding_config); }); } +StatusOr XlaBuilder::PadInternal(const Shape& shape, XlaOp operand, + XlaOp padding_value, + const PaddingConfig& padding_config) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + *instr.mutable_padding_config() = padding_config; + return AddInstruction(std::move(instr), HloOpcode::kPad, + {operand, padding_value}); +} + XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes, int64 inferred_dimension) { @@ -1080,7 +1094,6 @@ XlaOp XlaBuilder::Select(XlaOp pred, XlaOp on_true, XlaOp on_false) { XlaOp XlaBuilder::Tuple(absl::Span elements) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; std::vector operand_shape_ptrs; TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements)); absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs), @@ -1088,14 +1101,19 @@ XlaOp XlaBuilder::Tuple(absl::Span elements) { TF_ASSIGN_OR_RETURN(const Shape shape, ShapeInference::InferVariadicOpShape( HloOpcode::kTuple, operand_shape_ptrs)); - *instr.mutable_shape() = shape.ToProto(); - return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); + return TupleInternal(shape, elements); }); } +StatusOr XlaBuilder::TupleInternal(const Shape& shape, + absl::Span elements) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTuple, elements); +} + XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; TF_ASSIGN_OR_RETURN(const Shape* tuple_shape, GetShapePtr(tuple_data)); if (!tuple_shape->IsTuple()) { return InvalidArgument( @@ -1107,16 +1125,22 @@ XlaOp XlaBuilder::GetTupleElement(XlaOp tuple_data, int64 index) { "GetTupleElement() index (%d) out of range for tuple shape %s", index, ShapeUtil::HumanString(*tuple_shape)); } - *instr.mutable_shape() = - ShapeUtil::GetTupleElementShape(*tuple_shape, index).ToProto(); - - instr.set_tuple_index(index); - - return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, - {tuple_data}); + return GetTupleElementInternal( + ShapeUtil::GetTupleElementShape(*tuple_shape, index), tuple_data, + index); }); } +StatusOr XlaBuilder::GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64 index) { + HloInstructionProto instr; + *instr.mutable_shape() = shape.ToProto(); + instr.set_tuple_index(index); + return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement, + {tuple_data}); +} + XlaOp XlaBuilder::Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config) { return ReportErrorOrReturn([&]() -> StatusOr { @@ -1407,14 +1431,11 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) { XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, const string& config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Given shape to Infeed must have a layout"); } const Shape infeed_instruction_shape = ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()}); - *instr.mutable_shape() = infeed_instruction_shape.ToProto(); - instr.set_infeed_config(config); if (shape.IsArray() && sharding() && sharding()->type() == OpSharding::OTHER) { @@ -1427,11 +1448,18 @@ XlaOp XlaBuilder::InfeedWithToken(XlaOp token, const Shape& shape, return InvalidArgument( "Replicated sharding is not yet supported for infeeds"); } - - return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); + return InfeedWithTokenInternal(infeed_instruction_shape, token, config); }); } +StatusOr XlaBuilder::InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, const string& config) { + HloInstructionProto instr; + *instr.mutable_shape() = infeed_instruction_shape.ToProto(); + instr.set_infeed_config(config); + return AddInstruction(std::move(instr), HloOpcode::kInfeed, {token}); +} + void XlaBuilder::Outfeed(XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config) { ReportErrorOrReturn([&]() -> StatusOr { @@ -1488,10 +1516,6 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config) { return ReportErrorOrReturn([&]() -> StatusOr { - HloInstructionProto instr; - - *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); - // Check and set outfeed shape. if (!LayoutUtil::HasLayout(shape_with_layout)) { return InvalidArgument("Given shape to Outfeed must have a layout"); @@ -1503,15 +1527,22 @@ XlaOp XlaBuilder::OutfeedWithToken(XlaOp operand, XlaOp token, ShapeUtil::HumanStringWithLayout(shape_with_layout), ShapeUtil::HumanStringWithLayout(*operand_shape)); } - *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); - - instr.set_outfeed_config(outfeed_config); - - return AddInstruction(std::move(instr), HloOpcode::kOutfeed, - {operand, token}); + return OutfeedWithTokenInternal(operand, token, shape_with_layout, + outfeed_config); }); } +StatusOr XlaBuilder::OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config) { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto(); + *instr.mutable_outfeed_shape() = shape_with_layout.ToProto(); + instr.set_outfeed_config(outfeed_config); + return AddInstruction(std::move(instr), HloOpcode::kOutfeed, + {operand, token}); +} + XlaOp XlaBuilder::CreateToken() { return ReportErrorOrReturn([&]() -> StatusOr { HloInstructionProto instr; diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index f320feec478..4eba598ff7d 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -364,6 +364,10 @@ class XlaBuilder { Status SetInstructionFrontendAttribute(XlaOp op, string attribute, string value); + // Returns shapes for the operands. + StatusOr> GetOperandShapes( + absl::Span operands) const; + private: // Build helper which takes the id of the root operation.. StatusOr Build(int64 root_id, bool remove_dynamic_dimensions); @@ -391,6 +395,10 @@ class XlaBuilder { XlaOp Pad(XlaOp operand, XlaOp padding_value, const PaddingConfig& padding_config); + virtual StatusOr PadInternal(const Shape& shape, XlaOp operand, + XlaOp padding_value, + const PaddingConfig& padding_config); + XlaOp Reshape(XlaOp operand, absl::Span dimensions, absl::Span new_sizes, int64 inferred_dimension = -1); @@ -406,9 +414,12 @@ class XlaBuilder { XlaOp Slice(XlaOp operand, absl::Span start_indices, absl::Span limit_indices, absl::Span strides); - - XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, - int64 stride, int64 dimno); + virtual StatusOr SliceInternal(const Shape& shape, XlaOp operand, + absl::Span start_indices, + absl::Span limit_indices, + absl::Span strides); + virtual XlaOp SliceInDim(XlaOp operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno); ABSL_DEPRECATED("Use span-of-indices form instead") XlaOp DynamicSlice(XlaOp operand, XlaOp start_indices, @@ -422,14 +433,22 @@ class XlaBuilder { absl::Span start_indices); XlaOp ConcatInDim(absl::Span operands, int64 dimension); + virtual StatusOr ConcatInDimInternal(const Shape& shape, + absl::Span operands, + int64 dimension); void Trace(const string& tag, XlaOp operand); XlaOp Select(XlaOp pred, XlaOp on_true, XlaOp on_false); XlaOp Tuple(absl::Span elements); + virtual StatusOr TupleInternal(const Shape& shape, + absl::Span elements); XlaOp GetTupleElement(XlaOp tuple_data, int64 index); + virtual StatusOr GetTupleElementInternal(const Shape& shape, + XlaOp tuple_data, + int64 index); XlaOp Dot(XlaOp lhs, XlaOp rhs, const PrecisionConfig* precision_config = nullptr); @@ -476,15 +495,18 @@ class XlaBuilder { absl::Span fft_length); XlaOp Infeed(const Shape& shape, const string& config = ""); - XlaOp InfeedWithToken(XlaOp token, const Shape& shape, - const string& config = ""); + XlaOp InfeedWithToken(XlaOp token, const Shape& shape, const string& config); + virtual StatusOr InfeedWithTokenInternal( + const Shape& infeed_instruction_shape, XlaOp token, const string& config); void Outfeed(XlaOp operand, const Shape& shape_with_layout, const string& outfeed_config); XlaOp OutfeedWithToken(XlaOp operand, XlaOp token, const Shape& shape_with_layout, const string& outfeed_config); - + virtual StatusOr OutfeedWithTokenInternal( + XlaOp operand, XlaOp token, const Shape& shape_with_layout, + const string& outfeed_config); XlaOp Call(const XlaComputation& computation, absl::Span operands); @@ -624,7 +646,7 @@ class XlaBuilder { XlaOp RecvFromHost(XlaOp token, const Shape& shape, const ChannelHandle& handle); - XlaOp CreateToken(); + virtual XlaOp CreateToken(); XlaOp AfterAll(absl::Span tokens); @@ -701,10 +723,6 @@ class XlaBuilder { // Returns the (inferred) result for the program shape using the given root. StatusOr GetProgramShape(int64 root_id) const; - // Returns shapes for the operands. - StatusOr> GetOperandShapes( - absl::Span operands) const; - // A visitor which checks whether an operation is a compile-time constant, // meaning that it doesn't depend on any parameters, or on any stateful // operation such as `RngNormal` or `Infeed`. The visitor walks the