[TF/XLA] Fix several layout issues.
1. The previous approach could produce different layouts for computation.GetProgramShape() and xla_output_shape: shape_representation_fn was applied only to xla_output_shape, not to the entry computation's program shape. Having the two disagree is confusing, and it can make a bug hard to reproduce from an HLO dump that lacks the HloModuleConfig.
2. Output shapes were not updated with the layout when there is sharding.
3. The updated value of a resource did not preserve the fast_mem annotation on the argument.

PiperOrigin-RevId: 295811071
Change-Id: I801a46d3039b2349dd0196cbc14ec3d9a8211d55
Parent: 8f3272028b · Commit: ac2c05a1d5
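The first fix is easiest to see in terms of shapes with layout: the same logical result can be stored with different minor-to-major orders, and a dump that records only the logical shape cannot tell which one the compiler meant. Below is a minimal standalone sketch, not part of this change, assuming only the public xla::ShapeUtil helpers; it shows two layouts of the same s32[2,3] result that are Compatible but not Equal, which is exactly the kind of mismatch that could previously arise between the entry program shape and xla_output_shape.

```cpp
#include <iostream>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

int main() {
  // Same logical s32[2,3] result, row-major vs. column-major storage order.
  xla::Shape row_major =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
  xla::Shape col_major =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});

  // Prints "s32[2,3]{1,0}" and "s32[2,3]{0,1}".
  std::cout << xla::ShapeUtil::HumanStringWithLayout(row_major) << "\n"
            << xla::ShapeUtil::HumanStringWithLayout(col_major) << "\n";

  // Compatible ignores layout (prints 1); Equal does not (prints 0).
  std::cout << xla::ShapeUtil::Compatible(row_major, col_major) << " "
            << xla::ShapeUtil::Equal(row_major, col_major) << "\n";
  return 0;
}
```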
@@ -97,6 +97,7 @@ xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
       {xla::U16, DT_UINT16},
       {xla::U32, DT_UINT32},
       {xla::U64, DT_UINT64},
+      {xla::C128, DT_COMPLEX128},
   });

  auto it = data_type_map.find(type);
@@ -139,6 +139,86 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
   return Status::OK();
 }
 
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape) {
+  if (sharding && !sharding->IsTileMaximal()) {
+    // After sharding, per core shape might have different layout. For example,
+    // before sharding, a shape [128, 128] will be assigned default
+    // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
+    // the sharded shapes will have minor-to-major {0, 1}.
+    //
+    // As a result, for sharded shapes, we set their layout to per core shape's
+    // layout.
+    //
+    // TODO(endlessroad): for variable input & update, we might have
+    // different layouts which will prevent input output aliasing and
+    // increase memory usage. Investigate such cases.
+    int64 device = *sharding->tile_assignment().begin();
+    std::vector<int64> offset =
+        sharding->TileOffsetForDevice(*xla_shape, device);
+    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
+    std::vector<int64> dimensions(xla_shape->rank());
+    for (int64 i = 0; i < xla_shape->rank(); ++i) {
+      dimensions[i] = limit[i] - offset[i];
+    }
+    xla::Shape per_device_xla_shape =
+        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+    TensorShape per_device_tensor_shape;
+    TF_RETURN_IF_ERROR(
+        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
+    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                            xla_shape->element_type()));
+    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
+                        shape_representation_fn(per_device_tensor_shape, dtype,
+                                                use_fast_memory));
+    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+  }
+  return Status::OK();
+}
+
+// If there is a shape_representation_fn or sharding for an output, this
+// function uses a reshape to fix the layout.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
+  if (original_shape.IsTuple()) {
+    std::vector<xla::XlaOp> elements;
+    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
+      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
+      TF_ASSIGN_OR_RETURN(auto element,
+                          ReshapeWithCorrectRepresentationAndSharding(
+                              builder, xla::GetTupleElement(original, i),
+                              original_shape.tuple_shapes(i),
+                              shape_representation_fn, subsharding, fast_mem));
+      elements.push_back(element);
+    }
+    return xla::Tuple(builder, elements);
+  }
+  if (!original_shape.IsArray()) return original;
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
+  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                          original_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(auto to_shape,
+                      shape_representation_fn(shape, dtype, fast_mem));
+  if (sharding) {
+    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                        xla::HloSharding::FromProto(*sharding));
+    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
+        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
+  }
+  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
+    for (int64 i = 0; i < original_shape.rank(); ++i) {
+      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
+    }
+  }
+  return xla::Reshape(to_shape, original);
+}
+
 // Builds the XLA computation.
 // - `args` is the list of input arguments
 // - `retvals` is the list of retvals produced by _Retval operators, in index
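For reference, here is the tile-dimension arithmetic from RewriteLayoutWithShardedShape worked through on the [128, 128] example in its comment. This is a standalone sketch, not XLA code: the offset and limit values are hard-coded stand-ins for what HloSharding::TileOffsetForDevice and HloSharding::TileLimitForDevice would return for device 0 when the shape is split in two along dimension 1.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> offset = {0, 0};    // tile start for device 0
  std::vector<int64_t> limit = {128, 64};  // tile end (exclusive) for device 0

  // dimensions[i] = limit[i] - offset[i], as in RewriteLayoutWithShardedShape.
  std::vector<int64_t> dimensions(offset.size());
  for (size_t i = 0; i < offset.size(); ++i) {
    dimensions[i] = limit[i] - offset[i];
  }

  // The per-device shape is [128, 64]; shape_representation_fn is then asked
  // for that shape's layout, and the resulting layout is copied back onto the
  // full [128, 128] shape.
  assert(dimensions == (std::vector<int64_t>{128, 64}));
  return 0;
}
```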
@@ -188,10 +268,6 @@ Status BuildComputation(
   std::vector<xla::XlaOp> elems;
   elems.reserve(retvals.size());
 
-  // Keeps track of the layout of each retval. If a retval is not in this list,
-  // a descending layout is used. The first element is the output index, second
-  // element is the new layout.
-  std::vector<std::pair<int64, xla::Layout>> retval_index_and_layout;
   // Keeps track of sharding of each retval. If a retval is not in this list,
   // replicate sharding is used. The first element is the output index, second
   // element is the sharding.
@@ -219,22 +295,22 @@ Status BuildComputation(
     TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
     xla::XlaOp value = retval.handle();
     auto it = retval_shardings.find(i);
-    xla::XlaScopedShardingAssignment assign_sharding(
-        builder, it == retval_shardings.end()
-                     ? absl::optional<xla::OpSharding>()
-                     : it->second);
+    absl::optional<xla::OpSharding> sharding =
+        it == retval_shardings.end() ? absl::optional<xla::OpSharding>()
+                                     : it->second;
     if (it != retval_shardings.end()) {
       retval_index_and_sharding[elems.size()] = it->second;
     }
     if (shape_representation_fn) {
-      // If there is a shape representation function, reshape the output
-      // tensor to the shape given by the representation shape function.
-      TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
-                                                output.shape, output.type,
-                                                /*use_fast_memory=*/false));
-      value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
-      retval_index_and_layout.emplace_back(elems.size(), shape.layout());
-    } else if (it != retval_shardings.end()) {
+      TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value));
+      TF_ASSIGN_OR_RETURN(value,
+                          ReshapeWithCorrectRepresentationAndSharding(
+                              builder, value, original_shape,
+                              shape_representation_fn, sharding,
+                              /*fast_mem=*/false));
+    }
+    if (it != retval_shardings.end()) {
+      xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
       // Apply the sharding to the output, if there is a core assignment.
       value = identity_op(value);
     }
@@ -312,43 +388,27 @@ Status BuildComputation(
         update.tensor_array_gradients_accessed.insert(grad.first);
       }
 
+      xla::XlaOp handle;
+      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+      auto sharding = it == arg_shardings.end()
+                          ? absl::optional<xla::OpSharding>()
+                          : it->second;
+      // Set layout of the retval to device representation layout.
+      if (shape_representation_fn) {
+        TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle));
+        TF_ASSIGN_OR_RETURN(
+            handle, ReshapeWithCorrectRepresentationAndSharding(
+                        builder, handle, original_shape,
+                        shape_representation_fn, sharding, arg.fast_mem));
+      }
+
       // Request that the value be returned on a specific core.
-      xla::XlaScopedShardingAssignment assign_sharding(
-          builder, it == arg_shardings.end() ? absl::optional<xla::OpSharding>()
-                                             : it->second);
+      xla::XlaScopedShardingAssignment assign_sharding(builder, sharding);
       if (it != arg_shardings.end()) {
         retval_index_and_sharding[elems.size()] = it->second;
       }
 
-      xla::XlaOp handle;
-      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
-
       // Ensures the correct sharding is applied to the output.
       handle = identity_op(handle);
 
-      // Set layout of the retval to device representation layout.
-      absl::optional<xla::Shape> representation_shape;
-      if (shape_representation_fn) {
-        TF_ASSIGN_OR_RETURN(
-            xla::Shape xla_shape,
-            shape_representation_fn(resource->shape(), resource->type(),
-                                    /*use_fast_memory=*/false));
-        representation_shape = xla_shape;
-      }
-      if (resource->representation_shape().has_value()) {
-        const xla::Shape& xla_shape = resource->representation_shape().value();
-        if (representation_shape) {
-          TF_RET_CHECK(
-              xla::ShapeUtil::Compatible(*representation_shape, xla_shape));
-        } else {
-          representation_shape = xla_shape;
-        }
-      }
-      if (representation_shape) {
-        retval_index_and_layout.emplace_back(elems.size(),
-                                             representation_shape->layout());
-      }
-
       elems.push_back(handle);
     }
   }
@@ -411,20 +471,8 @@ Status BuildComputation(
   }
   *computation = computation_status.ConsumeValueOrDie();
 
-  TF_ASSIGN_OR_RETURN(const auto& program_shape,
-                      computation->GetProgramShape());
+  TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape());
   *output_shape = program_shape.result();
-  // Update the output layout to the layout of retval.
-  for (auto& index_and_layout : retval_index_and_layout) {
-    if (!always_return_tuple && elems.size() == 1) {
-      *output_shape->mutable_layout() = index_and_layout.second;
-      continue;
-    }
-
-    xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape(
-        output_shape, {index_and_layout.first});
-    *output_sub_shape->mutable_layout() = index_and_layout.second;
-  }
   return Status::OK();
 }
 
@@ -779,47 +827,6 @@ Status XlaCompiler::XLAShapeForArgument(
     const XlaCompiler::Argument& arg, bool is_entry_computation,
     const absl::optional<xla::HloSharding>& arg_sharding,
     xla::Shape* xla_shape) const {
-  auto rewrite_layout_with_sharded_shape =
-      [](const absl::optional<xla::HloSharding>& arg_sharding,
-         bool use_fast_memory,
-         XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-         xla::Shape* xla_shape) {
-        if (arg_sharding && !arg_sharding->IsTileMaximal()) {
-          // After parameter sharding, per core parameter might have different
-          // layout. For example, before sharding, a parameter of shape [128,
-          // 128] will be assigned default minor-to-major {1, 0}. But after we
-          // shard this parameter to [128, 64] * 2, the sharded parameters
-          // will have minor-to-major {0, 1}.
-          //
-          // As a result, for sharded parameters, we set their layout to per
-          // core parameter's layout.
-          //
-          // TODO(endlessroad): for variable input & update, we might have
-          // different layouts which will prevent input output aliasing and
-          // increase memory usage. Investigate such cases.
-          int64 device = *arg_sharding->tile_assignment().begin();
-          std::vector<int64> offset =
-              arg_sharding->TileOffsetForDevice(*xla_shape, device);
-          std::vector<int64> limit =
-              arg_sharding->TileLimitForDevice(*xla_shape, device);
-          std::vector<int64> dimensions(xla_shape->rank());
-          for (int64 i = 0; i < xla_shape->rank(); ++i) {
-            dimensions[i] = limit[i] - offset[i];
-          }
-          xla::Shape per_device_xla_shape =
-              xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
-          TensorShape per_device_tensor_shape;
-          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape,
-                                                   &per_device_tensor_shape));
-          TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
-                                                  xla_shape->element_type()));
-          TF_ASSIGN_OR_RETURN(per_device_xla_shape,
-                              shape_representation_fn(per_device_tensor_shape,
-                                                      dtype, use_fast_memory));
-          *xla_shape->mutable_layout() = per_device_xla_shape.layout();
-        }
-        return Status::OK();
-      };
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
       LOG(FATAL) << "Unreachable case";
@@ -835,7 +842,7 @@ Status XlaCompiler::XLAShapeForArgument(
       TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
                                           shape, arg.type,
                                           /*use_fast_memory=*/false));
-      TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+      TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
          arg_sharding, /*use_fast_memory=*/false,
          options_.shape_representation_fn, xla_shape));
     } else {
@@ -863,7 +870,7 @@ Status XlaCompiler::XLAShapeForArgument(
           options_.shape_representation_fn(
               absl::get<TensorShape>(arg.shape), arg.type,
               /*use_fast_memory=*/arg.fast_mem));
-      TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+      TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
          arg_sharding, arg.fast_mem, options_.shape_representation_fn,
          xla_shape));
       return Status::OK();
@@ -365,7 +365,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) {
   compile_options.return_updated_values_for_all_resources = true;
   TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                      args, &result));
-  EXPECT_EQ(fast_mem_arg_count, 1);
+  // Count 2: one for argument, one for the return value.
+  EXPECT_EQ(fast_mem_arg_count, 2);
 }
 
 // Tests that the compiler can correctly propagate the layout assigned by
@@ -417,6 +418,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
   // Check that the return shapes are correctly tranposed.
   EXPECT_EQ(result.xla_output_shape,
             xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
+  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
+            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
 }
 
 // The layout of resource variable shouldn't change after transpose
@@ -1091,6 +1094,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) {
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
+  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
+            result.xla_output_shape);
 }
 
 TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
@@ -1131,6 +1136,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
   EXPECT_TRUE(xla::ShapeUtil::Equal(
       result.xla_output_shape,
       xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
+  EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(),
+            result.xla_output_shape);
 }
 
 // Tests a simple graph that reads and writes a variable.
@@ -528,7 +528,8 @@ StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
   }
 
   // Eliminate the size one dimensions.
-  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand));
+  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
+                      ReshapeInternal(reshaped_shape, operand));
   // Broadcast 'reshape' up to the larger size.
   return InDimBroadcast(broadcast_shape, reshaped_operand,
                         broadcast_dimensions);
@@ -828,7 +829,7 @@ XlaOp XlaBuilder::BroadcastInDim(
   });
 }
 
-StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
+StatusOr<XlaOp> XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand,
                                     int64 inferred_dimension) {
   TF_RETURN_IF_ERROR(first_error_);
 
@@ -1020,7 +1021,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> dimensions,
     XlaOp transposed = IsIdentityPermutation(dimensions)
                            ? operand
                            : Transpose(operand, dimensions);
-    return Reshape(shape, transposed, inferred_dimension);
+    return ReshapeInternal(shape, transposed, inferred_dimension);
   });
 }
 
@@ -1034,6 +1035,13 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
   });
 }
 
+XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand,
+                          int64 inferred_dimension) {
+  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    return ReshapeInternal(shape, operand, inferred_dimension);
+  });
+}
+
 XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span<const int64> dimensions) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     if (dimensions.size() <= 1) {
@@ -2951,6 +2959,10 @@ XlaOp Reshape(const XlaOp operand, absl::Span<const int64> new_sizes) {
   return operand.builder()->Reshape(operand, new_sizes);
 }
 
+XlaOp Reshape(const Shape& shape, XlaOp operand) {
+  return operand.builder()->Reshape(shape, operand);
+}
+
 XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                    absl::Span<const int64> new_sizes,
                                    int64 inferred_dimension) {
@@ -397,6 +397,9 @@ class XlaBuilder {
   XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes,
                 int64 inferred_dimension = -1);
 
+  XlaOp Reshape(const Shape& shape, XlaOp operand,
+                int64 inferred_dimension = -1);
+
   XlaOp Collapse(XlaOp operand, absl::Span<const int64> dimensions);
 
   XlaOp Slice(XlaOp operand, absl::Span<const int64> start_indices,
@@ -668,7 +671,7 @@ class XlaBuilder {
 
   // Internal helper method for creating a Reshape op with the already inferred
   // shape.
-  StatusOr<XlaOp> Reshape(const Shape& shape, XlaOp operand,
+  StatusOr<XlaOp> ReshapeInternal(const Shape& shape, XlaOp operand,
                           int64 inferred_dimension = -1);
 
   // Returns the (inferred) result for the program shape using the given root.
@@ -777,6 +780,8 @@ class XlaBuilder {
 
   friend XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
 
+  friend XlaOp Reshape(const Shape& shape, XlaOp operand);
+
   friend XlaOp ReshapeWithInferredDimension(XlaOp operand,
                                             absl::Span<const int64> new_sizes,
                                             int64 inferred_dimension);
@@ -1252,6 +1257,9 @@ XlaOp Reshape(XlaOp operand, absl::Span<const int64> dimensions,
 // sizes. Conceptually, this is a limited form of "shape casting".
 XlaOp Reshape(XlaOp operand, absl::Span<const int64> new_sizes);
 
+// Enqueues a Reshape op that uses an explicit target shape.
+XlaOp Reshape(const Shape& shape, XlaOp operand);
+
 // `inferred_dimension` represents the output dimension that's inferred by
 // upper-level framework by dividing the input element count by the known
 // output element count. While an inferred_dimension can be static, if there
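The new public Reshape overload above takes a full xla::Shape rather than just dimension sizes, so callers like ReshapeWithCorrectRepresentationAndSharding can hand the builder a target shape that already carries a layout; the dimensions-only overload goes through shape inference and gets the default descending layout. Below is a hypothetical usage sketch, assuming the usual XLA client headers; the function name and shapes are illustrative only, not part of this change.

```cpp
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::StatusOr<xla::XlaComputation> BuildReshapeWithLayout() {
  xla::XlaBuilder builder("reshape_with_layout");
  xla::Shape input_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  xla::XlaOp p = xla::Parameter(&builder, 0, input_shape, "p");

  // The target shape carries an explicit column-major layout {0, 1}; the
  // shape-based Reshape overload attaches it to the op directly instead of
  // using the default descending layout.
  xla::Shape with_layout =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});
  xla::XlaOp reshaped = xla::Reshape(with_layout, p);

  return builder.Build(reshaped);
}
```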