[TF/XLA] Fix several layout issues.

1. The previous approach could produce different layouts for computation.GetProgramShape() and xla_output_shape: shape_representation_fn was applied only to xla_output_shape, not to the entry computation's program shape. The mismatch is often confusing, and it can make a bug hard to reproduce from an HLO dump, which does not carry the HloModuleConfig.

2. Output shapes were not updated with the chosen layout when sharding is present.

3. The updated value of a resource did not preserve the fast_mem annotation on the argument.

PiperOrigin-RevId: 295811071
Change-Id: I801a46d3039b2349dd0196cbc14ec3d9a8211d55
Yuanzhong Xu 2020-02-18 13:39:31 -08:00 committed by TensorFlower Gardener
parent 8f3272028b
commit ac2c05a1d5
5 changed files with 145 additions and 110 deletions
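
To make fix 1 concrete, here is a minimal test-style sketch (in the spirit of the updated xla_compiler tests below; `result` is assumed to be an XlaCompiler::CompilationResult populated by CompileGraph). With this change, the entry computation's program shape and the reported xla_output_shape agree on layout:

// Sketch only: `result` is assumed to come from XlaCompiler::CompileGraph.
// Before this change, the two shapes below could disagree on layout because
// shape_representation_fn was applied only to xla_output_shape.
xla::ProgramShape program_shape =
    result.computation->GetProgramShape().ConsumeValueOrDie();
CHECK(xla::ShapeUtil::Equal(program_shape.result(), result.xla_output_shape));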

View File

@@ -97,6 +97,7 @@ xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
           {xla::U16, DT_UINT16},
           {xla::U32, DT_UINT32},
           {xla::U64, DT_UINT64},
+          {xla::C128, DT_COMPLEX128},
       });
   auto it = data_type_map.find(type);

View File

@@ -139,6 +139,86 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
   return Status::OK();
 }
 
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape) {
+  if (sharding && !sharding->IsTileMaximal()) {
+    // After sharding, per core shape might have different layout. For example,
+    // before sharding, a shape [128, 128] will be assigned default
+    // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
+    // the sharded shapes will have minor-to-major {0, 1}.
+    //
+    // As a result, for sharded shapes, we set their layout to per core shape's
+    // layout.
+    //
+    // TODO(endlessroad): for variable input & update, we might have
+    // different layouts which will prevent input output aliasing and
+    // increase memory usage. Investigate such cases.
+    int64 device = *sharding->tile_assignment().begin();
+    std::vector<int64> offset =
+        sharding->TileOffsetForDevice(*xla_shape, device);
+    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
+    std::vector<int64> dimensions(xla_shape->rank());
+    for (int64 i = 0; i < xla_shape->rank(); ++i) {
+      dimensions[i] = limit[i] - offset[i];
+    }
+    xla::Shape per_device_xla_shape =
+        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+    TensorShape per_device_tensor_shape;
+    TF_RETURN_IF_ERROR(
+        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
+    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                            xla_shape->element_type()));
+    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
+                        shape_representation_fn(per_device_tensor_shape, dtype,
+                                                use_fast_memory));
+    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+  }
+  return Status::OK();
+}
+
+// If there is a shape_representation_fn or sharding for an output, this
+// function uses a reshape to fix the layout.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
+  if (original_shape.IsTuple()) {
+    std::vector<xla::XlaOp> elements;
+    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
+      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
+      TF_ASSIGN_OR_RETURN(auto element,
+                          ReshapeWithCorrectRepresentationAndSharding(
+                              builder, xla::GetTupleElement(original, i),
+                              original_shape.tuple_shapes(i),
+                              shape_representation_fn, subsharding, fast_mem));
+      elements.push_back(element);
+    }
+    return xla::Tuple(builder, elements);
+  }
+  if (!original_shape.IsArray()) return original;
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
+  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                          original_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(auto to_shape,
+                      shape_representation_fn(shape, dtype, fast_mem));
+  if (sharding) {
+    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                        xla::HloSharding::FromProto(*sharding));
+    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
+        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
+  }
+  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
+    for (int64 i = 0; i < original_shape.rank(); ++i) {
+      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
+    }
+  }
+  return xla::Reshape(to_shape, original);
+}
+
 // Builds the XLA computation.
 // - `args` is the list of input arguments
 // - `retvals` is the list of retvals produced by _Retval operators, in index
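
For reference, a shape_representation_fn in the spirit of the ones used by the tests in this change might look like the sketch below. This is illustrative only (the name reversed_layout_fn is hypothetical; it assumes the usual tf2xla headers plus <numeric>): it builds the default XLA shape for the tensor and then gives it an ascending minor-to-major order, i.e. the reverse of the default descending layout.

// Illustrative sketch of a shape_representation_fn.
XlaCompiler::ShapeRepresentationFn reversed_layout_fn =
    [](const TensorShape& shape, DataType type,
       bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
  // Minor-to-major {0, 1, ...} is the reverse of the default {rank-1, ..., 0}.
  std::vector<int64> minor_to_major(xla_shape.rank());
  std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
  *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout(minor_to_major);
  return xla_shape;
};

Both RewriteLayoutWithShardedShape and ReshapeWithCorrectRepresentationAndSharding above simply defer to whatever layout such a function chooses, either for the per-core shape (sharded case) or for the full shape.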
@@ -188,10 +268,6 @@ Status BuildComputation(
   std::vector<xla::XlaOp> elems;
   elems.reserve(retvals.size());
 
-  // Keeps track of the layout of each retval. If a retval is not in this list,
-  // a descending layout is used. The first element is the output index, second
-  // element is the new layout.
-  std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
   // Keeps track of sharding of each retval. If a retval is not in this list,
   // replicate sharding is used. The first element is the output index, second
   // element is the sharding.
@@ -219,22 +295,22 @@ Status BuildComputation(
         TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
         xla::XlaOp value = retval.handle();
         auto it = retval_shardings.find(i);
-        xla::XlaScopedShardingAssignment assign_sharding(
-            builder, it == retval_shardings.end()
-                         ? absl::optional<xla::OpSharding>()
-                         : it->second);
+        absl::optional<xla::OpSharding> sharding =
+            it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
+                                         : it->second;
         if (it != retval_shardings.end()) {
           retval_index_and_sharding[elems.size()] = it->second;
         }
         if (shape_representation_fn) {
-          // If there is a shape representation function, reshape the output
-          // tensor to the shape given by the representation shape function.
-          TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
-                                                    output.shape, output.type,
-                                                    /*use_fast_memory=*/false));
-          value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
-          retval_index_and_layout.emplace_back(elems.size(), shape.layout());
-        } else if (it != retval_shardings.end()) {
+          TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
+          TF_ASSIGN_OR_RETURN(value,
+                              ReshapeWithCorrectRepresentationAndSharding(
+                                  builder, value, original_shape,
+                                  shape_representation_fn, sharding,
+                                  /*fast_mem=*/false));
+        }
+        if (it != retval_shardings.end()) {
+          xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
           // Apply the sharding to the output, if there is a core assignment.
           value = identity_op(value);
         }
@@ -312,43 +388,27 @@ Status BuildComputation(
         update.tensor_array_gradients_accessed.insert(grad.first);
       }
 
+      xla::XlaOp handle;
+      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+      auto sharding = it == arg_shardings.end()
+                          ? absl::optional<xla::OpSharding>()
+                          : it->second;
+      // Set layout of the retval to device representation layout.
+      if (shape_representation_fn) {
+        TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
+        TF_ASSIGN_OR_RETURN(
+            handle, ReshapeWithCorrectRepresentationAndSharding(
+                        builder, handle, original_shape,
+                        shape_representation_fn, sharding, arg.fast_mem));
+      }
+
       // Request that the value be returned on a specific core.
-      xla::XlaScopedShardingAssignment assign_sharding(
-          builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
-                                             : it->second);
+      xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
       if (it != arg_shardings.end()) {
         retval_index_and_sharding[elems.size()] = it->second;
       }
-      xla::XlaOp handle;
-      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
       // Ensures the correct sharding is applied to the output.
       handle = identity_op(handle);
 
-      // Set layout of the retval to device representation layout.
-      absl::optional<xla::Shape> representation_shape;
-      if (shape_representation_fn) {
-        TF_ASSIGN_OR_RETURN(
-            xla::Shape xla_shape,
-            shape_representation_fn(resource->shape(), resource->type(),
-                                    /*use_fast_memory=*/false));
-        representation_shape = xla_shape;
-      }
-      if (resource->representation_shape().has_value()) {
-        const xla::Shape& xla_shape = resource->representation_shape().value();
-        if (representation_shape) {
-          TF_RET_CHECK(
-              xla::ShapeUtil::Compatible(*representation_shape, xla_shape));
-        } else {
-          representation_shape = xla_shape;
-        }
-      }
-      if (representation_shape) {
-        retval_index_and_layout.emplace_back(elems.size(),
-                                             representation_shape->layout());
-      }
       elems.push_back(handle);
     }
   }
@@ -411,20 +471,8 @@ Status BuildComputation(
   }
   *computation = computation_status.ConsumeValueOrDie();
 
-  TF_ASSIGN_OR_RETURN(const auto& program_shape,
-                      computation->GetProgramShape());
+  TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
   *output_shape = program_shape.result();
-  // Update the output layout to the layout of retval.
-  for (auto& index_and_layout : retval_index_and_layout) {
-    if (!always_return_tuple && elems.size() == 1) {
-      *output_shape->mutable_layout() = index_and_layout.second;
-      continue;
-    }
-
-    xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
-        output_shape, {index_and_layout.first});
-    *output_sub_shape->mutable_layout() = index_and_layout.second;
-  }
   return Status::OK();
 }
@@ -779,47 +827,6 @@ Status XlaCompiler::XLAShapeForArgument(
     const XlaCompiler::Argument& arg, bool is_entry_computation,
     const absl::optional<xla::HloSharding>& arg_sharding,
     xla::Shape* xla_shape) const {
-  auto rewrite_layout_with_sharded_shape =
-      [](const absl::optional<xla::HloSharding>& arg_sharding,
-         bool use_fast_memory,
-         XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-         xla::Shape* xla_shape) {
-        if (arg_sharding && !arg_sharding->IsTileMaximal()) {
-          // After parameter sharding, per core parameter might have different
-          // layout. For example, before sharding, a parameter of shape [128,
-          // 128] will be assigned default minor-to-major {1, 0}. But after we
-          // shard this parameter to [128, 64] * 2, the sharded parameters
-          // will have minor-to-major {0, 1}.
-          //
-          // As a result, for sharded parameters, we set their layout to per
-          // core parameter's layout.
-          //
-          // TODO(endlessroad): for variable input & update, we might have
-          // different layouts which will prevent input output aliasing and
-          // increase memory usage. Investigate such cases.
-          int64 device = *arg_sharding->tile_assignment().begin();
-          std::vector<int64> offset =
-              arg_sharding->TileOffsetForDevice(*xla_shape, device);
-          std::vector<int64> limit =
-              arg_sharding->TileLimitForDevice(*xla_shape, device);
-          std::vector<int64> dimensions(xla_shape->rank());
-          for (int64 i = 0; i < xla_shape->rank(); ++i) {
-            dimensions[i] = limit[i] - offset[i];
-          }
-          xla::Shape per_device_xla_shape =
-              xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
-          TensorShape per_device_tensor_shape;
-          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape,
-                                                   &per_device_tensor_shape));
-          TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
-                                                  xla_shape->element_type()));
-          TF_ASSIGN_OR_RETURN(per_device_xla_shape,
-                              shape_representation_fn(per_device_tensor_shape,
-                                                      dtype, use_fast_memory));
-          *xla_shape->mutable_layout() = per_device_xla_shape.layout();
-        }
-        return Status::OK();
-      };
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
       LOG(FATAL) << "Unreachable case";
@@ -835,7 +842,7 @@ Status XlaCompiler::XLAShapeForArgument(
         TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
                                             shape, arg.type,
                                             /*use_fast_memory=*/false));
-        TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+        TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
             arg_sharding, /*use_fast_memory=*/false,
             options_.shape_representation_fn, xla_shape));
       } else {
@@ -863,7 +870,7 @@ Status XlaCompiler::XLAShapeForArgument(
                           options_.shape_representation_fn(
                               absl::get<TensorShape>(arg.shape), arg.type,
                               /*use_fast_memory=*/arg.fast_mem));
-      TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+      TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
           arg_sharding, arg.fast_mem, options_.shape_representation_fn,
           xla_shape));
       return Status::OK();

View File

@@ -365,7 +365,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                      args, &result));
-  EXPECT_EQ(fast_mem_arg_count, 1);
+  // Count 2: one for argument, one for the return value.
+  EXPECT_EQ(fast_mem_arg_count, 2);
 }
 
 // Tests that the compiler can correctly propagate the layout assigned by
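
For context, the counting shape_representation_fn this test relies on is roughly of the following form (a sketch, not the exact test code; the layout transposition the real test also performs is omitted). It illustrates fix 3: with the change, the resource argument and its returned update both take the use_fast_memory=true path, so the count becomes 2.

// Sketch: count how often the compiler asks for a fast-memory representation.
int fast_mem_arg_count = 0;
auto shape_representation_fn =
    [&](const TensorShape& shape, DataType type,
        bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
  if (use_fast_memory) ++fast_mem_arg_count;
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
  return xla_shape;
};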
@@ -417,6 +418,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
   // Check that the return shapes are correctly transposed.
   EXPECT_EQ(result.xla_output_shape,
             xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
+  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
+            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
 }
 
 // The layout of resource variable shouldn't change after transpose
@@ -1091,6 +1094,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
+  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
+            result.xla_output_shape);
 }
 
 TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
@@ -1131,6 +1136,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
+  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
+            result.xla_output_shape);
 }
 
 // Tests a simple graph that reads and writes a variable.

View File

@@ -528,7 +528,8 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
   }
 
   // Eliminate the size one dimensions.
-  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
+  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
+                      ReshapeInternal(reshaped_shape, operand));
   // Broadcast 'reshape' up to the larger size.
   return InDimBroadcast(broadcast_shape, reshaped_operand,
                         broadcast_dimensions);
@@ -828,7 +829,7 @@ XlaOp XlaBuilder::BroadcastInDim(
   });
 }
 
-StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
+StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                     int64 inferred_dimension) {
   TF_RETURN_IF_ERROR(first_error_);
@@ -1020,7 +1021,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
     XlaOp transposed = IsIdentityPermutation(dimensions)
                            ? operand
                            : Transpose(operand, dimensions);
-    return Reshape(shape, transposed, inferred_dimension);
+    return ReshapeInternal(shape, transposed, inferred_dimension);
   });
 }
@@ -1034,6 +1035,13 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
   });
 }
 
+XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
+                          int64 inferred_dimension) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    return ReshapeInternal(shape, operand, inferred_dimension);
+  });
+}
+
 XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     if (dimensions.size() <= 1) {
@@ -2951,6 +2959,10 @@ XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
   return operand.builder()->Reshape(operand, new_sizes);
 }
 
+XlaOp Reshape(const Shape& shape, XlaOp operand) {
+  return operand.builder()->Reshape(shape, operand);
+}
+
 XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                    absl::Span<const int64> new_sizes,
                                    int64 inferred_dimension) {

View File

@@ -397,6 +397,9 @@ class XlaBuilder {
   XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                 int64 inferred_dimension = -1);
 
+  XlaOp Reshape(const Shape& shape, XlaOp operand,
+                int64 inferred_dimension = -1);
+
   XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
 
   XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
@@ -668,7 +671,7 @@ class XlaBuilder {
   // Internal helper method for creating a Reshape op with the already inferred
   // shape.
-  StatusOr<XlaOp> Reshape(const Shape& shape, XlaOp operand,
+  StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                           int64 inferred_dimension = -1);
 
   // Returns the (inferred) result for the program shape using the given root.
@@ -777,6 +780,8 @@ class XlaBuilder {
   friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
 
+  friend XlaOp Reshape(const Shape& shape, XlaOp operand);
+
   friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                             absl::Span<const int64> new_sizes,
                                             int64 inferred_dimension);
@@ -1252,6 +1257,9 @@ XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
 // sizes. Conceptually, this is a limited form of "shape casting".
 XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
 
+// Enqueues a Reshape op that uses an explicit target shape.
+XlaOp Reshape(const Shape& shape, XlaOp operand);
+
 // `inferred_dimension` represents the output dimension that's inferred by
 // upper-level framework by dividing the input element count by the known
 // output element count. While an inferred_dimension can be static, if there
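
As a usage note for the new overload: a caller can now pass a fully specified target shape, layout included, which is what ReshapeWithCorrectRepresentationAndSharding above relies on. A minimal sketch (builder and parameter names are illustrative, not from this commit):

// Sketch: reshape an f32[2,3] parameter onto an explicit target shape whose
// minor-to-major order is {0, 1} instead of the default {1, 0}.
xla::XlaBuilder b("reshape_with_explicit_shape");
xla::XlaOp p =
    xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {2, 3}), "p");
xla::Shape target =
    xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});
xla::XlaOp r = xla::Reshape(target, p);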