From ac2c05a1d57398653057405018a8c1e51e99756a Mon Sep 17 00:00:00 2001
From: Yuanzhong Xu
Date: Tue, 18 Feb 2020 13:39:31 -0800
Subject: [PATCH] [TF/XLA] Fix several layout issues.

1. The previous approach could produce different layouts for
   computation.GetProgramShape() and xla_output_shape: shape_representation_fn
   was applied to xla_output_shape but not to the entry computation's program
   shape. The mismatch is confusing and can make it hard to reproduce a bug
   from an HLO dump that lacks the HloModuleConfig.
2. Output shapes were not updated with the chosen layout when there was
   sharding.
3. The updated value of a resource did not preserve the fast_mem annotation on
   the argument.

PiperOrigin-RevId: 295811071
Change-Id: I801a46d3039b2349dd0196cbc14ec3d9a8211d55
---
 tensorflow/compiler/tf2xla/type_util.cc       |   1 +
 tensorflow/compiler/tf2xla/xla_compiler.cc    | 213 +++++++++---------
 .../compiler/tf2xla/xla_compiler_test.cc      |   9 +-
 tensorflow/compiler/xla/client/xla_builder.cc |  20 +-
 tensorflow/compiler/xla/client/xla_builder.h  |  12 +-
 5 files changed, 145 insertions(+), 110 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc
index 634f64e01e6..2266a07463d 100644
--- a/tensorflow/compiler/tf2xla/type_util.cc
+++ b/tensorflow/compiler/tf2xla/type_util.cc
@@ -97,6 +97,7 @@ xla::StatusOr<DataType> EncodePrimitiveTypeAsDataType(xla::PrimitiveType type) {
       {xla::U16, DT_UINT16},
       {xla::U32, DT_UINT32},
       {xla::U64, DT_UINT64},
+      {xla::C128, DT_COMPLEX128},
   });
 
   auto it = data_type_map.find(type);
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 8e44d3d4255..3ea62882dcb 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -139,6 +139,86 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
   return Status::OK();
 }
 
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape) {
+  if (sharding && !sharding->IsTileMaximal()) {
+    // After sharding, the per-core shape might have a different layout. For
+    // example, before sharding, a shape [128, 128] will be assigned the
+    // default minor-to-major {1, 0}; but after we shard this shape to
+    // [128, 64] * 2, the sharded shapes will have minor-to-major {0, 1}.
+    //
+    // As a result, for sharded shapes we set their layout to the per-core
+    // shape's layout.
+    //
+    // TODO(endlessroad): for variable input & update, we might have
+    // different layouts which will prevent input output aliasing and
+    // increase memory usage. Investigate such cases.
+ int64 device = *sharding->tile_assignment().begin(); + std::vector offset = + sharding->TileOffsetForDevice(*xla_shape, device); + std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); + std::vector dimensions(xla_shape->rank()); + for (int64 i = 0; i < xla_shape->rank(); ++i) { + dimensions[i] = limit[i] - offset[i]; + } + xla::Shape per_device_xla_shape = + xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); + TensorShape per_device_tensor_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape->element_type())); + TF_ASSIGN_OR_RETURN(per_device_xla_shape, + shape_representation_fn(per_device_tensor_shape, dtype, + use_fast_memory)); + *xla_shape->mutable_layout() = per_device_xla_shape.layout(); + } + return Status::OK(); +} + +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. +xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN(auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + original_shape.element_type())); + TF_ASSIGN_OR_RETURN(auto to_shape, + shape_representation_fn(shape, dtype, fast_mem)); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64 i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + // Builds the XLA computation. // - `args` is the list of input arguments // - `retvals` is the list of retvals produced by _Retval operators, in index @@ -188,10 +268,6 @@ Status BuildComputation( std::vector elems; elems.reserve(retvals.size()); - // Keeps track of the layout of each retval. If a retval is not in this list, - // a descending layout is used. The first element is the output index, second - // element is the new layout. - std::vector> retval_index_and_layout; // Keeps track of sharding of each retval. If a retval is not in this list, // replicate sharding is used. The first element is the output index, second // element is the sharding. @@ -219,22 +295,22 @@ Status BuildComputation( TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); xla::XlaOp value = retval.handle(); auto it = retval_shardings.find(i); - xla::XlaScopedShardingAssignment assign_sharding( - builder, it == retval_shardings.end() - ? 
absl::optional() - : it->second); + absl::optional sharding = + it == retval_shardings.end() ? absl::optional() + : it->second; if (it != retval_shardings.end()) { retval_index_and_sharding[elems.size()] = it->second; } if (shape_representation_fn) { - // If there is a shape representation function, reshape the output - // tensor to the shape given by the representation shape function. - TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( - output.shape, output.type, - /*use_fast_memory=*/false)); - value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); - retval_index_and_layout.emplace_back(elems.size(), shape.layout()); - } else if (it != retval_shardings.end()) { + TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(value)); + TF_ASSIGN_OR_RETURN(value, + ReshapeWithCorrectRepresentationAndSharding( + builder, value, original_shape, + shape_representation_fn, sharding, + /*fast_mem=*/false)); + } + if (it != retval_shardings.end()) { + xla::XlaScopedShardingAssignment assign_sharding(builder, sharding); // Apply the sharding to the output, if there is a core assignment. value = identity_op(value); } @@ -312,43 +388,27 @@ Status BuildComputation( update.tensor_array_gradients_accessed.insert(grad.first); } + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + auto sharding = it == arg_shardings.end() + ? absl::optional() + : it->second; + // Set layout of the retval to device representation layout. + if (shape_representation_fn) { + TF_ASSIGN_OR_RETURN(auto original_shape, builder->GetShape(handle)); + TF_ASSIGN_OR_RETURN( + handle, ReshapeWithCorrectRepresentationAndSharding( + builder, handle, original_shape, + shape_representation_fn, sharding, arg.fast_mem)); + } + // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, it == arg_shardings.end() ? absl::optional() - : it->second); + xla::XlaScopedShardingAssignment assign_sharding(builder, sharding); if (it != arg_shardings.end()) { retval_index_and_sharding[elems.size()] = it->second; } - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - // Ensures the correct sharding is applied to the output. handle = identity_op(handle); - - // Set layout of the retval to device representation layout. - absl::optional representation_shape; - if (shape_representation_fn) { - TF_ASSIGN_OR_RETURN( - xla::Shape xla_shape, - shape_representation_fn(resource->shape(), resource->type(), - /*use_fast_memory=*/false)); - representation_shape = xla_shape; - } - if (resource->representation_shape().has_value()) { - const xla::Shape& xla_shape = resource->representation_shape().value(); - if (representation_shape) { - TF_RET_CHECK( - xla::ShapeUtil::Compatible(*representation_shape, xla_shape)); - } else { - representation_shape = xla_shape; - } - } - if (representation_shape) { - retval_index_and_layout.emplace_back(elems.size(), - representation_shape->layout()); - } - elems.push_back(handle); } } @@ -411,20 +471,8 @@ Status BuildComputation( } *computation = computation_status.ConsumeValueOrDie(); - TF_ASSIGN_OR_RETURN(const auto& program_shape, - computation->GetProgramShape()); + TF_ASSIGN_OR_RETURN(auto program_shape, computation->GetProgramShape()); *output_shape = program_shape.result(); - // Update the output layout to the layout of retval. 
- for (auto& index_and_layout : retval_index_and_layout) { - if (!always_return_tuple && elems.size() == 1) { - *output_shape->mutable_layout() = index_and_layout.second; - continue; - } - - xla::Shape* output_sub_shape = xla::ShapeUtil::GetMutableSubshape( - output_shape, {index_and_layout.first}); - *output_sub_shape->mutable_layout() = index_and_layout.second; - } return Status::OK(); } @@ -779,47 +827,6 @@ Status XlaCompiler::XLAShapeForArgument( const XlaCompiler::Argument& arg, bool is_entry_computation, const absl::optional& arg_sharding, xla::Shape* xla_shape) const { - auto rewrite_layout_with_sharded_shape = - [](const absl::optional& arg_sharding, - bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape) { - if (arg_sharding && !arg_sharding->IsTileMaximal()) { - // After parameter sharding, per core parameter might have different - // layout. For example, before sharding, a parameter of shape [128, - // 128] will be assigned default minor-to-major {1, 0}. But after we - // shard this parameter to [128, 64] * 2, the sharded parameters - // will have minor-to-major {0, 1}. - // - // As a result, for sharded parameters, we set their layout to per - // core parameter's layout. - // - // TODO(endlessroad): for variable input & update, we might have - // different layouts which will prevent input output aliasing and - // increase memory usage. Investigate such cases. - int64 device = *arg_sharding->tile_assignment().begin(); - std::vector offset = - arg_sharding->TileOffsetForDevice(*xla_shape, device); - std::vector limit = - arg_sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->rank()); - for (int64 i = 0; i < xla_shape->rank(); ++i) { - dimensions[i] = limit[i] - offset[i]; - } - xla::Shape per_device_xla_shape = - xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); - TensorShape per_device_tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(per_device_xla_shape, - &per_device_tensor_shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - xla_shape->element_type())); - TF_ASSIGN_OR_RETURN(per_device_xla_shape, - shape_representation_fn(per_device_tensor_shape, - dtype, use_fast_memory)); - *xla_shape->mutable_layout() = per_device_xla_shape.layout(); - } - return Status::OK(); - }; switch (arg.kind) { case XlaCompiler::Argument::kConstant: LOG(FATAL) << "Unreachable case"; @@ -835,7 +842,7 @@ Status XlaCompiler::XLAShapeForArgument( TF_ASSIGN_OR_RETURN(*xla_shape, options_.shape_representation_fn( shape, arg.type, /*use_fast_memory=*/false)); - TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape( + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( arg_sharding, /*use_fast_memory=*/false, options_.shape_representation_fn, xla_shape)); } else { @@ -863,7 +870,7 @@ Status XlaCompiler::XLAShapeForArgument( options_.shape_representation_fn( absl::get(arg.shape), arg.type, /*use_fast_memory=*/arg.fast_mem)); - TF_RETURN_IF_ERROR(rewrite_layout_with_sharded_shape( + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( arg_sharding, arg.fast_mem, options_.shape_representation_fn, xla_shape)); return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index cf8bd6b6ce4..76780167187 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -365,7 +365,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForFastMemVar) { 
compile_options.return_updated_values_for_all_resources = true; TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph), args, &result)); - EXPECT_EQ(fast_mem_arg_count, 1); + // Count 2: one for argument, one for the return value. + EXPECT_EQ(fast_mem_arg_count, 2); } // Tests that the compiler can correctly propagate the layout assigned by @@ -417,6 +418,8 @@ TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) { // Check that the return shapes are correctly tranposed. EXPECT_EQ(result.xla_output_shape, xla::ShapeUtil::MakeTupleShape({transposed, transposed})); + EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(), + xla::ShapeUtil::MakeTupleShape({transposed, transposed})); } // The layout of resource variable shouldn't change after transpose @@ -1091,6 +1094,8 @@ TEST_F(XlaCompilerTest, ResultLayoutSingle) { EXPECT_TRUE(xla::ShapeUtil::Equal( result.xla_output_shape, xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1}))); + EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(), + result.xla_output_shape); } TEST_F(XlaCompilerTest, ResultLayoutMultiple) { @@ -1131,6 +1136,8 @@ TEST_F(XlaCompilerTest, ResultLayoutMultiple) { EXPECT_TRUE(xla::ShapeUtil::Equal( result.xla_output_shape, xla::ShapeUtil::MakeTupleShape({result_shape, result_shape}))); + EXPECT_EQ(result.computation->GetProgramShape().ConsumeValueOrDie().result(), + result.xla_output_shape); } // Tests a simple graph that reads and writes a variable. diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index a7e761b7dd0..d4a267d4356 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -528,7 +528,8 @@ StatusOr XlaBuilder::AddBroadcastSequence(const Shape& output_shape, } // Eliminate the size one dimensions. - TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, Reshape(reshaped_shape, operand)); + TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand, + ReshapeInternal(reshaped_shape, operand)); // Broadcast 'reshape' up to the larger size. return InDimBroadcast(broadcast_shape, reshaped_operand, broadcast_dimensions); @@ -828,8 +829,8 @@ XlaOp XlaBuilder::BroadcastInDim( }); } -StatusOr XlaBuilder::Reshape(const Shape& shape, XlaOp operand, - int64 inferred_dimension) { +StatusOr XlaBuilder::ReshapeInternal(const Shape& shape, XlaOp operand, + int64 inferred_dimension) { TF_RETURN_IF_ERROR(first_error_); HloInstructionProto instr; @@ -1020,7 +1021,7 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span dimensions, XlaOp transposed = IsIdentityPermutation(dimensions) ? 
operand : Transpose(operand, dimensions); - return Reshape(shape, transposed, inferred_dimension); + return ReshapeInternal(shape, transposed, inferred_dimension); }); } @@ -1034,6 +1035,13 @@ XlaOp XlaBuilder::Reshape(XlaOp operand, absl::Span new_sizes, }); } +XlaOp XlaBuilder::Reshape(const Shape& shape, XlaOp operand, + int64 inferred_dimension) { + return ReportErrorOrReturn([&]() -> StatusOr { + return ReshapeInternal(shape, operand, inferred_dimension); + }); +} + XlaOp XlaBuilder::Collapse(XlaOp operand, absl::Span dimensions) { return ReportErrorOrReturn([&]() -> StatusOr { if (dimensions.size() <= 1) { @@ -2951,6 +2959,10 @@ XlaOp Reshape(const XlaOp operand, absl::Span new_sizes) { return operand.builder()->Reshape(operand, new_sizes); } +XlaOp Reshape(const Shape& shape, XlaOp operand) { + return operand.builder()->Reshape(shape, operand); +} + XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension) { diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 993394ea275..6ec9aeb809f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -397,6 +397,9 @@ class XlaBuilder { XlaOp Reshape(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension = -1); + XlaOp Reshape(const Shape& shape, XlaOp operand, + int64 inferred_dimension = -1); + XlaOp Collapse(XlaOp operand, absl::Span dimensions); XlaOp Slice(XlaOp operand, absl::Span start_indices, @@ -668,8 +671,8 @@ class XlaBuilder { // Internal helper method for creating a Reshape op with the already inferred // shape. - StatusOr Reshape(const Shape& shape, XlaOp operand, - int64 inferred_dimension = -1); + StatusOr ReshapeInternal(const Shape& shape, XlaOp operand, + int64 inferred_dimension = -1); // Returns the (inferred) result for the program shape using the given root. StatusOr GetProgramShape(int64 root_id) const; @@ -777,6 +780,8 @@ class XlaBuilder { friend XlaOp Reshape(XlaOp operand, absl::Span new_sizes); + friend XlaOp Reshape(const Shape& shape, XlaOp operand); + friend XlaOp ReshapeWithInferredDimension(XlaOp operand, absl::Span new_sizes, int64 inferred_dimension); @@ -1252,6 +1257,9 @@ XlaOp Reshape(XlaOp operand, absl::Span dimensions, // sizes. Conceptually, this is a limited form of "shape casting". XlaOp Reshape(XlaOp operand, absl::Span new_sizes); +// Enqueues a Reshape op that uses an explicit target shape. +XlaOp Reshape(const Shape& shape, XlaOp operand); + // `inferred_dimension` represents the output dimension that's inferred by // upper-level framework by dividing the input element count by the known // output element count. While an inferred_dimension can be static, if there
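The comment in RewriteLayoutWithShardedShape describes how a [128, 128] shape sharded into two [128, 64] tiles ends up with a per-core layout. Below is a small worked example of that arithmetic, using the same HloSharding helpers the new function calls; the tile assignment, device choice, and function name are illustrative assumptions, not code from this patch.

#include <vector>

#include "tensorflow/compiler/xla/array.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Splits f32[128, 128] along dimension 1 across two devices and returns the
// per-device shape, mirroring the arithmetic in RewriteLayoutWithShardedShape.
xla::Shape PerDeviceShapeExample() {
  xla::Shape full_shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 128});
  xla::Array<xla::int64> assignment({{0, 1}});  // 1x2 device grid.
  xla::HloSharding sharding = xla::HloSharding::Tile(assignment);

  const xla::int64 device = *sharding.tile_assignment().begin();
  std::vector<xla::int64> offset = sharding.TileOffsetForDevice(full_shape, device);
  std::vector<xla::int64> limit = sharding.TileLimitForDevice(full_shape, device);
  std::vector<xla::int64> dims(full_shape.rank());
  for (xla::int64 i = 0; i < full_shape.rank(); ++i) {
    dims[i] = limit[i] - offset[i];  // {128, 64} for this sharding.
  }
  return xla::ShapeUtil::MakeShape(xla::F32, dims);
}

RewriteLayoutWithShardedShape feeds this per-device shape through shape_representation_fn and copies the resulting layout back onto the full shape, which is how sharded parameters and retvals pick up the per-core layout.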
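The new test expectations (the transposed retval shapes and the matching GetProgramShape() results) rely on a shape_representation_fn that requests a non-default layout. Here is a minimal sketch of such a callback, loosely following the HonorShapeRepresentationFnForRetVal test; the function name and the hard-coded {0, 1} layout are illustrative, not taken from the patch.

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"

// Picks a column-major layout for rank-2 tensors; the compiler uses the result
// for arguments, retvals, resource updates, xla_output_shape, and (after this
// patch) the entry computation's program shape.
xla::StatusOr<xla::Shape> TransposedRepresentation(
    const tensorflow::TensorShape& shape, tensorflow::DataType dtype,
    bool use_fast_memory) {
  xla::Shape xla_shape;
  tensorflow::Status status =
      tensorflow::TensorShapeToXLAShape(dtype, shape, &xla_shape);
  if (!status.ok()) return status;
  if (xla_shape.rank() == 2) {
    // Minor-to-major {0, 1} instead of the default descending {1, 0}.
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
  }
  return xla_shape;
}

With a callback like this installed as XlaCompiler::Options::shape_representation_fn, result.xla_output_shape and result.computation->GetProgramShape().result() now carry the same layout, which is exactly what the added EXPECT_EQ checks assert.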
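The xla_builder change adds a public Reshape overload that takes an explicit target xla::Shape instead of a list of dimension sizes; it is what ReshapeWithCorrectRepresentationAndSharding uses to apply the representation shape. A minimal usage sketch follows (the builder name, parameter, and shapes are illustrative).

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Reshapes an f32[6] parameter to an explicit f32[2, 3] target shape using the
// new Reshape(const Shape&, XlaOp) overload.
xla::XlaComputation BuildReshapeExample() {
  xla::XlaBuilder builder("reshape_to_explicit_shape");
  xla::XlaOp param = xla::Parameter(
      &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {6}), "input");
  xla::Shape target = xla::ShapeUtil::MakeShape(xla::F32, {2, 3});
  xla::Reshape(target, param);  // The last enqueued op becomes the root.
  return builder.Build().ConsumeValueOrDie();
}

Because the target shape is handed to ReshapeInternal as the already inferred shape, its dynamic dimensions and any layout set on it are taken as-is rather than being re-derived from a list of sizes.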