diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index defd96b570c..d69e40e93e2 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -758,9 +758,51 @@ Status XlaCompiler::CompileFunction(
 }
 
 // Computes the XLA shape for argument 'arg'.
-Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
-                                        bool is_entry_computation,
-                                        xla::Shape* xla_shape) const {
+Status XlaCompiler::XLAShapeForArgument(
+    const XlaCompiler::Argument& arg, bool is_entry_computation,
+    const absl::optional<xla::HloSharding>& arg_sharding,
+    xla::Shape* xla_shape) const {
+  auto rewrite_layout_with_sharded_shape =
+      [](const absl::optional<xla::HloSharding>& arg_sharding,
+         bool use_fast_memory,
+         XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+         xla::Shape* xla_shape) {
+        if (arg_sharding && !arg_sharding->IsTileMaximal()) {
+          // After parameter sharding, per core parameter might have different
+          // layout. For example, before sharding, a parameter of shape [128,
+          // 128] will be assigned default minor-to-major {1, 0}. But after we
+          // shard this parameter to [128, 64] * 2, the sharded parameters
+          // will have minor-to-major {0, 1}.
+          //
+          // As a result, for sharded parameters, we set their layout to per
+          // core parameter's layout.
+          //
+          // TODO(endlessroad): for variable input & update, we might have
+          // different layouts which will prevent input output aliasing and
+          // increase memory usage. Investigate such cases.
+          int64 device = *arg_sharding->tile_assignment().begin();
+          std::vector<int64> offset =
+              arg_sharding->TileOffsetForDevice(*xla_shape, device);
+          std::vector<int64> limit =
+              arg_sharding->TileLimitForDevice(*xla_shape, device);
+          std::vector<int64> dimensions(xla_shape->rank());
+          for (int64 i = 0; i < xla_shape->rank(); ++i) {
+            dimensions[i] = limit[i] - offset[i];
+          }
+          xla::Shape per_device_xla_shape =
+              xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+          TensorShape per_device_tensor_shape;
+          TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape,
+                                                   &per_device_tensor_shape));
+          TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                                  xla_shape->element_type()));
+          TF_ASSIGN_OR_RETURN(per_device_xla_shape,
+                              shape_representation_fn(per_device_tensor_shape,
+                                                      dtype, use_fast_memory));
+          *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+        }
+        return Status::OK();
+      };
   switch (arg.kind) {
     case XlaCompiler::Argument::kConstant:
       LOG(FATAL) << "Unreachable case";
@@ -776,6 +818,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
         TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn(
                                             shape, arg.type,
                                             /*use_fast_memory=*/false));
+        TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+            arg_sharding, /*use_fast_memory=*/false,
+            options_.shape_representation_fn, xla_shape));
       } else {
         if (absl::holds_alternative<xla::Shape>(arg.shape)) {
           *xla_shape = absl::get<xla::Shape>(arg.shape);
@@ -801,6 +846,9 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
             options_.shape_representation_fn(
                 absl::get<TensorShape>(arg.shape), arg.type,
                 /*use_fast_memory=*/arg.fast_mem));
+        TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape(
+            arg_sharding, arg.fast_mem, options_.shape_representation_fn,
+            xla_shape));
         return Status::OK();
       }
       case XlaResource::kTensorArray: {
@@ -939,8 +987,16 @@ Status XlaCompiler::BuildArguments(
   std::vector<xla::Shape> arg_shapes(input_to_args->size());
   for (std::vector<int>::size_type i = 0; i < input_to_args->size(); ++i) {
     // Computes the shapes of non-constant arguments.
-    TF_RETURN_IF_ERROR(XLAShapeForArgument(
-        args[(*input_to_args)[i]], is_entry_computation, &arg_shapes[i]));
+    auto arg_sharding = arg_shardings.find((*input_to_args)[i]);
+    absl::optional<xla::HloSharding> sharding;
+    if (arg_sharding != arg_shardings.end()) {
+      TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                          xla::HloSharding::FromProto(arg_sharding->second));
+      sharding = hlo_sharding;
+    }
+    TF_RETURN_IF_ERROR(XLAShapeForArgument(args[(*input_to_args)[i]],
+                                           is_entry_computation, sharding,
+                                           &arg_shapes[i]));
   }
 
   if (use_tuple_arg) {
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index 670da043c1a..1ae82c3ea54 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -381,8 +381,10 @@ class XlaCompiler {
   // Returns the shape of the XLA parameter for an argument 'arg'.
   // See the class comment for more details about the argument passing
   // convention.
-  Status XLAShapeForArgument(const Argument& arg, bool is_entry_computation,
-                             xla::Shape* xla_shape) const;
+  Status XLAShapeForArgument(
+      const Argument& arg, bool is_entry_computation,
+      const absl::optional<xla::HloSharding>& arg_sharding,
+      xla::Shape* xla_shape) const;
 
   // Retrieves the channel handle associated with `key`. Allocates
   // a new channel handle if none exists.
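
Reviewer note: the layout rewrite above derives the per-core shape by taking, for one representative device, the difference between HloSharding::TileLimitForDevice and HloSharding::TileOffsetForDevice in each dimension, and only then asks shape_representation_fn which layout that smaller shape should get. Below is a minimal standalone sketch of just that shape arithmetic; the TileOffsetForDevice/TileLimitForDevice functions here are hypothetical stand-ins (not the real XLA API) hard-coding the [128, 128] -> [128, 64] * 2 example from the comment in the diff.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-ins for HloSharding::TileOffsetForDevice and
// HloSharding::TileLimitForDevice, hard-coded for a [128, 128]
// parameter tiled [1, 2]: device d owns columns [d * 64, (d + 1) * 64).
std::vector<int64_t> TileOffsetForDevice(int64_t device) {
  return {0, device * 64};
}
std::vector<int64_t> TileLimitForDevice(int64_t device) {
  return {128, (device + 1) * 64};
}

int main() {
  const std::vector<int64_t> full_shape = {128, 128};
  for (int64_t device = 0; device < 2; ++device) {
    const std::vector<int64_t> offset = TileOffsetForDevice(device);
    const std::vector<int64_t> limit = TileLimitForDevice(device);
    // Same arithmetic as the lambda in the diff: each per-device
    // dimension is the extent of the tile owned by that device.
    std::vector<int64_t> dimensions(full_shape.size());
    for (size_t i = 0; i < full_shape.size(); ++i) {
      dimensions[i] = limit[i] - offset[i];
    }
    std::cout << "device " << device << ": [" << dimensions[0] << ", "
              << dimensions[1] << "]\n";  // [128, 64] for both devices
  }
  return 0;
}

Both devices report a [128, 64] tile; the change then runs that per-device shape through shape_representation_fn, which may choose minor-to-major {0, 1} where the unsharded [128, 128] parameter would get {1, 0}, and copies only the resulting layout back onto the full-shape parameter.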