From 79ca75b61858084e7e141f8effcdac705ffc0446 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Thu, 26 Mar 2020 13:29:16 -0700 Subject: [PATCH] Add support for updating argument/result shapes and layouts with associated shardings of entry function. Sharding is present with model parallelism. Depending on what type of sharding is present, argument/result shapes and layouts need to be updated. ShapeRepresentationFn and shardings are used to determine the new shapes and layouts. PiperOrigin-RevId: 303182568 Change-Id: I4185c1ae12de618b0b2ce9c07d2cd795c4e329b8 --- tensorflow/compiler/mlir/tensorflow/BUILD | 3 + .../tensorflow/utils/compile_mlir_util.cc | 74 ++++-- .../mlir/tensorflow/utils/compile_mlir_util.h | 10 +- .../utils/compile_mlir_util_test.cc | 147 +++++++++- tensorflow/compiler/mlir/xla/BUILD | 5 + .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 250 +++++++++++++++--- .../compiler/mlir/xla/mlir_hlo_to_hlo.h | 5 +- tensorflow/compiler/tf2xla/xla_compiler.cc | 97 +++---- tensorflow/compiler/tf2xla/xla_compiler.h | 10 + 9 files changed, 472 insertions(+), 129 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index dfa642f378c..c2120ccc4ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [ ":tensorflow_dialect_registration", ":tensorflow_passes", ":translate_utils", + "@com_google_absl//absl/types:optional", "@llvm-project//llvm:support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", @@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/service:hlo", ] # Prefer to link 'compile_mlir_util' library that also links necessary diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 3fd711b9ef8..f2a1cc13b01 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h" +#include "absl/types/optional.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project @@ -43,6 +44,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/transforms/passes.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/logging.h" namespace tensorflow { @@ -74,7 +77,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string, Status GetXlaInputShapes( mlir::ModuleOp module, llvm::ArrayRef arg_shapes, bool use_tuple_args, - const xla::CustomShapeRepresentationFn shape_representation_fn, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, std::vector* xla_input_shapes) { xla_input_shapes->clear(); @@ -93,7 +96,24 @@ Status GetXlaInputShapes( DataType dtype; TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype)); TF_ASSIGN_OR_RETURN(xla_shape, - shape_representation_fn(arg_shapes[i], dtype)); + shape_representation_fn(arg_shapes[i], dtype, + /*use_fast_memory=*/false)); + + // Rewrite layout with sharding, if sharding is set. + auto sharding = + main_func.getArgAttrOfType(i, "xla_hlo.sharding"); + if (!sharding) continue; + + absl::optional arg_sharding; + xla::OpSharding op_sharding; + if (!op_sharding.ParseFromString(sharding.getValue().str())) + return errors::InvalidArgument("failed to parse argument sharding ", i, + " '", sharding.getValue().str(), "'"); + + TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding)); + TF_RETURN_IF_ERROR( + RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false, + shape_representation_fn, &xla_shape)); } if (use_tuple_args) { xla_input_shapes->push_back( @@ -108,9 +128,14 @@ Status GetXlaInputShapes( // output based on static shapes in MLIR module Status GetOutputInfo( mlir::ModuleOp module, - const xla::CustomShapeRepresentationFn shape_representation_fn, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn, xla::Shape* xla_output_shape, std::vector* outputs) { + auto shape_representation_fn_no_fast_memory = + [shape_representation_fn](const TensorShape& shape, DataType dtype) { + return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); + }; + mlir::FuncOp main_func = module.lookupSymbol("main"); mlir::FunctionType func_type = main_func.getType(); @@ -121,8 +146,9 @@ Status GetOutputInfo( shapes.reserve(func_type.getNumResults()); for (mlir::Type type : func_type.getResults()) { - TF_ASSIGN_OR_RETURN(xla::Shape shape, - TypeToShape(type, shape_representation_fn)); + TF_ASSIGN_OR_RETURN( + xla::Shape shape, + xla::TypeToShape(type, shape_representation_fn_no_fast_memory)); auto tensor_type = type.dyn_cast(); shapes.push_back(shape); @@ -225,12 +251,12 @@ static void RegisterDialects() { (void)init_once; } -} // namespace -// namespace +} // namespace -Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, - xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple) { +Status ConvertMLIRToXlaComputation( + mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, + bool use_tuple_args, bool return_tuple, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass()); @@ -273,7 +299,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, xla::HloProto hlo_proto; TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto, - use_tuple_args, 
return_tuple)); + use_tuple_args, return_tuple, + shape_representation_fn)); *xla_computation = xla::XlaComputation(hlo_proto.hlo_module()); return Status::OK(); } @@ -281,7 +308,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, static Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -292,35 +319,28 @@ static Status CompileMlirToXlaHlo( if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op); + if (!shape_representation_fn) + shape_representation_fn = IdentityShapeRepresentationFn(); + // Convert MLIR module to XLA HLO proto contained in XlaComputation. compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( module_op, compilation_result->computation.get(), use_tuple_args, - /*return_tuple=*/true)); + /*return_tuple=*/true, shape_representation_fn)); // Construct mapping from XlaComputation's arg to input edges of execute // node. GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping); - auto shape_representation_fn_no_fast_memory = - [shape_representation_fn](const TensorShape& shape, - DataType dtype) -> StatusOr { - if (shape_representation_fn) - return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; - // Compute all input shapes. TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args, - shape_representation_fn_no_fast_memory, + shape_representation_fn, &compilation_result->xla_input_shapes)); // Compute all output descriptions. - TF_RETURN_IF_ERROR(GetOutputInfo( - module_op, shape_representation_fn_no_fast_memory, - &compilation_result->xla_output_shape, &compilation_result->outputs)); + TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn, + &compilation_result->xla_output_shape, + &compilation_result->outputs)); // Compute what resource variables need to be updated after XlaComputation's // execution. diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 0dd4b8c5efe..2ce0a31eb78 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -43,9 +43,13 @@ namespace tensorflow { // entry computation. // return_tuple: when this is true, always create a tuple result for the // entry computation. -Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op, - xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple); +// shape_representation_fn: when this is set, this shape representation function +// will be used to determine argument and result shapes. Otherwise the +// original shape will be used as is. +Status ConvertMLIRToXlaComputation( + mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, + bool use_tuple_args, bool return_tuple, + const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. 
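Note (illustrative only, not part of the patch): the shape_representation_fn accepted by ConvertMLIRToXlaComputation has the signature StatusOr<xla::Shape>(const TensorShape&, DataType dtype, bool use_fast_memory), as used elsewhere in this change. A minimal sketch of such a callback, mirroring the identity behavior applied when none is supplied, together with a hypothetical call using the new parameter:

    // Sketch only: forwards to TensorShapeToXLAShape, i.e. the same behavior as
    // IdentityShapeRepresentationFn(); a real backend would pick device-specific
    // layouts (and fast-memory variants) here.
    xla::StatusOr<xla::Shape> ExampleShapeRepresentation(
        const TensorShape& shape, DataType dtype, bool use_fast_memory) {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
      return xla_shape;
    }

    // Hypothetical usage of the updated API:
    // TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
    //     module_op, &xla_computation, /*use_tuple_args=*/true,
    //     /*return_tuple=*/true, ExampleShapeRepresentation));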
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index f65fcc1016d..6b79ad2494f 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -40,7 +40,8 @@ xla::StatusOr TestShapeRepresentation(const TensorShape& shape, } TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { - string invalid_mlir_module = "totally @invalid MLIR module {here} <-"; + constexpr char invalid_mlir_module[] = + "totally @invalid MLIR module {here} <-"; std::vector arg_shapes; XlaCompiler::CompilationResult compilation_result; @@ -76,7 +77,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.6 + constexpr char expected_hlo_module_string[] = R"(HloModule main.6 ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) { %arg_tuple.1 = (f32[], f32[]) parameter(0) @@ -134,7 +135,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.5 + constexpr char expected_hlo_module_string[] = R"(HloModule main.5 ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { %Arg_0.1 = f32[] parameter(0) @@ -181,7 +182,7 @@ ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) { TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { // "tf.Shape" can only be folded away after shape inference. tf.Reshape can // only be lowered when tf.Shape is folded into a constant. 
- string mlir_module = R"( + constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> { %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64> @@ -205,7 +206,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.6 + constexpr char expected_hlo_module_string[] = R"(HloModule main.6 ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true} @@ -221,7 +222,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) { } TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { - string mlir_module = R"( + constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { func @main(%arg0: tensor<*xf32>, %arg1: tensor) -> tensor { %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor) -> tensor @@ -245,13 +246,14 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_signature = + constexpr char expected_signature[] = R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))"; EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(), ::testing::HasSubstr(expected_signature)); } -constexpr llvm::StringRef kBroadcastGradientArgsModule = R"( +TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { + constexpr char mlir_module[] = R"( module attributes {tf.versions = {producer = 179 : i32}} { func @main() -> (tensor<0xi32>, tensor<0xi32>) { %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32> @@ -261,12 +263,11 @@ module attributes {tf.versions = {producer = 179 : i32}} { } )"; -TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { std::vector arg_shapes(2, TensorShape()); XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - kBroadcastGradientArgsModule, arg_shapes, + mlir_module, arg_shapes, /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -275,7 +276,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) { auto status_or_hlo_module = xla::HloModule::CreateFromProto( compilation_result.computation->proto(), module_config); TF_ASSERT_OK(status_or_hlo_module.status()); - string expected_hlo_module_string = R"(HloModule main.4 + constexpr char expected_hlo_module_string[] = R"(HloModule main.4 ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { %arg_tuple.1 = () parameter(0) @@ -288,6 +289,128 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) { status_or_hlo_module.ValueOrDie()->ToString()); } +// The following xla::OpSharding protos are used: +// Serialized string: +// "\08\03\1A\02\01\02\22\02\00\01" +// Proto debug string: +// type: OTHER +// tile_assignment_dimensions: 1 +// tile_assignment_dimensions: 2 +// tile_assignment_devices: 0 +// tile_assignment_devices: 1 +// +// Serialized string: +// "\08\01\1A\01\01\22\01\00" +// Proto debug string: +// type: MAXIMAL +// tile_assignment_dimensions: 1 +// 
tile_assignment_devices: 0 +// +// Serialized string: +// "" +// Proto debug string (empty but would equivalent to): +// type: REPLICATED +TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) { + constexpr char mlir_module[] = R"( +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) { + return + } +} +)"; + + std::vector arg_shapes{TensorShape({128, 10}), + TensorShape({10, 1024}), + TensorShape({128, 1024})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + TF_ASSERT_OK(s); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + TF_ASSERT_OK(status_or_hlo_module.status()); + constexpr char expected_hlo_module_string[] = R"(HloModule main.6 + +ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () { + %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} + %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 + %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 + %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 + ROOT %tuple.5 = () tuple() +} + +)"; + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + +TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) { + constexpr char mlir_module[] = R"( +module attributes {tf.versions = {producer = 179 : i32}} { + func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) { + return + } +} +)"; + + std::vector arg_shapes{TensorShape({128, 10})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + ASSERT_FALSE(s.ok()); + EXPECT_EQ(s.error_message(), + "failed to parse argument sharding 0 'bad_sharding'"); +} + +TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) { + constexpr char mlir_module[] = R"( +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} { + func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) { + return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32> + } +} +)"; + + std::vector arg_shapes{TensorShape({128, 10}), + TensorShape({10, 1024}), + TensorShape({128, 1024})}; + XlaCompiler::CompilationResult compilation_result; + + Status s = CompileSerializedMlirToXlaHlo( + mlir_module, arg_shapes, + /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); + 
TF_ASSERT_OK(s); + + const xla::HloModuleConfig module_config( + compilation_result.computation->GetProgramShape().ValueOrDie()); + auto status_or_hlo_module = xla::HloModule::CreateFromProto( + compilation_result.computation->proto(), module_config); + TF_ASSERT_OK(status_or_hlo_module.status()); + constexpr char expected_hlo_module_string[] = R"(HloModule main.9 + +ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) { + %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0) + %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0 + %reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2) + %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1 + %reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3) + %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2 + %reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4) + ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}} +} + +)"; + EXPECT_EQ(expected_hlo_module_string, + status_or_hlo_module.ValueOrDie()->ToString()); +} + // Verify that conversion from Graph to MLIR and empty shape representation // function is successful. TEST(CompileGraphToXlaHlo, Basic) { @@ -311,7 +434,7 @@ TEST(CompileGraphToXlaHlo, Basic) { result.computation->proto(), module_config); ASSERT_TRUE(status_or_hlo_module.ok()); - string expected_hlo_module_string = R"(HloModule main.3 + constexpr char expected_hlo_module_string[] = R"(HloModule main.3 ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) { %Arg_0.1 = f32[] parameter(0) diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 6597eeaa967..c472c0ed29d 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -618,7 +618,10 @@ cc_library( ":hlo", ":type_to_shape", ":xla_dialect_registration", + "//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -629,6 +632,8 @@ cc_library( "//tensorflow/compiler/xla/client/lib:quantize", "//tensorflow/compiler/xla/client/lib:slicing", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", "//tensorflow/stream_executor/lib", "@llvm-project//llvm:support", "@llvm-project//mlir:Analysis", diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index 670f34b4318..8922cc131c6 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/MemoryBuffer.h" @@ -37,8 +38,12 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/UseDefLists.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/quantize.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -49,6 +54,8 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" using ::stream_executor::port::StatusOr; @@ -64,6 +71,7 @@ using ::tensorflow::uint8; constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map"; constexpr char kShapeIndicesAttr[] = "shape_indices"; constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices"; +constexpr char kShardingAttr[] = "xla_hlo.sharding"; constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas"; // Passes through everything except for unique_ptr, on which it calls get(). @@ -377,7 +385,7 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers( // returns absl::nullopt. static absl::optional CreateOpShardingFromAttribute( mlir::Operation* op) { - auto sharding = op->getAttrOfType("xla_hlo.sharding"); + auto sharding = op->getAttrOfType(kShardingAttr); if (!sharding) { return absl::nullopt; } @@ -389,6 +397,43 @@ static absl::optional CreateOpShardingFromAttribute( return sharding_proto; } +// Checks if all shardings are set. +static bool AllOptionalShardingsAreSet( + llvm::ArrayRef> shardings) { + return llvm::all_of(shardings, + [](const absl::optional& sharding) { + return sharding.has_value(); + }); +} + +// Extracts sharding from attribute string. +static absl::optional CreateOpShardingFromStringRef( + llvm::StringRef sharding) { + xla::OpSharding sharding_proto; + if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt; + return sharding_proto; +} + +// Extracts argument and result shardings from function. 
+static void ExtractShardingsFromFunction( + mlir::FuncOp function, + llvm::SmallVectorImpl>* arg_shardings, + llvm::SmallVectorImpl>* ret_shardings) { + arg_shardings->resize(function.getNumArguments(), + absl::optional()); + for (int i = 0; i < function.getNumArguments(); ++i) + if (auto sharding = + function.getArgAttrOfType(i, kShardingAttr)) + (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); + + ret_shardings->resize(function.getNumResults(), + absl::optional()); + for (int i = 0; i < function.getNumResults(); ++i) + if (auto sharding = + function.getResultAttrOfType(i, kShardingAttr)) + (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue()); +} + namespace mlir { namespace { class ConvertToHloModule { @@ -402,12 +447,17 @@ class ConvertToHloModule { // are converted to a tuple even when there is only a single return value. // Multiple return values are always converted to a tuple and returned as a // single value. - explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args, - bool return_tuple) + explicit ConvertToHloModule( + mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, + tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn) : module_(module), module_builder_("main"), use_tuple_args_(use_tuple_args), - return_tuple_(return_tuple) {} + return_tuple_(return_tuple), + shape_representation_fn_(shape_representation_fn) { + if (!shape_representation_fn_) + shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn(); + } // Perform the lowering to XLA. This function returns failure if an error was // encountered. @@ -432,6 +482,8 @@ class ConvertToHloModule { LogicalResult LowerBasicBlockAsFunction( Block* block, xla::XlaBuilder* builder, bool is_entry_function, const std::vector& entry_args_same_across_replicas, + llvm::ArrayRef> arg_shardings, + llvm::ArrayRef> ret_shardings, xla::XlaComputation* result); ::xla::HloModuleProto ConsumeMainProto() { @@ -445,10 +497,22 @@ class ConvertToHloModule { ConvertToHloModule::ValueLoweringMap* value_lowering); private: - LogicalResult Lower(mlir::Operation* inst, bool is_entry_function, - xla::XlaBuilder* builder, - ConvertToHloModule::ValueLoweringMap* value_lowering, - xla::XlaComputation* result); + LogicalResult Lower( + mlir::Operation* inst, bool is_entry_function, + llvm::ArrayRef> ret_shardings, + xla::XlaBuilder* builder, + ConvertToHloModule::ValueLoweringMap* value_lowering, + xla::XlaComputation* result); + + LogicalResult SetEntryTupleShapesAndLeafReplication( + Block* block, const std::vector& entry_args_same_across_replicas, + llvm::SmallVectorImpl* arg_shapes, + std::vector* leaf_replication); + + LogicalResult SetEntryTupleShardings( + Block* block, xla::XlaBuilder* builder, + llvm::ArrayRef> arg_shardings, + llvm::SmallVectorImpl* arg_shapes); // The module being lowered. mlir::ModuleOp module_; @@ -465,6 +529,10 @@ class ConvertToHloModule { // Whether to always return a tuple. bool return_tuple_; + // Shape representation function to determine entry function argument and + // result shapes. + tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + // Unique suffix to give to the name of the next lowered region. 
size_t region_id_ = 0; }; @@ -876,7 +944,9 @@ StatusOr CreateLiteralFromAttr(Type type, ElementsAttr attr) { } LogicalResult ConvertToHloModule::Lower( - mlir::Operation* inst, bool is_entry_function, xla::XlaBuilder* builder, + mlir::Operation* inst, bool is_entry_function, + llvm::ArrayRef> ret_shardings, + xla::XlaBuilder* builder, ConvertToHloModule::ValueLoweringMap* value_lowering, xla::XlaComputation* result) { if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) { @@ -906,11 +976,37 @@ LogicalResult ConvertToHloModule::Lower( xla::XlaOp return_value; unsigned num_return_values = inst->getNumOperands(); if ((return_tuple_ && is_entry_function) || num_return_values > 1) { + const bool has_ret_shardings = + !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings); + std::vector returns(num_return_values); - for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) { - returns[i] = value_map[inst->getOperand(i)]; + for (OpOperand& ret : inst->getOpOperands()) { + unsigned index = ret.getOperandNumber(); + returns[index] = value_map[ret.get()]; + if (!is_entry_function || !has_ret_shardings) continue; + + xla::Shape return_shape = xla::TypeToShape(ret.get().getType()); + StatusOr reshape = + tensorflow::ReshapeWithCorrectRepresentationAndSharding( + builder, returns[index], return_shape, shape_representation_fn_, + ret_shardings[index], /*fast_mem=*/false); + if (!reshape.ok()) + return inst->emitError() << reshape.status().error_message(); + + returns[index] = reshape.ValueOrDie(); } + + if (has_ret_shardings) { + xla::OpSharding sharding; + sharding.set_type(xla::OpSharding::TUPLE); + for (auto& ret_sharding : ret_shardings) + *sharding.add_tuple_shardings() = ret_sharding.value(); + + builder->SetSharding(sharding); + } + return_value = xla::Tuple(builder, returns); + builder->ClearSharding(); } else if (num_return_values == 1) { return_value = value_map[inst->getOperand(0)]; } @@ -976,6 +1072,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { xla::XlaComputation computation; std::vector entry_args_same_across_replicas; + llvm::SmallVector, 4> arg_shardings; + llvm::SmallVector, 4> ret_shardings; if (entry_function) { bool any_arg_replicated = false; entry_args_same_across_replicas.reserve(f.getNumArguments()); @@ -1000,21 +1098,90 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) { // means no replication. This avoids the need for unrelated tests to handle // this field. 
if (!any_arg_replicated) entry_args_same_across_replicas.clear(); + + ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings); } - if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function, - entry_args_same_across_replicas, - &computation))) { + if (failed(LowerBasicBlockAsFunction( + &f.front(), &builder, entry_function, entry_args_same_across_replicas, + arg_shardings, ret_shardings, &computation))) { return failure(); } lowered_computation_[f] = std::move(computation); return success(); } +LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication( + Block* block, const std::vector& entry_args_same_across_replicas, + llvm::SmallVectorImpl* arg_shapes, + std::vector* leaf_replication) { + arg_shapes->reserve(block->getNumArguments()); + leaf_replication->reserve(block->getNumArguments()); + for (BlockArgument& arg : block->getArguments()) { + arg_shapes->push_back(xla::TypeToShape(arg.getType())); + xla::Shape& arg_shape = arg_shapes->back(); + tensorflow::TensorShape arg_tensor_shape; + auto status = + tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape); + if (!status.ok()) + return block->getParentOp()->emitError() << status.error_message(); + + tensorflow::DataType dtype; + status = tensorflow::ConvertToDataType(arg.getType(), &dtype); + if (!status.ok()) + return block->getParentOp()->emitError() << status.error_message(); + + auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype, + /*use_fast_memory=*/false); + if (!arg_shape_status.ok()) + return block->getParentOp()->emitError() + << arg_shape_status.status().error_message(); + + arg_shape = std::move(arg_shape_status.ValueOrDie()); + + if (entry_args_same_across_replicas.empty()) continue; + for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i) + leaf_replication->push_back( + entry_args_same_across_replicas[arg.getArgNumber()]); + } + + return success(); +} + +LogicalResult ConvertToHloModule::SetEntryTupleShardings( + Block* block, xla::XlaBuilder* builder, + llvm::ArrayRef> arg_shardings, + llvm::SmallVectorImpl* arg_shapes) { + if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) { + xla::OpSharding sharding; + sharding.set_type(xla::OpSharding::TUPLE); + for (auto arg_sharding : llvm::enumerate(arg_shardings)) { + auto hlo_sharding = + xla::HloSharding::FromProto(arg_sharding.value().value()); + if (!hlo_sharding.ok()) + return block->getParentOp()->emitError() + << hlo_sharding.status().error_message(); + + auto status = tensorflow::RewriteLayoutWithShardedShape( + hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false, + shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]); + if (!status.ok()) + return block->getParentOp()->emitError() << status.error_message(); + + *sharding.add_tuple_shardings() = arg_sharding.value().value(); + } + + builder->SetSharding(sharding); + } + + return success(); +} + LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( Block* block, xla::XlaBuilder* builder, bool is_entry_function, const std::vector& entry_args_same_across_replicas, + llvm::ArrayRef> arg_shardings, + llvm::ArrayRef> ret_shardings, xla::XlaComputation* result) { - auto& bb = *block; // Mapping from the Value to lowered XlaOp. The code below lowers in // program order and will fail if an operand is unseen. This can be improved. 
ValueLoweringMap lowering; @@ -1022,29 +1189,28 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( // If using tuples as input, then there is only one input parameter that is a // tuple. if (is_entry_function && use_tuple_args_) { - std::vector arg_shapes; - arg_shapes.reserve(bb.getNumArguments()); + llvm::SmallVector arg_shapes; std::vector leaf_replication; - for (auto& arg : bb.getArguments()) { - arg_shapes.push_back(xla::TypeToShape(arg.getType())); - if (!entry_args_same_across_replicas.empty()) { - for (int i = 0; i < xla::ShapeUtil::GetLeafCount(arg_shapes.back()); - ++i) { - leaf_replication.push_back( - entry_args_same_across_replicas[arg.getArgNumber()]); - } - } - } + if (failed(SetEntryTupleShapesAndLeafReplication( + block, entry_args_same_across_replicas, &arg_shapes, + &leaf_replication))) + return failure(); + + if (failed( + SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes))) + return failure(); + xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes); auto tuple = xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication); - for (auto& it : llvm::enumerate(bb.getArguments())) { - lowering[it.value()] = xla::GetTupleElement(tuple, it.index()); - } + + builder->ClearSharding(); + + for (BlockArgument& arg : block->getArguments()) + lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber()); } else { - for (auto& it : llvm::enumerate(bb.getArguments())) { - auto arg = it.value(); - auto num = it.index(); + for (BlockArgument& arg : block->getArguments()) { + auto num = arg.getArgNumber(); xla::Shape shape = xla::TypeToShape(arg.getType()); if (entry_args_same_across_replicas.empty()) { lowering[arg] = @@ -1058,8 +1224,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction( } } - for (auto& inst : bb) - if (failed(Lower(&inst, is_entry_function, builder, &lowering, result))) + for (auto& inst : *block) + if (failed(Lower(&inst, is_entry_function, ret_shardings, builder, + &lowering, result))) return failure(); return success(); @@ -1069,8 +1236,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation( mlir::Region* region, xla::XlaComputation* func) { std::unique_ptr builder = module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++)); - return LowerBasicBlockAsFunction(®ion->front(), builder.get(), - /*is_entry_function=*/false, {}, func); + return LowerBasicBlockAsFunction( + ®ion->front(), builder.get(), + /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{}, + /*arg_shardings=*/{}, /*ret_shardings=*/{}, func); } std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) { @@ -1241,9 +1410,12 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, } // namespace Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple) { + bool use_tuple_args, bool return_tuple, + const tensorflow::XlaCompiler::ShapeRepresentationFn + shape_representation_fn) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); - ConvertToHloModule converter(module, use_tuple_args, return_tuple); + ConvertToHloModule converter(module, use_tuple_args, return_tuple, + shape_representation_fn); if (failed(converter.Run())) return diag_handler.ConsumeStatus(); auto hlo_module = converter.ConsumeMainProto(); hlo_proto->mutable_hlo_module()->Swap(&hlo_module); diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 983d61a8af2..1a341b00d0c 
100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -18,6 +18,7 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,7 +32,9 @@ namespace mlir { // Multiple return values are always converted to a tuple and returned as a // single value. Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, - bool use_tuple_args, bool return_tuple); + bool use_tuple_args, bool return_tuple, + const tensorflow::XlaCompiler::ShapeRepresentationFn + shape_representation_fn = nullptr); // Creates XlaOp equivalent of a given MLIR operation using the operand info // from `value_lowering` map. diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 9b17ebe0260..85f2d5c1fc6 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -138,46 +138,6 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return Status::OK(); } -// There is a shape_representation_fn or sharding for an output, this function -// uses a reshape to fix the layout. -xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - absl::optional sharding, bool fast_mem) { - if (original_shape.IsTuple()) { - std::vector elements; - for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { - auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; - TF_ASSIGN_OR_RETURN(auto element, - ReshapeWithCorrectRepresentationAndSharding( - builder, xla::GetTupleElement(original, i), - original_shape.tuple_shapes(i), - shape_representation_fn, subsharding, fast_mem)); - elements.push_back(element); - } - return xla::Tuple(builder, elements); - } - if (!original_shape.IsArray()) return original; - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - original_shape.element_type())); - TF_ASSIGN_OR_RETURN(auto to_shape, - shape_representation_fn(shape, dtype, fast_mem)); - if (sharding) { - TF_ASSIGN_OR_RETURN(auto hlo_sharding, - xla::HloSharding::FromProto(*sharding)); - TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( - hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); - } - if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64 i = 0; i < original_shape.rank(); ++i) { - to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); - } - } - return xla::Reshape(to_shape, original); -} - // Builds the XLA computation. // - `args` is the list of input arguments // - `retvals` is the list of retvals produced by _Retval operators, in index @@ -562,13 +522,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) // The default shape representation function is the identity. 
if (!options_.shape_representation_fn) { - options_.shape_representation_fn = - [](const TensorShape& shape, DataType dtype, - bool use_fast_memory) -> xla::StatusOr { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; + options_.shape_representation_fn = IdentityShapeRepresentationFn(); } } @@ -1502,6 +1456,15 @@ xla::StatusOr XlaCompiler::GetNodeToken(const string& node_name) { return iter->second; } +XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() { + return [](const TensorShape& shape, DataType dtype, + bool use_fast_memory) -> xla::StatusOr { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; +} + // Rewrites the layout of xla_shape if there is tiled sharding. Status RewriteLayoutWithShardedShape( const absl::optional& sharding, bool use_fast_memory, @@ -1542,4 +1505,44 @@ Status RewriteLayoutWithShardedShape( return Status::OK(); } +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. +xla::StatusOr ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaCompiler::ShapeRepresentationFn shape_representation_fn, + absl::optional sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector elements; + for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN(auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + original_shape.element_type())); + TF_ASSIGN_OR_RETURN(auto to_shape, + shape_representation_fn(shape, dtype, fast_mem)); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64 i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index d67b1f26696..b95d250636a 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -518,12 +518,22 @@ class XlaCompiler { TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; +// Creates an identity shape representation function. +XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn(); + // Rewrites the layout of xla_shape if there is tiled sharding. Status RewriteLayoutWithShardedShape( const absl::optional& sharding, bool use_fast_memory, XlaCompiler::ShapeRepresentationFn shape_representation_fn, xla::Shape* xla_shape); +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. 
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
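For reference (illustrative only, not part of the patch): the serialized sharding strings attached as "xla_hlo.sharding" attributes in the tests above can be produced from xla::OpSharding protos directly. A minimal sketch, assuming the standard protobuf accessors for the fields shown in the test comments:

    // Sketch only: the tiled sharding used in the tests,
    // "\x08\x03\x1A\x02\x01\x02\x22\x02\x00\x01" -- a [1,2] tile over devices {0,1}.
    xla::OpSharding tiled;
    tiled.set_type(xla::OpSharding::OTHER);
    tiled.add_tile_assignment_dimensions(1);
    tiled.add_tile_assignment_dimensions(2);
    tiled.add_tile_assignment_devices(0);
    tiled.add_tile_assignment_devices(1);
    std::string tiled_attr;
    tiled.SerializeToString(&tiled_attr);  // value for the "xla_hlo.sharding" attribute

    // The maximal sharding on device 0, "\x08\x01\x1A\x01\x01\x22\x01\x00".
    xla::OpSharding maximal;
    maximal.set_type(xla::OpSharding::MAXIMAL);
    maximal.add_tile_assignment_dimensions(1);
    maximal.add_tile_assignment_devices(0);

    // An empty serialized string parses to the default proto, which the
    // converter treats as a replicated sharding.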