Propagate sharded argument layouts through TF/XLA bridge.

After parameter sharding, the per-core arguments might have different layouts. Inside the XLA compiler we can no longer deduce the layout of a sharded parameter (shape_representation_fn is no longer accessible there), so we override the XLA parameter layout with the sharded parameter's layout.
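
For illustration, a minimal self-contained sketch of the shape arithmetic involved (plain standard C++, no TF/XLA dependencies; the 1x2 tile split and the layout-picking rule are stand-ins for the real sharding and shape_representation_fn, not the bridge's actual code): a [128, 128] parameter split in half along dimension 1 gives each core a [128, 64] shard, and a representation function that prefers a different minor-to-major order for the narrower shard is exactly the case where the compiler's default layout must be overridden.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> full_dims = {128, 128};
  const std::vector<int64_t> tiles = {1, 2};  // split dim 1 across 2 cores

  // Tile bounds for core 0, mirroring what TileOffsetForDevice /
  // TileLimitForDevice compute for a simple even tiling.
  std::vector<int64_t> offset(full_dims.size()), limit(full_dims.size());
  std::vector<int64_t> per_core_dims(full_dims.size());
  for (size_t i = 0; i < full_dims.size(); ++i) {
    const int64_t tile_size = full_dims[i] / tiles[i];
    offset[i] = 0 * tile_size;  // core 0
    limit[i] = offset[i] + tile_size;
    per_core_dims[i] = limit[i] - offset[i];
  }

  // Hypothetical stand-in for shape_representation_fn's layout choice:
  // tall, narrow 2-D shapes get minor-to-major {0, 1} instead of {1, 0}.
  auto minor_to_major = [](const std::vector<int64_t>& dims) {
    return dims[0] > dims[1] ? std::vector<int64_t>{0, 1}
                             : std::vector<int64_t>{1, 0};
  };

  const auto layout = minor_to_major(per_core_dims);
  std::cout << "per-core shape: [" << per_core_dims[0] << ", "
            << per_core_dims[1] << "], minor-to-major: {" << layout[0]
            << ", " << layout[1] << "}\n";
  // Prints: per-core shape: [128, 64], minor-to-major: {0, 1}
  return 0;
}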

In XlaDeviceContext, CopyCPUTensorToDevice() uses shape_representation_fn(cpu_tensor_shape) as the device tensor shape, so the XLA compiler input shape must match it. For CopyDeviceTensorToCPU(), the device tensor shape is defined directly by the XLA compiler, so nothing needs to be fixed there.
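
A rough sketch of the invariant this relies on (hypothetical stand-ins only, not the real XlaDeviceContext or shape_representation_fn): whatever layout the host-to-device copy derives from the CPU tensor shape must equal the layout the compiled executable expects for that parameter, otherwise the copied bytes would be interpreted with the wrong dimension order.

#include <cassert>
#include <cstdint>
#include <vector>

// Minor-to-major dimension order, standing in for xla::Layout.
using Layout = std::vector<int64_t>;

// Hypothetical stand-in for shape_representation_fn: picks a layout from the
// per-core host tensor shape.
Layout RepresentationLayout(const std::vector<int64_t>& host_dims) {
  return host_dims[0] > host_dims[1] ? Layout{0, 1} : Layout{1, 0};
}

int main() {
  const std::vector<int64_t> per_core_host_dims = {128, 64};

  // Layout the copy path would use for the device tensor (derived from the
  // host shape), and the layout the compiler was given for the parameter
  // after the sharded-layout override.
  const Layout copy_layout = RepresentationLayout(per_core_host_dims);
  const Layout compiler_arg_layout = {0, 1};

  // If these diverged, CopyCPUTensorToDevice() would write the buffer in a
  // layout the compiled program does not expect.
  assert(copy_layout == compiler_arg_layout);
  return 0;
}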

PiperOrigin-RevId: 284812560
Change-Id: I567f180a8035ff71982d49910b84c98d07eb25d1
Tong Shen, 2019-12-10 11:31:26 -08:00, committed by TensorFlower Gardener
parent a9f9b02a3e
commit d7336a9186
2 changed files with 65 additions and 7 deletions

tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -758,9 +758,51 @@ Status XlaCompiler::CompileFunction(
 }
 
 // Computes the XLA shape for argument 'arg'.
-Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
-                                        bool is_entry_computation,
-                                        xla::Shape* xla_shape) const {
+Status XlaCompiler::XLAShapeForArgument(
+    const XlaCompiler::Argument& arg, bool is_entry_computation,
+    const absl::optional<xla::HloSharding>& arg_sharding,
+    xla::Shape* xla_shape) const {
+  auto rewrite_layout_with_sharded_shape =
+      [](const absl::optional<xla::HloSharding>& arg_sharding,
+         bool use_fast_memory,
+         XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+         xla::Shape* xla_shape) {
+        if (arg_sharding && !arg_sharding->IsTileMaximal()) {
+          // After parameter sharding, per core parameter might have different
+          // layout. For example, before sharding, a parameter of shape [128,
+          // 128] will be assigned default minor-to-major {1, 0}. But after we
+          // shard this parameter to [128, 64] * 2, the sharded parameters
+          // will have minor-to-major {0, 1}.
+          //
+          // As a result, for sharded parameters, we set their layout to per
+          // core parameter's layout.
+          //
+          // TODO(endlessroad): for variable input & update, we might have
+          // different layouts which will prevent input output aliasing and
+          // increase memory usage. Investigate such cases.
+          int64 device = *arg_sharding->tile_assignment().begin();
+          std::vector<int64> offset =
+              arg_sharding->TileOffsetForDevice(*xla_shape, device);
+          std::vector<int64> limit =
+              arg_sharding->TileLimitForDevice(*xla_shape, device);
+          std::vector<int64> dimensions(xla_shape->rank());
+          for (int64 i = 0; i < xla_shape->rank(); ++i) {
+            dimensions[i] = limit[i] - offset[i];
+          }
+          xla::Shape per_device_xla_shape =
+              xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+          TensorShape per_device_tensor_shape;
+          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape,
+                                                   &per_device_tensor_shape));
+          TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                                  xla_shape->element_type()));
+          TF_ASSIGN_OR_RETURN(per_device_xla_shape,
+                              shape_representation_fn(per_device_tensor_shape,
+                                                      dtype, use_fast_memory));
+          *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+        }
+        return Status::OK();
+      };
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
       LOG(FATAL) << "Unreachable case";
@@ -776,6 +818,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
         TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
                                             shape, arg.type,
                                             /*use_fast_memory=*/false));
+        TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+            arg_sharding, /*use_fast_memory=*/false,
+            options_.shape_representation_fn, xla_shape));
       } else {
         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
           *xla_shape = absl::get<xla::Shape>(arg.shape);
@@ -801,6 +846,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
                               options_.shape_representation_fn(
                                   absl::get<TensorShape>(arg.shape), arg.type,
                                   /*use_fast_memory=*/arg.fast_mem));
+          TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+              arg_sharding, arg.fast_mem, options_.shape_representation_fn,
+              xla_shape));
           return Status::OK();
         }
         case XlaResource::kTensorArray: {
@@ -939,8 +987,16 @@ Status XlaCompiler::BuildArguments(
   std::vector<xla::Shape> arg_shapes(input_to_args->size());
   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    TF_RETURN_IF_ERROR(XLAShapeForArgument(
-        args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i]));
+    auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
+    absl::optional<xla::HloSharding> sharding;
+    if (arg_sharding != arg_shardings.end()) {
+      TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                          xla::HloSharding::FromProto(arg_sharding->second));
+      sharding = hlo_sharding;
+    }
+    TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
+                                           is_entry_computation, sharding,
+                                           &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {

tensorflow/compiler/tf2xla/xla_compiler.h

@@ -381,8 +381,10 @@ class XlaCompiler {
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.
-  Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
-                             xla::Shape* xla_shape) const;
+  Status XLAShapeForArgument(
+      const Argument& arg, bool is_entry_computation,
+      const absl::optional<xla::HloSharding>& arg_sharding,
+      xla::Shape* xla_shape) const;
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.