Replace Shape with a C++ class in XLA.

No functional change. Rename the proto message Shape to ShapeProto and define an in-place replacement C++ class named Shape with an interface which mirrors the protobuf generated code interface. Having Shape as a C++ class enables greater flexibility in the interface, enables enforcement of invariants, and potential performance improvements.

PiperOrigin-RevId: 223252977
This commit is contained in:
Mark Heffernan 2018-11-28 16:04:22 -08:00 committed by TensorFlower Gardener
parent 0f98c067fa
commit bd737c846c
46 changed files with 671 additions and 443 deletions

View File

@ -175,7 +175,8 @@ Status GenArgMethods(const tf2xla::Config& config,
}
for (int i = 0; i < num_args; ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
const string code = R"(
void set_arg{{NAME}}_data(void* data) {
set_arg_data({{I}}, data);
@ -218,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
}
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
std::vector<std::pair<string, string>> rewrites;
TF_RETURN_IF_ERROR(
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
TF_RETURN_IF_ERROR(AddRewritesForShape(
i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
string code = R"(
{{TYPE}}* result{{NAME}}_data() {
return static_cast<{{TYPE}}*>(result_data({{I}}));
@ -588,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
{"{{METHODS_RESULT}}\n", methods_result},
{"{{NS_END}}\n", ns_end},
{"{{NS_START}}\n", ns_start},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
metadata_result.program_shape_access_shim},
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},

View File

@ -58,15 +58,21 @@ Status CompileXla(xla::CompileOnlyClient* client,
}
compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
xla::ProgramShapeProto* pshape = &compile_result->program_shape;
std::vector<const xla::Shape*> arg_layouts;
arg_layouts.reserve(pshape->parameters_size());
// AotXlaComputationInstance::argument_layouts is a vector of Shape
// pointers. Accumulate the Shape objects themselves in a separate vector
// while building the vector of pointers.
std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
for (int i = 0; i < pshape->parameters_size(); ++i) {
arg_layouts.push_back(pshape->mutable_parameters(i));
arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
arg_layout_ptrs[i] = &arg_layouts[i];
}
xla::CompileOnlyClient::AotXlaComputationInstance instance;
instance.computation = &computation;
instance.argument_layouts = std::move(arg_layouts);
instance.result_layout = &pshape->result();
instance.argument_layouts = std::move(arg_layout_ptrs);
xla::Shape result_shape(pshape->result());
instance.result_layout = &result_shape;
xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
aot_or = client->CompileAheadOfTime({instance}, aot_opts);
if (!aot_or.ok()) {

View File

@ -529,10 +529,12 @@ TEST(TFCompileTest, ProgramShape) {
const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
ASSERT_TRUE(muladd_shape != nullptr);
ASSERT_EQ(muladd_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
EXPECT_TRUE(
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
EXPECT_TRUE(
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));
const xla::Shape& muladd_result = muladd_shape->result();
const xla::Shape muladd_result(muladd_shape->result());
ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
const xla::Shape& muladd_result0 =

View File

@ -18,6 +18,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"

View File

@ -116,13 +116,13 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
// Check program shape.
using xla::ShapeUtil;
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
const xla::ProgramShapeProto* program_shape = function.ProgramShape();
ASSERT_TRUE(program_shape != nullptr);
ASSERT_EQ(program_shape->parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32));
ASSERT_TRUE(function.ProgramShape() != nullptr);
const xla::ProgramShape program_shape(*function.ProgramShape());
ASSERT_EQ(program_shape.parameters_size(), 2);
EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32));
EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32));
const xla::Shape& result = program_shape->result();
const xla::Shape& result = program_shape.result();
ASSERT_EQ(result.element_type(), xla::TUPLE);
ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1);
const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0);

View File

@ -81,6 +81,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/execution_options_util.h"
@ -42,7 +43,7 @@ StatusOr<Literal> Client::Transfer(const GlobalData& data,
TransferToClientRequest request;
*request.mutable_data() = data.handle();
if (shape_with_layout != nullptr) {
*request.mutable_shape_with_layout() = *shape_with_layout;
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
}
TransferToClientResponse response;
@ -123,7 +124,7 @@ StatusOr<Literal> Client::TransferFromOutfeed(
}
request.set_replica_id(replica_id);
if (shape_with_layout != nullptr) {
*request.mutable_shape_with_layout() = *shape_with_layout;
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
}
TransferFromOutfeedResponse response;
@ -170,11 +171,14 @@ StatusOr<Literal> Client::ExecuteAndTransfer(
std::unique_ptr<GlobalData> data,
Execute(computation, arguments, execution_options, execution_profile));
const Shape* shape_with_output_layout = nullptr;
absl::optional<Shape> shape_with_output_layout;
if (execution_options && execution_options->has_shape_with_output_layout()) {
shape_with_output_layout = &execution_options->shape_with_output_layout();
shape_with_output_layout =
Shape(execution_options->shape_with_output_layout());
}
return Transfer(*data, shape_with_output_layout);
return Transfer(*data, shape_with_output_layout.has_value()
? &(*shape_with_output_layout)
: nullptr);
}
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
@ -229,7 +233,7 @@ StatusOr<ExecutionHandle> Client::Compile(
// The argument shapes affect how the computation is compiled.
for (const auto& arg_shape : argument_shapes) {
*request.add_input_shape_with_layout() = arg_shape;
*request.add_input_shape_with_layout() = arg_shape.ToProto();
}
CompileResponse response;
@ -458,7 +462,7 @@ StatusOr<Shape> Client::GetShape(const GlobalData& data) {
return s;
}
return response.shape();
return Shape(response.shape());
}
StatusOr<string> Client::ExecutionStatsAsString(

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -66,7 +66,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
XlaComputation computation = b.Build().ConsumeValueOrDie();
auto execution_options = CreateDefaultExecutionOptions();
*execution_options.mutable_shape_with_output_layout() = shape;
*execution_options.mutable_shape_with_output_layout() = shape.ToProto();
return client->Execute(computation, /*arguments=*/{}, &execution_options)
.ConsumeValueOrDie();
}
@ -98,8 +98,8 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
auto program_shape = computation.proto().host_program_shape();
std::vector<std::unique_ptr<GlobalData>> results;
for (const Shape& shape : program_shape.parameters()) {
results.push_back(MakeFakeDataOrDie(shape, client));
for (const ShapeProto& shape : program_shape.parameters()) {
results.push_back(MakeFakeDataOrDie(Shape(shape), client));
}
return results;
}

View File

@ -36,7 +36,7 @@ OpSharding Tile(const Shape& tile_shape,
const TileAssignment& tile_assignment) {
OpSharding result;
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
*result.mutable_tile_shape() = tile_shape;
*result.mutable_tile_shape() = tile_shape.ToProto();
for (int64 dim : tile_assignment.dimensions()) {
result.add_tile_assignment_dimensions(dim);
}
@ -52,7 +52,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) {
CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
std::vector<int64> dimensions(1, num_tiles);
*result.mutable_tile_shape() = tile_shape;
*result.mutable_tile_shape() = tile_shape.ToProto();
auto& tile_dimension =
(*result.mutable_tile_shape()->mutable_dimensions())[0];
tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);

View File

@ -102,7 +102,7 @@ StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
TF_RETURN_IF_ERROR(first_error_);
TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
return instr->shape();
return Shape(instr->shape());
}
StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
@ -155,7 +155,7 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
ProgramShape program_shape;
*program_shape.mutable_result() = root_proto->shape();
*program_shape.mutable_result() = Shape(root_proto->shape());
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@ -172,7 +172,7 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
const int64 index = instr.parameter_number();
TF_RET_CHECK(index >= 0 && index < param_count)
<< "invalid parameter number: " << index;
*program_shape.mutable_parameters(index) = instr.shape();
*program_shape.mutable_parameters(index) = Shape(instr.shape());
*program_shape.mutable_parameter_names(index) = instr.name();
}
}
@ -329,7 +329,7 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
TF_RETURN_IF_ERROR(first_error_);
HloInstructionProto instr;
*instr.mutable_shape() = shape;
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : broadcast_dimensions) {
instr.add_dimensions(dim);
}
@ -380,8 +380,9 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferUnaryOpShape(unop, operand_shape));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), unop, {operand});
});
}
@ -392,9 +393,10 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferBinaryOpShape(
binop, lhs_shape, rhs_shape, broadcast_dimensions));
*instr.mutable_shape() = shape.ToProto();
const int64 lhs_rank = ShapeUtil::Rank(lhs_shape);
const int64 rhs_rank = ShapeUtil::Rank(rhs_shape);
@ -408,7 +410,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape;
std::vector<int64> to_size;
for (int64 size : instr.shape().dimensions()) {
for (int64 size : shape.dimensions()) {
to_size.push_back(size);
}
for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape);
@ -428,14 +430,14 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
}
TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs));
if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) {
if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_lhs,
AddBroadcastSequence(instr.shape(), updated_lhs));
AddBroadcastSequence(shape, updated_lhs));
}
TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs));
if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) {
if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) {
TF_ASSIGN_OR_RETURN(updated_rhs,
AddBroadcastSequence(instr.shape(), updated_rhs));
AddBroadcastSequence(shape, updated_rhs));
}
return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
@ -449,30 +451,28 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferTernaryOpShape(
triop, lhs_shape, rhs_shape, ehs_shape));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape,
rhs_shape, ehs_shape));
*instr.mutable_shape() = shape.ToProto();
XlaOp updated_lhs = lhs;
XlaOp updated_rhs = rhs;
XlaOp updated_ehs = ehs;
if (!ShapeUtil::IsTuple(instr.shape())) {
if (!ShapeUtil::IsTuple(shape)) {
if (!ShapeUtil::IsTuple(lhs_shape) &&
!ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) {
!ShapeUtil::SameDimensions(shape, lhs_shape)) {
// lhs is being implicitly broadcasted. Change to explicit.
TF_ASSIGN_OR_RETURN(updated_lhs,
AddBroadcastSequence(instr.shape(), lhs));
TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs));
}
if (!ShapeUtil::IsTuple(rhs_shape) &&
!ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) {
!ShapeUtil::SameDimensions(shape, rhs_shape)) {
// rhs is being implicitly broadcasted. Change to explicit.
TF_ASSIGN_OR_RETURN(updated_rhs,
AddBroadcastSequence(instr.shape(), rhs));
TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs));
}
if (!ShapeUtil::IsTuple(ehs_shape) &&
!ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) {
!ShapeUtil::SameDimensions(shape, ehs_shape)) {
// ehs is being implicitly broadcasted. Change to explicit.
TF_ASSIGN_OR_RETURN(updated_ehs,
AddBroadcastSequence(instr.shape(), ehs));
TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs));
}
}
return AddInstruction(std::move(instr), triop,
@ -493,7 +493,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = literal.shape();
*instr.mutable_shape() = literal.shape().ToProto();
*instr.mutable_literal() = literal.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kConstant);
});
@ -502,7 +502,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = shape;
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(iota_dimension);
return AddInstruction(std::move(instr), HloOpcode::kIota);
});
@ -522,10 +522,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCallShape(operand_shape_ptrs,
/*to_apply=*/called_program_shape));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
operand_shape_ptrs,
/*to_apply=*/called_program_shape));
*instr.mutable_shape() = shape.ToProto();
AddCalledComputation(computation, &instr);
@ -543,7 +543,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
}
instr.set_parameter_number(parameter_number);
instr.set_name(name);
*instr.mutable_shape() = shape;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kParameter);
});
}
@ -601,7 +601,7 @@ StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
TF_RETURN_IF_ERROR(first_error_);
HloInstructionProto instr;
*instr.mutable_shape() = shape;
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
}
@ -613,9 +613,9 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferSliceShape(operand_shape, start_indices,
limit_indices, strides));
Shape shape, ShapeInference::InferSliceShape(
operand_shape, start_indices, limit_indices, strides));
*instr.mutable_shape() = shape.ToProto();
for (int i = 0; i < start_indices.size(); i++) {
auto* slice_config = instr.add_slice_dimensions();
slice_config->set_start(start_indices[i]);
@ -650,9 +650,10 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
GetShape(start_indices));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDynamicSliceShape(
operand_shape, start_indices_shape, slice_sizes));
*instr.mutable_shape() = shape.ToProto();
for (int64 size : slice_sizes) {
instr.add_dynamic_slice_sizes(size);
@ -672,9 +673,10 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
GetShape(start_indices));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDynamicUpdateSliceShape(
operand_shape, update_shape, start_indices_shape));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
{operand, update, start_indices});
@ -690,9 +692,9 @@ XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
operand_shape_ptrs, dimension));
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
@ -709,10 +711,9 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape,
GetShape(padding_value));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferPadShape(operand_shape, padding_value_shape,
padding_config));
Shape shape, ShapeInference::InferPadShape(
operand_shape, padding_value_shape, padding_config));
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_padding_config() = padding_config;
return AddInstruction(std::move(instr), HloOpcode::kPad,
@ -725,7 +726,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(const Shape& shape,
TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferReshapeShape(
operand_shape, dimensions, new_sizes));
XlaOp transposed = IsIdentityPermutation(dimensions)
@ -738,7 +739,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
absl::Span<const int64> new_sizes) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand));
std::vector<int64> dimensions(shape.dimensions_size());
std::iota(dimensions.begin(), dimensions.end(), 0);
return Reshape(operand, dimensions, new_sizes);
@ -788,7 +789,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeNil();
*instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
*instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
});
@ -814,9 +815,10 @@ XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(const Shape shape,
ShapeInference::InferVariadicOpShape(
HloOpcode::kTuple, operand_shape_ptrs));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
});
}
@ -831,7 +833,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
ShapeUtil::HumanString(tuple_shape));
}
*instr.mutable_shape() =
ShapeUtil::GetTupleElementShape(tuple_shape, index);
ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto();
instr.set_tuple_index(index);
@ -890,9 +892,10 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
dimension_numbers));
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_dot_dimension_numbers() = dimension_numbers;
if (precision_config != nullptr) {
*instr.mutable_precision_config() = *precision_config;
@ -1034,10 +1037,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
MakeWindow(window_dimensions, window_strides, padding,
lhs_dilation, rhs_dilation));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferConvolveShape(
lhs_shape, rhs_shape, feature_group_count,
instr.window(), dimension_numbers));
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
instr.set_feature_group_count(feature_group_count);
@ -1110,10 +1114,9 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferFftShape(operand_shape, fft_type, fft_length));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
operand_shape, fft_type, fft_length));
*instr.mutable_shape() = shape.ToProto();
instr.set_fft_type(fft_type);
for (int64 i : fft_length) {
instr.add_fft_length(i);
@ -1131,7 +1134,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
}
const Shape infeed_instruction_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
*instr.mutable_shape() = infeed_instruction_shape;
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
instr.set_infeed_config(config);
if (ShapeUtil::IsArray(shape) && sharding() &&
@ -1152,7 +1155,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
XlaOp token;
auto make_token = [&]() {
HloInstructionProto token_instr;
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
};
if (sharding()) {
@ -1191,7 +1194,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto infeed_data;
*infeed_data.mutable_shape() = shape;
*infeed_data.mutable_shape() = shape.ToProto();
infeed_data.set_tuple_index(0);
return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
{infeed});
@ -1207,7 +1210,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
}
const Shape infeed_instruction_shape =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
*instr.mutable_shape() = infeed_instruction_shape;
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
instr.set_infeed_config(config);
if (ShapeUtil::IsArray(shape) && sharding() &&
@ -1232,7 +1235,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
// Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) {
@ -1245,14 +1248,14 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
ShapeUtil::HumanStringWithLayout(shape_with_layout),
ShapeUtil::HumanStringWithLayout(operand_shape));
}
*instr.mutable_outfeed_shape() = shape_with_layout;
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
instr.set_outfeed_config(outfeed_config);
// Outfeed takes a token as its second operand. Generate the token to pass
// to the outfeed.
HloInstructionProto token_instr;
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
@ -1266,7 +1269,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto tuple_instr;
*tuple_instr.mutable_shape() = ShapeUtil::MakeNil();
*tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
// The dummy tuple should have no sharding.
{
@ -1285,7 +1288,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
// Check and set outfeed shape.
if (!LayoutUtil::HasLayout(shape_with_layout)) {
@ -1298,7 +1301,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
ShapeUtil::HumanStringWithLayout(shape_with_layout),
ShapeUtil::HumanStringWithLayout(operand_shape));
}
*instr.mutable_outfeed_shape() = shape_with_layout;
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
instr.set_outfeed_config(outfeed_config);
@ -1310,7 +1313,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp XlaBuilder::CreateToken() {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
});
}
@ -1330,7 +1333,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
}
}
HloInstructionProto instr;
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
});
}
@ -1347,7 +1350,7 @@ XlaOp XlaBuilder::CustomCall(
"are reserved for internal use.",
call_target_name);
}
*instr.mutable_shape() = shape;
*instr.mutable_shape() = shape.ToProto();
instr.set_custom_call_target(call_target_name);
instr.set_custom_call_opaque(opaque);
if (operand_shapes_with_layout.has_value()) {
@ -1371,7 +1374,7 @@ XlaOp XlaBuilder::CustomCall(
"constrained layout.",
operand_num);
}
*instr.add_operand_shapes_with_layout() = operand_shape;
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
++operand_num;
}
}
@ -1525,9 +1528,9 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferTransposeShape(operand_shape, permutation));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
operand_shape, permutation));
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : permutation) {
instr.add_dimensions(dim);
}
@ -1540,9 +1543,9 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferReverseShape(operand_shape, dimensions));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
operand_shape, dimensions));
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : dimensions) {
instr.add_dimensions(dim);
}
@ -1561,9 +1564,9 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
GetOperandShapes(values));
absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferVariadicOpShape(
HloOpcode::kSort, operand_shape_ptrs));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
HloOpcode::kSort, operand_shape_ptrs));
*instr.mutable_shape() = shape.ToProto();
if (dimension == -1) {
TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
dimension = ShapeUtil::Rank(keys_shape) - 1;
@ -1585,9 +1588,9 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferConvertShape(operand_shape, new_element_type));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
operand_shape, new_element_type));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
});
}
@ -1597,9 +1600,9 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferConvertShape(operand_shape, new_element_type));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
operand_shape, new_element_type));
*instr.mutable_shape() = shape.ToProto();
return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
{operand});
});
@ -1631,11 +1634,11 @@ XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape,
dimensions));
Shape shape, ShapeInference::InferMapShape(
operand_shape_ptrs, called_program_shape, dimensions));
*instr.mutable_shape() = shape.ToProto();
const Shape& output_shape = instr.shape();
Shape output_shape(instr.shape());
const int64 output_rank = ShapeUtil::Rank(output_shape);
AddCalledComputation(computation, &instr);
std::vector<XlaOp> new_operands(operands.begin(), operands.end());
@ -1678,7 +1681,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
*instr.mutable_shape() = shape;
*instr.mutable_shape() = shape.ToProto();
instr.set_distribution(distribution);
@ -1706,10 +1709,10 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
condition.GetProgramShape());
TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferWhileShape(condition_program_shape,
body_program_shape, init_shape));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
condition_program_shape,
body_program_shape, init_shape));
*instr.mutable_shape() = shape.ToProto();
// Body comes before condition computation in the vector.
AddCalledComputation(body, &instr);
AddCalledComputation(condition, &instr);
@ -1726,10 +1729,10 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
GetShape(start_indices));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferGatherShape(input_shape, start_indices_shape,
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
input_shape, start_indices_shape,
dimension_numbers, slice_sizes));
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_gather_dimension_numbers() = dimension_numbers;
for (int64 bound : slice_sizes) {
@ -1754,10 +1757,11 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates));
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
update_computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferScatterShape(
input_shape, scatter_indices_shape, updates_shape,
to_apply_shape, dimension_numbers));
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_scatter_dimension_numbers() = dimension_numbers;
@ -1784,10 +1788,11 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape,
false_computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
Shape shape,
ShapeInference::InferConditionalShape(
predicate_shape, true_operand_shape, false_operand_shape,
true_computation_shape, false_computation_shape));
*instr.mutable_shape() = shape.ToProto();
// The index of true_computation must be 0 and that of false computation
// must be 1.
@ -1829,9 +1834,10 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
Shape shape,
ShapeInference::InferReduceShape(
operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
*instr.mutable_shape() = shape.ToProto();
for (int64 dim : dimensions_to_reduce) {
instr.add_dimensions(dim);
@ -1894,10 +1900,10 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
MakeWindow(window_dimensions, window_strides, padding,
/*lhs_dilation=*/base_dilations,
/*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
instr.window(), to_apply_shape));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
operand_shape, init_shape,
instr.window(), to_apply_shape));
*instr.mutable_shape() = shape.ToProto();
AddCalledComputation(computation, &instr);
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
@ -1915,9 +1921,10 @@ XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
Shape shape,
ShapeInference::InferBatchNormTrainingShape(
operand_shape, scale_shape, offset_shape, feature_index));
*instr.mutable_shape() = shape.ToProto();
instr.set_epsilon(epsilon);
instr.set_feature_index(feature_index);
@ -1939,10 +1946,11 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean));
TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferBatchNormInferenceShape(
operand_shape, scale_shape, offset_shape,
mean_shape, variance_shape, feature_index));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferBatchNormInferenceShape(
operand_shape, scale_shape, offset_shape, mean_shape,
variance_shape, feature_index));
*instr.mutable_shape() = shape.ToProto();
instr.set_epsilon(epsilon);
instr.set_feature_index(feature_index);
@ -1964,10 +1972,11 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean));
TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var));
TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferBatchNormGradShape(
operand_shape, scale_shape, batch_mean_shape,
batch_var_shape, grad_output_shape, feature_index));
*instr.mutable_shape() = shape.ToProto();
instr.set_epsilon(epsilon);
instr.set_feature_index(feature_index);
@ -1998,9 +2007,9 @@ XlaOp XlaBuilder::CrossReplicaSum(
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape(
{&operand_shape}));
*instr.mutable_shape() = shape.ToProto();
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
@ -2053,8 +2062,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
*instr.mutable_shape() = shape.ToProto();
for (const ReplicaGroup& group : replica_groups) {
*instr.add_replica_groups() = group;
}
@ -2079,8 +2088,9 @@ XlaOp XlaBuilder::CollectivePermute(
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
Shape shape,
ShapeInference::InferCollectivePermuteShape(operand_shape));
*instr.mutable_shape() = shape.ToProto();
for (const auto& pair : source_target_pairs) {
auto* proto_pair = instr.add_source_target_pairs();
@ -2129,10 +2139,11 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
MakeWindow(window_dimensions, window_strides, padding,
/*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferSelectAndScatterShape(
operand_shape, select_shape, instr.window(),
source_shape, init_shape, scatter_shape));
*instr.mutable_shape() = shape.ToProto();
AddCalledComputation(select, &instr);
AddCalledComputation(scatter, &instr);
@ -2147,9 +2158,10 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferReducePrecisionShape(
operand_shape, exponent_bits, mantissa_bits));
*instr.mutable_shape() = shape.ToProto();
instr.set_exponent_bits(exponent_bits);
instr.set_mantissa_bits(mantissa_bits);
return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
@ -2164,7 +2176,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto token_instr;
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
@ -2183,15 +2195,17 @@ XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
// token}.
HloInstructionProto send_instr;
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
*send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
*send_instr.mutable_shape() =
ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
.ToProto();
send_instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp send,
AddInstruction(std::move(send_instr), HloOpcode::kSend,
{operand, token}));
HloInstructionProto send_done_instr;
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
send_done_instr.set_channel_id(handle.handle());
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
{send});
@ -2205,7 +2219,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto token_instr;
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
HloOpcode::kAfterAll, {}));
@ -2216,7 +2230,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
// TODO(b/80000000): Remove this when clients have been updated to handle
// tokens.
HloInstructionProto recv_data;
*recv_data.mutable_shape() = shape;
*recv_data.mutable_shape() = shape.ToProto();
recv_data.set_tuple_index(0);
return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
{recv});
@ -2233,15 +2247,18 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
// Recv instruction produces a tuple of {receive buffer, U32 context,
// token}.
HloInstructionProto recv_instr;
*recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
*recv_instr.mutable_shape() =
ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
.ToProto();
recv_instr.set_channel_id(handle.handle());
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
HloOpcode::kRecv, {token}));
HloInstructionProto recv_done_instr;
*recv_done_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
.ToProto();
recv_done_instr.set_channel_id(handle.handle());
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
{recv});
@ -2275,9 +2292,11 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
// Send instruction produces a tuple of {aliased operand, U32 context,
// token}.
HloInstructionProto send_instr;
*send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
{shape_with_layout, ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()});
*send_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape_with_layout,
ShapeUtil::MakeShape(U32, {}),
ShapeUtil::MakeTokenShape()})
.ToProto();
send_instr.set_channel_id(handle.handle());
send_instr.set_is_host_transfer(true);
TF_ASSIGN_OR_RETURN(XlaOp send,
@ -2285,7 +2304,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
{operand, token}));
HloInstructionProto send_done_instr;
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
send_done_instr.set_channel_id(handle.handle());
send_done_instr.set_is_host_transfer(true);
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
@ -2314,8 +2333,10 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
// Recv instruction produces a tuple of {receive buffer, U32 context,
// token}.
HloInstructionProto recv_instr;
*recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
*recv_instr.mutable_shape() =
ShapeUtil::MakeTupleShape(
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
.ToProto();
recv_instr.set_channel_id(handle.handle());
recv_instr.set_is_host_transfer(true);
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
@ -2323,7 +2344,8 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
HloInstructionProto recv_done_instr;
*recv_done_instr.mutable_shape() =
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
.ToProto();
recv_done_instr.set_channel_id(handle.handle());
recv_done_instr.set_is_host_transfer(true);
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
@ -2335,9 +2357,9 @@ XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension));
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
operand_shape, dimension));
*instr.mutable_shape() = shape.ToProto();
instr.add_dimensions(dimension);
return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
{operand});

View File

@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/macros.h"

View File

@ -292,16 +292,17 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
if (!proto.has_shape()) {
return InvalidArgument("LiteralProto has no shape");
}
if (ShapeUtil::HasPrimitiveType(proto.shape(), OPAQUE)) {
Shape shape(proto.shape());
if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) {
return InvalidArgument("Literal shape cannot include OPAQUE sub-shape");
}
if (!LayoutUtil::HasLayout(proto.shape())) {
if (!LayoutUtil::HasLayout(shape)) {
return InvalidArgument("LiteralProto has no layout");
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
Literal literal(proto.shape());
Literal literal(shape);
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
[&](const ShapeIndex& index, Piece* piece) {
@ -1794,7 +1795,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest,
} // namespace
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
*proto->mutable_shape() = subshape();
*proto->mutable_shape() = subshape().ToProto();
switch (subshape().element_type()) {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
@ -1900,8 +1901,9 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
// These conditions should have been checked in
// MutableLiteralBase::CreateFromProto.
TF_RET_CHECK(proto.has_shape());
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
Shape shape(proto.shape());
TF_RET_CHECK(LayoutUtil::HasLayout(shape));
TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));
if (LayoutUtil::IsSparseArray(subshape())) {
// Compute the number of elements (indices) in the sparse shape and reserve

View File

@ -1377,13 +1377,26 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
absl::StrContains(status.error_message(), "bit widths are different"));
}
// Installs the default layout on the given array-shaped ShapeProto: a DENSE
// format whose minor-to-major order lists dimensions from last to first
// (i.e. {rank-1, ..., 1, 0}).
void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
  CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
  auto* layout = shape_proto->mutable_layout();
  layout->set_format(DENSE);
  auto* minor_to_major = layout->mutable_minor_to_major();
  const int64 rank = shape_proto->dimensions_size();
  // Ensure the repeated field has exactly `rank` entries before writing.
  minor_to_major->Resize(rank, 0);
  for (int64 dim = 0; dim < rank; ++dim) {
    minor_to_major->Set(dim, rank - 1 - dim);
  }
}
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
LiteralProto p;
p.mutable_shape()->set_element_type(PRED);
for (int len = 0; len < 25; ++len) {
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(len);
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
SetDefaultLayoutOnProto(p.mutable_shape());
p.clear_preds();
for (int i = 0; i < len; ++i) {
p.add_preds((i % 2) == (len % 2));
@ -1409,7 +1422,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
EXPECT_EQ(4, m.data<half>().size());
LiteralProto p = m.ToProto();
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
EXPECT_EQ(8, p.f16s().size());
const char* d = p.f16s().data();
EXPECT_EQ(d[0], 0);
@ -1432,7 +1445,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
p.mutable_shape()->set_element_type(F16);
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(4);
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
SetDefaultLayoutOnProto(p.mutable_shape());
p.clear_f16s();
p.set_f16s(half_vals, 8);
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
@ -1454,7 +1467,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) {
p.mutable_shape()->set_element_type(U16);
p.mutable_shape()->clear_dimensions();
p.mutable_shape()->add_dimensions(4);
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
SetDefaultLayoutOnProto(p.mutable_shape());
p.clear_u16s();
p.set_u16s(uint16_vals, 8);
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
@ -1756,7 +1769,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
// Proto contains a shape, but no values.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
Status status = Literal::CreateFromProto(proto).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
@ -1777,7 +1790,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
// Proto contains values in wrong container.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
proto.add_preds(false);
proto.add_preds(true);
proto.add_preds(false);
@ -1790,7 +1803,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
// Proto contains too few values.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2});
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
proto.add_f32s(1.0);
proto.add_f32s(2.0);
proto.add_f32s(3.0);
@ -1803,7 +1816,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
// Proto contains too many values.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2});
*proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
proto.add_s32s(42);
proto.add_s32s(-10);
proto.add_s32s(100);
@ -1816,8 +1829,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
// Proto shape missing layout.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2});
LayoutUtil::ClearLayout(proto.mutable_shape());
*proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
proto.mutable_shape()->clear_layout();
proto.add_preds(true);
proto.add_preds(false);
proto.add_preds(true);
@ -1830,11 +1843,13 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
// Proto has the too few tuple elements.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
*proto.mutable_shape() =
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
.ToProto();
LiteralProto* element0 = proto.add_tuple_literals();
*element0->mutable_shape() =
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
element0->add_preds(false);
element0->add_preds(true);
@ -1846,19 +1861,21 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
// Proto has the too many tuple elements.
LiteralProto proto;
*proto.mutable_shape() = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
*proto.mutable_shape() =
ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
.ToProto();
LiteralProto* element0 = proto.add_tuple_literals();
*element0->mutable_shape() =
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
element0->add_preds(false);
element0->add_preds(true);
LiteralProto* element1 = proto.add_tuple_literals();
*element1->mutable_shape() =
ShapeUtil::GetTupleElementShape(proto.shape(), 1);
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto();
element1->add_f32s(42.0);
LiteralProto* element2 = proto.add_tuple_literals();
*element2->mutable_shape() = ShapeUtil::MakeShape(F32, {});
*element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
element2->add_f32s(123.0);
Status status = Literal::CreateFromProto(proto).status();

View File

@ -25,9 +25,10 @@ from tensorflow.compiler.xla.python_api import types
class Shape(object):
"""Wraps a xla_data_pb2.Shape message with a convenient Python type.
"""Wraps a xla_data_pb2.ShapeProto message with a convenient Python type.
Provides direct access to the underlying xla_data_pb2.Shape message in the
Provides direct access to the underlying xla_data_pb2.ShapeProto message in
the
message attribute, along with accessor wrappers to the message's fields.
Avoid direct access to .message unless interacting directly with protobuf APIs
like CopyFrom. In other words, prefer hauling the shape around in a Shape, and
@ -48,7 +49,7 @@ class Shape(object):
Raises:
ValueError: if element_type is TUPLE but dimensions are not Shape objects.
"""
self.message = xla_data_pb2.Shape()
self.message = xla_data_pb2.ShapeProto()
self.message.element_type = element_type
if element_type == xla_data_pb2.TUPLE:
if not all(isinstance(subshape, Shape) for subshape in dimensions):

View File

@ -16,7 +16,6 @@ xla_proto_library(
use_grpc_plugin = True,
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla:xla_proto",
],
)

View File

@ -43,7 +43,6 @@ limitations under the License.
syntax = "proto3";
import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";
package xla;

View File

@ -89,7 +89,7 @@ CompileOnlyService::CompileAheadOfTime(
ExecutionOptions execution_options;
*execution_options.mutable_debug_options() = debug_options;
*execution_options.mutable_shape_with_output_layout() =
*instance.result_layout;
instance.result_layout->ToProto();
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <deque>
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

View File

@ -51,7 +51,7 @@ message HloInstructionProto {
string name = 1;
string opcode = 2;
xla.Shape shape = 3;
xla.ShapeProto shape = 3;
xla.OpMetadata metadata = 7;
@ -132,7 +132,7 @@ message HloInstructionProto {
string custom_call_opaque = 53;
// Shape of outfeed request.
xla.Shape outfeed_shape = 29;
xla.ShapeProto outfeed_shape = 29;
// Describes the dimension numbers used for a dot operation
xla.DotDimensionNumbers dot_dimension_numbers = 30;
@ -190,7 +190,7 @@ message HloInstructionProto {
// 'operand_shapes_with_layout' must contain a shape with layout for each
// operand.
bool constrain_layout = 56;
repeated Shape operand_shapes_with_layout = 57;
repeated xla.ShapeProto operand_shapes_with_layout = 57;
}
// Serialization of HloComputation.

View File

@ -93,7 +93,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
[&computation_map](int64 id) { return computation_map.contains(id); }))
<< proto.name() << " instruction references invalid computation id(s)";
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
Shape shape(proto.shape());
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
switch (opcode) {
// Ops migrated to subclasses.
@ -101,23 +102,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.operand_ids_size() == 3)
<< "BatchNormTraining instruction should have 3 operands but sees "
<< proto.operand_ids_size();
instruction = CreateBatchNormTraining(
proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(),
proto.feature_index());
instruction =
CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kBatchNormInference:
TF_RET_CHECK(proto.operand_ids_size() == 5)
<< "BatchNormInference instruction should have 5 operands but sees "
<< proto.operand_ids_size();
instruction = CreateBatchNormInference(
proto.shape(), operands(0), operands(1), operands(2), operands(3),
shape, operands(0), operands(1), operands(2), operands(3),
operands(4), proto.epsilon(), proto.feature_index());
break;
case HloOpcode::kBatchNormGrad:
TF_RET_CHECK(proto.operand_ids_size() == 5)
<< "BatchNormGrad instruction should have 5 operands but sees "
<< proto.operand_ids_size();
instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1),
instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
operands(2), operands(3), operands(4),
proto.epsilon(), proto.feature_index());
break;
@ -127,7 +128,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.operand_ids_size();
std::vector<int64> fft_length(proto.fft_length().begin(),
proto.fft_length().end());
instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
instruction = CreateFft(shape, operands(0), proto.fft_type(),
absl::Span<const int64>(fft_length));
break;
}
@ -148,7 +149,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Recv instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0),
instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
proto.channel_id(), proto.is_host_transfer());
break;
case HloOpcode::kRecvDone:
@ -161,7 +162,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Reverse instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction = CreateReverse(proto.shape(), operands(0),
instruction = CreateReverse(shape, operands(0),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
@ -170,7 +171,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Concatenate instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction =
CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0));
CreateConcatenate(shape, all_operands(), proto.dimensions(0));
break;
case HloOpcode::kReduce:
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
@ -188,7 +189,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
absl::MakeSpan(reduce_operands)
.subspan(reduce_operands.size() / 2, reduce_operands.size());
instruction =
CreateReduce(proto.shape(), inputs, init_values,
CreateReduce(shape, inputs, init_values,
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()),
computations(0));
@ -203,7 +204,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
auto sort_operands = all_operands();
HloInstruction* keys = sort_operands[0];
instruction = CreateSort(
proto.shape(), proto.dimensions(0), keys,
shape, proto.dimensions(0), keys,
absl::Span<HloInstruction* const>(sort_operands).subspan(1));
break;
}
@ -212,7 +213,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Transpose instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction =
CreateTranspose(proto.shape(), operands(0),
CreateTranspose(shape, operands(0),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
@ -221,7 +222,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Broadcast instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction =
CreateBroadcast(proto.shape(), operands(0),
CreateBroadcast(shape, operands(0),
std::vector<int64>(proto.dimensions().begin(),
proto.dimensions().end()));
break;
@ -229,7 +230,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "Map instruction should have 1 called computation but sees "
<< proto.called_computation_ids_size();
instruction = CreateMap(proto.shape(), all_operands(), computations(0));
instruction = CreateMap(shape, all_operands(), computations(0));
break;
case HloOpcode::kSlice: {
TF_RET_CHECK(proto.operand_ids_size() == 1)
@ -242,8 +243,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
slice_limits.push_back(slice_dimensions.limit());
slice_strides.push_back(slice_dimensions.stride());
}
instruction = CreateSlice(proto.shape(), operands(0), slice_starts,
slice_limits, slice_strides);
instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
slice_strides);
break;
}
case HloOpcode::kConstant: {
@ -253,7 +254,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
Literal::CreateFromProto(proto.literal()));
instruction = CreateConstant(std::move(literal));
} else {
instruction = absl::make_unique<HloConstantInstruction>(proto.shape());
instruction = absl::make_unique<HloConstantInstruction>(shape);
}
break;
}
@ -284,55 +285,54 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
TF_RET_CHECK(fused_computation != nullptr)
<< "No fusion computation with id " << fusion_id;
instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
fused_computation);
instruction =
CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
break;
}
case HloOpcode::kRng:
instruction =
CreateRng(proto.shape(), proto.distribution(), all_operands());
instruction = CreateRng(shape, proto.distribution(), all_operands());
break;
case HloOpcode::kParameter:
instruction = CreateParameter(proto.parameter_number(), proto.shape(),
proto.name());
instruction =
CreateParameter(proto.parameter_number(), shape, proto.name());
break;
case HloOpcode::kGetTupleElement:
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "GetTupleElement instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction = CreateGetTupleElement(proto.shape(), operands(0),
proto.tuple_index());
instruction =
CreateGetTupleElement(shape, operands(0), proto.tuple_index());
break;
case HloOpcode::kReducePrecision:
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "ReducePrecision instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction =
CreateReducePrecision(proto.shape(), operands(0),
proto.exponent_bits(), proto.mantissa_bits());
instruction = CreateReducePrecision(
shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
break;
case HloOpcode::kInfeed: {
TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) &&
(ShapeUtil::TupleElementCount(proto.shape()) == 2))
TF_RET_CHECK(ShapeUtil::IsTuple(shape) &&
(ShapeUtil::TupleElementCount(shape) == 2))
<< "Infeed should have a tuple shape with 2 operands, but has: "
<< proto.shape();
const Shape& data_shape =
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
<< shape;
const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Infeed instruction should have 1 operand but sees "
<< proto.operand_ids_size();
instruction =
CreateInfeed(data_shape, operands(0), proto.infeed_config());
} break;
case HloOpcode::kOutfeed:
case HloOpcode::kOutfeed: {
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Outfeed instruction should have 2 operands but sees "
<< proto.operand_ids_size();
Shape outfeed_shape(proto.outfeed_shape());
TF_RETURN_IF_ERROR(
ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape()));
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
operands(1), proto.outfeed_config());
ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
proto.outfeed_config());
break;
}
case HloOpcode::kCrossReplicaSum: {
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "CrossReplicaSum should have 1 called computation but sees "
@ -342,7 +342,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
all_reduce_id = proto.all_reduce_id();
}
instruction = CreateCrossReplicaSum(
proto.shape(), all_operands(), computations(0),
shape, all_operands(), computations(0),
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()),
@ -352,7 +352,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
case HloOpcode::kAllToAll: {
instruction = CreateAllToAll(
proto.shape(), all_operands(),
shape, all_operands(),
/*replica_groups=*/
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
proto.replica_groups().end()));
@ -368,8 +368,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
source_target_pairs[i].first = proto.source_target_pairs(i).source();
source_target_pairs[i].second = proto.source_target_pairs(i).target();
}
instruction = CreateCollectivePermute(proto.shape(), operands(0),
source_target_pairs);
instruction =
CreateCollectivePermute(shape, operands(0), source_target_pairs);
break;
}
case HloOpcode::kConvolution: {
@ -382,7 +382,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
precision_config.mutable_operand_precision()->Resize(
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
instruction = CreateConvolve(
proto.shape(), operands(0), operands(1),
shape, operands(0), operands(1),
std::max<int64>(proto.feature_group_count(), 1), proto.window(),
proto.convolution_dimension_numbers(), precision_config);
break;
@ -394,7 +394,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "ReduceWindow should have 1 called computation but sees "
<< proto.called_computation_ids_size();
instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1),
instruction = CreateReduceWindow(shape, operands(0), operands(1),
proto.window(), computations(0));
break;
case HloOpcode::kSelectAndScatter:
@ -404,9 +404,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
TF_RET_CHECK(proto.called_computation_ids_size() == 2)
<< "SelectAndScatter should have 2 called computations but sees "
<< proto.called_computation_ids_size();
instruction = CreateSelectAndScatter(
proto.shape(), operands(0), computations(0), proto.window(),
operands(1), operands(2), computations(1));
instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
proto.window(), operands(1),
operands(2), computations(1));
break;
case HloOpcode::kCustomCall:
if (proto.constrain_layout()) {
@ -414,16 +414,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
// vector of pointers essentially) so create a vector of shapes to pass
// in.
std::vector<Shape> operand_shapes;
for (const Shape& shape : proto.operand_shapes_with_layout()) {
operand_shapes.push_back(shape);
for (const ShapeProto& shape_proto :
proto.operand_shapes_with_layout()) {
operand_shapes.emplace_back(shape_proto);
}
instruction = CreateCustomCall(
proto.shape(), all_operands(), proto.custom_call_target(),
operand_shapes, proto.custom_call_opaque());
instruction =
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
operand_shapes, proto.custom_call_opaque());
} else {
instruction = CreateCustomCall(proto.shape(), all_operands(),
proto.custom_call_target(),
proto.custom_call_opaque());
instruction =
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
proto.custom_call_opaque());
}
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
@ -443,8 +444,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Pad instruction should have 2 operands but sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.has_padding_config());
instruction = CreatePad(proto.shape(), operands(0), operands(1),
proto.padding_config());
instruction =
CreatePad(shape, operands(0), operands(1), proto.padding_config());
break;
case HloOpcode::kDynamicSlice: {
TF_RET_CHECK(proto.operand_ids_size() == 2)
@ -452,8 +453,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< proto.operand_ids_size();
std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
slice_sizes);
instruction =
CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes);
break;
}
case HloOpcode::kGather: {
@ -469,7 +470,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
for (int64 bound : proto.gather_slice_sizes()) {
gather_slice_sizes.push_back(bound);
}
instruction = CreateGather(proto.shape(), operands(0), operands(1),
instruction = CreateGather(shape, operands(0), operands(1),
*gather_dimension_numbers, gather_slice_sizes);
break;
}
@ -485,16 +486,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
auto scatter_dimension_numbers =
absl::make_unique<ScatterDimensionNumbers>(
proto.scatter_dimension_numbers());
instruction =
CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers);
instruction = CreateScatter(shape, operands(0), operands(1), operands(2),
computations(0), *scatter_dimension_numbers);
break;
}
case HloOpcode::kIota:
TF_RET_CHECK(proto.dimensions_size() == 1)
<< "Iota instruction should have 1 dimension but sees "
<< proto.dimensions_size();
instruction = CreateIota(proto.shape(), proto.dimensions(0));
instruction = CreateIota(shape, proto.dimensions(0));
break;
case HloOpcode::kDot: {
TF_RET_CHECK(proto.has_dot_dimension_numbers())
@ -506,8 +506,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
precision_config.mutable_operand_precision()->Resize(
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
instruction = absl::make_unique<HloDotInstruction>(
proto.shape(), operands(0), operands(1),
proto.dot_dimension_numbers(), precision_config);
shape, operands(0), operands(1), proto.dot_dimension_numbers(),
precision_config);
break;
}
case HloOpcode::kDomain: {
@ -529,7 +529,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
}
instruction = absl::make_unique<HloDomainInstruction>(
proto.shape(), operands(0),
shape, operands(0),
absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
break;
@ -537,11 +537,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
case HloOpcode::kGetDimensionSize:
TF_RET_CHECK(proto.operand_ids_size() == 1);
TF_RET_CHECK(proto.dimensions_size() == 1);
instruction = CreateGetDimensionSize(proto.shape(), operands(0),
proto.dimensions(0));
instruction =
CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
break;
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
for (const int64 operand_id : proto.operand_ids()) {
instruction->AppendOperand(instruction_map.at(operand_id));
}
@ -2234,7 +2234,7 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_id(unique_id_);
proto.set_name(name_);
proto.set_opcode(HloOpcodeString(opcode_));
*proto.mutable_shape() = shape_;
*proto.mutable_shape() = shape_.ToProto();
for (const HloInstruction* operand : operands_) {
proto.add_operand_ids(operand->unique_id());
}

View File

@ -1615,7 +1615,7 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
HloInstructionProto HloOutfeedInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_outfeed_config(outfeed_config());
*proto.mutable_outfeed_shape() = outfeed_shape();
*proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
return proto;
}
@ -1867,7 +1867,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
if (layout_constrained()) {
proto.set_constrain_layout(true);
for (const Shape& shape : operand_shapes_with_layout_) {
*proto.add_operand_shapes_with_layout() = shape;
*proto.add_operand_shapes_with_layout() = shape.ToProto();
}
}
return proto;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_token.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -257,7 +257,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// the entry parameters and root.
TF_RET_CHECK(proto.has_host_program_shape())
<< "No program shape found in the proto";
const auto& expected_program_shape = proto.host_program_shape();
ProgramShape expected_program_shape(proto.host_program_shape());
TF_RET_CHECK(expected_program_shape.parameters_size() ==
module_config.entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
@ -369,7 +369,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
const HloModuleProto& module, const DebugOptions& debug_options) {
TF_RET_CHECK(module.has_host_program_shape())
<< "No program shape found in the proto";
const auto& program_shape = module.host_program_shape();
ProgramShape program_shape(module.host_program_shape());
HloModuleConfig module_config(ProgramShape{program_shape});
module_config.set_debug_options(debug_options);

View File

@ -48,7 +48,7 @@ StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
return std::move(module);
}
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
const HloProto& hlo_proto) {
if (!hlo_proto.has_hlo_module()) {
return NotFound("HloProto missing HloModuleProto.");
@ -57,15 +57,16 @@ StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
return NotFound("HloProto missing program shape.");
}
std::vector<const Shape*> parameter_shapes;
std::vector<const ShapeProto*> parameter_shapes;
const auto& program_shape = hlo_proto.hlo_module().host_program_shape();
for (const Shape& shape : program_shape.parameters()) {
for (const ShapeProto& shape : program_shape.parameters()) {
parameter_shapes.push_back(&shape);
}
return parameter_shapes;
}
StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) {
StatusOr<const ShapeProto*> EntryComputationOutputShape(
const HloProto& hlo_proto) {
if (!hlo_proto.has_hlo_module()) {
return NotFound("HloProto missing HloModuleProto.");
}

View File

@ -43,12 +43,13 @@ StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
// Returns the shapes of the parameters of the entry computation. Shape pointers
// refer to shapes inside of the given HloProto.
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
const HloProto& hlo_proto);
// Returns the shape of the output of the entry computation. The shape pointer
// refers to the output shape inside of the given HloProto.
StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto);
StatusOr<const ShapeProto*> EntryComputationOutputShape(
const HloProto& hlo_proto);
} // namespace xla

View File

@ -25,6 +25,7 @@ limitations under the License.
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

View File

@ -244,10 +244,11 @@ StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
int32 size_bytes) {
Shape shape;
TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes));
ShapeProto shape_proto;
TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes));
Shape shape(shape_proto);
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
return shape;
return std::move(shape);
}
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,

View File

@ -101,12 +101,12 @@ ExecutionOptions CreateExecutionOptions(
}
if (build_options.result_layout() != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
*build_options.result_layout();
build_options.result_layout()->ToProto();
} else {
Shape result_shape(program_shape->result());
LayoutUtil::SetToDefaultLayout(&result_shape);
*execution_options.mutable_shape_with_output_layout() =
program_shape->result();
LayoutUtil::SetToDefaultLayout(
execution_options.mutable_shape_with_output_layout());
result_shape.ToProto();
}
return execution_options;
}

View File

@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/source_map_util.h"
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -275,8 +276,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
}
if (execution_options != nullptr &&
execution_options->has_shape_with_output_layout()) {
const auto& shape_with_output_layout =
execution_options->shape_with_output_layout();
const Shape shape_with_output_layout(
execution_options->shape_with_output_layout());
TF_RETURN_IF_ERROR(
ValidateResultShape(shape_with_output_layout, program_shape.result()));
TF_RETURN_IF_ERROR(
@ -818,14 +819,17 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
"The compile request does not support multiple device handles.");
}
std::vector<const Shape*> argument_shapes;
absl::c_transform(arg->input_shape_with_layout(),
std::back_inserter(argument_shapes),
[](const Shape& shape) { return &shape; });
std::vector<Shape> argument_shapes;
argument_shapes.reserve(arg->input_shape_with_layout_size());
std::vector<const Shape*> argument_shape_ptrs;
for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
argument_shapes.push_back(Shape(shape_proto));
argument_shape_ptrs.push_back(&argument_shapes.back());
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
argument_shapes, &arg->execution_options()));
argument_shape_ptrs, &arg->execution_options()));
VLOG(3) << "Compile created HloModuleConfig computation layout: "
<< module_config->entry_computation_layout().ToString();
@ -930,14 +934,14 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
allocation_tracker_.ResolveForReplica(arg->data(), 0));
const Shape* return_shape;
Shape return_shape;
if (arg->has_shape_with_layout()) {
if (!LayoutUtil::HasLayout(arg->shape_with_layout())) {
return_shape = Shape(arg->shape_with_layout());
if (!LayoutUtil::HasLayout(return_shape)) {
return InvalidArgument("shape_with_layout must have layout if present.");
}
return_shape = &arg->shape_with_layout();
} else {
return_shape = &shaped_buffer->on_host_shape();
return_shape = Shape(shaped_buffer->on_host_shape());
}
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
@ -948,11 +952,11 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
stream.get(), *shaped_buffer));
if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
*result->mutable_literal() = result_literal.ToProto();
} else {
*result->mutable_literal() =
result_literal.Relayout(*return_shape).ToProto();
result_literal.Relayout(return_shape).ToProto();
}
return Status::OK();
}
@ -1045,11 +1049,11 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
executor = replicas[arg->replica_id()];
}
auto literal = Literal::CreateFromShape(arg->shape_with_layout());
auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));
TF_RETURN_IF_ERROR(
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
executor, arg->shape_with_layout(), literal));
executor, Shape(arg->shape_with_layout()), literal));
*result->mutable_literal() = literal.ToProto();
return Status::OK();
}
@ -1103,7 +1107,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
allocation_tracker_.ResolveForReplica(arg->data(), 0));
*result->mutable_shape() = buffer->on_host_shape();
*result->mutable_shape() = buffer->on_host_shape().ToProto();
return Status::OK();
}

View File

@ -1018,7 +1018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
switch (opcode) {
case HloOpcode::kTuple: {
Shape result = ShapeUtil::MakeTupleShape({});
result.mutable_tuple_shapes()->Reserve(operand_shapes.size());
result.mutable_tuple_shapes()->reserve(operand_shapes.size());
for (const Shape* shape : operand_shapes) {
ShapeUtil::AppendShapeToTuple(*shape, &result);
}

View File

@ -21,11 +21,56 @@ limitations under the License.
namespace xla {
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
for (const Shape& shape : program_shape_proto.parameters()) {
*add_parameters() = shape;
// Constructs a Shape from its wire-format ShapeProto representation.
// Tuple element shapes are converted recursively; the layout is copied only
// when the proto carries one, so has_layout() mirrors the proto's state.
Shape::Shape(const ShapeProto& shape_proto) {
  set_element_type(shape_proto.element_type());
  dimensions_.reserve(shape_proto.dimensions_size());
  for (const int64 dimension : shape_proto.dimensions()) {
    add_dimensions(dimension);
  }
  tuple_shapes_.reserve(shape_proto.tuple_shapes_size());
  for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) {
    // Recursively convert each tuple element shape.
    *add_tuple_shapes() = Shape(element_shape);
  }
  if (shape_proto.has_layout()) {
    *mutable_layout() = shape_proto.layout();
  }
}
// Converts this Shape back into its wire-format ShapeProto form; this is the
// inverse of the Shape(const ShapeProto&) constructor.
ShapeProto Shape::ToProto() const {
  ShapeProto result;
  result.set_element_type(element_type_);
  result.mutable_dimensions()->Reserve(dimensions_size());
  for (const int64 dim : dimensions_) {
    result.add_dimensions(dim);
  }
  result.mutable_tuple_shapes()->Reserve(tuple_shapes_size());
  for (const Shape& element : tuple_shapes_) {
    *result.add_tuple_shapes() = element.ToProto();
  }
  if (has_layout()) {
    // Only serialize the layout field when one is actually present.
    *result.mutable_layout() = layout();
  }
  return result;
}
// Renders this shape as a human-readable string, optionally including the
// layout (e.g. "f32[42,12]{0,1}" with layout vs. "f32[42,12]" without).
string Shape::ToString(bool print_layout) const {
  return print_layout ? ShapeUtil::HumanStringWithLayout(*this)
                      : ShapeUtil::HumanString(*this);
}
// Streams the shape including its layout, so logged shapes that differ only
// in layout remain distinguishable.
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
  return out << shape.ToString(/*print_layout=*/true);
}
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
for (const ShapeProto& shape_proto : program_shape_proto.parameters()) {
*add_parameters() = Shape(shape_proto);
}
*mutable_result() = Shape(program_shape_proto.result());
for (const string& name : program_shape_proto.parameter_names()) {
add_parameter_names(name);
}
@ -34,9 +79,9 @@ ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
ProgramShapeProto ProgramShape::ToProto() const {
ProgramShapeProto proto;
for (const Shape& shape : parameters()) {
*proto.add_parameters() = shape;
*proto.add_parameters() = shape.ToProto();
}
*proto.mutable_result() = result();
*proto.mutable_result() = result().ToProto();
for (const string& name : parameter_names()) {
proto.add_parameter_names(name);
}

View File

@ -26,6 +26,102 @@ limitations under the License.
namespace xla {
// A shape describes the number of dimensions in a array, the bounds of each
// dimension, and the primitive component type. For tuples, shape describes the
// structure (number of elements and nesting).
class Shape {
 public:
  Shape() = default;

  // Construct a shape from a ShapeProto.
  explicit Shape(const ShapeProto& shape_proto);

  // Returns a ShapeProto representation of the Shape.
  ShapeProto ToProto() const;

  // Returns a human-readable string that represents the given shape, with or
  // without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
  string ToString(bool print_layout = false) const;

  // The following methods mirror the protobuf generated code interface for the
  // message ShapeProto. This enabled easy migration of this data structure
  // from a proto to a proper C++ class.
  // TODO(b/29771030): Replace or augment these methods with a more ergonomic
  // interface.

  // Methods for accessing the primitive type.
  PrimitiveType element_type() const { return element_type_; }
  void set_element_type(PrimitiveType value) { element_type_ = value; }

  // Methods for accessing the dimensions array.
  int dimensions_size() const { return dimensions_.size(); }
  int64 dimensions(int index) const { return dimensions_.at(index); }
  void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
  void add_dimensions(int64 value) { dimensions_.push_back(value); }
  void clear_dimensions() { dimensions_.clear(); }
  const std::vector<int64>& dimensions() const { return dimensions_; }
  std::vector<int64>* mutable_dimensions() { return &dimensions_; }

  // Methods for accessing the tuple subshapes. This field only non-empty for
  // tuple shapes.
  int tuple_shapes_size() const { return tuple_shapes_.size(); }
  const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
  Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
  Shape* add_tuple_shapes() {
    tuple_shapes_.push_back(Shape());
    return &tuple_shapes_.back();
  }
  void clear_tuple_shapes() { tuple_shapes_.clear(); }
  const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
  std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }

  // Methods for accessing the layout field.
  bool has_layout() const { return layout_.has_value(); }
  const Layout& layout() const {
    // Mirror proto semantics: reading an unset field yields the default
    // instance rather than failing.
    if (layout_.has_value()) {
      return *layout_;
    } else {
      return Layout::default_instance();
    }
  }
  Layout* mutable_layout() {
    if (!layout_.has_value()) {
      layout_ = Layout();
    }
    return &layout_.value();
  }
  void clear_layout() { layout_.reset(); }

  void Swap(Shape* other) {
    using std::swap;
    swap(*this, *other);
  }

  // Resets the shape to a default-constructed (invalid) state.
  void Clear() {
    element_type_ = PRIMITIVE_TYPE_INVALID;
    dimensions_.clear();
    tuple_shapes_.clear();
    layout_.reset();
  }

  string SerializeAsString() const { return ToProto().SerializeAsString(); }
  string ShortDebugString() const { return ToProto().ShortDebugString(); }
  string DebugString() const { return ToProto().DebugString(); }

 private:
  // The element type of this shape (tuple, array, etc).
  PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;

  // The array bounds of the dimensions. This is nonempty only for array shapes.
  std::vector<int64> dimensions_;

  // The tuple element subshapes. This is nonempty only for tuple shapes.
  std::vector<Shape> tuple_shapes_;

  // The array layout of the shape. This is present only for array shapes.
  absl::optional<Layout> layout_;
};
// Shape of the parameters and output of an XLA computation. This is analogous
// to a traditional function signature.
class ProgramShape {
@ -61,7 +157,6 @@ class ProgramShape {
// Methods for accessing and manipulating the Shape of the result.
const Shape& result() const { return result_; }
Shape* mutable_result() { return &result_; }
void clear_result() { result_.Clear(); }
// Methods for accessing and manipulating the names of the parameters.
int parameter_names_size() const { return parameter_names_.size(); }
@ -101,6 +196,7 @@ class ProgramShape {
Shape result_;
};
std::ostream& operator<<(std::ostream& out, const Shape& shape);
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
} // namespace xla

View File

@ -30,7 +30,51 @@ limitations under the License.
namespace xla {
namespace {
TEST(ShapeTest, ProgramShapeToFromProto) {
// Test fixture providing a representative set of shapes shared by the Shape
// tests: opaque, token, scalar, arrays (default and explicit layout), a tuple,
// and a nested tuple.
class ShapeTest : public ::testing::Test {
 protected:
  const Shape opaque_ = ShapeUtil::MakeOpaqueShape();
  const Shape token_ = ShapeUtil::MakeTokenShape();
  const Shape scalar_ = ShapeUtil::MakeShape(F32, {});
  const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2});
  // Carries an explicit {0,1} layout (non-default minor-to-major order).
  const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
  const Shape tuple_ =
      ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_});
  const Shape nested_tuple_ =
      ShapeUtil::MakeTupleShape({tuple_, matrix_, token_});
};
// Round-trips each fixture shape through ShapeProto and verifies that
// equality (including layout) is preserved.
TEST_F(ShapeTest, ShapeToFromProto) {
  const Shape all_shapes[] = {opaque_,  token_, scalar_,      matrix_,
                              matrix2_, tuple_, nested_tuple_};
  for (const Shape& original : all_shapes) {
    const Shape round_tripped(original.ToProto());
    EXPECT_TRUE(ShapeUtil::Equal(original, round_tripped))
        << original << " != " << round_tripped;
  }
}
// Verifies human-readable rendering of shapes, both without layout (the
// default) and with layout (print_layout=true). Layout braces appear only for
// array shapes that carry a layout.
TEST_F(ShapeTest, ShapeToString) {
  EXPECT_EQ("opaque[]", opaque_.ToString());
  EXPECT_EQ("token[]", token_.ToString());
  EXPECT_EQ("f32[]", scalar_.ToString());
  EXPECT_EQ("u32[1,2]", matrix_.ToString());
  EXPECT_EQ("s32[3,4]", matrix2_.ToString());
  EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", tuple_.ToString());
  EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
            nested_tuple_.ToString());

  // With print_layout=true, array shapes render their minor-to-major order in
  // braces; opaque/token/scalar shapes print unchanged.
  EXPECT_EQ("opaque[]", opaque_.ToString(/*print_layout=*/true));
  EXPECT_EQ("f32[]", scalar_.ToString(/*print_layout=*/true));
  EXPECT_EQ("u32[1,2]{1,0}", matrix_.ToString(/*print_layout=*/true));
  EXPECT_EQ("s32[3,4]{0,1}", matrix2_.ToString(/*print_layout=*/true));
  EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
            tuple_.ToString(/*print_layout=*/true));
  EXPECT_EQ(
      "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
      "token[])",
      nested_tuple_.ToString(/*print_layout=*/true));
}
TEST_F(ShapeTest, ProgramShapeToFromProto) {
ProgramShape program_shape;
*program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
*program_shape.add_parameters() = ShapeUtil::MakeTokenShape();
@ -67,17 +111,10 @@ TEST(ShapeTest, ProgramShapeToFromProto) {
}
}
TEST(ShapeTest, ProgramShapeToString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
TEST_F(ShapeTest, ProgramShapeToString) {
ProgramShape prog = ShapeUtil::MakeProgramShape(
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
{opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_},
nested_tuple_);
EXPECT_EQ(
"((unknown): opaque[], "
"(unknown): f32[], "
@ -87,7 +124,7 @@ TEST(ShapeTest, ProgramShapeToString) {
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.ToString());
prog.add_parameter_names("arg0");
prog.add_parameter_names("scalar");
@ -105,7 +142,7 @@ TEST(ShapeTest, ProgramShapeToString) {
"token[])) "
"-> "
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(prog));
prog.ToString());
}
} // namespace

View File

@ -79,14 +79,14 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
indices_.subspan(0, prefix.size()) == prefix.indices_;
}
namespace {
// Returns whether the given primitive type corresponds to an array shape.
bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
PrimitiveType primitive_type) {
return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
primitive_type != OPAQUE && primitive_type != TOKEN;
}
namespace {
// Recursive helper for comparing the equality of two shapes. Returns true if
// the shapes are the same. If compare_layouts is true, then layouts must also
// match.
@ -203,7 +203,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
std::initializer_list<Shape> parameters, Shape result) {
ProgramShape program_shape;
for (const auto& shape : parameters) {
for (const Shape& shape : parameters) {
*program_shape.add_parameters() = shape;
}
*program_shape.mutable_result() = std::move(result);
@ -272,7 +272,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
Shape result;
result.set_element_type(TUPLE);
result.mutable_tuple_shapes()->Reserve(shapes.size());
result.mutable_tuple_shapes()->reserve(shapes.size());
for (const auto& shape : shapes) {
AppendShapeToTuple(shape, &result);
}
@ -563,20 +563,6 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
HumanString(program_shape.result()));
}
/* static */ string ShapeUtil::HumanString(
const ProgramShapeProto& program_shape_proto) {
std::vector<string> parameters;
for (auto& shape : program_shape_proto.parameters()) {
const int i = parameters.size();
parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size()
? program_shape_proto.parameter_names(i)
: "(unknown)",
": ", HumanString(shape)));
}
return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
HumanString(program_shape_proto.result()));
}
namespace {
// Parses shapes with simple recursive descent structure -- consumes from the
// front of s and passes that view recursively as required.
@ -1610,7 +1596,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
Shape shape) {
CHECK(IsArray(shape));
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() +
dim_to_delete);
if (LayoutUtil::HasLayout(shape)) {
Layout* layout = shape.mutable_layout();
layout->set_format(DENSE);
@ -1644,11 +1631,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
return shape;
}
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
out << ShapeUtil::HumanStringWithLayout(shape);
return out;
}
/*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
using tensorflow::hash;
using tensorflow::Hash64Combine;

View File

@ -240,7 +240,6 @@ class ShapeUtil {
//
// (param_name: f32[42x12], ...) -> f32[24x42]
static string HumanString(const ProgramShape& program_shape);
static string HumanString(const ProgramShapeProto& program_shape_proto);
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
@ -469,6 +468,9 @@ class ShapeUtil {
// arrays.
static bool IsArray(const Shape& shape);
// Returns whether the given primitive type corresponds to an array shape.
static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
// Returns whether the shape is a tuple with at least one element which is
// also a tuple.
static bool IsNestedTuple(const Shape& shape);
@ -796,8 +798,6 @@ class ShapeUtil {
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};
std::ostream& operator<<(std::ostream& out, const Shape& shape);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_

View File

@ -546,37 +546,6 @@ TEST(ShapeUtilTest, IsLeafIndex) {
EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1}));
}
TEST(ShapeUtilTest, HumanString) {
Shape opaque = ShapeUtil::MakeOpaqueShape();
Shape token = ShapeUtil::MakeTokenShape();
Shape scalar = ShapeUtil::MakeShape(F32, {});
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
ShapeUtil::HumanString(tuple));
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
ShapeUtil::HumanString(nested_tuple));
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar));
EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix));
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
ShapeUtil::HumanStringWithLayout(tuple));
EXPECT_EQ(
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
"token[])",
ShapeUtil::HumanStringWithLayout(nested_tuple));
}
TEST(ShapeUtilTest, ForEachSubshapeArray) {
const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
int calls = 0;

View File

@ -107,7 +107,7 @@ StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
*shape_with_output_layout;
shape_with_output_layout->ToProto();
}
return client_->ExecuteAndTransfer(computation, arguments,
&execution_options);
@ -127,7 +127,7 @@ StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
ExecutionOptions execution_options = execution_options_;
if (shape_with_output_layout != nullptr) {
*execution_options.mutable_shape_with_output_layout() =
*shape_with_output_layout;
shape_with_output_layout->ToProto();
}
execution_options.clear_device_handles();
return ref_client_->ExecuteAndTransfer(computation, arguments,

View File

@ -50,7 +50,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
execute_layout);
execute_layout)
.ToProto();
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<GlobalData> data,
client_->Execute(computation, {}, &execution_options));
@ -84,7 +85,8 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
{ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{0, 1}),
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
/*minor_to_major=*/{1, 0})});
/*minor_to_major=*/{1, 0})})
.ToProto();
TF_ASSERT_OK_AND_ASSIGN(
auto result,

View File

@ -618,7 +618,8 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
{1, 0});
{1, 0})
.ToProto();
Literal actual =
client_
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
@ -767,7 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
ExecutionOptions execution_options = execution_options_;
*execution_options.mutable_shape_with_output_layout() =
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
{2, 3, 0, 1});
{2, 3, 0, 1})
.ToProto();
Literal output_literal =
client_
->ExecuteAndTransfer(computation, {input_data.get()},

View File

@ -82,13 +82,17 @@ struct Options {
std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
LocalClient* client) {
XlaComputation computation(module.hlo().hlo_module());
std::vector<const Shape*> argument_layouts;
for (const auto& param :
std::vector<Shape> argument_layouts;
argument_layouts.reserve(
computation.proto().host_program_shape().parameters_size());
std::vector<const Shape*> argument_layout_ptrs;
for (const ShapeProto& param :
computation.proto().host_program_shape().parameters()) {
argument_layouts.push_back(&param);
argument_layouts.push_back(Shape(param));
argument_layout_ptrs.push_back(&argument_layouts.back());
}
return client
->Compile(computation, argument_layouts, ExecutableBuildOptions())
->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions())
.ValueOrDie();
}
@ -149,7 +153,7 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
<< "--generate_fake_infeed only works if the model has 0 or 1 "
"infeed ops, but this one has >= 2.";
provide_infeed = true;
infeed_shape = instruction.shape();
infeed_shape = Shape(instruction.shape());
LOG(INFO) << "Generating fake infeed shape for inferred shape: "
<< ShapeUtil::HumanString(infeed_shape);
}
@ -315,9 +319,10 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
if (snapshot.has_result()) {
Literal literal =
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
fprintf(stdout, "was %s:%s\n",
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
literal.ToString().c_str());
fprintf(
stdout, "was %s:%s\n",
ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(),
literal.ToString().c_str());
}
}
}

View File

@ -152,6 +152,13 @@ static inline absl::Span<const int64> AsInt64Slice(
slice.size());
}
// TODO(b/29771030): This nop overload was added to simplify the migration of
// Shape from a proto to a C++ class. Remove after class has been migrated.
static inline absl::Span<const int64> AsInt64Slice(
absl::Span<const int64> slice) {
return slice;
}
// As above, but for uint64 types.
static inline absl::Span<const uint64> AsUInt64Slice(
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {

View File

@ -224,7 +224,7 @@ message ExecutionOptions {
// may be faster when using this layout.
//
// We use a Shape here to accommodate computations that return a tuple.
Shape shape_with_output_layout = 2;
ShapeProto shape_with_output_layout = 2;
// Used to seed random-number generators used in this computation. If this is
// 0, we generate a seed ourselves.
@ -253,7 +253,7 @@ message TransferToClientRequest {
// This optional field directs the service to return the literal in this
// layout. A shape is used to hold the layout to accommodate tuples.
Shape shape_with_layout = 2;
ShapeProto shape_with_layout = 2;
}
message TransferToClientResponse {
@ -281,7 +281,7 @@ message TransferToInfeedResponse {
message TransferFromOutfeedRequest {
// This optional field directs the service to return the literal in this
// layout. A shape is used to hold the layout to accommodate tuples.
Shape shape_with_layout = 1;
ShapeProto shape_with_layout = 1;
int64 replica_id = 2;
DeviceHandle device_handle = 3;
@ -332,7 +332,7 @@ message CompileRequest {
// The layouts of the input arguments. If not set, the default layout will be
// used. Although the real arguments are not needed in compilation, the
// layouts of the arguments can affect the compilation.
repeated Shape input_shape_with_layout = 3;
repeated ShapeProto input_shape_with_layout = 3;
}
message CompileResponse {
@ -406,7 +406,7 @@ message LoadDataRequest {
string columnio_field = 2;
// Individual element shape, excluding rows.
Shape element_shape = 3;
ShapeProto element_shape = 3;
// Warning: ColumnIO does not support random-access, so use offset with
// caution in performance-critical scenarios.
@ -422,7 +422,7 @@ message LoadDataRequest {
message LoadDataResponse {
GlobalDataHandle data = 1;
Shape data_shape = 2;
ShapeProto data_shape = 2;
int64 available_rows = 3;
int64 rows_loaded = 4;
int64 nanoseconds = 5;
@ -433,7 +433,7 @@ message GetShapeRequest {
}
message GetShapeResponse {
Shape shape = 1;
ShapeProto shape = 1;
}
message UnpackRequest {

View File

@ -154,7 +154,7 @@ message Layout {
// See the XLA documentation for more information on shapes and layouts.
//
// LINT.IfChange
message Shape {
message ShapeProto {
reserved 1;
reserved "rank";
@ -169,7 +169,7 @@ message Shape {
repeated int64 dimensions = 3;
// For tuples only, the shapes of constitutent shapes in the tuple sequence.
repeated Shape tuple_shapes = 4;
repeated ShapeProto tuple_shapes = 4;
// The layout used to back this shape.
Layout layout = 5;
@ -184,8 +184,8 @@ message Shape {
// Shape of the parameters and output of a computation (like a traditional
// function signature).
message ProgramShapeProto {
repeated Shape parameters = 1;
Shape result = 2;
repeated ShapeProto parameters = 1;
ShapeProto result = 2;
repeated string parameter_names = 3;
}
@ -320,7 +320,7 @@ message DeviceAssignmentProto {
// Transfers to/from the client are encoded in literal form, and the structure
// of the repeated fields is implied by the shape.
message LiteralProto {
Shape shape = 1;
ShapeProto shape = 1;
repeated bool preds = 2;
bytes s8s = 15;
bytes u8s = 3;
@ -521,7 +521,7 @@ message OpSharding {
}
Type type = 1;
// The shape of the sharded tile.
Shape tile_shape = 2;
ShapeProto tile_shape = 2;
// The shape of the tile assignment tensor - this must be the same rank as
// tile_shape and the product of its dimensions must equal
// tile_assignment_devices.size().

View File

@ -109,14 +109,17 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
client->LoadSnapshot(computation_proto.hlo_snapshot()));
std::vector<const xla::Shape*> argument_layouts(
std::vector<xla::Shape> argument_layouts(
config.program_shape().parameters_size());
std::vector<const xla::Shape*> argument_layout_ptrs(
config.program_shape().parameters_size());
for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
argument_layouts[i] = &config.program_shape().parameters(i);
argument_layouts[i] = xla::Shape(config.program_shape().parameters(i));
argument_layout_ptrs[i] = &argument_layouts[i];
}
xla::ExecutableBuildOptions build_options;
build_options.set_device_ordinal(client->default_device_ordinal());
build_options.set_result_layout(config.program_shape().result());
build_options.set_result_layout(xla::Shape(config.program_shape().result()));
build_options.set_device_allocator(device_ref.backend()->memory_allocator());
if (config.has_debug_options()) {
*build_options.mutable_debug_options() =
@ -125,7 +128,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
VLOG(1) << "Building executable";
auto compile_result =
client->Compile(computation, argument_layouts, build_options);
client->Compile(computation, argument_layout_ptrs, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}

View File

@ -375,9 +375,12 @@ TEST(RawApiTest, CompileAndExecute) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
@ -427,9 +430,12 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
@ -494,8 +500,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = param_shape;
*shapes->mutable_result() = result_shape;
*shapes->add_parameters() = param_shape.ToProto();
*shapes->mutable_result() = result_shape.ToProto();
StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
@ -510,8 +516,9 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
{c_handle.program_shape}, {release}, &outputs));
xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
xla::ProgramShapeProto program_shape_proto;
EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec<string>()(0)));
xla::ProgramShape program_shape(program_shape_proto);
EXPECT_EQ(program_shape.parameters_size(), 1);
VLOG(2) << "Param: "
@ -547,11 +554,11 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) {
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
@ -592,7 +599,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {});
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
@ -632,10 +639,13 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::F32, {2})});
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
.ToProto();
StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
@ -675,10 +685,13 @@ TEST(RawApiTest, LeakCompilationReference) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
*shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
{xla::ShapeUtil::MakeShape(xla::F32, {2})});
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
.ToProto();
StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
@ -703,9 +716,9 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {});
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {});
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {});
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
@ -742,8 +755,8 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
xla::ProgramShapeProto program_shape;
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
EXPECT_EQ(program_shape.parameters_size(), 2);
EXPECT_TRUE(
xla::ShapeUtil::HasPrimitiveType(program_shape.result(), xla::S64));
EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
xla::Shape(program_shape.result()), xla::S64));
}
} // namespace