Replace Shape with a C++ class in XLA.
No functional change. Rename the proto message Shape to ShapeProto and define an in-place replacement C++ class named Shape with an interface which mirrors the protobuf generated code interface. Having Shape as a C++ class enables greater flexibility in the interface, enables enforcement of invariants, and potential performance improvements. PiperOrigin-RevId: 223252977
This commit is contained in:
parent
0f98c067fa
commit
bd737c846c
tensorflow/compiler
aot
tf2xla
xla
client
index_util.hliteral.ccliteral_test.ccpython_api
rpc
service
compile_only_service.cc
shape.ccshape.hshape_test.ccshape_util.ccshape_util.hshape_util_test.cccpu
hlo.protohlo_instruction.cchlo_instructions.cchlo_lexer.hhlo_module.cchlo_proto_util.cchlo_proto_util.hllvm_ir
local_service.ccservice.ccshape_inference.cctests
tools
util.hxla.protoxla_data.protoxrt
@ -175,7 +175,8 @@ Status GenArgMethods(const tf2xla::Config& config,
|
||||
}
|
||||
for (int i = 0; i < num_args; ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites));
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRewritesForShape(i, xla::Shape(ps.parameters(i)), &rewrites));
|
||||
const string code = R"(
|
||||
void set_arg{{NAME}}_data(void* data) {
|
||||
set_arg_data({{I}}, data);
|
||||
@ -218,8 +219,8 @@ Status GenResultMethods(const tf2xla::Config& config,
|
||||
}
|
||||
for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) {
|
||||
std::vector<std::pair<string, string>> rewrites;
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites));
|
||||
TF_RETURN_IF_ERROR(AddRewritesForShape(
|
||||
i, xla::Shape(ps.result().tuple_shapes(i)), &rewrites));
|
||||
string code = R"(
|
||||
{{TYPE}}* result{{NAME}}_data() {
|
||||
return static_cast<{{TYPE}}*>(result_data({{I}}));
|
||||
@ -588,7 +589,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
|
||||
{"{{METHODS_RESULT}}\n", methods_result},
|
||||
{"{{NS_END}}\n", ns_end},
|
||||
{"{{NS_START}}\n", ns_start},
|
||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)},
|
||||
{"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))},
|
||||
{"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}",
|
||||
metadata_result.program_shape_access_shim},
|
||||
{"{{RESULT_INDEX}}", absl::StrCat(result_index)},
|
||||
|
@ -58,15 +58,21 @@ Status CompileXla(xla::CompileOnlyClient* client,
|
||||
}
|
||||
compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
|
||||
xla::ProgramShapeProto* pshape = &compile_result->program_shape;
|
||||
std::vector<const xla::Shape*> arg_layouts;
|
||||
arg_layouts.reserve(pshape->parameters_size());
|
||||
|
||||
// AotXlaComputationInstance::argument_layouts is a vector of Shape
|
||||
// pointers. Accumulate the Shape objects themselves in a separate vector
|
||||
// while building the vector of pointers.
|
||||
std::vector<const xla::Shape*> arg_layout_ptrs(pshape->parameters_size());
|
||||
std::vector<xla::Shape> arg_layouts(pshape->parameters_size());
|
||||
for (int i = 0; i < pshape->parameters_size(); ++i) {
|
||||
arg_layouts.push_back(pshape->mutable_parameters(i));
|
||||
arg_layouts[i] = xla::Shape(*pshape->mutable_parameters(i));
|
||||
arg_layout_ptrs[i] = &arg_layouts[i];
|
||||
}
|
||||
xla::CompileOnlyClient::AotXlaComputationInstance instance;
|
||||
instance.computation = &computation;
|
||||
instance.argument_layouts = std::move(arg_layouts);
|
||||
instance.result_layout = &pshape->result();
|
||||
instance.argument_layouts = std::move(arg_layout_ptrs);
|
||||
xla::Shape result_shape(pshape->result());
|
||||
instance.result_layout = &result_shape;
|
||||
xla::StatusOr<std::vector<std::unique_ptr<xla::AotCompilationResult>>>
|
||||
aot_or = client->CompileAheadOfTime({instance}, aot_opts);
|
||||
if (!aot_or.ok()) {
|
||||
|
@ -529,10 +529,12 @@ TEST(TFCompileTest, ProgramShape) {
|
||||
const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
|
||||
ASSERT_TRUE(muladd_shape != nullptr);
|
||||
ASSERT_EQ(muladd_shape->parameters_size(), 2);
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2));
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(0)), f32_2x2));
|
||||
EXPECT_TRUE(
|
||||
ShapeUtil::Compatible(xla::Shape(muladd_shape->parameters(1)), f32_2x2));
|
||||
|
||||
const xla::Shape& muladd_result = muladd_shape->result();
|
||||
const xla::Shape muladd_result(muladd_shape->result());
|
||||
ASSERT_EQ(muladd_result.element_type(), xla::TUPLE);
|
||||
ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2);
|
||||
const xla::Shape& muladd_result0 =
|
||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_SHAPE_UTIL_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
|
@ -116,13 +116,13 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
|
||||
// Check program shape.
|
||||
using xla::ShapeUtil;
|
||||
const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
|
||||
const xla::ProgramShapeProto* program_shape = function.ProgramShape();
|
||||
ASSERT_TRUE(program_shape != nullptr);
|
||||
ASSERT_EQ(program_shape->parameters_size(), 2);
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32));
|
||||
ASSERT_TRUE(function.ProgramShape() != nullptr);
|
||||
const xla::ProgramShape program_shape(*function.ProgramShape());
|
||||
ASSERT_EQ(program_shape.parameters_size(), 2);
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32));
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32));
|
||||
|
||||
const xla::Shape& result = program_shape->result();
|
||||
const xla::Shape& result = program_shape.result();
|
||||
ASSERT_EQ(result.element_type(), xla::TUPLE);
|
||||
ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1);
|
||||
const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0);
|
||||
|
@ -81,6 +81,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/execution_options_util.h"
|
||||
@ -42,7 +43,7 @@ StatusOr<Literal> Client::Transfer(const GlobalData& data,
|
||||
TransferToClientRequest request;
|
||||
*request.mutable_data() = data.handle();
|
||||
if (shape_with_layout != nullptr) {
|
||||
*request.mutable_shape_with_layout() = *shape_with_layout;
|
||||
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
|
||||
}
|
||||
TransferToClientResponse response;
|
||||
|
||||
@ -123,7 +124,7 @@ StatusOr<Literal> Client::TransferFromOutfeed(
|
||||
}
|
||||
request.set_replica_id(replica_id);
|
||||
if (shape_with_layout != nullptr) {
|
||||
*request.mutable_shape_with_layout() = *shape_with_layout;
|
||||
*request.mutable_shape_with_layout() = shape_with_layout->ToProto();
|
||||
}
|
||||
TransferFromOutfeedResponse response;
|
||||
|
||||
@ -170,11 +171,14 @@ StatusOr<Literal> Client::ExecuteAndTransfer(
|
||||
std::unique_ptr<GlobalData> data,
|
||||
Execute(computation, arguments, execution_options, execution_profile));
|
||||
|
||||
const Shape* shape_with_output_layout = nullptr;
|
||||
absl::optional<Shape> shape_with_output_layout;
|
||||
if (execution_options && execution_options->has_shape_with_output_layout()) {
|
||||
shape_with_output_layout = &execution_options->shape_with_output_layout();
|
||||
shape_with_output_layout =
|
||||
Shape(execution_options->shape_with_output_layout());
|
||||
}
|
||||
return Transfer(*data, shape_with_output_layout);
|
||||
return Transfer(*data, shape_with_output_layout.has_value()
|
||||
? &(*shape_with_output_layout)
|
||||
: nullptr);
|
||||
}
|
||||
|
||||
StatusOr<Literal> Client::ComputeConstant(const XlaComputation& computation,
|
||||
@ -229,7 +233,7 @@ StatusOr<ExecutionHandle> Client::Compile(
|
||||
|
||||
// The argument shapes affect how the computation is compiled.
|
||||
for (const auto& arg_shape : argument_shapes) {
|
||||
*request.add_input_shape_with_layout() = arg_shape;
|
||||
*request.add_input_shape_with_layout() = arg_shape.ToProto();
|
||||
}
|
||||
|
||||
CompileResponse response;
|
||||
@ -458,7 +462,7 @@ StatusOr<Shape> Client::GetShape(const GlobalData& data) {
|
||||
return s;
|
||||
}
|
||||
|
||||
return response.shape();
|
||||
return Shape(response.shape());
|
||||
}
|
||||
|
||||
StatusOr<string> Client::ExecutionStatsAsString(
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
@ -66,7 +66,7 @@ std::unique_ptr<GlobalData> MakeFakeDataViaDeviceOrDie(const Shape& shape,
|
||||
XlaComputation computation = b.Build().ConsumeValueOrDie();
|
||||
|
||||
auto execution_options = CreateDefaultExecutionOptions();
|
||||
*execution_options.mutable_shape_with_output_layout() = shape;
|
||||
*execution_options.mutable_shape_with_output_layout() = shape.ToProto();
|
||||
return client->Execute(computation, /*arguments=*/{}, &execution_options)
|
||||
.ConsumeValueOrDie();
|
||||
}
|
||||
@ -98,8 +98,8 @@ std::vector<std::unique_ptr<GlobalData>> MakeFakeArgumentsOrDie(
|
||||
auto program_shape = computation.proto().host_program_shape();
|
||||
|
||||
std::vector<std::unique_ptr<GlobalData>> results;
|
||||
for (const Shape& shape : program_shape.parameters()) {
|
||||
results.push_back(MakeFakeDataOrDie(shape, client));
|
||||
for (const ShapeProto& shape : program_shape.parameters()) {
|
||||
results.push_back(MakeFakeDataOrDie(Shape(shape), client));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
@ -36,7 +36,7 @@ OpSharding Tile(const Shape& tile_shape,
|
||||
const TileAssignment& tile_assignment) {
|
||||
OpSharding result;
|
||||
result.set_type(OpSharding::Type::OpSharding_Type_OTHER);
|
||||
*result.mutable_tile_shape() = tile_shape;
|
||||
*result.mutable_tile_shape() = tile_shape.ToProto();
|
||||
for (int64 dim : tile_assignment.dimensions()) {
|
||||
result.add_tile_assignment_dimensions(dim);
|
||||
}
|
||||
@ -52,7 +52,7 @@ OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) {
|
||||
|
||||
CHECK_EQ(ShapeUtil::Rank(tile_shape), 1);
|
||||
std::vector<int64> dimensions(1, num_tiles);
|
||||
*result.mutable_tile_shape() = tile_shape;
|
||||
*result.mutable_tile_shape() = tile_shape.ToProto();
|
||||
auto& tile_dimension =
|
||||
(*result.mutable_tile_shape()->mutable_dimensions())[0];
|
||||
tile_dimension = CeilOfRatio(static_cast<int64>(tile_dimension), num_tiles);
|
||||
|
@ -102,7 +102,7 @@ StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
|
||||
return instr->shape();
|
||||
return Shape(instr->shape());
|
||||
}
|
||||
|
||||
StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
|
||||
@ -155,7 +155,7 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
|
||||
|
||||
ProgramShape program_shape;
|
||||
|
||||
*program_shape.mutable_result() = root_proto->shape();
|
||||
*program_shape.mutable_result() = Shape(root_proto->shape());
|
||||
|
||||
// Check that the parameter numbers are continuous from 0, and add parameter
|
||||
// shapes and names to the program shape.
|
||||
@ -172,7 +172,7 @@ StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
|
||||
const int64 index = instr.parameter_number();
|
||||
TF_RET_CHECK(index >= 0 && index < param_count)
|
||||
<< "invalid parameter number: " << index;
|
||||
*program_shape.mutable_parameters(index) = instr.shape();
|
||||
*program_shape.mutable_parameters(index) = Shape(instr.shape());
|
||||
*program_shape.mutable_parameter_names(index) = instr.name();
|
||||
}
|
||||
}
|
||||
@ -329,7 +329,7 @@ StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int64 dim : broadcast_dimensions) {
|
||||
instr.add_dimensions(dim);
|
||||
}
|
||||
@ -380,8 +380,9 @@ XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferUnaryOpShape(unop, operand_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), unop, {operand});
|
||||
});
|
||||
}
|
||||
@ -392,9 +393,10 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferBinaryOpShape(
|
||||
binop, lhs_shape, rhs_shape, broadcast_dimensions));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
const int64 lhs_rank = ShapeUtil::Rank(lhs_shape);
|
||||
const int64 rhs_rank = ShapeUtil::Rank(rhs_shape);
|
||||
@ -408,7 +410,7 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
const Shape& from_shape = should_broadcast_lhs ? lhs_shape : rhs_shape;
|
||||
|
||||
std::vector<int64> to_size;
|
||||
for (int64 size : instr.shape().dimensions()) {
|
||||
for (int64 size : shape.dimensions()) {
|
||||
to_size.push_back(size);
|
||||
}
|
||||
for (int64 from_dim = 0; from_dim < ShapeUtil::Rank(from_shape);
|
||||
@ -428,14 +430,14 @@ XlaOp XlaBuilder::BinaryOp(HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Shape updated_lhs_shape, GetShape(updated_lhs));
|
||||
if (!ShapeUtil::SameDimensions(instr.shape(), updated_lhs_shape)) {
|
||||
if (!ShapeUtil::SameDimensions(shape, updated_lhs_shape)) {
|
||||
TF_ASSIGN_OR_RETURN(updated_lhs,
|
||||
AddBroadcastSequence(instr.shape(), updated_lhs));
|
||||
AddBroadcastSequence(shape, updated_lhs));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(Shape updated_rhs_shape, GetShape(updated_rhs));
|
||||
if (!ShapeUtil::SameDimensions(instr.shape(), updated_rhs_shape)) {
|
||||
if (!ShapeUtil::SameDimensions(shape, updated_rhs_shape)) {
|
||||
TF_ASSIGN_OR_RETURN(updated_rhs,
|
||||
AddBroadcastSequence(instr.shape(), updated_rhs));
|
||||
AddBroadcastSequence(shape, updated_rhs));
|
||||
}
|
||||
|
||||
return AddInstruction(std::move(instr), binop, {updated_lhs, updated_rhs});
|
||||
@ -449,30 +451,28 @@ XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
ShapeInference::InferTernaryOpShape(
|
||||
triop, lhs_shape, rhs_shape, ehs_shape));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, ShapeInference::InferTernaryOpShape(triop, lhs_shape,
|
||||
rhs_shape, ehs_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
XlaOp updated_lhs = lhs;
|
||||
XlaOp updated_rhs = rhs;
|
||||
XlaOp updated_ehs = ehs;
|
||||
if (!ShapeUtil::IsTuple(instr.shape())) {
|
||||
if (!ShapeUtil::IsTuple(shape)) {
|
||||
if (!ShapeUtil::IsTuple(lhs_shape) &&
|
||||
!ShapeUtil::SameDimensions(instr.shape(), lhs_shape)) {
|
||||
!ShapeUtil::SameDimensions(shape, lhs_shape)) {
|
||||
// lhs is being implicitly broadcasted. Change to explicit.
|
||||
TF_ASSIGN_OR_RETURN(updated_lhs,
|
||||
AddBroadcastSequence(instr.shape(), lhs));
|
||||
TF_ASSIGN_OR_RETURN(updated_lhs, AddBroadcastSequence(shape, lhs));
|
||||
}
|
||||
if (!ShapeUtil::IsTuple(rhs_shape) &&
|
||||
!ShapeUtil::SameDimensions(instr.shape(), rhs_shape)) {
|
||||
!ShapeUtil::SameDimensions(shape, rhs_shape)) {
|
||||
// rhs is being implicitly broadcasted. Change to explicit.
|
||||
TF_ASSIGN_OR_RETURN(updated_rhs,
|
||||
AddBroadcastSequence(instr.shape(), rhs));
|
||||
TF_ASSIGN_OR_RETURN(updated_rhs, AddBroadcastSequence(shape, rhs));
|
||||
}
|
||||
if (!ShapeUtil::IsTuple(ehs_shape) &&
|
||||
!ShapeUtil::SameDimensions(instr.shape(), ehs_shape)) {
|
||||
!ShapeUtil::SameDimensions(shape, ehs_shape)) {
|
||||
// ehs is being implicitly broadcasted. Change to explicit.
|
||||
TF_ASSIGN_OR_RETURN(updated_ehs,
|
||||
AddBroadcastSequence(instr.shape(), ehs));
|
||||
TF_ASSIGN_OR_RETURN(updated_ehs, AddBroadcastSequence(shape, ehs));
|
||||
}
|
||||
}
|
||||
return AddInstruction(std::move(instr), triop,
|
||||
@ -493,7 +493,7 @@ XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
|
||||
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = literal.shape();
|
||||
*instr.mutable_shape() = literal.shape().ToProto();
|
||||
*instr.mutable_literal() = literal.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kConstant);
|
||||
});
|
||||
@ -502,7 +502,7 @@ XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
|
||||
XlaOp XlaBuilder::Iota(const Shape& shape, int64 iota_dimension) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.add_dimensions(iota_dimension);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kIota);
|
||||
});
|
||||
@ -522,10 +522,10 @@ XlaOp XlaBuilder::Call(const XlaComputation& computation,
|
||||
[](const Shape& shape) { return &shape; });
|
||||
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
|
||||
computation.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferCallShape(operand_shape_ptrs,
|
||||
/*to_apply=*/called_program_shape));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCallShape(
|
||||
operand_shape_ptrs,
|
||||
/*to_apply=*/called_program_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
AddCalledComputation(computation, &instr);
|
||||
|
||||
@ -543,7 +543,7 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
|
||||
}
|
||||
instr.set_parameter_number(parameter_number);
|
||||
instr.set_name(name);
|
||||
*instr.mutable_shape() = shape;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kParameter);
|
||||
});
|
||||
}
|
||||
@ -601,7 +601,7 @@ StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = shape;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReshape, {operand});
|
||||
}
|
||||
|
||||
@ -613,9 +613,9 @@ XlaOp XlaBuilder::Slice(const XlaOp& operand,
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferSliceShape(operand_shape, start_indices,
|
||||
limit_indices, strides));
|
||||
Shape shape, ShapeInference::InferSliceShape(
|
||||
operand_shape, start_indices, limit_indices, strides));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int i = 0; i < start_indices.size(); i++) {
|
||||
auto* slice_config = instr.add_slice_dimensions();
|
||||
slice_config->set_start(start_indices[i]);
|
||||
@ -650,9 +650,10 @@ XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
|
||||
GetShape(start_indices));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferDynamicSliceShape(
|
||||
operand_shape, start_indices_shape, slice_sizes));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
for (int64 size : slice_sizes) {
|
||||
instr.add_dynamic_slice_sizes(size);
|
||||
@ -672,9 +673,10 @@ XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
|
||||
GetShape(start_indices));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferDynamicUpdateSliceShape(
|
||||
operand_shape, update_shape, start_indices_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kDynamicUpdateSlice,
|
||||
{operand, update, start_indices});
|
||||
@ -690,9 +692,9 @@ XlaOp XlaBuilder::ConcatInDim(absl::Span<const XlaOp> operands,
|
||||
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
|
||||
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
|
||||
[](const Shape& shape) { return &shape; });
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferConcatOpShape(operand_shape_ptrs, dimension));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConcatOpShape(
|
||||
operand_shape_ptrs, dimension));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.add_dimensions(dimension);
|
||||
|
||||
@ -709,10 +711,9 @@ XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& padding_value_shape,
|
||||
GetShape(padding_value));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferPadShape(operand_shape, padding_value_shape,
|
||||
padding_config));
|
||||
|
||||
Shape shape, ShapeInference::InferPadShape(
|
||||
operand_shape, padding_value_shape, padding_config));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
*instr.mutable_padding_config() = padding_config;
|
||||
|
||||
return AddInstruction(std::move(instr), HloOpcode::kPad,
|
||||
@ -725,7 +726,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
|
||||
absl::Span<const int64> new_sizes) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
||||
TF_ASSIGN_OR_RETURN(const Shape shape,
|
||||
ShapeInference::InferReshapeShape(
|
||||
operand_shape, dimensions, new_sizes));
|
||||
XlaOp transposed = IsIdentityPermutation(dimensions)
|
||||
@ -738,7 +739,7 @@ XlaOp XlaBuilder::Reshape(const XlaOp& operand,
|
||||
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
|
||||
absl::Span<const int64> new_sizes) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand));
|
||||
std::vector<int64> dimensions(shape.dimensions_size());
|
||||
std::iota(dimensions.begin(), dimensions.end(), 0);
|
||||
return Reshape(operand, dimensions, new_sizes);
|
||||
@ -788,7 +789,7 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
|
||||
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
|
||||
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = ShapeUtil::MakeNil();
|
||||
*instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
|
||||
*instr.mutable_literal() = LiteralUtil::CreateR1U8(tag).ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
|
||||
});
|
||||
@ -814,9 +815,10 @@ XlaOp XlaBuilder::Tuple(absl::Span<const XlaOp> elements) {
|
||||
TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(elements));
|
||||
absl::c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
|
||||
[](const Shape& shape) { return &shape; });
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(const Shape shape,
|
||||
ShapeInference::InferVariadicOpShape(
|
||||
HloOpcode::kTuple, operand_shape_ptrs));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kTuple, elements);
|
||||
});
|
||||
}
|
||||
@ -831,7 +833,7 @@ XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
|
||||
ShapeUtil::HumanString(tuple_shape));
|
||||
}
|
||||
*instr.mutable_shape() =
|
||||
ShapeUtil::GetTupleElementShape(tuple_shape, index);
|
||||
ShapeUtil::GetTupleElementShape(tuple_shape, index).ToProto();
|
||||
|
||||
instr.set_tuple_index(index);
|
||||
|
||||
@ -890,9 +892,10 @@ XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
|
||||
dimension_numbers));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
*instr.mutable_dot_dimension_numbers() = dimension_numbers;
|
||||
if (precision_config != nullptr) {
|
||||
*instr.mutable_precision_config() = *precision_config;
|
||||
@ -1034,10 +1037,11 @@ XlaOp XlaBuilder::ConvGeneralDilated(
|
||||
MakeWindow(window_dimensions, window_strides, padding,
|
||||
lhs_dilation, rhs_dilation));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferConvolveShape(
|
||||
lhs_shape, rhs_shape, feature_group_count,
|
||||
instr.window(), dimension_numbers));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
*instr.mutable_convolution_dimension_numbers() = dimension_numbers;
|
||||
instr.set_feature_group_count(feature_group_count);
|
||||
@ -1110,10 +1114,9 @@ XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferFftShape(operand_shape, fft_type, fft_length));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferFftShape(
|
||||
operand_shape, fft_type, fft_length));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_fft_type(fft_type);
|
||||
for (int64 i : fft_length) {
|
||||
instr.add_fft_length(i);
|
||||
@ -1131,7 +1134,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
|
||||
}
|
||||
const Shape infeed_instruction_shape =
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
|
||||
*instr.mutable_shape() = infeed_instruction_shape;
|
||||
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
|
||||
instr.set_infeed_config(config);
|
||||
|
||||
if (ShapeUtil::IsArray(shape) && sharding() &&
|
||||
@ -1152,7 +1155,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
|
||||
XlaOp token;
|
||||
auto make_token = [&]() {
|
||||
HloInstructionProto token_instr;
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
return AddInstruction(std::move(token_instr), HloOpcode::kAfterAll, {});
|
||||
};
|
||||
if (sharding()) {
|
||||
@ -1191,7 +1194,7 @@ XlaOp XlaBuilder::Infeed(const Shape& shape, const string& config) {
|
||||
// TODO(b/80000000): Remove this when clients have been updated to handle
|
||||
// tokens.
|
||||
HloInstructionProto infeed_data;
|
||||
*infeed_data.mutable_shape() = shape;
|
||||
*infeed_data.mutable_shape() = shape.ToProto();
|
||||
infeed_data.set_tuple_index(0);
|
||||
return AddInstruction(std::move(infeed_data), HloOpcode::kGetTupleElement,
|
||||
{infeed});
|
||||
@ -1207,7 +1210,7 @@ XlaOp XlaBuilder::InfeedWithToken(const XlaOp& token, const Shape& shape,
|
||||
}
|
||||
const Shape infeed_instruction_shape =
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
|
||||
*instr.mutable_shape() = infeed_instruction_shape;
|
||||
*instr.mutable_shape() = infeed_instruction_shape.ToProto();
|
||||
instr.set_infeed_config(config);
|
||||
|
||||
if (ShapeUtil::IsArray(shape) && sharding() &&
|
||||
@ -1232,7 +1235,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
|
||||
ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
|
||||
// Check and set outfeed shape.
|
||||
if (!LayoutUtil::HasLayout(shape_with_layout)) {
|
||||
@ -1245,14 +1248,14 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
|
||||
ShapeUtil::HumanStringWithLayout(shape_with_layout),
|
||||
ShapeUtil::HumanStringWithLayout(operand_shape));
|
||||
}
|
||||
*instr.mutable_outfeed_shape() = shape_with_layout;
|
||||
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
|
||||
|
||||
instr.set_outfeed_config(outfeed_config);
|
||||
|
||||
// Outfeed takes a token as its second operand. Generate the token to pass
|
||||
// to the outfeed.
|
||||
HloInstructionProto token_instr;
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
|
||||
HloOpcode::kAfterAll, {}));
|
||||
|
||||
@ -1266,7 +1269,7 @@ void XlaBuilder::Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
|
||||
// TODO(b/80000000): Remove this when clients have been updated to handle
|
||||
// tokens.
|
||||
HloInstructionProto tuple_instr;
|
||||
*tuple_instr.mutable_shape() = ShapeUtil::MakeNil();
|
||||
*tuple_instr.mutable_shape() = ShapeUtil::MakeNil().ToProto();
|
||||
|
||||
// The dummy tuple should have no sharding.
|
||||
{
|
||||
@ -1285,7 +1288,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
|
||||
// Check and set outfeed shape.
|
||||
if (!LayoutUtil::HasLayout(shape_with_layout)) {
|
||||
@ -1298,7 +1301,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
|
||||
ShapeUtil::HumanStringWithLayout(shape_with_layout),
|
||||
ShapeUtil::HumanStringWithLayout(operand_shape));
|
||||
}
|
||||
*instr.mutable_outfeed_shape() = shape_with_layout;
|
||||
*instr.mutable_outfeed_shape() = shape_with_layout.ToProto();
|
||||
|
||||
instr.set_outfeed_config(outfeed_config);
|
||||
|
||||
@ -1310,7 +1313,7 @@ XlaOp XlaBuilder::OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
|
||||
XlaOp XlaBuilder::CreateToken() {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kAfterAll);
|
||||
});
|
||||
}
|
||||
@ -1330,7 +1333,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
|
||||
}
|
||||
}
|
||||
HloInstructionProto instr;
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kAfterAll, tokens);
|
||||
});
|
||||
}
|
||||
@ -1347,7 +1350,7 @@ XlaOp XlaBuilder::CustomCall(
|
||||
"are reserved for internal use.",
|
||||
call_target_name);
|
||||
}
|
||||
*instr.mutable_shape() = shape;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_custom_call_target(call_target_name);
|
||||
instr.set_custom_call_opaque(opaque);
|
||||
if (operand_shapes_with_layout.has_value()) {
|
||||
@ -1371,7 +1374,7 @@ XlaOp XlaBuilder::CustomCall(
|
||||
"constrained layout.",
|
||||
operand_num);
|
||||
}
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape;
|
||||
*instr.add_operand_shapes_with_layout() = operand_shape.ToProto();
|
||||
++operand_num;
|
||||
}
|
||||
}
|
||||
@ -1525,9 +1528,9 @@ XlaOp XlaBuilder::Transpose(const XlaOp& operand,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferTransposeShape(operand_shape, permutation));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferTransposeShape(
|
||||
operand_shape, permutation));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int64 dim : permutation) {
|
||||
instr.add_dimensions(dim);
|
||||
}
|
||||
@ -1540,9 +1543,9 @@ XlaOp XlaBuilder::Rev(const XlaOp& operand,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferReverseShape(operand_shape, dimensions));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReverseShape(
|
||||
operand_shape, dimensions));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (int64 dim : dimensions) {
|
||||
instr.add_dimensions(dim);
|
||||
}
|
||||
@ -1561,9 +1564,9 @@ XlaOp XlaBuilder::Sort(const XlaOp& keys, absl::Span<const XlaOp> values,
|
||||
GetOperandShapes(values));
|
||||
absl::c_transform(values_shapes, std::back_inserter(operand_shape_ptrs),
|
||||
[](const Shape& shape) { return &shape; });
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
ShapeInference::InferVariadicOpShape(
|
||||
HloOpcode::kSort, operand_shape_ptrs));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferVariadicOpShape(
|
||||
HloOpcode::kSort, operand_shape_ptrs));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
if (dimension == -1) {
|
||||
TF_ASSIGN_OR_RETURN(const Shape& keys_shape, GetShape(keys));
|
||||
dimension = ShapeUtil::Rank(keys_shape) - 1;
|
||||
@ -1585,9 +1588,9 @@ XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferConvertShape(operand_shape, new_element_type));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
|
||||
operand_shape, new_element_type));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kConvert, {operand});
|
||||
});
|
||||
}
|
||||
@ -1597,9 +1600,9 @@ XlaOp XlaBuilder::BitcastConvertType(const XlaOp& operand,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferConvertShape(operand_shape, new_element_type));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferConvertShape(
|
||||
operand_shape, new_element_type));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
return AddInstruction(std::move(instr), HloOpcode::kBitcastConvert,
|
||||
{operand});
|
||||
});
|
||||
@ -1631,11 +1634,11 @@ XlaOp XlaBuilder::Map(absl::Span<const XlaOp> operands,
|
||||
TF_ASSIGN_OR_RETURN(const ProgramShape& called_program_shape,
|
||||
computation.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferMapShape(operand_shape_ptrs, called_program_shape,
|
||||
dimensions));
|
||||
Shape shape, ShapeInference::InferMapShape(
|
||||
operand_shape_ptrs, called_program_shape, dimensions));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
const Shape& output_shape = instr.shape();
|
||||
Shape output_shape(instr.shape());
|
||||
const int64 output_rank = ShapeUtil::Rank(output_shape);
|
||||
AddCalledComputation(computation, &instr);
|
||||
std::vector<XlaOp> new_operands(operands.begin(), operands.end());
|
||||
@ -1678,7 +1681,7 @@ XlaOp XlaBuilder::RngOp(RandomDistribution distribution,
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
|
||||
*instr.mutable_shape() = shape;
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.set_distribution(distribution);
|
||||
|
||||
@ -1706,10 +1709,10 @@ XlaOp XlaBuilder::While(const XlaComputation& condition,
|
||||
TF_ASSIGN_OR_RETURN(const auto& condition_program_shape,
|
||||
condition.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(const Shape& init_shape, GetShape(init));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferWhileShape(condition_program_shape,
|
||||
body_program_shape, init_shape));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferWhileShape(
|
||||
condition_program_shape,
|
||||
body_program_shape, init_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
// Body comes before condition computation in the vector.
|
||||
AddCalledComputation(body, &instr);
|
||||
AddCalledComputation(condition, &instr);
|
||||
@ -1726,10 +1729,10 @@ XlaOp XlaBuilder::Gather(const XlaOp& input, const XlaOp& start_indices,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(input));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& start_indices_shape,
|
||||
GetShape(start_indices));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferGatherShape(input_shape, start_indices_shape,
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGatherShape(
|
||||
input_shape, start_indices_shape,
|
||||
dimension_numbers, slice_sizes));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
*instr.mutable_gather_dimension_numbers() = dimension_numbers;
|
||||
for (int64 bound : slice_sizes) {
|
||||
@ -1754,10 +1757,11 @@ XlaOp XlaBuilder::Scatter(const XlaOp& input, const XlaOp& scatter_indices,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& updates_shape, GetShape(updates));
|
||||
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
|
||||
update_computation.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferScatterShape(
|
||||
input_shape, scatter_indices_shape, updates_shape,
|
||||
to_apply_shape, dimension_numbers));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
*instr.mutable_scatter_dimension_numbers() = dimension_numbers;
|
||||
|
||||
@ -1784,10 +1788,11 @@ XlaOp XlaBuilder::Conditional(const XlaOp& predicate, const XlaOp& true_operand,
|
||||
TF_ASSIGN_OR_RETURN(const ProgramShape& false_computation_shape,
|
||||
false_computation.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
Shape shape,
|
||||
ShapeInference::InferConditionalShape(
|
||||
predicate_shape, true_operand_shape, false_operand_shape,
|
||||
true_computation_shape, false_computation_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
// The index of true_computation must be 0 and that of false computation
|
||||
// must be 1.
|
||||
@ -1829,9 +1834,10 @@ XlaOp XlaBuilder::Reduce(absl::Span<const XlaOp> operands,
|
||||
[](const Shape& shape) { return &shape; });
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
Shape shape,
|
||||
ShapeInference::InferReduceShape(
|
||||
operand_shape_ptrs, dimensions_to_reduce, called_program_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
for (int64 dim : dimensions_to_reduce) {
|
||||
instr.add_dimensions(dim);
|
||||
@ -1894,10 +1900,10 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
|
||||
MakeWindow(window_dimensions, window_strides, padding,
|
||||
/*lhs_dilation=*/base_dilations,
|
||||
/*rhs_dilation=*/window_dilations));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferReduceWindowShape(operand_shape, init_shape,
|
||||
instr.window(), to_apply_shape));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferReduceWindowShape(
|
||||
operand_shape, init_shape,
|
||||
instr.window(), to_apply_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
AddCalledComputation(computation, &instr);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
|
||||
@ -1915,9 +1921,10 @@ XlaOp XlaBuilder::BatchNormTraining(const XlaOp& operand, const XlaOp& scale,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& scale_shape, GetShape(scale));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
Shape shape,
|
||||
ShapeInference::InferBatchNormTrainingShape(
|
||||
operand_shape, scale_shape, offset_shape, feature_index));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.set_epsilon(epsilon);
|
||||
instr.set_feature_index(feature_index);
|
||||
@ -1939,10 +1946,11 @@ XlaOp XlaBuilder::BatchNormInference(const XlaOp& operand, const XlaOp& scale,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& offset_shape, GetShape(offset));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& mean_shape, GetShape(mean));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& variance_shape, GetShape(variance));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
ShapeInference::InferBatchNormInferenceShape(
|
||||
operand_shape, scale_shape, offset_shape,
|
||||
mean_shape, variance_shape, feature_index));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape shape, ShapeInference::InferBatchNormInferenceShape(
|
||||
operand_shape, scale_shape, offset_shape, mean_shape,
|
||||
variance_shape, feature_index));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.set_epsilon(epsilon);
|
||||
instr.set_feature_index(feature_index);
|
||||
@ -1964,10 +1972,11 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
|
||||
TF_ASSIGN_OR_RETURN(const Shape& batch_mean_shape, GetShape(batch_mean));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& batch_var_shape, GetShape(batch_var));
|
||||
TF_ASSIGN_OR_RETURN(const Shape& grad_output_shape, GetShape(grad_output));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferBatchNormGradShape(
|
||||
operand_shape, scale_shape, batch_mean_shape,
|
||||
batch_var_shape, grad_output_shape, feature_index));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
instr.set_epsilon(epsilon);
|
||||
instr.set_feature_index(feature_index);
|
||||
@ -1998,9 +2007,9 @@ XlaOp XlaBuilder::CrossReplicaSum(
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape(
|
||||
{&operand_shape}));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
for (const ReplicaGroup& group : replica_groups) {
|
||||
*instr.add_replica_groups() = group;
|
||||
@ -2053,8 +2062,8 @@ XlaOp XlaBuilder::AllToAll(const XlaOp& operand, int64 split_dimension,
|
||||
absl::c_transform(slice_shapes, std::back_inserter(slice_shape_ptrs),
|
||||
[](const Shape& shape) { return &shape; });
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
|
||||
Shape shape, ShapeInference::InferAllToAllTupleShape(slice_shape_ptrs));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
for (const ReplicaGroup& group : replica_groups) {
|
||||
*instr.add_replica_groups() = group;
|
||||
}
|
||||
@ -2079,8 +2088,9 @@ XlaOp XlaBuilder::CollectivePermute(
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
Shape shape,
|
||||
ShapeInference::InferCollectivePermuteShape(operand_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
for (const auto& pair : source_target_pairs) {
|
||||
auto* proto_pair = instr.add_source_target_pairs();
|
||||
@ -2129,10 +2139,11 @@ XlaOp XlaBuilder::SelectAndScatterWithGeneralPadding(
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_window(),
|
||||
MakeWindow(window_dimensions, window_strides, padding,
|
||||
/*lhs_dilation=*/{}, /*rhs_dilation=*/{}));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferSelectAndScatterShape(
|
||||
operand_shape, select_shape, instr.window(),
|
||||
source_shape, init_shape, scatter_shape));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
|
||||
AddCalledComputation(select, &instr);
|
||||
AddCalledComputation(scatter, &instr);
|
||||
@ -2147,9 +2158,10 @@ XlaOp XlaBuilder::ReducePrecision(const XlaOp& operand, const int exponent_bits,
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
|
||||
TF_ASSIGN_OR_RETURN(Shape shape,
|
||||
ShapeInference::InferReducePrecisionShape(
|
||||
operand_shape, exponent_bits, mantissa_bits));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.set_exponent_bits(exponent_bits);
|
||||
instr.set_mantissa_bits(mantissa_bits);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kReducePrecision,
|
||||
@ -2164,7 +2176,7 @@ void XlaBuilder::Send(const XlaOp& operand, const ChannelHandle& handle) {
|
||||
// TODO(b/80000000): Remove this when clients have been updated to handle
|
||||
// tokens.
|
||||
HloInstructionProto token_instr;
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
|
||||
HloOpcode::kAfterAll, {}));
|
||||
|
||||
@ -2183,15 +2195,17 @@ XlaOp XlaBuilder::SendWithToken(const XlaOp& operand, const XlaOp& token,
|
||||
// token}.
|
||||
HloInstructionProto send_instr;
|
||||
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
|
||||
*send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
|
||||
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
|
||||
*send_instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
|
||||
.ToProto();
|
||||
send_instr.set_channel_id(handle.handle());
|
||||
TF_ASSIGN_OR_RETURN(XlaOp send,
|
||||
AddInstruction(std::move(send_instr), HloOpcode::kSend,
|
||||
{operand, token}));
|
||||
|
||||
HloInstructionProto send_done_instr;
|
||||
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
send_done_instr.set_channel_id(handle.handle());
|
||||
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
|
||||
{send});
|
||||
@ -2205,7 +2219,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
|
||||
// TODO(b/80000000): Remove this when clients have been updated to handle
|
||||
// tokens.
|
||||
HloInstructionProto token_instr;
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*token_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
TF_ASSIGN_OR_RETURN(XlaOp token, AddInstruction(std::move(token_instr),
|
||||
HloOpcode::kAfterAll, {}));
|
||||
|
||||
@ -2216,7 +2230,7 @@ XlaOp XlaBuilder::Recv(const Shape& shape, const ChannelHandle& handle) {
|
||||
// TODO(b/80000000): Remove this when clients have been updated to handle
|
||||
// tokens.
|
||||
HloInstructionProto recv_data;
|
||||
*recv_data.mutable_shape() = shape;
|
||||
*recv_data.mutable_shape() = shape.ToProto();
|
||||
recv_data.set_tuple_index(0);
|
||||
return AddInstruction(std::move(recv_data), HloOpcode::kGetTupleElement,
|
||||
{recv});
|
||||
@ -2233,15 +2247,18 @@ XlaOp XlaBuilder::RecvWithToken(const XlaOp& token, const Shape& shape,
|
||||
// Recv instruction produces a tuple of {receive buffer, U32 context,
|
||||
// token}.
|
||||
HloInstructionProto recv_instr;
|
||||
*recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
|
||||
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
|
||||
*recv_instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
|
||||
.ToProto();
|
||||
recv_instr.set_channel_id(handle.handle());
|
||||
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
|
||||
HloOpcode::kRecv, {token}));
|
||||
|
||||
HloInstructionProto recv_done_instr;
|
||||
*recv_done_instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
|
||||
.ToProto();
|
||||
recv_done_instr.set_channel_id(handle.handle());
|
||||
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
|
||||
{recv});
|
||||
@ -2275,9 +2292,11 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
|
||||
// Send instruction produces a tuple of {aliased operand, U32 context,
|
||||
// token}.
|
||||
HloInstructionProto send_instr;
|
||||
*send_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
|
||||
{shape_with_layout, ShapeUtil::MakeShape(U32, {}),
|
||||
ShapeUtil::MakeTokenShape()});
|
||||
*send_instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape({shape_with_layout,
|
||||
ShapeUtil::MakeShape(U32, {}),
|
||||
ShapeUtil::MakeTokenShape()})
|
||||
.ToProto();
|
||||
send_instr.set_channel_id(handle.handle());
|
||||
send_instr.set_is_host_transfer(true);
|
||||
TF_ASSIGN_OR_RETURN(XlaOp send,
|
||||
@ -2285,7 +2304,7 @@ XlaOp XlaBuilder::SendToHost(const XlaOp& operand, const XlaOp& token,
|
||||
{operand, token}));
|
||||
|
||||
HloInstructionProto send_done_instr;
|
||||
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape();
|
||||
*send_done_instr.mutable_shape() = ShapeUtil::MakeTokenShape().ToProto();
|
||||
send_done_instr.set_channel_id(handle.handle());
|
||||
send_done_instr.set_is_host_transfer(true);
|
||||
return AddInstruction(std::move(send_done_instr), HloOpcode::kSendDone,
|
||||
@ -2314,8 +2333,10 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
|
||||
// Recv instruction produces a tuple of {receive buffer, U32 context,
|
||||
// token}.
|
||||
HloInstructionProto recv_instr;
|
||||
*recv_instr.mutable_shape() = ShapeUtil::MakeTupleShape(
|
||||
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()});
|
||||
*recv_instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{shape, ShapeUtil::MakeShape(U32, {}), ShapeUtil::MakeTokenShape()})
|
||||
.ToProto();
|
||||
recv_instr.set_channel_id(handle.handle());
|
||||
recv_instr.set_is_host_transfer(true);
|
||||
TF_ASSIGN_OR_RETURN(XlaOp recv, AddInstruction(std::move(recv_instr),
|
||||
@ -2323,7 +2344,8 @@ XlaOp XlaBuilder::RecvFromHost(const XlaOp& token, const Shape& shape,
|
||||
|
||||
HloInstructionProto recv_done_instr;
|
||||
*recv_done_instr.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()});
|
||||
ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeTokenShape()})
|
||||
.ToProto();
|
||||
recv_done_instr.set_channel_id(handle.handle());
|
||||
recv_done_instr.set_is_host_transfer(true);
|
||||
return AddInstruction(std::move(recv_done_instr), HloOpcode::kRecvDone,
|
||||
@ -2335,9 +2357,9 @@ XlaOp XlaBuilder::GetDimensionSize(const XlaOp& operand, int64 dimension) {
|
||||
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
HloInstructionProto instr;
|
||||
TF_ASSIGN_OR_RETURN(const auto& operand_shape, GetShape(operand));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*instr.mutable_shape(),
|
||||
ShapeInference::InferGetDimensionSizeShape(operand_shape, dimension));
|
||||
TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferGetDimensionSizeShape(
|
||||
operand_shape, dimension));
|
||||
*instr.mutable_shape() = shape.ToProto();
|
||||
instr.add_dimensions(dimension);
|
||||
return AddInstruction(std::move(instr), HloOpcode::kGetDimensionSize,
|
||||
{operand});
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
@ -292,16 +292,17 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
|
||||
if (!proto.has_shape()) {
|
||||
return InvalidArgument("LiteralProto has no shape");
|
||||
}
|
||||
if (ShapeUtil::HasPrimitiveType(proto.shape(), OPAQUE)) {
|
||||
Shape shape(proto.shape());
|
||||
if (ShapeUtil::HasPrimitiveType(shape, OPAQUE)) {
|
||||
return InvalidArgument("Literal shape cannot include OPAQUE sub-shape");
|
||||
}
|
||||
if (!LayoutUtil::HasLayout(proto.shape())) {
|
||||
if (!LayoutUtil::HasLayout(shape)) {
|
||||
return InvalidArgument("LiteralProto has no layout");
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
|
||||
|
||||
Literal literal(proto.shape());
|
||||
Literal literal(shape);
|
||||
|
||||
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
|
||||
[&](const ShapeIndex& index, Piece* piece) {
|
||||
@ -1794,7 +1795,7 @@ void CopyToRepeatedField(RepeatedFieldT* dest,
|
||||
} // namespace
|
||||
|
||||
void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
|
||||
*proto->mutable_shape() = subshape();
|
||||
*proto->mutable_shape() = subshape().ToProto();
|
||||
switch (subshape().element_type()) {
|
||||
case PRED:
|
||||
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
|
||||
@ -1900,8 +1901,9 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
|
||||
// These conditions should have been checked in
|
||||
// MutableLiteralBase::CreateFromProto.
|
||||
TF_RET_CHECK(proto.has_shape());
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
|
||||
Shape shape(proto.shape());
|
||||
TF_RET_CHECK(LayoutUtil::HasLayout(shape));
|
||||
TF_RET_CHECK(ShapeUtil::Equal(shape, subshape()));
|
||||
|
||||
if (LayoutUtil::IsSparseArray(subshape())) {
|
||||
// Compute the number of elements (indices) in the sparse shape and reserve
|
||||
|
@ -1377,13 +1377,26 @@ TEST_F(LiteralUtilTest, BitcastConvertBetweenInvalidTypes) {
|
||||
absl::StrContains(status.error_message(), "bit widths are different"));
|
||||
}
|
||||
|
||||
// Sets the layout of the given ShapeProto to the default.
|
||||
void SetDefaultLayoutOnProto(ShapeProto* shape_proto) {
|
||||
CHECK(ShapeUtil::IsArrayPrimitiveType(shape_proto->element_type()));
|
||||
shape_proto->mutable_layout()->set_format(DENSE);
|
||||
auto* minor_to_major =
|
||||
shape_proto->mutable_layout()->mutable_minor_to_major();
|
||||
minor_to_major->Resize(shape_proto->dimensions_size(), 0);
|
||||
const int64 size = minor_to_major->size();
|
||||
for (int64 i = 0; i < size; ++i) {
|
||||
minor_to_major->Set(i, size - 1 - i);
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LiteralUtilTest, CopyFromProto_Bool) {
|
||||
LiteralProto p;
|
||||
p.mutable_shape()->set_element_type(PRED);
|
||||
for (int len = 0; len < 25; ++len) {
|
||||
p.mutable_shape()->clear_dimensions();
|
||||
p.mutable_shape()->add_dimensions(len);
|
||||
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
|
||||
SetDefaultLayoutOnProto(p.mutable_shape());
|
||||
p.clear_preds();
|
||||
for (int i = 0; i < len; ++i) {
|
||||
p.add_preds((i % 2) == (len % 2));
|
||||
@ -1409,7 +1422,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) {
|
||||
EXPECT_EQ(4, m.data<half>().size());
|
||||
|
||||
LiteralProto p = m.ToProto();
|
||||
EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape()));
|
||||
EXPECT_EQ(4, ShapeUtil::ElementsIn(Shape(p.shape())));
|
||||
EXPECT_EQ(8, p.f16s().size());
|
||||
const char* d = p.f16s().data();
|
||||
EXPECT_EQ(d[0], 0);
|
||||
@ -1432,7 +1445,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) {
|
||||
p.mutable_shape()->set_element_type(F16);
|
||||
p.mutable_shape()->clear_dimensions();
|
||||
p.mutable_shape()->add_dimensions(4);
|
||||
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
|
||||
SetDefaultLayoutOnProto(p.mutable_shape());
|
||||
p.clear_f16s();
|
||||
p.set_f16s(half_vals, 8);
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
|
||||
@ -1454,7 +1467,7 @@ TEST_F(LiteralUtilTest, CopyFromProto_u16) {
|
||||
p.mutable_shape()->set_element_type(U16);
|
||||
p.mutable_shape()->clear_dimensions();
|
||||
p.mutable_shape()->add_dimensions(4);
|
||||
LayoutUtil::SetToDefaultLayout(p.mutable_shape());
|
||||
SetDefaultLayoutOnProto(p.mutable_shape());
|
||||
p.clear_u16s();
|
||||
p.set_u16s(uint16_vals, 8);
|
||||
TF_ASSERT_OK_AND_ASSIGN(Literal literal, Literal::CreateFromProto(p));
|
||||
@ -1756,7 +1769,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoNoValues) {
|
||||
// Proto contains a shape, but no values.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
||||
Status status = Literal::CreateFromProto(proto).status();
|
||||
ASSERT_FALSE(status.ok());
|
||||
EXPECT_THAT(status.error_message(),
|
||||
@ -1777,7 +1790,7 @@ TEST_F(LiteralUtilTest, InvalidProtoNoShape) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
|
||||
// Proto contains values in wrong container.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3});
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}).ToProto();
|
||||
proto.add_preds(false);
|
||||
proto.add_preds(true);
|
||||
proto.add_preds(false);
|
||||
@ -1790,7 +1803,7 @@ TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
|
||||
// Proto contains too few values.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2});
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}).ToProto();
|
||||
proto.add_f32s(1.0);
|
||||
proto.add_f32s(2.0);
|
||||
proto.add_f32s(3.0);
|
||||
@ -1803,7 +1816,7 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
|
||||
// Proto contains too many values.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2});
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}).ToProto();
|
||||
proto.add_s32s(42);
|
||||
proto.add_s32s(-10);
|
||||
proto.add_s32s(100);
|
||||
@ -1816,8 +1829,8 @@ TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
|
||||
// Proto shape missing layout.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2});
|
||||
LayoutUtil::ClearLayout(proto.mutable_shape());
|
||||
*proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}).ToProto();
|
||||
proto.mutable_shape()->clear_layout();
|
||||
proto.add_preds(true);
|
||||
proto.add_preds(false);
|
||||
proto.add_preds(true);
|
||||
@ -1830,11 +1843,13 @@ TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
|
||||
// Proto has the too few tuple elements.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
|
||||
*proto.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
|
||||
.ToProto();
|
||||
LiteralProto* element0 = proto.add_tuple_literals();
|
||||
*element0->mutable_shape() =
|
||||
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
|
||||
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
|
||||
element0->add_preds(false);
|
||||
element0->add_preds(true);
|
||||
|
||||
@ -1846,19 +1861,21 @@ TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) {
|
||||
TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) {
|
||||
// Proto has the too many tuple elements.
|
||||
LiteralProto proto;
|
||||
*proto.mutable_shape() = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})});
|
||||
*proto.mutable_shape() =
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})})
|
||||
.ToProto();
|
||||
LiteralProto* element0 = proto.add_tuple_literals();
|
||||
*element0->mutable_shape() =
|
||||
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
|
||||
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 0).ToProto();
|
||||
element0->add_preds(false);
|
||||
element0->add_preds(true);
|
||||
LiteralProto* element1 = proto.add_tuple_literals();
|
||||
*element1->mutable_shape() =
|
||||
ShapeUtil::GetTupleElementShape(proto.shape(), 1);
|
||||
ShapeUtil::GetTupleElementShape(Shape(proto.shape()), 1).ToProto();
|
||||
element1->add_f32s(42.0);
|
||||
LiteralProto* element2 = proto.add_tuple_literals();
|
||||
*element2->mutable_shape() = ShapeUtil::MakeShape(F32, {});
|
||||
*element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}).ToProto();
|
||||
element2->add_f32s(123.0);
|
||||
|
||||
Status status = Literal::CreateFromProto(proto).status();
|
||||
|
@ -25,9 +25,10 @@ from tensorflow.compiler.xla.python_api import types
|
||||
|
||||
|
||||
class Shape(object):
|
||||
"""Wraps a xla_data_pb2.Shape message with a convenient Python type.
|
||||
"""Wraps a xla_data_pb2.ShapeProto message with a convenient Python type.
|
||||
|
||||
Provides direct access to the underlying xla_data_pb2.Shape message in the
|
||||
Provides direct access to the underlying xla_data_pb2.ShapeProto message in
|
||||
the
|
||||
message attribute, along with accessor wrappers to the message's fields.
|
||||
Avoid direct access to .message unless interacting directly with protobuf APIs
|
||||
like CopyFrom. In other words, prefer hauling the shape around in a Shape, and
|
||||
@ -48,7 +49,7 @@ class Shape(object):
|
||||
Raises:
|
||||
ValueError: if element_type is TUPLE but dimensions are not Shape objects.
|
||||
"""
|
||||
self.message = xla_data_pb2.Shape()
|
||||
self.message = xla_data_pb2.ShapeProto()
|
||||
self.message.element_type = element_type
|
||||
if element_type == xla_data_pb2.TUPLE:
|
||||
if not all(isinstance(subshape, Shape) for subshape in dimensions):
|
||||
|
@ -16,7 +16,6 @@ xla_proto_library(
|
||||
use_grpc_plugin = True,
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
],
|
||||
)
|
||||
|
@ -43,7 +43,6 @@ limitations under the License.
|
||||
syntax = "proto3";
|
||||
|
||||
import "tensorflow/compiler/xla/xla.proto";
|
||||
import "tensorflow/compiler/xla/xla_data.proto";
|
||||
|
||||
package xla;
|
||||
|
||||
|
@ -89,7 +89,7 @@ CompileOnlyService::CompileAheadOfTime(
|
||||
ExecutionOptions execution_options;
|
||||
*execution_options.mutable_debug_options() = debug_options;
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
*instance.result_layout;
|
||||
instance.result_layout->ToProto();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
CreateModuleConfig(
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <deque>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
@ -51,7 +51,7 @@ message HloInstructionProto {
|
||||
|
||||
string name = 1;
|
||||
string opcode = 2;
|
||||
xla.Shape shape = 3;
|
||||
xla.ShapeProto shape = 3;
|
||||
|
||||
xla.OpMetadata metadata = 7;
|
||||
|
||||
@ -132,7 +132,7 @@ message HloInstructionProto {
|
||||
string custom_call_opaque = 53;
|
||||
|
||||
// Shape of outfeed request.
|
||||
xla.Shape outfeed_shape = 29;
|
||||
xla.ShapeProto outfeed_shape = 29;
|
||||
|
||||
// Describes the dimension numbers used for a dot operation
|
||||
xla.DotDimensionNumbers dot_dimension_numbers = 30;
|
||||
@ -190,7 +190,7 @@ message HloInstructionProto {
|
||||
// 'operand_shapes_with_layout' must contain a shape with layout for each
|
||||
// operand.
|
||||
bool constrain_layout = 56;
|
||||
repeated Shape operand_shapes_with_layout = 57;
|
||||
repeated xla.ShapeProto operand_shapes_with_layout = 57;
|
||||
}
|
||||
|
||||
// Serialization of HloComputation.
|
||||
|
@ -93,7 +93,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
[&computation_map](int64 id) { return computation_map.contains(id); }))
|
||||
<< proto.name() << " instruction references invalid computation id(s)";
|
||||
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
|
||||
Shape shape(proto.shape());
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(shape));
|
||||
|
||||
switch (opcode) {
|
||||
// Ops migrated to subclasses.
|
||||
@ -101,23 +102,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 3)
|
||||
<< "BatchNormTraining instruction should have 3 operands but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction = CreateBatchNormTraining(
|
||||
proto.shape(), operands(0), operands(1), operands(2), proto.epsilon(),
|
||||
proto.feature_index());
|
||||
instruction =
|
||||
CreateBatchNormTraining(shape, operands(0), operands(1), operands(2),
|
||||
proto.epsilon(), proto.feature_index());
|
||||
break;
|
||||
case HloOpcode::kBatchNormInference:
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 5)
|
||||
<< "BatchNormInference instruction should have 5 operands but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction = CreateBatchNormInference(
|
||||
proto.shape(), operands(0), operands(1), operands(2), operands(3),
|
||||
shape, operands(0), operands(1), operands(2), operands(3),
|
||||
operands(4), proto.epsilon(), proto.feature_index());
|
||||
break;
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 5)
|
||||
<< "BatchNormGrad instruction should have 5 operands but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction = CreateBatchNormGrad(proto.shape(), operands(0), operands(1),
|
||||
instruction = CreateBatchNormGrad(shape, operands(0), operands(1),
|
||||
operands(2), operands(3), operands(4),
|
||||
proto.epsilon(), proto.feature_index());
|
||||
break;
|
||||
@ -127,7 +128,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
<< proto.operand_ids_size();
|
||||
std::vector<int64> fft_length(proto.fft_length().begin(),
|
||||
proto.fft_length().end());
|
||||
instruction = CreateFft(proto.shape(), operands(0), proto.fft_type(),
|
||||
instruction = CreateFft(shape, operands(0), proto.fft_type(),
|
||||
absl::Span<const int64>(fft_length));
|
||||
break;
|
||||
}
|
||||
@ -148,7 +149,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||
<< "Recv instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0),
|
||||
instruction = CreateRecv(shape.tuple_shapes(0), operands(0),
|
||||
proto.channel_id(), proto.is_host_transfer());
|
||||
break;
|
||||
case HloOpcode::kRecvDone:
|
||||
@ -161,7 +162,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||
<< "Reverse instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction = CreateReverse(proto.shape(), operands(0),
|
||||
instruction = CreateReverse(shape, operands(0),
|
||||
std::vector<int64>(proto.dimensions().begin(),
|
||||
proto.dimensions().end()));
|
||||
break;
|
||||
@ -170,7 +171,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
<< "Concatenate instruction should have 1 dimension but sees "
|
||||
<< proto.dimensions_size();
|
||||
instruction =
|
||||
CreateConcatenate(proto.shape(), all_operands(), proto.dimensions(0));
|
||||
CreateConcatenate(shape, all_operands(), proto.dimensions(0));
|
||||
break;
|
||||
case HloOpcode::kReduce:
|
||||
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
|
||||
@ -188,7 +189,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
absl::MakeSpan(reduce_operands)
|
||||
.subspan(reduce_operands.size() / 2, reduce_operands.size());
|
||||
instruction =
|
||||
CreateReduce(proto.shape(), inputs, init_values,
|
||||
CreateReduce(shape, inputs, init_values,
|
||||
std::vector<int64>(proto.dimensions().begin(),
|
||||
proto.dimensions().end()),
|
||||
computations(0));
|
||||
@ -203,7 +204,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
auto sort_operands = all_operands();
|
||||
HloInstruction* keys = sort_operands[0];
|
||||
instruction = CreateSort(
|
||||
proto.shape(), proto.dimensions(0), keys,
|
||||
shape, proto.dimensions(0), keys,
|
||||
absl::Span<HloInstruction* const>(sort_operands).subspan(1));
|
||||
break;
|
||||
}
|
||||
@ -212,7 +213,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
<< "Transpose instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction =
|
||||
CreateTranspose(proto.shape(), operands(0),
|
||||
CreateTranspose(shape, operands(0),
|
||||
std::vector<int64>(proto.dimensions().begin(),
|
||||
proto.dimensions().end()));
|
||||
break;
|
||||
@ -221,7 +222,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
<< "Broadcast instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction =
|
||||
CreateBroadcast(proto.shape(), operands(0),
|
||||
CreateBroadcast(shape, operands(0),
|
||||
std::vector<int64>(proto.dimensions().begin(),
|
||||
proto.dimensions().end()));
|
||||
break;
|
||||
@ -229,7 +230,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
|
||||
<< "Map instruction should have 1 called computation but sees "
|
||||
<< proto.called_computation_ids_size();
|
||||
instruction = CreateMap(proto.shape(), all_operands(), computations(0));
|
||||
instruction = CreateMap(shape, all_operands(), computations(0));
|
||||
break;
|
||||
case HloOpcode::kSlice: {
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||
@ -242,8 +243,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
slice_limits.push_back(slice_dimensions.limit());
|
||||
slice_strides.push_back(slice_dimensions.stride());
|
||||
}
|
||||
instruction = CreateSlice(proto.shape(), operands(0), slice_starts,
|
||||
slice_limits, slice_strides);
|
||||
instruction = CreateSlice(shape, operands(0), slice_starts, slice_limits,
|
||||
slice_strides);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kConstant: {
|
||||
@ -253,7 +254,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
Literal::CreateFromProto(proto.literal()));
|
||||
instruction = CreateConstant(std::move(literal));
|
||||
} else {
|
||||
instruction = absl::make_unique<HloConstantInstruction>(proto.shape());
|
||||
instruction = absl::make_unique<HloConstantInstruction>(shape);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -284,55 +285,54 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
|
||||
TF_RET_CHECK(fused_computation != nullptr)
|
||||
<< "No fusion computation with id " << fusion_id;
|
||||
instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
|
||||
fused_computation);
|
||||
instruction =
|
||||
CreateFusion(shape, fusion_kind, all_operands(), fused_computation);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kRng:
|
||||
instruction =
|
||||
CreateRng(proto.shape(), proto.distribution(), all_operands());
|
||||
instruction = CreateRng(shape, proto.distribution(), all_operands());
|
||||
break;
|
||||
case HloOpcode::kParameter:
|
||||
instruction = CreateParameter(proto.parameter_number(), proto.shape(),
|
||||
proto.name());
|
||||
instruction =
|
||||
CreateParameter(proto.parameter_number(), shape, proto.name());
|
||||
break;
|
||||
case HloOpcode::kGetTupleElement:
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||
<< "GetTupleElement instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction = CreateGetTupleElement(proto.shape(), operands(0),
|
||||
proto.tuple_index());
|
||||
instruction =
|
||||
CreateGetTupleElement(shape, operands(0), proto.tuple_index());
|
||||
break;
|
||||
case HloOpcode::kReducePrecision:
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||
<< "ReducePrecision instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction =
|
||||
CreateReducePrecision(proto.shape(), operands(0),
|
||||
proto.exponent_bits(), proto.mantissa_bits());
|
||||
instruction = CreateReducePrecision(
|
||||
shape, operands(0), proto.exponent_bits(), proto.mantissa_bits());
|
||||
break;
|
||||
case HloOpcode::kInfeed: {
|
||||
TF_RET_CHECK(ShapeUtil::IsTuple(proto.shape()) &&
|
||||
(ShapeUtil::TupleElementCount(proto.shape()) == 2))
|
||||
TF_RET_CHECK(ShapeUtil::IsTuple(shape) &&
|
||||
(ShapeUtil::TupleElementCount(shape) == 2))
|
||||
<< "Infeed should have a tuple shape with 2 operands, but has: "
|
||||
<< proto.shape();
|
||||
const Shape& data_shape =
|
||||
ShapeUtil::GetTupleElementShape(proto.shape(), 0);
|
||||
<< shape;
|
||||
const Shape& data_shape = ShapeUtil::GetTupleElementShape(shape, 0);
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1)
|
||||
<< "Infeed instruction should have 1 operand but sees "
|
||||
<< proto.operand_ids_size();
|
||||
instruction =
|
||||
CreateInfeed(data_shape, operands(0), proto.infeed_config());
|
||||
} break;
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kOutfeed: {
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 2)
|
||||
<< "Outfeed instruction should have 2 operands but sees "
|
||||
<< proto.operand_ids_size();
|
||||
Shape outfeed_shape(proto.outfeed_shape());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape()));
|
||||
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
|
||||
operands(1), proto.outfeed_config());
|
||||
ShapeUtil::ValidateShapeWithOptionalLayout(outfeed_shape));
|
||||
instruction = CreateOutfeed(outfeed_shape, operands(0), operands(1),
|
||||
proto.outfeed_config());
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCrossReplicaSum: {
|
||||
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
|
||||
<< "CrossReplicaSum should have 1 called computation but sees "
|
||||
@ -342,7 +342,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
all_reduce_id = proto.all_reduce_id();
|
||||
}
|
||||
instruction = CreateCrossReplicaSum(
|
||||
proto.shape(), all_operands(), computations(0),
|
||||
shape, all_operands(), computations(0),
|
||||
/*replica_groups=*/
|
||||
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
|
||||
proto.replica_groups().end()),
|
||||
@ -352,7 +352,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
}
|
||||
case HloOpcode::kAllToAll: {
|
||||
instruction = CreateAllToAll(
|
||||
proto.shape(), all_operands(),
|
||||
shape, all_operands(),
|
||||
/*replica_groups=*/
|
||||
std::vector<ReplicaGroup>(proto.replica_groups().begin(),
|
||||
proto.replica_groups().end()));
|
||||
@ -368,8 +368,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
source_target_pairs[i].first = proto.source_target_pairs(i).source();
|
||||
source_target_pairs[i].second = proto.source_target_pairs(i).target();
|
||||
}
|
||||
instruction = CreateCollectivePermute(proto.shape(), operands(0),
|
||||
source_target_pairs);
|
||||
instruction =
|
||||
CreateCollectivePermute(shape, operands(0), source_target_pairs);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kConvolution: {
|
||||
@ -382,7 +382,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
precision_config.mutable_operand_precision()->Resize(
|
||||
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
|
||||
instruction = CreateConvolve(
|
||||
proto.shape(), operands(0), operands(1),
|
||||
shape, operands(0), operands(1),
|
||||
std::max<int64>(proto.feature_group_count(), 1), proto.window(),
|
||||
proto.convolution_dimension_numbers(), precision_config);
|
||||
break;
|
||||
@ -394,7 +394,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
|
||||
<< "ReduceWindow should have 1 called computation but sees "
|
||||
<< proto.called_computation_ids_size();
|
||||
instruction = CreateReduceWindow(proto.shape(), operands(0), operands(1),
|
||||
instruction = CreateReduceWindow(shape, operands(0), operands(1),
|
||||
proto.window(), computations(0));
|
||||
break;
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
@ -404,9 +404,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
TF_RET_CHECK(proto.called_computation_ids_size() == 2)
|
||||
<< "SelectAndScatter should have 2 called computations but sees "
|
||||
<< proto.called_computation_ids_size();
|
||||
instruction = CreateSelectAndScatter(
|
||||
proto.shape(), operands(0), computations(0), proto.window(),
|
||||
operands(1), operands(2), computations(1));
|
||||
instruction = CreateSelectAndScatter(shape, operands(0), computations(0),
|
||||
proto.window(), operands(1),
|
||||
operands(2), computations(1));
|
||||
break;
|
||||
case HloOpcode::kCustomCall:
|
||||
if (proto.constrain_layout()) {
|
||||
@ -414,16 +414,17 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
// vector of pointers essentially) so create a vector of shapes to pass
|
||||
// in.
|
||||
std::vector<Shape> operand_shapes;
|
||||
for (const Shape& shape : proto.operand_shapes_with_layout()) {
|
||||
operand_shapes.push_back(shape);
|
||||
for (const ShapeProto& shape_proto :
|
||||
proto.operand_shapes_with_layout()) {
|
||||
operand_shapes.emplace_back(shape_proto);
|
||||
}
|
||||
instruction = CreateCustomCall(
|
||||
proto.shape(), all_operands(), proto.custom_call_target(),
|
||||
operand_shapes, proto.custom_call_opaque());
|
||||
instruction =
|
||||
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
|
||||
operand_shapes, proto.custom_call_opaque());
|
||||
} else {
|
||||
instruction = CreateCustomCall(proto.shape(), all_operands(),
|
||||
proto.custom_call_target(),
|
||||
proto.custom_call_opaque());
|
||||
instruction =
|
||||
CreateCustomCall(shape, all_operands(), proto.custom_call_target(),
|
||||
proto.custom_call_opaque());
|
||||
}
|
||||
if (proto.has_window()) {
|
||||
static_cast<HloCustomCallInstruction*>(instruction.get())
|
||||
@ -443,8 +444,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
<< "Pad instruction should have 2 operands but sees "
|
||||
<< proto.operand_ids_size();
|
||||
TF_RET_CHECK(proto.has_padding_config());
|
||||
instruction = CreatePad(proto.shape(), operands(0), operands(1),
|
||||
proto.padding_config());
|
||||
instruction =
|
||||
CreatePad(shape, operands(0), operands(1), proto.padding_config());
|
||||
break;
|
||||
case HloOpcode::kDynamicSlice: {
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 2)
|
||||
@ -452,8 +453,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
<< proto.operand_ids_size();
|
||||
std::vector<int64> slice_sizes(proto.dynamic_slice_sizes_size());
|
||||
absl::c_copy(proto.dynamic_slice_sizes(), slice_sizes.begin());
|
||||
instruction = CreateDynamicSlice(proto.shape(), operands(0), operands(1),
|
||||
slice_sizes);
|
||||
instruction =
|
||||
CreateDynamicSlice(shape, operands(0), operands(1), slice_sizes);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kGather: {
|
||||
@ -469,7 +470,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
for (int64 bound : proto.gather_slice_sizes()) {
|
||||
gather_slice_sizes.push_back(bound);
|
||||
}
|
||||
instruction = CreateGather(proto.shape(), operands(0), operands(1),
|
||||
instruction = CreateGather(shape, operands(0), operands(1),
|
||||
*gather_dimension_numbers, gather_slice_sizes);
|
||||
break;
|
||||
}
|
||||
@ -485,16 +486,15 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
auto scatter_dimension_numbers =
|
||||
absl::make_unique<ScatterDimensionNumbers>(
|
||||
proto.scatter_dimension_numbers());
|
||||
instruction =
|
||||
CreateScatter(proto.shape(), operands(0), operands(1), operands(2),
|
||||
computations(0), *scatter_dimension_numbers);
|
||||
instruction = CreateScatter(shape, operands(0), operands(1), operands(2),
|
||||
computations(0), *scatter_dimension_numbers);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kIota:
|
||||
TF_RET_CHECK(proto.dimensions_size() == 1)
|
||||
<< "Iota instruction should have 1 dimension but sees "
|
||||
<< proto.dimensions_size();
|
||||
instruction = CreateIota(proto.shape(), proto.dimensions(0));
|
||||
instruction = CreateIota(shape, proto.dimensions(0));
|
||||
break;
|
||||
case HloOpcode::kDot: {
|
||||
TF_RET_CHECK(proto.has_dot_dimension_numbers())
|
||||
@ -506,8 +506,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
precision_config.mutable_operand_precision()->Resize(
|
||||
proto.operand_ids_size(), PrecisionConfig::DEFAULT);
|
||||
instruction = absl::make_unique<HloDotInstruction>(
|
||||
proto.shape(), operands(0), operands(1),
|
||||
proto.dot_dimension_numbers(), precision_config);
|
||||
shape, operands(0), operands(1), proto.dot_dimension_numbers(),
|
||||
precision_config);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kDomain: {
|
||||
@ -529,7 +529,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
exit_hlo_sharding = std::make_shared<const HloSharding>(sharding);
|
||||
}
|
||||
instruction = absl::make_unique<HloDomainInstruction>(
|
||||
proto.shape(), operands(0),
|
||||
shape, operands(0),
|
||||
absl::make_unique<ShardingMetadata>(entry_hlo_sharding),
|
||||
absl::make_unique<ShardingMetadata>(exit_hlo_sharding));
|
||||
break;
|
||||
@ -537,11 +537,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
case HloOpcode::kGetDimensionSize:
|
||||
TF_RET_CHECK(proto.operand_ids_size() == 1);
|
||||
TF_RET_CHECK(proto.dimensions_size() == 1);
|
||||
instruction = CreateGetDimensionSize(proto.shape(), operands(0),
|
||||
proto.dimensions(0));
|
||||
instruction =
|
||||
CreateGetDimensionSize(shape, operands(0), proto.dimensions(0));
|
||||
break;
|
||||
default: {
|
||||
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
|
||||
instruction = absl::WrapUnique(new HloInstruction(opcode, shape));
|
||||
for (const int64 operand_id : proto.operand_ids()) {
|
||||
instruction->AppendOperand(instruction_map.at(operand_id));
|
||||
}
|
||||
@ -2234,7 +2234,7 @@ HloInstructionProto HloInstruction::ToProto() const {
|
||||
proto.set_id(unique_id_);
|
||||
proto.set_name(name_);
|
||||
proto.set_opcode(HloOpcodeString(opcode_));
|
||||
*proto.mutable_shape() = shape_;
|
||||
*proto.mutable_shape() = shape_.ToProto();
|
||||
for (const HloInstruction* operand : operands_) {
|
||||
proto.add_operand_ids(operand->unique_id());
|
||||
}
|
||||
|
@ -1615,7 +1615,7 @@ HloOutfeedInstruction::HloOutfeedInstruction(const Shape& outfeed_shape,
|
||||
HloInstructionProto HloOutfeedInstruction::ToProto() const {
|
||||
HloInstructionProto proto = HloInstruction::ToProto();
|
||||
proto.set_outfeed_config(outfeed_config());
|
||||
*proto.mutable_outfeed_shape() = outfeed_shape();
|
||||
*proto.mutable_outfeed_shape() = outfeed_shape().ToProto();
|
||||
return proto;
|
||||
}
|
||||
|
||||
@ -1867,7 +1867,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
|
||||
if (layout_constrained()) {
|
||||
proto.set_constrain_layout(true);
|
||||
for (const Shape& shape : operand_shapes_with_layout_) {
|
||||
*proto.add_operand_shapes_with_layout() = shape;
|
||||
*proto.add_operand_shapes_with_layout() = shape.ToProto();
|
||||
}
|
||||
}
|
||||
return proto;
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_token.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
@ -257,7 +257,7 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
// the entry parameters and root.
|
||||
TF_RET_CHECK(proto.has_host_program_shape())
|
||||
<< "No program shape found in the proto";
|
||||
const auto& expected_program_shape = proto.host_program_shape();
|
||||
ProgramShape expected_program_shape(proto.host_program_shape());
|
||||
TF_RET_CHECK(expected_program_shape.parameters_size() ==
|
||||
module_config.entry_computation_layout().parameter_count());
|
||||
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
|
||||
@ -369,7 +369,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
|
||||
const HloModuleProto& module, const DebugOptions& debug_options) {
|
||||
TF_RET_CHECK(module.has_host_program_shape())
|
||||
<< "No program shape found in the proto";
|
||||
const auto& program_shape = module.host_program_shape();
|
||||
ProgramShape program_shape(module.host_program_shape());
|
||||
|
||||
HloModuleConfig module_config(ProgramShape{program_shape});
|
||||
module_config.set_debug_options(debug_options);
|
||||
|
@ -48,7 +48,7 @@ StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
|
||||
StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
|
||||
const HloProto& hlo_proto) {
|
||||
if (!hlo_proto.has_hlo_module()) {
|
||||
return NotFound("HloProto missing HloModuleProto.");
|
||||
@ -57,15 +57,16 @@ StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
|
||||
return NotFound("HloProto missing program shape.");
|
||||
}
|
||||
|
||||
std::vector<const Shape*> parameter_shapes;
|
||||
std::vector<const ShapeProto*> parameter_shapes;
|
||||
const auto& program_shape = hlo_proto.hlo_module().host_program_shape();
|
||||
for (const Shape& shape : program_shape.parameters()) {
|
||||
for (const ShapeProto& shape : program_shape.parameters()) {
|
||||
parameter_shapes.push_back(&shape);
|
||||
}
|
||||
return parameter_shapes;
|
||||
}
|
||||
|
||||
StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto) {
|
||||
StatusOr<const ShapeProto*> EntryComputationOutputShape(
|
||||
const HloProto& hlo_proto) {
|
||||
if (!hlo_proto.has_hlo_module()) {
|
||||
return NotFound("HloProto missing HloModuleProto.");
|
||||
}
|
||||
|
@ -43,12 +43,13 @@ StatusOr<std::unique_ptr<HloModule>> CreateModuleFromProto(
|
||||
|
||||
// Returns the shapes of the parameters of the entry computation. Shape pointers
|
||||
// refer to shapes inside of the given HloProto.
|
||||
StatusOr<std::vector<const Shape*>> EntryComputationParameterShapes(
|
||||
StatusOr<std::vector<const ShapeProto*>> EntryComputationParameterShapes(
|
||||
const HloProto& hlo_proto);
|
||||
|
||||
// Returns the shape of the output of the entry computation. The shape pointer
|
||||
// refers to the output shape inside of the given HloProto.
|
||||
StatusOr<const Shape*> EntryComputationOutputShape(const HloProto& hlo_proto);
|
||||
StatusOr<const ShapeProto*> EntryComputationOutputShape(
|
||||
const HloProto& hlo_proto);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
|
@ -25,6 +25,7 @@ limitations under the License.
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
@ -244,10 +244,11 @@ StatusOr<llvm::Value*> EncodeSelfDescribingShapeConstant(const Shape& shape,
|
||||
|
||||
StatusOr<Shape> DecodeSelfDescribingShapeConstant(const void* shape_ptr,
|
||||
int32 size_bytes) {
|
||||
Shape shape;
|
||||
TF_RET_CHECK(shape.ParseFromArray(shape_ptr, size_bytes));
|
||||
ShapeProto shape_proto;
|
||||
TF_RET_CHECK(shape_proto.ParseFromArray(shape_ptr, size_bytes));
|
||||
Shape shape(shape_proto);
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
|
||||
return shape;
|
||||
return std::move(shape);
|
||||
}
|
||||
|
||||
llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal,
|
||||
|
@ -101,12 +101,12 @@ ExecutionOptions CreateExecutionOptions(
|
||||
}
|
||||
if (build_options.result_layout() != nullptr) {
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
*build_options.result_layout();
|
||||
build_options.result_layout()->ToProto();
|
||||
} else {
|
||||
Shape result_shape(program_shape->result());
|
||||
LayoutUtil::SetToDefaultLayout(&result_shape);
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
program_shape->result();
|
||||
LayoutUtil::SetToDefaultLayout(
|
||||
execution_options.mutable_shape_with_output_layout());
|
||||
result_shape.ToProto();
|
||||
}
|
||||
return execution_options;
|
||||
}
|
||||
|
@ -41,6 +41,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/source_map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/stream_pool.h"
|
||||
#include "tensorflow/compiler/xla/service/transfer_manager.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/shape_layout.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -275,8 +276,8 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
|
||||
}
|
||||
if (execution_options != nullptr &&
|
||||
execution_options->has_shape_with_output_layout()) {
|
||||
const auto& shape_with_output_layout =
|
||||
execution_options->shape_with_output_layout();
|
||||
const Shape shape_with_output_layout(
|
||||
execution_options->shape_with_output_layout());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ValidateResultShape(shape_with_output_layout, program_shape.result()));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -818,14 +819,17 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
|
||||
"The compile request does not support multiple device handles.");
|
||||
}
|
||||
|
||||
std::vector<const Shape*> argument_shapes;
|
||||
absl::c_transform(arg->input_shape_with_layout(),
|
||||
std::back_inserter(argument_shapes),
|
||||
[](const Shape& shape) { return &shape; });
|
||||
std::vector<Shape> argument_shapes;
|
||||
argument_shapes.reserve(arg->input_shape_with_layout_size());
|
||||
std::vector<const Shape*> argument_shape_ptrs;
|
||||
for (const ShapeProto& shape_proto : arg->input_shape_with_layout()) {
|
||||
argument_shapes.push_back(Shape(shape_proto));
|
||||
argument_shape_ptrs.push_back(&argument_shapes.back());
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
|
||||
argument_shapes, &arg->execution_options()));
|
||||
argument_shape_ptrs, &arg->execution_options()));
|
||||
VLOG(3) << "Compile created HloModuleConfig computation layout: "
|
||||
<< module_config->entry_computation_layout().ToString();
|
||||
|
||||
@ -930,14 +934,14 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
|
||||
TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer,
|
||||
allocation_tracker_.ResolveForReplica(arg->data(), 0));
|
||||
|
||||
const Shape* return_shape;
|
||||
Shape return_shape;
|
||||
if (arg->has_shape_with_layout()) {
|
||||
if (!LayoutUtil::HasLayout(arg->shape_with_layout())) {
|
||||
return_shape = Shape(arg->shape_with_layout());
|
||||
if (!LayoutUtil::HasLayout(return_shape)) {
|
||||
return InvalidArgument("shape_with_layout must have layout if present.");
|
||||
}
|
||||
return_shape = &arg->shape_with_layout();
|
||||
} else {
|
||||
return_shape = &shaped_buffer->on_host_shape();
|
||||
return_shape = Shape(shaped_buffer->on_host_shape());
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto stream, execute_backend_->BorrowStream(
|
||||
@ -948,11 +952,11 @@ Status Service::TransferToClient(const TransferToClientRequest* arg,
|
||||
execute_backend_->transfer_manager()->TransferLiteralFromDevice(
|
||||
stream.get(), *shaped_buffer));
|
||||
|
||||
if (LayoutUtil::LayoutsInShapesEqual(*return_shape, result_literal.shape())) {
|
||||
if (LayoutUtil::LayoutsInShapesEqual(return_shape, result_literal.shape())) {
|
||||
*result->mutable_literal() = result_literal.ToProto();
|
||||
} else {
|
||||
*result->mutable_literal() =
|
||||
result_literal.Relayout(*return_shape).ToProto();
|
||||
result_literal.Relayout(return_shape).ToProto();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1045,11 +1049,11 @@ Status Service::TransferFromOutfeed(const TransferFromOutfeedRequest* arg,
|
||||
executor = replicas[arg->replica_id()];
|
||||
}
|
||||
|
||||
auto literal = Literal::CreateFromShape(arg->shape_with_layout());
|
||||
auto literal = Literal::CreateFromShape(Shape(arg->shape_with_layout()));
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
execute_backend_->transfer_manager()->TransferLiteralFromOutfeed(
|
||||
executor, arg->shape_with_layout(), literal));
|
||||
executor, Shape(arg->shape_with_layout()), literal));
|
||||
*result->mutable_literal() = literal.ToProto();
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1103,7 +1107,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
|
||||
Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) {
|
||||
TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer,
|
||||
allocation_tracker_.ResolveForReplica(arg->data(), 0));
|
||||
*result->mutable_shape() = buffer->on_host_shape();
|
||||
*result->mutable_shape() = buffer->on_host_shape().ToProto();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1018,7 +1018,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
|
||||
switch (opcode) {
|
||||
case HloOpcode::kTuple: {
|
||||
Shape result = ShapeUtil::MakeTupleShape({});
|
||||
result.mutable_tuple_shapes()->Reserve(operand_shapes.size());
|
||||
result.mutable_tuple_shapes()->reserve(operand_shapes.size());
|
||||
for (const Shape* shape : operand_shapes) {
|
||||
ShapeUtil::AppendShapeToTuple(*shape, &result);
|
||||
}
|
||||
|
@ -21,11 +21,56 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
|
||||
for (const Shape& shape : program_shape_proto.parameters()) {
|
||||
*add_parameters() = shape;
|
||||
Shape::Shape(const ShapeProto& shape_proto) {
|
||||
set_element_type(shape_proto.element_type());
|
||||
dimensions_.reserve(shape_proto.dimensions_size());
|
||||
for (const int64 dimension : shape_proto.dimensions()) {
|
||||
add_dimensions(dimension);
|
||||
}
|
||||
*mutable_result() = program_shape_proto.result();
|
||||
tuple_shapes_.reserve(shape_proto.tuple_shapes_size());
|
||||
for (const ShapeProto& element_shape : shape_proto.tuple_shapes()) {
|
||||
*add_tuple_shapes() = Shape(element_shape);
|
||||
}
|
||||
if (shape_proto.has_layout()) {
|
||||
*mutable_layout() = shape_proto.layout();
|
||||
}
|
||||
}
|
||||
|
||||
ShapeProto Shape::ToProto() const {
|
||||
ShapeProto proto;
|
||||
proto.set_element_type(element_type_);
|
||||
proto.mutable_dimensions()->Reserve(dimensions_size());
|
||||
for (const int64 dimension : dimensions()) {
|
||||
proto.add_dimensions(dimension);
|
||||
}
|
||||
proto.mutable_tuple_shapes()->Reserve(tuple_shapes_size());
|
||||
for (const Shape& shape : tuple_shapes()) {
|
||||
*proto.add_tuple_shapes() = shape.ToProto();
|
||||
}
|
||||
if (has_layout()) {
|
||||
*proto.mutable_layout() = layout();
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
string Shape::ToString(bool print_layout) const {
|
||||
if (print_layout) {
|
||||
return ShapeUtil::HumanStringWithLayout(*this);
|
||||
} else {
|
||||
return ShapeUtil::HumanString(*this);
|
||||
}
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
|
||||
out << shape.ToString(/*print_layout=*/true);
|
||||
return out;
|
||||
}
|
||||
|
||||
ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
|
||||
for (const ShapeProto& shape_proto : program_shape_proto.parameters()) {
|
||||
*add_parameters() = Shape(shape_proto);
|
||||
}
|
||||
*mutable_result() = Shape(program_shape_proto.result());
|
||||
for (const string& name : program_shape_proto.parameter_names()) {
|
||||
add_parameter_names(name);
|
||||
}
|
||||
@ -34,9 +79,9 @@ ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
|
||||
ProgramShapeProto ProgramShape::ToProto() const {
|
||||
ProgramShapeProto proto;
|
||||
for (const Shape& shape : parameters()) {
|
||||
*proto.add_parameters() = shape;
|
||||
*proto.add_parameters() = shape.ToProto();
|
||||
}
|
||||
*proto.mutable_result() = result();
|
||||
*proto.mutable_result() = result().ToProto();
|
||||
for (const string& name : parameter_names()) {
|
||||
proto.add_parameter_names(name);
|
||||
}
|
||||
|
@ -26,6 +26,102 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
// A shape describes the number of dimensions in a array, the bounds of each
|
||||
// dimension, and the primitive component type. For tuples, shape describes the
|
||||
// structure (number of elements and nesting).
|
||||
class Shape {
|
||||
public:
|
||||
Shape() = default;
|
||||
|
||||
// Construct a shape from a ShapeProto.
|
||||
explicit Shape(const ShapeProto& shape_proto);
|
||||
|
||||
// Returns a ShapeProto representation of the Shape.
|
||||
ShapeProto ToProto() const;
|
||||
|
||||
// Returns a human-readable string that represents the given shape, with or
|
||||
// without layout. e.g. "F32[42,12] {0, 1}" or "F32[64]".
|
||||
string ToString(bool print_layout = false) const;
|
||||
|
||||
// The following methods mirror the protobuf generated code interface for the
|
||||
// message ShapeProto. This enabled easy migration of this data structure
|
||||
// from a proto to a proper C++ class.
|
||||
// TODO(b/29771030): Replace or augment these methods with a more ergonomic
|
||||
// interface.
|
||||
|
||||
// Methods for accessing the primitive type.
|
||||
PrimitiveType element_type() const { return element_type_; }
|
||||
void set_element_type(PrimitiveType value) { element_type_ = value; }
|
||||
|
||||
// Methods for accessing the dimensions array.
|
||||
int dimensions_size() const { return dimensions_.size(); }
|
||||
int64 dimensions(int index) const { return dimensions_.at(index); }
|
||||
void set_dimensions(int index, int64 value) { dimensions_.at(index) = value; }
|
||||
void add_dimensions(int64 value) { dimensions_.push_back(value); }
|
||||
void clear_dimensions() { dimensions_.clear(); }
|
||||
const std::vector<int64>& dimensions() const { return dimensions_; }
|
||||
std::vector<int64>* mutable_dimensions() { return &dimensions_; }
|
||||
|
||||
// Methods for accessing the tuple subshapes. This field only non-empty for
|
||||
// tuple shapes.
|
||||
int tuple_shapes_size() const { return tuple_shapes_.size(); }
|
||||
const Shape& tuple_shapes(int index) const { return tuple_shapes_.at(index); }
|
||||
Shape* mutable_tuple_shapes(int index) { return &tuple_shapes_.at(index); }
|
||||
Shape* add_tuple_shapes() {
|
||||
tuple_shapes_.push_back(Shape());
|
||||
return &tuple_shapes_.back();
|
||||
}
|
||||
void clear_tuple_shapes() { tuple_shapes_.clear(); }
|
||||
const std::vector<Shape>& tuple_shapes() const { return tuple_shapes_; }
|
||||
std::vector<Shape>* mutable_tuple_shapes() { return &tuple_shapes_; }
|
||||
|
||||
// Methods for accessing the layout field.
|
||||
bool has_layout() const { return layout_.has_value(); }
|
||||
const Layout& layout() const {
|
||||
if (layout_.has_value()) {
|
||||
return *layout_;
|
||||
} else {
|
||||
return Layout::default_instance();
|
||||
}
|
||||
}
|
||||
Layout* mutable_layout() {
|
||||
if (!layout_.has_value()) {
|
||||
layout_ = Layout();
|
||||
}
|
||||
return &layout_.value();
|
||||
}
|
||||
void clear_layout() { layout_.reset(); }
|
||||
|
||||
void Swap(Shape* other) {
|
||||
using std::swap;
|
||||
swap(*this, *other);
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
element_type_ = PRIMITIVE_TYPE_INVALID;
|
||||
dimensions_.clear();
|
||||
tuple_shapes_.clear();
|
||||
layout_.reset();
|
||||
}
|
||||
|
||||
string SerializeAsString() const { return ToProto().SerializeAsString(); }
|
||||
string ShortDebugString() const { return ToProto().ShortDebugString(); }
|
||||
string DebugString() const { return ToProto().DebugString(); }
|
||||
|
||||
public:
|
||||
// The element type of this shape (tuple, array, etc).
|
||||
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
|
||||
|
||||
// The array bounds of the dimensions. This is nonempty only for array shapes.
|
||||
std::vector<int64> dimensions_;
|
||||
|
||||
// The tuple element subshapes. This is nonempty only for tuple shapes.
|
||||
std::vector<Shape> tuple_shapes_;
|
||||
|
||||
// The array layout of the shape. This is present only for array shapes.
|
||||
absl::optional<Layout> layout_;
|
||||
};
|
||||
|
||||
// Shape of the parameters and output of an XLA computation. This is analogous
|
||||
// to a traditional function signature.
|
||||
class ProgramShape {
|
||||
@ -61,7 +157,6 @@ class ProgramShape {
|
||||
// Methods for accessing and manipulating the Shape of the result.
|
||||
const Shape& result() const { return result_; }
|
||||
Shape* mutable_result() { return &result_; }
|
||||
void clear_result() { result_.Clear(); }
|
||||
|
||||
// Methods for accessing and manipulating the names of the parameters.
|
||||
int parameter_names_size() const { return parameter_names_.size(); }
|
||||
@ -101,6 +196,7 @@ class ProgramShape {
|
||||
Shape result_;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Shape& shape);
|
||||
std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);
|
||||
|
||||
} // namespace xla
|
||||
|
@ -30,7 +30,51 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
TEST(ShapeTest, ProgramShapeToFromProto) {
|
||||
class ShapeTest : public ::testing::Test {
|
||||
protected:
|
||||
const Shape opaque_ = ShapeUtil::MakeOpaqueShape();
|
||||
const Shape token_ = ShapeUtil::MakeTokenShape();
|
||||
const Shape scalar_ = ShapeUtil::MakeShape(F32, {});
|
||||
const Shape matrix_ = ShapeUtil::MakeShape(U32, {1, 2});
|
||||
const Shape matrix2_ = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
|
||||
const Shape tuple_ =
|
||||
ShapeUtil::MakeTupleShape({opaque_, scalar_, matrix_, matrix2_});
|
||||
const Shape nested_tuple_ =
|
||||
ShapeUtil::MakeTupleShape({tuple_, matrix_, token_});
|
||||
};
|
||||
|
||||
TEST_F(ShapeTest, ShapeToFromProto) {
|
||||
for (const Shape& shape :
|
||||
{opaque_, token_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_}) {
|
||||
Shape shape_copy(shape.ToProto());
|
||||
EXPECT_TRUE(ShapeUtil::Equal(shape, shape_copy))
|
||||
<< shape << " != " << shape_copy;
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ShapeTest, ShapeToString) {
|
||||
EXPECT_EQ("opaque[]", opaque_.ToString());
|
||||
EXPECT_EQ("token[]", token_.ToString());
|
||||
EXPECT_EQ("f32[]", scalar_.ToString());
|
||||
EXPECT_EQ("u32[1,2]", matrix_.ToString());
|
||||
EXPECT_EQ("s32[3,4]", matrix2_.ToString());
|
||||
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])", tuple_.ToString());
|
||||
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
nested_tuple_.ToString());
|
||||
|
||||
EXPECT_EQ("opaque[]", opaque_.ToString(/*print_layout=*/true));
|
||||
EXPECT_EQ("f32[]", scalar_.ToString(/*print_layout=*/true));
|
||||
EXPECT_EQ("u32[1,2]{1,0}", matrix_.ToString(/*print_layout=*/true));
|
||||
EXPECT_EQ("s32[3,4]{0,1}", matrix2_.ToString(/*print_layout=*/true));
|
||||
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
|
||||
tuple_.ToString(/*print_layout=*/true));
|
||||
EXPECT_EQ(
|
||||
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
|
||||
"token[])",
|
||||
nested_tuple_.ToString(/*print_layout=*/true));
|
||||
}
|
||||
|
||||
TEST_F(ShapeTest, ProgramShapeToFromProto) {
|
||||
ProgramShape program_shape;
|
||||
*program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
|
||||
*program_shape.add_parameters() = ShapeUtil::MakeTokenShape();
|
||||
@ -67,17 +111,10 @@ TEST(ShapeTest, ProgramShapeToFromProto) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(ShapeTest, ProgramShapeToString) {
|
||||
Shape opaque = ShapeUtil::MakeOpaqueShape();
|
||||
Shape token = ShapeUtil::MakeTokenShape();
|
||||
Shape scalar = ShapeUtil::MakeShape(F32, {});
|
||||
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
|
||||
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
|
||||
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
|
||||
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
|
||||
|
||||
TEST_F(ShapeTest, ProgramShapeToString) {
|
||||
ProgramShape prog = ShapeUtil::MakeProgramShape(
|
||||
{opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
|
||||
{opaque_, scalar_, matrix_, matrix2_, tuple_, nested_tuple_},
|
||||
nested_tuple_);
|
||||
EXPECT_EQ(
|
||||
"((unknown): opaque[], "
|
||||
"(unknown): f32[], "
|
||||
@ -87,7 +124,7 @@ TEST(ShapeTest, ProgramShapeToString) {
|
||||
"(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
|
||||
"-> "
|
||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
ShapeUtil::HumanString(prog));
|
||||
prog.ToString());
|
||||
|
||||
prog.add_parameter_names("arg0");
|
||||
prog.add_parameter_names("scalar");
|
||||
@ -105,7 +142,7 @@ TEST(ShapeTest, ProgramShapeToString) {
|
||||
"token[])) "
|
||||
"-> "
|
||||
"((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
ShapeUtil::HumanString(prog));
|
||||
prog.ToString());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -79,14 +79,14 @@ bool ShapeIndexView::StartsWith(ShapeIndexView prefix) const {
|
||||
indices_.subspan(0, prefix.size()) == prefix.indices_;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns whether the given primitive type corresponds to an array shape.
|
||||
bool IsArrayPrimitiveType(PrimitiveType primitive_type) {
|
||||
/* static */ bool ShapeUtil::IsArrayPrimitiveType(
|
||||
PrimitiveType primitive_type) {
|
||||
return primitive_type != PRIMITIVE_TYPE_INVALID && primitive_type != TUPLE &&
|
||||
primitive_type != OPAQUE && primitive_type != TOKEN;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Recursive helper for comparing the equality of two shapes. Returns true if
|
||||
// the shapes are the same. If compare_layouts is true, then layouts must also
|
||||
// match.
|
||||
@ -203,7 +203,7 @@ StatusOr<Shape> MakeShapeWithLayoutInternal(
|
||||
/* static */ ProgramShape ShapeUtil::MakeProgramShape(
|
||||
std::initializer_list<Shape> parameters, Shape result) {
|
||||
ProgramShape program_shape;
|
||||
for (const auto& shape : parameters) {
|
||||
for (const Shape& shape : parameters) {
|
||||
*program_shape.add_parameters() = shape;
|
||||
}
|
||||
*program_shape.mutable_result() = std::move(result);
|
||||
@ -272,7 +272,7 @@ ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(
|
||||
/* static */ Shape ShapeUtil::MakeTupleShape(absl::Span<const Shape> shapes) {
|
||||
Shape result;
|
||||
result.set_element_type(TUPLE);
|
||||
result.mutable_tuple_shapes()->Reserve(shapes.size());
|
||||
result.mutable_tuple_shapes()->reserve(shapes.size());
|
||||
for (const auto& shape : shapes) {
|
||||
AppendShapeToTuple(shape, &result);
|
||||
}
|
||||
@ -563,20 +563,6 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
|
||||
HumanString(program_shape.result()));
|
||||
}
|
||||
|
||||
/* static */ string ShapeUtil::HumanString(
|
||||
const ProgramShapeProto& program_shape_proto) {
|
||||
std::vector<string> parameters;
|
||||
for (auto& shape : program_shape_proto.parameters()) {
|
||||
const int i = parameters.size();
|
||||
parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size()
|
||||
? program_shape_proto.parameter_names(i)
|
||||
: "(unknown)",
|
||||
": ", HumanString(shape)));
|
||||
}
|
||||
return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
|
||||
HumanString(program_shape_proto.result()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Parses shapes with simple recursive descent structure -- consumes from the
|
||||
// front of s and passes that view recursively as required.
|
||||
@ -1610,7 +1596,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
/* static */ Shape ShapeUtil::DeleteDimension(int64 dim_to_delete,
|
||||
Shape shape) {
|
||||
CHECK(IsArray(shape));
|
||||
shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete);
|
||||
shape.mutable_dimensions()->erase(shape.mutable_dimensions()->begin() +
|
||||
dim_to_delete);
|
||||
if (LayoutUtil::HasLayout(shape)) {
|
||||
Layout* layout = shape.mutable_layout();
|
||||
layout->set_format(DENSE);
|
||||
@ -1644,11 +1631,6 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
|
||||
return shape;
|
||||
}
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
|
||||
out << ShapeUtil::HumanStringWithLayout(shape);
|
||||
return out;
|
||||
}
|
||||
|
||||
/*static*/ size_t ShapeUtil::Hash(const Shape& shape) {
|
||||
using tensorflow::hash;
|
||||
using tensorflow::Hash64Combine;
|
||||
|
@ -240,7 +240,6 @@ class ShapeUtil {
|
||||
//
|
||||
// (param_name: f32[42x12], ...) -> f32[24x42]
|
||||
static string HumanString(const ProgramShape& program_shape);
|
||||
static string HumanString(const ProgramShapeProto& program_shape_proto);
|
||||
|
||||
// Parses a ShapeUtil::HumanString-format shape string back into a shape
|
||||
// object.
|
||||
@ -469,6 +468,9 @@ class ShapeUtil {
|
||||
// arrays.
|
||||
static bool IsArray(const Shape& shape);
|
||||
|
||||
// Returns whether the given primitive type corresponds to an array shape.
|
||||
static bool IsArrayPrimitiveType(PrimitiveType primitive_type);
|
||||
|
||||
// Returns whether the shape is a tuple with at least one element which is
|
||||
// also a tuple.
|
||||
static bool IsNestedTuple(const Shape& shape);
|
||||
@ -796,8 +798,6 @@ class ShapeUtil {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, const Shape& shape);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_
|
||||
|
@ -546,37 +546,6 @@ TEST(ShapeUtilTest, IsLeafIndex) {
|
||||
EXPECT_TRUE(ShapeUtil::IsLeafIndex(nested_tuple_shape, {1, 1}));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, HumanString) {
|
||||
Shape opaque = ShapeUtil::MakeOpaqueShape();
|
||||
Shape token = ShapeUtil::MakeTokenShape();
|
||||
Shape scalar = ShapeUtil::MakeShape(F32, {});
|
||||
Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
|
||||
Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
|
||||
Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
|
||||
Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});
|
||||
|
||||
EXPECT_EQ("opaque[]", ShapeUtil::HumanString(opaque));
|
||||
EXPECT_EQ("token[]", ShapeUtil::HumanString(token));
|
||||
EXPECT_EQ("f32[]", ShapeUtil::HumanString(scalar));
|
||||
EXPECT_EQ("u32[1,2]", ShapeUtil::HumanString(matrix));
|
||||
EXPECT_EQ("s32[3,4]", ShapeUtil::HumanString(matrix2));
|
||||
EXPECT_EQ("(opaque[], f32[], u32[1,2], s32[3,4])",
|
||||
ShapeUtil::HumanString(tuple));
|
||||
EXPECT_EQ("((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
|
||||
ShapeUtil::HumanString(nested_tuple));
|
||||
|
||||
EXPECT_EQ("opaque[]", ShapeUtil::HumanStringWithLayout(opaque));
|
||||
EXPECT_EQ("f32[]", ShapeUtil::HumanStringWithLayout(scalar));
|
||||
EXPECT_EQ("u32[1,2]{1,0}", ShapeUtil::HumanStringWithLayout(matrix));
|
||||
EXPECT_EQ("s32[3,4]{0,1}", ShapeUtil::HumanStringWithLayout(matrix2));
|
||||
EXPECT_EQ("(opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1})",
|
||||
ShapeUtil::HumanStringWithLayout(tuple));
|
||||
EXPECT_EQ(
|
||||
"((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
|
||||
"token[])",
|
||||
ShapeUtil::HumanStringWithLayout(nested_tuple));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, ForEachSubshapeArray) {
|
||||
const Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||
int calls = 0;
|
||||
|
@ -107,7 +107,7 @@ StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransfer(
|
||||
ExecutionOptions execution_options = execution_options_;
|
||||
if (shape_with_output_layout != nullptr) {
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
*shape_with_output_layout;
|
||||
shape_with_output_layout->ToProto();
|
||||
}
|
||||
return client_->ExecuteAndTransfer(computation, arguments,
|
||||
&execution_options);
|
||||
@ -127,7 +127,7 @@ StatusOr<Literal> ClientLibraryTestBase::ExecuteAndTransferReference(
|
||||
ExecutionOptions execution_options = execution_options_;
|
||||
if (shape_with_output_layout != nullptr) {
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
*shape_with_output_layout;
|
||||
shape_with_output_layout->ToProto();
|
||||
}
|
||||
execution_options.clear_device_handles();
|
||||
return ref_client_->ExecuteAndTransfer(computation, arguments,
|
||||
|
@ -50,7 +50,8 @@ XLA_TEST_F(ClientTest, ExecuteWithLayout) {
|
||||
ExecutionOptions execution_options = execution_options_;
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
|
||||
execute_layout);
|
||||
execute_layout)
|
||||
.ToProto();
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<GlobalData> data,
|
||||
client_->Execute(computation, {}, &execution_options));
|
||||
@ -84,7 +85,8 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
|
||||
{ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
|
||||
/*minor_to_major=*/{0, 1}),
|
||||
ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2},
|
||||
/*minor_to_major=*/{1, 0})});
|
||||
/*minor_to_major=*/{1, 0})})
|
||||
.ToProto();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto result,
|
||||
|
@ -618,7 +618,8 @@ XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) {
|
||||
ExecutionOptions execution_options = execution_options_;
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8},
|
||||
{1, 0});
|
||||
{1, 0})
|
||||
.ToProto();
|
||||
Literal actual =
|
||||
client_
|
||||
->ExecuteAndTransfer(computation, {input.get()}, &execution_options)
|
||||
@ -767,7 +768,8 @@ XLA_TEST_P(ReshapeTest, NoopReshape) {
|
||||
ExecutionOptions execution_options = execution_options_;
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5},
|
||||
{2, 3, 0, 1});
|
||||
{2, 3, 0, 1})
|
||||
.ToProto();
|
||||
Literal output_literal =
|
||||
client_
|
||||
->ExecuteAndTransfer(computation, {input_data.get()},
|
||||
|
@ -82,13 +82,17 @@ struct Options {
|
||||
std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
|
||||
LocalClient* client) {
|
||||
XlaComputation computation(module.hlo().hlo_module());
|
||||
std::vector<const Shape*> argument_layouts;
|
||||
for (const auto& param :
|
||||
std::vector<Shape> argument_layouts;
|
||||
argument_layouts.reserve(
|
||||
computation.proto().host_program_shape().parameters_size());
|
||||
std::vector<const Shape*> argument_layout_ptrs;
|
||||
for (const ShapeProto& param :
|
||||
computation.proto().host_program_shape().parameters()) {
|
||||
argument_layouts.push_back(¶m);
|
||||
argument_layouts.push_back(Shape(param));
|
||||
argument_layout_ptrs.push_back(&argument_layouts.back());
|
||||
}
|
||||
return client
|
||||
->Compile(computation, argument_layouts, ExecutableBuildOptions())
|
||||
->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions())
|
||||
.ValueOrDie();
|
||||
}
|
||||
|
||||
@ -149,7 +153,7 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
|
||||
<< "--generate_fake_infeed only works if the model has 0 or 1 "
|
||||
"infeed ops, but this one has >= 2.";
|
||||
provide_infeed = true;
|
||||
infeed_shape = instruction.shape();
|
||||
infeed_shape = Shape(instruction.shape());
|
||||
LOG(INFO) << "Generating fake infeed shape for inferred shape: "
|
||||
<< ShapeUtil::HumanString(infeed_shape);
|
||||
}
|
||||
@ -315,9 +319,10 @@ int RealMain(absl::Span<char* const> args, const Options& opts) {
|
||||
if (snapshot.has_result()) {
|
||||
Literal literal =
|
||||
Literal::CreateFromProto(snapshot.result()).ConsumeValueOrDie();
|
||||
fprintf(stdout, "was %s:%s\n",
|
||||
ShapeUtil::HumanString(snapshot.result().shape()).c_str(),
|
||||
literal.ToString().c_str());
|
||||
fprintf(
|
||||
stdout, "was %s:%s\n",
|
||||
ShapeUtil::HumanString(Shape(snapshot.result().shape())).c_str(),
|
||||
literal.ToString().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -152,6 +152,13 @@ static inline absl::Span<const int64> AsInt64Slice(
|
||||
slice.size());
|
||||
}
|
||||
|
||||
// TODO(b/29771030): This nop overload was added to simplify the migration of
|
||||
// Shape from a proto to a C++ class. Remove after class has been migrated.
|
||||
static inline absl::Span<const int64> AsInt64Slice(
|
||||
absl::Span<const int64> slice) {
|
||||
return slice;
|
||||
}
|
||||
|
||||
// As above, but for uint64 types.
|
||||
static inline absl::Span<const uint64> AsUInt64Slice(
|
||||
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_uint64>& v) {
|
||||
|
@ -224,7 +224,7 @@ message ExecutionOptions {
|
||||
// may be faster when using this layout.
|
||||
//
|
||||
// We use a Shape here to accommodate computations that return a tuple.
|
||||
Shape shape_with_output_layout = 2;
|
||||
ShapeProto shape_with_output_layout = 2;
|
||||
|
||||
// Used to seed random-number generators used in this computation. If this is
|
||||
// 0, we generate a seed ourselves.
|
||||
@ -253,7 +253,7 @@ message TransferToClientRequest {
|
||||
|
||||
// This optional field directs the service to return the literal in this
|
||||
// layout. A shape is used to hold the layout to accommodate tuples.
|
||||
Shape shape_with_layout = 2;
|
||||
ShapeProto shape_with_layout = 2;
|
||||
}
|
||||
|
||||
message TransferToClientResponse {
|
||||
@ -281,7 +281,7 @@ message TransferToInfeedResponse {
|
||||
message TransferFromOutfeedRequest {
|
||||
// This optional field directs the service to return the literal in this
|
||||
// layout. A shape is used to hold the layout to accommodate tuples.
|
||||
Shape shape_with_layout = 1;
|
||||
ShapeProto shape_with_layout = 1;
|
||||
|
||||
int64 replica_id = 2;
|
||||
DeviceHandle device_handle = 3;
|
||||
@ -332,7 +332,7 @@ message CompileRequest {
|
||||
// The layouts of the input arguments. If not set, the default layout will be
|
||||
// used. Although the real arguments are not needed in compilation, the
|
||||
// layouts of the arguments can affect the compilation.
|
||||
repeated Shape input_shape_with_layout = 3;
|
||||
repeated ShapeProto input_shape_with_layout = 3;
|
||||
}
|
||||
|
||||
message CompileResponse {
|
||||
@ -406,7 +406,7 @@ message LoadDataRequest {
|
||||
string columnio_field = 2;
|
||||
|
||||
// Individual element shape, excluding rows.
|
||||
Shape element_shape = 3;
|
||||
ShapeProto element_shape = 3;
|
||||
|
||||
// Warning: ColumnIO does not support random-access, so use offset with
|
||||
// caution in performance-critical scenarios.
|
||||
@ -422,7 +422,7 @@ message LoadDataRequest {
|
||||
|
||||
message LoadDataResponse {
|
||||
GlobalDataHandle data = 1;
|
||||
Shape data_shape = 2;
|
||||
ShapeProto data_shape = 2;
|
||||
int64 available_rows = 3;
|
||||
int64 rows_loaded = 4;
|
||||
int64 nanoseconds = 5;
|
||||
@ -433,7 +433,7 @@ message GetShapeRequest {
|
||||
}
|
||||
|
||||
message GetShapeResponse {
|
||||
Shape shape = 1;
|
||||
ShapeProto shape = 1;
|
||||
}
|
||||
|
||||
message UnpackRequest {
|
||||
|
@ -154,7 +154,7 @@ message Layout {
|
||||
// See the XLA documentation for more information on shapes and layouts.
|
||||
//
|
||||
// LINT.IfChange
|
||||
message Shape {
|
||||
message ShapeProto {
|
||||
reserved 1;
|
||||
reserved "rank";
|
||||
|
||||
@ -169,7 +169,7 @@ message Shape {
|
||||
repeated int64 dimensions = 3;
|
||||
|
||||
// For tuples only, the shapes of constitutent shapes in the tuple sequence.
|
||||
repeated Shape tuple_shapes = 4;
|
||||
repeated ShapeProto tuple_shapes = 4;
|
||||
|
||||
// The layout used to back this shape.
|
||||
Layout layout = 5;
|
||||
@ -184,8 +184,8 @@ message Shape {
|
||||
// Shape of the parameters and output of a computation (like a traditional
|
||||
// function signature).
|
||||
message ProgramShapeProto {
|
||||
repeated Shape parameters = 1;
|
||||
Shape result = 2;
|
||||
repeated ShapeProto parameters = 1;
|
||||
ShapeProto result = 2;
|
||||
repeated string parameter_names = 3;
|
||||
}
|
||||
|
||||
@ -320,7 +320,7 @@ message DeviceAssignmentProto {
|
||||
// Transfers to/from the client are encoded in literal form, and the structure
|
||||
// of the repeated fields is implied by the shape.
|
||||
message LiteralProto {
|
||||
Shape shape = 1;
|
||||
ShapeProto shape = 1;
|
||||
repeated bool preds = 2;
|
||||
bytes s8s = 15;
|
||||
bytes u8s = 3;
|
||||
@ -521,7 +521,7 @@ message OpSharding {
|
||||
}
|
||||
Type type = 1;
|
||||
// The shape of the sharded tile.
|
||||
Shape tile_shape = 2;
|
||||
ShapeProto tile_shape = 2;
|
||||
// The shape of the tile assignment tensor - this must be the same rank as
|
||||
// tile_shape and the product of its dimensions must equal
|
||||
// tile_assignment_devices.size().
|
||||
|
@ -109,14 +109,17 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
|
||||
TF_ASSIGN_OR_RETURN(xla::XlaComputation computation,
|
||||
client->LoadSnapshot(computation_proto.hlo_snapshot()));
|
||||
|
||||
std::vector<const xla::Shape*> argument_layouts(
|
||||
std::vector<xla::Shape> argument_layouts(
|
||||
config.program_shape().parameters_size());
|
||||
std::vector<const xla::Shape*> argument_layout_ptrs(
|
||||
config.program_shape().parameters_size());
|
||||
for (int i = 0; i < config.program_shape().parameters_size(); ++i) {
|
||||
argument_layouts[i] = &config.program_shape().parameters(i);
|
||||
argument_layouts[i] = xla::Shape(config.program_shape().parameters(i));
|
||||
argument_layout_ptrs[i] = &argument_layouts[i];
|
||||
}
|
||||
xla::ExecutableBuildOptions build_options;
|
||||
build_options.set_device_ordinal(client->default_device_ordinal());
|
||||
build_options.set_result_layout(config.program_shape().result());
|
||||
build_options.set_result_layout(xla::Shape(config.program_shape().result()));
|
||||
build_options.set_device_allocator(device_ref.backend()->memory_allocator());
|
||||
if (config.has_debug_options()) {
|
||||
*build_options.mutable_debug_options() =
|
||||
@ -125,7 +128,7 @@ Status XRTCompileOp::Compile(OpKernelContext* ctx,
|
||||
|
||||
VLOG(1) << "Building executable";
|
||||
auto compile_result =
|
||||
client->Compile(computation, argument_layouts, build_options);
|
||||
client->Compile(computation, argument_layout_ptrs, build_options);
|
||||
if (!compile_result.ok()) {
|
||||
return compile_result.status();
|
||||
}
|
||||
|
@ -375,9 +375,12 @@ TEST(RawApiTest, CompileAndExecute) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->mutable_result() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
@ -427,9 +430,12 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->mutable_result() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
@ -494,8 +500,8 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() = param_shape;
|
||||
*shapes->mutable_result() = result_shape;
|
||||
*shapes->add_parameters() = param_shape.ToProto();
|
||||
*shapes->mutable_result() = result_shape.ToProto();
|
||||
StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
|
||||
|
||||
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
|
||||
@ -510,8 +516,9 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
|
||||
TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
|
||||
{c_handle.program_shape}, {release}, &outputs));
|
||||
|
||||
xla::ProgramShapeProto program_shape;
|
||||
EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
|
||||
xla::ProgramShapeProto program_shape_proto;
|
||||
EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec<string>()(0)));
|
||||
xla::ProgramShape program_shape(program_shape_proto);
|
||||
EXPECT_EQ(program_shape.parameters_size(), 1);
|
||||
|
||||
VLOG(2) << "Param: "
|
||||
@ -547,11 +554,11 @@ TEST(RawApiTest, DotGeneralWithLayoutTest) {
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
|
||||
*shapes->mutable_result() =
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1});
|
||||
xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
|
||||
StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
@ -592,7 +599,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {});
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
e.set_release_input_handles(true);
|
||||
@ -632,10 +639,13 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
|
||||
{xla::ShapeUtil::MakeShape(xla::F32, {2})});
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->mutable_result() =
|
||||
xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
|
||||
.ToProto();
|
||||
StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
@ -675,10 +685,13 @@ TEST(RawApiTest, LeakCompilationReference) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {2});
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeTupleShape(
|
||||
{xla::ShapeUtil::MakeShape(xla::F32, {2})});
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->add_parameters() =
|
||||
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
|
||||
*shapes->mutable_result() =
|
||||
xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
|
||||
.ToProto();
|
||||
StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());
|
||||
|
||||
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
|
||||
@ -703,9 +716,9 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
|
||||
xrt::XLAComputation c;
|
||||
auto config = c.mutable_config();
|
||||
auto shapes = config->mutable_program_shape();
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {});
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {});
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {});
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
|
||||
*shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
|
||||
*shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
|
||||
StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());
|
||||
|
||||
xrt::XRTExecutionConfig e;
|
||||
@ -742,8 +755,8 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
|
||||
xla::ProgramShapeProto program_shape;
|
||||
EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
|
||||
EXPECT_EQ(program_shape.parameters_size(), 2);
|
||||
EXPECT_TRUE(
|
||||
xla::ShapeUtil::HasPrimitiveType(program_shape.result(), xla::S64));
|
||||
EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
|
||||
xla::Shape(program_shape.result()), xla::S64));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user