Add support for updating argument/result shapes and layouts with associated shardings of entry function.

Sharding is present with model parallelism. Depending on what type of sharding is present, argument/result shapes and layouts need to be updated. ShapeRepresentationFn and shardings are used to determine the new shapes and layouts.

PiperOrigin-RevId: 303182568
Change-Id: I4185c1ae12de618b0b2ce9c07d2cd795c4e329b8
This commit is contained in:
Andy Ly 2020-03-26 13:29:16 -07:00 committed by TensorFlower Gardener
parent c5b4b6dc1f
commit 79ca75b618
9 changed files with 472 additions and 129 deletions

View File

@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [
":tensorflow_dialect_registration",
":tensorflow_passes",
":translate_utils",
"@com_google_absl//absl/types:optional",
"@llvm-project//llvm:support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Parser",
@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:logging",
"//tensorflow/stream_executor/lib",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:hlo",
]
# Prefer to link 'compile_mlir_util' library that also links necessary

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@ -43,6 +44,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@ -74,7 +77,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
Status GetXlaInputShapes(
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const xla::CustomShapeRepresentationFn shape_representation_fn,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
std::vector<xla::Shape>* xla_input_shapes) {
xla_input_shapes->clear();
@ -93,7 +96,24 @@ Status GetXlaInputShapes(
DataType dtype;
TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
TF_ASSIGN_OR_RETURN(xla_shape,
shape_representation_fn(arg_shapes[i], dtype));
shape_representation_fn(arg_shapes[i], dtype,
/*use_fast_memory=*/false));
// Rewrite layout with sharding, if sharding is set.
auto sharding =
main_func.getArgAttrOfType<mlir::StringAttr>(i, "xla_hlo.sharding");
if (!sharding) continue;
absl::optional<xla::HloSharding> arg_sharding;
xla::OpSharding op_sharding;
if (!op_sharding.ParseFromString(sharding.getValue().str()))
return errors::InvalidArgument("failed to parse argument sharding ", i,
" '", sharding.getValue().str(), "'");
TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding));
TF_RETURN_IF_ERROR(
RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
shape_representation_fn, &xla_shape));
}
if (use_tuple_args) {
xla_input_shapes->push_back(
@ -108,9 +128,14 @@ Status GetXlaInputShapes(
// output based on static shapes in MLIR module
Status GetOutputInfo(
mlir::ModuleOp module,
const xla::CustomShapeRepresentationFn shape_representation_fn,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
xla::Shape* xla_output_shape,
std::vector<XlaCompiler::OutputDescription>* outputs) {
auto shape_representation_fn_no_fast_memory =
[shape_representation_fn](const TensorShape& shape, DataType dtype) {
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
};
mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
mlir::FunctionType func_type = main_func.getType();
@ -121,8 +146,9 @@ Status GetOutputInfo(
shapes.reserve(func_type.getNumResults());
for (mlir::Type type : func_type.getResults()) {
TF_ASSIGN_OR_RETURN(xla::Shape shape,
TypeToShape(type, shape_representation_fn));
TF_ASSIGN_OR_RETURN(
xla::Shape shape,
xla::TypeToShape(type, shape_representation_fn_no_fast_memory));
auto tensor_type = type.dyn_cast<mlir::RankedTensorType>();
shapes.push_back(shape);
@ -225,12 +251,12 @@ static void RegisterDialects() {
(void)init_once;
}
} // namespace
// namespace
} // namespace
Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
xla::XlaComputation* xla_computation,
bool use_tuple_args, bool return_tuple) {
Status ConvertMLIRToXlaComputation(
mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
bool use_tuple_args, bool return_tuple,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn) {
mlir::PassManager tf2xla(module_op.getContext());
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
@ -273,7 +299,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
xla::HloProto hlo_proto;
TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
use_tuple_args, return_tuple));
use_tuple_args, return_tuple,
shape_representation_fn));
*xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
return Status::OK();
}
@ -281,7 +308,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
static Status CompileMlirToXlaHlo(
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
@ -292,35 +319,28 @@ static Status CompileMlirToXlaHlo(
if (VLOG_IS_ON(1))
tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
if (!shape_representation_fn)
shape_representation_fn = IdentityShapeRepresentationFn();
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
compilation_result->computation = std::make_shared<xla::XlaComputation>();
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
module_op, compilation_result->computation.get(), use_tuple_args,
/*return_tuple=*/true));
/*return_tuple=*/true, shape_representation_fn));
// Construct mapping from XlaComputation's arg to input edges of execute
// node.
GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
auto shape_representation_fn_no_fast_memory =
[shape_representation_fn](const TensorShape& shape,
DataType dtype) -> StatusOr<xla::Shape> {
if (shape_representation_fn)
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
// Compute all input shapes.
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
shape_representation_fn_no_fast_memory,
shape_representation_fn,
&compilation_result->xla_input_shapes));
// Compute all output descriptions.
TF_RETURN_IF_ERROR(GetOutputInfo(
module_op, shape_representation_fn_no_fast_memory,
&compilation_result->xla_output_shape, &compilation_result->outputs));
TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn,
&compilation_result->xla_output_shape,
&compilation_result->outputs));
// Compute what resource variables need to be updated after XlaComputation's
// execution.

View File

@ -43,9 +43,13 @@ namespace tensorflow {
// entry computation.
// return_tuple: when this is true, always create a tuple result for the
// entry computation.
Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
xla::XlaComputation* xla_computation,
bool use_tuple_args, bool return_tuple);
// shape_representation_fn: when this is set, this shape representation function
// will be used to determine argument and result shapes. Otherwise the
// original shape will be used as is.
Status ConvertMLIRToXlaComputation(
mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
bool use_tuple_args, bool return_tuple,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr);
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult.

View File

@ -40,7 +40,8 @@ xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
}
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
constexpr char invalid_mlir_module[] =
"totally @invalid MLIR module {here} <-";
std::vector<TensorShape> arg_shapes;
XlaCompiler::CompilationResult compilation_result;
@ -76,7 +77,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
string expected_hlo_module_string = R"(HloModule main.6
constexpr char expected_hlo_module_string[] = R"(HloModule main.6
ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
%arg_tuple.1 = (f32[], f32[]) parameter(0)
@ -134,7 +135,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
string expected_hlo_module_string = R"(HloModule main.5
constexpr char expected_hlo_module_string[] = R"(HloModule main.5
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
%Arg_0.1 = f32[] parameter(0)
@ -181,7 +182,7 @@ ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
// "tf.Shape" can only be folded away after shape inference. tf.Reshape can
// only be lowered when tf.Shape is folded into a constant.
string mlir_module = R"(
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> {
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
@ -205,7 +206,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
string expected_hlo_module_string = R"(HloModule main.6
constexpr char expected_hlo_module_string[] = R"(HloModule main.6
ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
%arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true}
@ -221,7 +222,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
}
TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
string mlir_module = R"(
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<?x19xf32>) -> tensor<?x19xf32>
@ -245,13 +246,14 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
string expected_signature =
constexpr char expected_signature[] =
R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))";
EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(),
::testing::HasSubstr(expected_signature));
}
constexpr llvm::StringRef kBroadcastGradientArgsModule = R"(
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main() -> (tensor<0xi32>, tensor<0xi32>) {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
@ -261,12 +263,11 @@ module attributes {tf.versions = {producer = 179 : i32}} {
}
)";
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
std::vector<TensorShape> arg_shapes(2, TensorShape());
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
kBroadcastGradientArgsModule, arg_shapes,
mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
TF_ASSERT_OK(s);
@ -275,7 +276,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
string expected_hlo_module_string = R"(HloModule main.4
constexpr char expected_hlo_module_string[] = R"(HloModule main.4
ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
%arg_tuple.1 = () parameter(0)
@ -288,6 +289,128 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
status_or_hlo_module.ValueOrDie()->ToString());
}
// The following xla::OpSharding protos are used:
// Serialized string:
// "\08\03\1A\02\01\02\22\02\00\01"
// Proto debug string:
// type: OTHER
// tile_assignment_dimensions: 1
// tile_assignment_dimensions: 2
// tile_assignment_devices: 0
// tile_assignment_devices: 1
//
// Serialized string:
// "\08\01\1A\01\01\22\01\00"
// Proto debug string:
// type: MAXIMAL
// tile_assignment_dimensions: 1
// tile_assignment_devices: 0
//
// Serialized string:
// ""
// Proto debug string (empty but would equivalent to):
// type: REPLICATED
TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
return
}
}
)";
std::vector<TensorShape> arg_shapes{TensorShape({128, 10}),
TensorShape({10, 1024}),
TensorShape({128, 1024})};
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
TF_ASSERT_OK(s);
const xla::HloModuleConfig module_config(
compilation_result.computation->GetProgramShape().ValueOrDie());
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
constexpr char expected_hlo_module_string[] = R"(HloModule main.6
ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () {
%arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
%get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
%get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
%get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
ROOT %tuple.5 = () tuple()
}
)";
EXPECT_EQ(expected_hlo_module_string,
status_or_hlo_module.ValueOrDie()->ToString());
}
TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) {
return
}
}
)";
std::vector<TensorShape> arg_shapes{TensorShape({128, 10})};
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
ASSERT_FALSE(s.ok());
EXPECT_EQ(s.error_message(),
"failed to parse argument sharding 0 'bad_sharding'");
}
TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) {
constexpr char mlir_module[] = R"(
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} {
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>
}
}
)";
std::vector<TensorShape> arg_shapes{TensorShape({128, 10}),
TensorShape({10, 1024}),
TensorShape({128, 1024})};
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
TF_ASSERT_OK(s);
const xla::HloModuleConfig module_config(
compilation_result.computation->GetProgramShape().ValueOrDie());
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
TF_ASSERT_OK(status_or_hlo_module.status());
constexpr char expected_hlo_module_string[] = R"(HloModule main.9
ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) {
%arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0)
%get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
%reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2)
%get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
%reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3)
%get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
%reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4)
ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
}
)";
EXPECT_EQ(expected_hlo_module_string,
status_or_hlo_module.ValueOrDie()->ToString());
}
// Verify that conversion from Graph to MLIR and empty shape representation
// function is successful.
TEST(CompileGraphToXlaHlo, Basic) {
@ -311,7 +434,7 @@ TEST(CompileGraphToXlaHlo, Basic) {
result.computation->proto(), module_config);
ASSERT_TRUE(status_or_hlo_module.ok());
string expected_hlo_module_string = R"(HloModule main.3
constexpr char expected_hlo_module_string[] = R"(HloModule main.3
ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
%Arg_0.1 = f32[] parameter(0)

View File

@ -618,7 +618,10 @@ cc_library(
":hlo",
":type_to_shape",
":xla_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:convert_type",
"//tensorflow/compiler/mlir/tensorflow:error_util",
"//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/xla:comparison_util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
@ -629,6 +632,8 @@ cc_library(
"//tensorflow/compiler/xla/client/lib:quantize",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/stream_executor/lib",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Analysis",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
@ -37,8 +38,12 @@ limitations under the License.
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/UseDefLists.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
@ -49,6 +54,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"
using ::stream_executor::port::StatusOr;
@ -64,6 +71,7 @@ using ::tensorflow::uint8;
constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas";
// Passes through everything except for unique_ptr, on which it calls get().
@ -377,7 +385,7 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
// returns absl::nullopt.
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
mlir::Operation* op) {
auto sharding = op->getAttrOfType<mlir::StringAttr>("xla_hlo.sharding");
auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
if (!sharding) {
return absl::nullopt;
}
@ -389,6 +397,43 @@ static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
return sharding_proto;
}
// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
return llvm::all_of(shardings,
[](const absl::optional<xla::OpSharding>& sharding) {
return sharding.has_value();
});
}
// Extracts sharding from attribute string.
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
llvm::StringRef sharding) {
xla::OpSharding sharding_proto;
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
return sharding_proto;
}
// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
mlir::FuncOp function,
llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
arg_shardings->resize(function.getNumArguments(),
absl::optional<xla::OpSharding>());
for (int i = 0; i < function.getNumArguments(); ++i)
if (auto sharding =
function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
(*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
ret_shardings->resize(function.getNumResults(),
absl::optional<xla::OpSharding>());
for (int i = 0; i < function.getNumResults(); ++i)
if (auto sharding =
function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
(*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
}
namespace mlir {
namespace {
class ConvertToHloModule {
@ -402,12 +447,17 @@ class ConvertToHloModule {
// are converted to a tuple even when there is only a single return value.
// Multiple return values are always converted to a tuple and returned as a
// single value.
explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
bool return_tuple)
explicit ConvertToHloModule(
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn)
: module_(module),
module_builder_("main"),
use_tuple_args_(use_tuple_args),
return_tuple_(return_tuple) {}
return_tuple_(return_tuple),
shape_representation_fn_(shape_representation_fn) {
if (!shape_representation_fn_)
shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
}
// Perform the lowering to XLA. This function returns failure if an error was
// encountered.
@ -432,6 +482,8 @@ class ConvertToHloModule {
LogicalResult LowerBasicBlockAsFunction(
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
const std::vector<bool>& entry_args_same_across_replicas,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaComputation* result);
::xla::HloModuleProto ConsumeMainProto() {
@ -445,10 +497,22 @@ class ConvertToHloModule {
ConvertToHloModule::ValueLoweringMap* value_lowering);
private:
LogicalResult Lower(mlir::Operation* inst, bool is_entry_function,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result);
LogicalResult Lower(
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result);
LogicalResult SetEntryTupleShapesAndLeafReplication(
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
std::vector<bool>* leaf_replication);
LogicalResult SetEntryTupleShardings(
Block* block, xla::XlaBuilder* builder,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
// The module being lowered.
mlir::ModuleOp module_;
@ -465,6 +529,10 @@ class ConvertToHloModule {
// Whether to always return a tuple.
bool return_tuple_;
// Shape representation function to determine entry function argument and
// result shapes.
tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
// Unique suffix to give to the name of the next lowered region.
size_t region_id_ = 0;
};
@ -876,7 +944,9 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(Type type, ElementsAttr attr) {
}
LogicalResult ConvertToHloModule::Lower(
mlir::Operation* inst, bool is_entry_function, xla::XlaBuilder* builder,
mlir::Operation* inst, bool is_entry_function,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaBuilder* builder,
ConvertToHloModule::ValueLoweringMap* value_lowering,
xla::XlaComputation* result) {
if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
@ -906,11 +976,37 @@ LogicalResult ConvertToHloModule::Lower(
xla::XlaOp return_value;
unsigned num_return_values = inst->getNumOperands();
if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
const bool has_ret_shardings =
!ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
std::vector<xla::XlaOp> returns(num_return_values);
for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) {
returns[i] = value_map[inst->getOperand(i)];
for (OpOperand& ret : inst->getOpOperands()) {
unsigned index = ret.getOperandNumber();
returns[index] = value_map[ret.get()];
if (!is_entry_function || !has_ret_shardings) continue;
xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
StatusOr<xla::XlaOp> reshape =
tensorflow::ReshapeWithCorrectRepresentationAndSharding(
builder, returns[index], return_shape, shape_representation_fn_,
ret_shardings[index], /*fast_mem=*/false);
if (!reshape.ok())
return inst->emitError() << reshape.status().error_message();
returns[index] = reshape.ValueOrDie();
}
if (has_ret_shardings) {
xla::OpSharding sharding;
sharding.set_type(xla::OpSharding::TUPLE);
for (auto& ret_sharding : ret_shardings)
*sharding.add_tuple_shardings() = ret_sharding.value();
builder->SetSharding(sharding);
}
return_value = xla::Tuple(builder, returns);
builder->ClearSharding();
} else if (num_return_values == 1) {
return_value = value_map[inst->getOperand(0)];
}
@ -976,6 +1072,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
xla::XlaComputation computation;
std::vector<bool> entry_args_same_across_replicas;
llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
if (entry_function) {
bool any_arg_replicated = false;
entry_args_same_across_replicas.reserve(f.getNumArguments());
@ -1000,21 +1098,90 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
// means no replication. This avoids the need for unrelated tests to handle
// this field.
if (!any_arg_replicated) entry_args_same_across_replicas.clear();
ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
}
if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function,
entry_args_same_across_replicas,
&computation))) {
if (failed(LowerBasicBlockAsFunction(
&f.front(), &builder, entry_function, entry_args_same_across_replicas,
arg_shardings, ret_shardings, &computation))) {
return failure();
}
lowered_computation_[f] = std::move(computation);
return success();
}
LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
std::vector<bool>* leaf_replication) {
arg_shapes->reserve(block->getNumArguments());
leaf_replication->reserve(block->getNumArguments());
for (BlockArgument& arg : block->getArguments()) {
arg_shapes->push_back(xla::TypeToShape(arg.getType()));
xla::Shape& arg_shape = arg_shapes->back();
tensorflow::TensorShape arg_tensor_shape;
auto status =
tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
tensorflow::DataType dtype;
status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
/*use_fast_memory=*/false);
if (!arg_shape_status.ok())
return block->getParentOp()->emitError()
<< arg_shape_status.status().error_message();
arg_shape = std::move(arg_shape_status.ValueOrDie());
if (entry_args_same_across_replicas.empty()) continue;
for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
leaf_replication->push_back(
entry_args_same_across_replicas[arg.getArgNumber()]);
}
return success();
}
LogicalResult ConvertToHloModule::SetEntryTupleShardings(
Block* block, xla::XlaBuilder* builder,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
xla::OpSharding sharding;
sharding.set_type(xla::OpSharding::TUPLE);
for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
auto hlo_sharding =
xla::HloSharding::FromProto(arg_sharding.value().value());
if (!hlo_sharding.ok())
return block->getParentOp()->emitError()
<< hlo_sharding.status().error_message();
auto status = tensorflow::RewriteLayoutWithShardedShape(
hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
if (!status.ok())
return block->getParentOp()->emitError() << status.error_message();
*sharding.add_tuple_shardings() = arg_sharding.value().value();
}
builder->SetSharding(sharding);
}
return success();
}
LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
const std::vector<bool>& entry_args_same_across_replicas,
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
xla::XlaComputation* result) {
auto& bb = *block;
// Mapping from the Value to lowered XlaOp. The code below lowers in
// program order and will fail if an operand is unseen. This can be improved.
ValueLoweringMap lowering;
@ -1022,29 +1189,28 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
// If using tuples as input, then there is only one input parameter that is a
// tuple.
if (is_entry_function && use_tuple_args_) {
std::vector<xla::Shape> arg_shapes;
arg_shapes.reserve(bb.getNumArguments());
llvm::SmallVector<xla::Shape, 4> arg_shapes;
std::vector<bool> leaf_replication;
for (auto& arg : bb.getArguments()) {
arg_shapes.push_back(xla::TypeToShape(arg.getType()));
if (!entry_args_same_across_replicas.empty()) {
for (int i = 0; i < xla::ShapeUtil::GetLeafCount(arg_shapes.back());
++i) {
leaf_replication.push_back(
entry_args_same_across_replicas[arg.getArgNumber()]);
}
}
}
if (failed(SetEntryTupleShapesAndLeafReplication(
block, entry_args_same_across_replicas, &arg_shapes,
&leaf_replication)))
return failure();
if (failed(
SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
return failure();
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
auto tuple =
xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
for (auto& it : llvm::enumerate(bb.getArguments())) {
lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
}
builder->ClearSharding();
for (BlockArgument& arg : block->getArguments())
lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
} else {
for (auto& it : llvm::enumerate(bb.getArguments())) {
auto arg = it.value();
auto num = it.index();
for (BlockArgument& arg : block->getArguments()) {
auto num = arg.getArgNumber();
xla::Shape shape = xla::TypeToShape(arg.getType());
if (entry_args_same_across_replicas.empty()) {
lowering[arg] =
@ -1058,8 +1224,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
}
}
for (auto& inst : bb)
if (failed(Lower(&inst, is_entry_function, builder, &lowering, result)))
for (auto& inst : *block)
if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
&lowering, result)))
return failure();
return success();
@ -1069,8 +1236,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation(
mlir::Region* region, xla::XlaComputation* func) {
std::unique_ptr<xla::XlaBuilder> builder =
module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
return LowerBasicBlockAsFunction(&region->front(), builder.get(),
/*is_entry_function=*/false, {}, func);
return LowerBasicBlockAsFunction(
&region->front(), builder.get(),
/*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
/*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
}
std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
@ -1241,9 +1410,12 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
} // namespace
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
bool use_tuple_args, bool return_tuple) {
bool use_tuple_args, bool return_tuple,
const tensorflow::XlaCompiler::ShapeRepresentationFn
shape_representation_fn) {
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
ConvertToHloModule converter(module, use_tuple_args, return_tuple);
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
shape_representation_fn);
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
auto hlo_module = converter.ConsumeMainProto();
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -31,7 +32,9 @@ namespace mlir {
// Multiple return values are always converted to a tuple and returned as a
// single value.
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
bool use_tuple_args, bool return_tuple);
bool use_tuple_args, bool return_tuple,
const tensorflow::XlaCompiler::ShapeRepresentationFn
shape_representation_fn = nullptr);
// Creates XlaOp equivalent of a given MLIR operation using the operand info
// from `value_lowering` map.

View File

@ -138,46 +138,6 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
return Status::OK();
}
// There is a shape_representation_fn or sharding for an output, this function
// uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
if (original_shape.IsTuple()) {
std::vector<xla::XlaOp> elements;
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
TF_ASSIGN_OR_RETURN(auto element,
ReshapeWithCorrectRepresentationAndSharding(
builder, xla::GetTupleElement(original, i),
original_shape.tuple_shapes(i),
shape_representation_fn, subsharding, fast_mem));
elements.push_back(element);
}
return xla::Tuple(builder, elements);
}
if (!original_shape.IsArray()) return original;
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
original_shape.element_type()));
TF_ASSIGN_OR_RETURN(auto to_shape,
shape_representation_fn(shape, dtype, fast_mem));
if (sharding) {
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
xla::HloSharding::FromProto(*sharding));
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
}
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
for (int64 i = 0; i < original_shape.rank(); ++i) {
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
}
}
return xla::Reshape(to_shape, original);
}
// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
@ -562,13 +522,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
// The default shape representation function is the identity.
if (!options_.shape_representation_fn) {
options_.shape_representation_fn =
[](const TensorShape& shape, DataType dtype,
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
options_.shape_representation_fn = IdentityShapeRepresentationFn();
}
}
@ -1502,6 +1456,15 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
return iter->second;
}
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() {
return [](const TensorShape& shape, DataType dtype,
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
return xla_shape;
};
}
// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
@ -1542,4 +1505,44 @@ Status RewriteLayoutWithShardedShape(
return Status::OK();
}
// There is a shape_representation_fn or sharding for an output, this function
// uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
if (original_shape.IsTuple()) {
std::vector<xla::XlaOp> elements;
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
TF_ASSIGN_OR_RETURN(auto element,
ReshapeWithCorrectRepresentationAndSharding(
builder, xla::GetTupleElement(original, i),
original_shape.tuple_shapes(i),
shape_representation_fn, subsharding, fast_mem));
elements.push_back(element);
}
return xla::Tuple(builder, elements);
}
if (!original_shape.IsArray()) return original;
TensorShape shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
original_shape.element_type()));
TF_ASSIGN_OR_RETURN(auto to_shape,
shape_representation_fn(shape, dtype, fast_mem));
if (sharding) {
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
xla::HloSharding::FromProto(*sharding));
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
}
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
for (int64 i = 0; i < original_shape.rank(); ++i) {
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
}
}
return xla::Reshape(to_shape, original);
}
} // namespace tensorflow

View File

@ -518,12 +518,22 @@ class XlaCompiler {
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};
// Creates an identity shape representation function.
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn();
// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
xla::Shape* xla_shape);
// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
// sharding is present.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
absl::optional<xla::OpSharding> sharding, bool fast_mem);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_