Add support for updating argument/result shapes and layouts with associated shardings of the entry function.

Sharding is present with model parallelism. Depending on the type of sharding present, argument/result shapes and layouts need to be updated. The ShapeRepresentationFn and the shardings are used to determine the new shapes and layouts.

PiperOrigin-RevId: 303182568
Change-Id: I4185c1ae12de618b0b2ce9c07d2cd795c4e329b8
parent c5b4b6dc1f
commit 79ca75b618
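For orientation before the diff: a shape representation function maps a logical TensorShape and DataType to the on-device xla::Shape, optionally choosing a different layout for fast memory. Below is a minimal, hypothetical sketch of such a function (the name ExampleShapeRepresentation and the layout choice are illustrative only); this commit threads a function with exactly this signature, together with per-argument/per-result `xla_hlo.sharding` attributes on the entry function, through the MLIR-to-HLO lowering so entry shapes and layouts can be rewritten.

#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/layout_util.h"

// Hypothetical shape representation function; mirrors the
// (TensorShape, DataType, use_fast_memory) -> StatusOr<xla::Shape> signature
// used throughout this change.
xla::StatusOr<xla::Shape> ExampleShapeRepresentation(
    const tensorflow::TensorShape& shape, tensorflow::DataType dtype,
    bool use_fast_memory) {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(
      tensorflow::TensorShapeToXLAShape(dtype, shape, &xla_shape));
  // Illustrative layout choice only; a real backend-specific function would
  // pick a device layout here (and could consult `use_fast_memory`).
  *xla_shape.mutable_layout() =
      xla::LayoutUtil::GetDefaultLayoutForShape(xla_shape);
  return xla_shape;
}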
@@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [
    ":tensorflow_dialect_registration",
    ":tensorflow_passes",
    ":translate_utils",
    "@com_google_absl//absl/types:optional",
    "@llvm-project//llvm:support",
    "@llvm-project//mlir:IR",
    "@llvm-project//mlir:Parser",
@@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [
    "//tensorflow/core:protos_all_cc",
    "//tensorflow/core/platform:logging",
    "//tensorflow/stream_executor/lib",
    "//tensorflow/compiler/xla:xla_data_proto_cc",
    "//tensorflow/compiler/xla/service:hlo",
]

# Prefer to link 'compile_mlir_util' library that also links necessary
@@ -15,6 +15,7 @@ limitations under the License.

#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

#include "absl/types/optional.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
@@ -43,6 +44,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
@@ -74,7 +77,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
Status GetXlaInputShapes(
    mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
    bool use_tuple_args,
    const xla::CustomShapeRepresentationFn shape_representation_fn,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    std::vector<xla::Shape>* xla_input_shapes) {
  xla_input_shapes->clear();

@@ -93,7 +96,24 @@ Status GetXlaInputShapes(
    DataType dtype;
    TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
    TF_ASSIGN_OR_RETURN(xla_shape,
                        shape_representation_fn(arg_shapes[i], dtype));
                        shape_representation_fn(arg_shapes[i], dtype,
                                                /*use_fast_memory=*/false));

    // Rewrite layout with sharding, if sharding is set.
    auto sharding =
        main_func.getArgAttrOfType<mlir::StringAttr>(i, "xla_hlo.sharding");
    if (!sharding) continue;

    absl::optional<xla::HloSharding> arg_sharding;
    xla::OpSharding op_sharding;
    if (!op_sharding.ParseFromString(sharding.getValue().str()))
      return errors::InvalidArgument("failed to parse argument sharding ", i,
                                     " '", sharding.getValue().str(), "'");

    TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding));
    TF_RETURN_IF_ERROR(
        RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
                                      shape_representation_fn, &xla_shape));
  }
  if (use_tuple_args) {
    xla_input_shapes->push_back(
@@ -108,9 +128,14 @@ Status GetXlaInputShapes(
// output based on static shapes in MLIR module
Status GetOutputInfo(
    mlir::ModuleOp module,
    const xla::CustomShapeRepresentationFn shape_representation_fn,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_output_shape,
    std::vector<XlaCompiler::OutputDescription>* outputs) {
  auto shape_representation_fn_no_fast_memory =
      [shape_representation_fn](const TensorShape& shape, DataType dtype) {
        return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
      };

  mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
  mlir::FunctionType func_type = main_func.getType();

@@ -121,8 +146,9 @@ Status GetOutputInfo(
  shapes.reserve(func_type.getNumResults());

  for (mlir::Type type : func_type.getResults()) {
    TF_ASSIGN_OR_RETURN(xla::Shape shape,
                        TypeToShape(type, shape_representation_fn));
    TF_ASSIGN_OR_RETURN(
        xla::Shape shape,
        xla::TypeToShape(type, shape_representation_fn_no_fast_memory));
    auto tensor_type = type.dyn_cast<mlir::RankedTensorType>();
    shapes.push_back(shape);

@@ -225,12 +251,12 @@ static void RegisterDialects() {
  (void)init_once;
}

}  // namespace

Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
                                   xla::XlaComputation* xla_computation,
                                   bool use_tuple_args, bool return_tuple) {
Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
    bool use_tuple_args, bool return_tuple,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn) {
  mlir::PassManager tf2xla(module_op.getContext());
  tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
@@ -273,7 +299,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,

  xla::HloProto hlo_proto;
  TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
                                               use_tuple_args, return_tuple));
                                               use_tuple_args, return_tuple,
                                               shape_representation_fn));
  *xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
  return Status::OK();
}

@@ -281,7 +308,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
static Status CompileMlirToXlaHlo(
    mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
    bool use_tuple_args,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    XlaCompiler::CompilationResult* compilation_result) {
  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);

@@ -292,35 +319,28 @@ static Status CompileMlirToXlaHlo(
  if (VLOG_IS_ON(1))
    tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);

  if (!shape_representation_fn)
    shape_representation_fn = IdentityShapeRepresentationFn();

  // Convert MLIR module to XLA HLO proto contained in XlaComputation.
  compilation_result->computation = std::make_shared<xla::XlaComputation>();
  TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
      module_op, compilation_result->computation.get(), use_tuple_args,
      /*return_tuple=*/true));
      /*return_tuple=*/true, shape_representation_fn));

  // Construct mapping from XlaComputation's arg to input edges of execute
  // node.
  GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);

  auto shape_representation_fn_no_fast_memory =
      [shape_representation_fn](const TensorShape& shape,
                                DataType dtype) -> StatusOr<xla::Shape> {
    if (shape_representation_fn)
      return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    return xla_shape;
  };

  // Compute all input shapes.
  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                       shape_representation_fn_no_fast_memory,
                                       shape_representation_fn,
                                       &compilation_result->xla_input_shapes));

  // Compute all output descriptions.
  TF_RETURN_IF_ERROR(GetOutputInfo(
      module_op, shape_representation_fn_no_fast_memory,
      &compilation_result->xla_output_shape, &compilation_result->outputs));
  TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn,
                                   &compilation_result->xla_output_shape,
                                   &compilation_result->outputs));

  // Compute what resource variables need to be updated after XlaComputation's
  // execution.
@@ -43,9 +43,13 @@ namespace tensorflow {
//   entry computation.
// return_tuple: when this is true, always create a tuple result for the
//   entry computation.
Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
                                   xla::XlaComputation* xla_computation,
                                   bool use_tuple_args, bool return_tuple);
// shape_representation_fn: when this is set, this shape representation function
//   will be used to determine argument and result shapes. Otherwise the
//   original shape will be used as is.
Status ConvertMLIRToXlaComputation(
    mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
    bool use_tuple_args, bool return_tuple,
    const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr);

// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
// metadata and stores them in CompilationResult.
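A hypothetical call site for the extended API above (sketch only; `module_op` is assumed to be an already-parsed and verified MLIR module, and IdentityShapeRepresentationFn stands in for a real backend-specific function):

// Sketch: lower an MLIR module to an xla::XlaComputation, supplying a
// shape representation function for entry argument/result shapes.
tensorflow::Status LowerModule(mlir::ModuleOp module_op) {
  xla::XlaComputation computation;
  TF_RETURN_IF_ERROR(tensorflow::ConvertMLIRToXlaComputation(
      module_op, &computation,
      /*use_tuple_args=*/true, /*return_tuple=*/true,
      /*shape_representation_fn=*/tensorflow::IdentityShapeRepresentationFn()));
  // `computation` now holds the HLO module proto for further compilation.
  return tensorflow::Status::OK();
}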
@@ -40,7 +40,8 @@ xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
}

TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
  string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
  constexpr char invalid_mlir_module[] =
      "totally @invalid MLIR module {here} <-";
  std::vector<TensorShape> arg_shapes;
  XlaCompiler::CompilationResult compilation_result;

@@ -76,7 +77,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());
  string expected_hlo_module_string = R"(HloModule main.6
  constexpr char expected_hlo_module_string[] = R"(HloModule main.6

ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
  %arg_tuple.1 = (f32[], f32[]) parameter(0)
@@ -134,7 +135,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());
  string expected_hlo_module_string = R"(HloModule main.5
  constexpr char expected_hlo_module_string[] = R"(HloModule main.5

ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
  %Arg_0.1 = f32[] parameter(0)
@@ -181,7 +182,7 @@ ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
  // "tf.Shape" can only be folded away after shape inference. tf.Reshape can
  // only be lowered when tf.Shape is folded into a constant.
  string mlir_module = R"(
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
  func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> {
    %0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
@@ -205,7 +206,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());
  string expected_hlo_module_string = R"(HloModule main.6
  constexpr char expected_hlo_module_string[] = R"(HloModule main.6

ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
  %arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true}
@@ -221,7 +222,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
}

TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
  string mlir_module = R"(
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
  func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
    %0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<?x19xf32>) -> tensor<?x19xf32>
@@ -245,13 +246,14 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());

  string expected_signature =
  constexpr char expected_signature[] =
      R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))";
  EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(),
              ::testing::HasSubstr(expected_signature));
}

constexpr llvm::StringRef kBroadcastGradientArgsModule = R"(
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
  func @main() -> (tensor<0xi32>, tensor<0xi32>) {
    %0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
@@ -261,12 +263,11 @@ module attributes {tf.versions = {producer = 179 : i32}} {
}
)";

TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
  std::vector<TensorShape> arg_shapes(2, TensorShape());
  XlaCompiler::CompilationResult compilation_result;

  Status s = CompileSerializedMlirToXlaHlo(
      kBroadcastGradientArgsModule, arg_shapes,
      mlir_module, arg_shapes,
      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
  TF_ASSERT_OK(s);

@@ -275,7 +276,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());
  string expected_hlo_module_string = R"(HloModule main.4
  constexpr char expected_hlo_module_string[] = R"(HloModule main.4

ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
  %arg_tuple.1 = () parameter(0)
@@ -288,6 +289,128 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
            status_or_hlo_module.ValueOrDie()->ToString());
}

// The following xla::OpSharding protos are used:
// Serialized string:
//   "\08\03\1A\02\01\02\22\02\00\01"
// Proto debug string:
//   type: OTHER
//   tile_assignment_dimensions: 1
//   tile_assignment_dimensions: 2
//   tile_assignment_devices: 0
//   tile_assignment_devices: 1
//
// Serialized string:
//   "\08\01\1A\01\01\22\01\00"
// Proto debug string:
//   type: MAXIMAL
//   tile_assignment_dimensions: 1
//   tile_assignment_devices: 0
//
// Serialized string:
//   ""
// Proto debug string (empty but would equivalent to):
//   type: REPLICATED
TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
  func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
    return
  }
}
)";

  std::vector<TensorShape> arg_shapes{TensorShape({128, 10}),
                                      TensorShape({10, 1024}),
                                      TensorShape({128, 1024})};
  XlaCompiler::CompilationResult compilation_result;

  Status s = CompileSerializedMlirToXlaHlo(
      mlir_module, arg_shapes,
      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
  TF_ASSERT_OK(s);

  const xla::HloModuleConfig module_config(
      compilation_result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());
  constexpr char expected_hlo_module_string[] = R"(HloModule main.6

ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () {
  %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
  %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
  %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
  %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
  ROOT %tuple.5 = () tuple()
}

)";
  EXPECT_EQ(expected_hlo_module_string,
            status_or_hlo_module.ValueOrDie()->ToString());
}

TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
  func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) {
    return
  }
}
)";

  std::vector<TensorShape> arg_shapes{TensorShape({128, 10})};
  XlaCompiler::CompilationResult compilation_result;

  Status s = CompileSerializedMlirToXlaHlo(
      mlir_module, arg_shapes,
      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
  ASSERT_FALSE(s.ok());
  EXPECT_EQ(s.error_message(),
            "failed to parse argument sharding 0 'bad_sharding'");
}

TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) {
  constexpr char mlir_module[] = R"(
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} {
  func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
    return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>
  }
}
)";

  std::vector<TensorShape> arg_shapes{TensorShape({128, 10}),
                                      TensorShape({10, 1024}),
                                      TensorShape({128, 1024})};
  XlaCompiler::CompilationResult compilation_result;

  Status s = CompileSerializedMlirToXlaHlo(
      mlir_module, arg_shapes,
      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
  TF_ASSERT_OK(s);

  const xla::HloModuleConfig module_config(
      compilation_result.computation->GetProgramShape().ValueOrDie());
  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
      compilation_result.computation->proto(), module_config);
  TF_ASSERT_OK(status_or_hlo_module.status());
  constexpr char expected_hlo_module_string[] = R"(HloModule main.9

ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) {
  %arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0)
  %get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
  %reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2)
  %get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
  %reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3)
  %get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
  %reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4)
  ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
}

)";
  EXPECT_EQ(expected_hlo_module_string,
            status_or_hlo_module.ValueOrDie()->ToString());
}
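The serialized sharding strings exercised by the tests above are ordinary xla::OpSharding protos. A hypothetical snippet showing how the tiled and maximal variants map to those bytes (field values taken from the debug strings in the comment above; the function names are illustrative, and SerializeAsString produces the bytes attached as the `xla_hlo.sharding` attribute):

#include <string>

#include "tensorflow/compiler/xla/xla_data.pb.h"

// Sketch: build the OpSharding protos described above and serialize them.
std::string MakeTileSharding() {
  xla::OpSharding tiled;  // serializes to "\08\03\1A\02\01\02\22\02\00\01"
  tiled.set_type(xla::OpSharding::OTHER);
  tiled.add_tile_assignment_dimensions(1);
  tiled.add_tile_assignment_dimensions(2);
  tiled.add_tile_assignment_devices(0);
  tiled.add_tile_assignment_devices(1);
  return tiled.SerializeAsString();
}

std::string MakeMaximalSharding() {
  xla::OpSharding maximal;  // serializes to "\08\01\1A\01\01\22\01\00"
  maximal.set_type(xla::OpSharding::MAXIMAL);
  maximal.add_tile_assignment_dimensions(1);
  maximal.add_tile_assignment_devices(0);
  return maximal.SerializeAsString();
}

// An empty attribute string parses to a default OpSharding, i.e. REPLICATED.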

// Verify that conversion from Graph to MLIR and empty shape representation
// function is successful.
TEST(CompileGraphToXlaHlo, Basic) {
@@ -311,7 +434,7 @@ TEST(CompileGraphToXlaHlo, Basic) {
      result.computation->proto(), module_config);
  ASSERT_TRUE(status_or_hlo_module.ok());

  string expected_hlo_module_string = R"(HloModule main.3
  constexpr char expected_hlo_module_string[] = R"(HloModule main.3

ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
  %Arg_0.1 = f32[] parameter(0)
@@ -618,7 +618,10 @@ cc_library(
        ":hlo",
        ":type_to_shape",
        ":xla_dialect_registration",
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
@@ -629,6 +632,8 @@ cc_library(
        "//tensorflow/compiler/xla/client/lib:quantize",
        "//tensorflow/compiler/xla/client/lib:slicing",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:framework",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/stream_executor/lib",
        "@llvm-project//llvm:support",
        "@llvm-project//mlir:Analysis",
@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
@@ -37,8 +38,12 @@ limitations under the License.
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/StandardTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/quantize.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -49,6 +54,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/stream_executor/lib/statusor.h"

using ::stream_executor::port::StatusOr;
@@ -64,6 +71,7 @@ using ::tensorflow::uint8;
constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
constexpr char kShapeIndicesAttr[] = "shape_indices";
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
constexpr char kShardingAttr[] = "xla_hlo.sharding";
constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas";

// Passes through everything except for unique_ptr, on which it calls get().
@@ -377,7 +385,7 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
// returns absl::nullopt.
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
    mlir::Operation* op) {
  auto sharding = op->getAttrOfType<mlir::StringAttr>("xla_hlo.sharding");
  auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
  if (!sharding) {
    return absl::nullopt;
  }
@@ -389,6 +397,43 @@ static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
  return sharding_proto;
}

// Checks if all shardings are set.
static bool AllOptionalShardingsAreSet(
    llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
  return llvm::all_of(shardings,
                      [](const absl::optional<xla::OpSharding>& sharding) {
                        return sharding.has_value();
                      });
}

// Extracts sharding from attribute string.
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
    llvm::StringRef sharding) {
  xla::OpSharding sharding_proto;
  if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
  return sharding_proto;
}

// Extracts argument and result shardings from function.
static void ExtractShardingsFromFunction(
    mlir::FuncOp function,
    llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
    llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
  arg_shardings->resize(function.getNumArguments(),
                        absl::optional<xla::OpSharding>());
  for (int i = 0; i < function.getNumArguments(); ++i)
    if (auto sharding =
            function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());

  ret_shardings->resize(function.getNumResults(),
                        absl::optional<xla::OpSharding>());
  for (int i = 0; i < function.getNumResults(); ++i)
    if (auto sharding =
            function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
      (*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
}

namespace mlir {
namespace {
class ConvertToHloModule {
@@ -402,12 +447,17 @@ class ConvertToHloModule {
  // are converted to a tuple even when there is only a single return value.
  // Multiple return values are always converted to a tuple and returned as a
  // single value.
  explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
                              bool return_tuple)
  explicit ConvertToHloModule(
      mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
      tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn)
      : module_(module),
        module_builder_("main"),
        use_tuple_args_(use_tuple_args),
        return_tuple_(return_tuple) {}
        return_tuple_(return_tuple),
        shape_representation_fn_(shape_representation_fn) {
    if (!shape_representation_fn_)
      shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
  }

  // Perform the lowering to XLA. This function returns failure if an error was
  // encountered.
@@ -432,6 +482,8 @@ class ConvertToHloModule {
  LogicalResult LowerBasicBlockAsFunction(
      Block* block, xla::XlaBuilder* builder, bool is_entry_function,
      const std::vector<bool>& entry_args_same_across_replicas,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
      xla::XlaComputation* result);

  ::xla::HloModuleProto ConsumeMainProto() {
@@ -445,10 +497,22 @@ class ConvertToHloModule {
      ConvertToHloModule::ValueLoweringMap* value_lowering);

 private:
  LogicalResult Lower(mlir::Operation* inst, bool is_entry_function,
                      xla::XlaBuilder* builder,
                      ConvertToHloModule::ValueLoweringMap* value_lowering,
                      xla::XlaComputation* result);
  LogicalResult Lower(
      mlir::Operation* inst, bool is_entry_function,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
      xla::XlaBuilder* builder,
      ConvertToHloModule::ValueLoweringMap* value_lowering,
      xla::XlaComputation* result);

  LogicalResult SetEntryTupleShapesAndLeafReplication(
      Block* block, const std::vector<bool>& entry_args_same_across_replicas,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
      std::vector<bool>* leaf_replication);

  LogicalResult SetEntryTupleShardings(
      Block* block, xla::XlaBuilder* builder,
      llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
      llvm::SmallVectorImpl<xla::Shape>* arg_shapes);

  // The module being lowered.
  mlir::ModuleOp module_;
@@ -465,6 +529,10 @@ class ConvertToHloModule {
  // Whether to always return a tuple.
  bool return_tuple_;

  // Shape representation function to determine entry function argument and
  // result shapes.
  tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_;

  // Unique suffix to give to the name of the next lowered region.
  size_t region_id_ = 0;
};
@@ -876,7 +944,9 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(Type type, ElementsAttr attr) {
}

LogicalResult ConvertToHloModule::Lower(
    mlir::Operation* inst, bool is_entry_function, xla::XlaBuilder* builder,
    mlir::Operation* inst, bool is_entry_function,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaBuilder* builder,
    ConvertToHloModule::ValueLoweringMap* value_lowering,
    xla::XlaComputation* result) {
  if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
@@ -906,11 +976,37 @@ LogicalResult ConvertToHloModule::Lower(
    xla::XlaOp return_value;
    unsigned num_return_values = inst->getNumOperands();
    if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
      const bool has_ret_shardings =
          !ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);

      std::vector<xla::XlaOp> returns(num_return_values);
      for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) {
        returns[i] = value_map[inst->getOperand(i)];
      for (OpOperand& ret : inst->getOpOperands()) {
        unsigned index = ret.getOperandNumber();
        returns[index] = value_map[ret.get()];
        if (!is_entry_function || !has_ret_shardings) continue;

        xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
        StatusOr<xla::XlaOp> reshape =
            tensorflow::ReshapeWithCorrectRepresentationAndSharding(
                builder, returns[index], return_shape, shape_representation_fn_,
                ret_shardings[index], /*fast_mem=*/false);
        if (!reshape.ok())
          return inst->emitError() << reshape.status().error_message();

        returns[index] = reshape.ValueOrDie();
      }

      if (has_ret_shardings) {
        xla::OpSharding sharding;
        sharding.set_type(xla::OpSharding::TUPLE);
        for (auto& ret_sharding : ret_shardings)
          *sharding.add_tuple_shardings() = ret_sharding.value();

        builder->SetSharding(sharding);
      }

      return_value = xla::Tuple(builder, returns);
      builder->ClearSharding();
    } else if (num_return_values == 1) {
      return_value = value_map[inst->getOperand(0)];
    }
@@ -976,6 +1072,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {

  xla::XlaComputation computation;
  std::vector<bool> entry_args_same_across_replicas;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
  llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
  if (entry_function) {
    bool any_arg_replicated = false;
    entry_args_same_across_replicas.reserve(f.getNumArguments());
@@ -1000,21 +1098,90 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
    // means no replication. This avoids the need for unrelated tests to handle
    // this field.
    if (!any_arg_replicated) entry_args_same_across_replicas.clear();

    ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
  }
  if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function,
                                       entry_args_same_across_replicas,
                                       &computation))) {
  if (failed(LowerBasicBlockAsFunction(
          &f.front(), &builder, entry_function, entry_args_same_across_replicas,
          arg_shardings, ret_shardings, &computation))) {
    return failure();
  }
  lowered_computation_[f] = std::move(computation);
  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
    Block* block, const std::vector<bool>& entry_args_same_across_replicas,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
    std::vector<bool>* leaf_replication) {
  arg_shapes->reserve(block->getNumArguments());
  leaf_replication->reserve(block->getNumArguments());
  for (BlockArgument& arg : block->getArguments()) {
    arg_shapes->push_back(xla::TypeToShape(arg.getType()));
    xla::Shape& arg_shape = arg_shapes->back();
    tensorflow::TensorShape arg_tensor_shape;
    auto status =
        tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    tensorflow::DataType dtype;
    status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
    if (!status.ok())
      return block->getParentOp()->emitError() << status.error_message();

    auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
                                                     /*use_fast_memory=*/false);
    if (!arg_shape_status.ok())
      return block->getParentOp()->emitError()
             << arg_shape_status.status().error_message();

    arg_shape = std::move(arg_shape_status.ValueOrDie());

    if (entry_args_same_across_replicas.empty()) continue;
    for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
      leaf_replication->push_back(
          entry_args_same_across_replicas[arg.getArgNumber()]);
  }

  return success();
}

LogicalResult ConvertToHloModule::SetEntryTupleShardings(
    Block* block, xla::XlaBuilder* builder,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
  if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
    xla::OpSharding sharding;
    sharding.set_type(xla::OpSharding::TUPLE);
    for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
      auto hlo_sharding =
          xla::HloSharding::FromProto(arg_sharding.value().value());
      if (!hlo_sharding.ok())
        return block->getParentOp()->emitError()
               << hlo_sharding.status().error_message();

      auto status = tensorflow::RewriteLayoutWithShardedShape(
          hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
          shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
      if (!status.ok())
        return block->getParentOp()->emitError() << status.error_message();

      *sharding.add_tuple_shardings() = arg_sharding.value().value();
    }

    builder->SetSharding(sharding);
  }

  return success();
}

LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
    Block* block, xla::XlaBuilder* builder, bool is_entry_function,
    const std::vector<bool>& entry_args_same_across_replicas,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
    llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
    xla::XlaComputation* result) {
  auto& bb = *block;
  // Mapping from the Value to lowered XlaOp. The code below lowers in
  // program order and will fail if an operand is unseen. This can be improved.
  ValueLoweringMap lowering;
@@ -1022,29 +1189,28 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
  // If using tuples as input, then there is only one input parameter that is a
  // tuple.
  if (is_entry_function && use_tuple_args_) {
    std::vector<xla::Shape> arg_shapes;
    arg_shapes.reserve(bb.getNumArguments());
    llvm::SmallVector<xla::Shape, 4> arg_shapes;
    std::vector<bool> leaf_replication;
    for (auto& arg : bb.getArguments()) {
      arg_shapes.push_back(xla::TypeToShape(arg.getType()));
      if (!entry_args_same_across_replicas.empty()) {
        for (int i = 0; i < xla::ShapeUtil::GetLeafCount(arg_shapes.back());
             ++i) {
          leaf_replication.push_back(
              entry_args_same_across_replicas[arg.getArgNumber()]);
        }
      }
    }
    if (failed(SetEntryTupleShapesAndLeafReplication(
            block, entry_args_same_across_replicas, &arg_shapes,
            &leaf_replication)))
      return failure();

    if (failed(
            SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
      return failure();

    xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
    auto tuple =
        xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
    for (auto& it : llvm::enumerate(bb.getArguments())) {
      lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
    }

    builder->ClearSharding();

    for (BlockArgument& arg : block->getArguments())
      lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
  } else {
    for (auto& it : llvm::enumerate(bb.getArguments())) {
      auto arg = it.value();
      auto num = it.index();
    for (BlockArgument& arg : block->getArguments()) {
      auto num = arg.getArgNumber();
      xla::Shape shape = xla::TypeToShape(arg.getType());
      if (entry_args_same_across_replicas.empty()) {
        lowering[arg] =
@@ -1058,8 +1224,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
    }
  }

  for (auto& inst : bb)
    if (failed(Lower(&inst, is_entry_function, builder, &lowering, result)))
  for (auto& inst : *block)
    if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
                     &lowering, result)))
      return failure();

  return success();
@@ -1069,8 +1236,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation(
    mlir::Region* region, xla::XlaComputation* func) {
  std::unique_ptr<xla::XlaBuilder> builder =
      module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
  return LowerBasicBlockAsFunction(&region->front(), builder.get(),
                                   /*is_entry_function=*/false, {}, func);
  return LowerBasicBlockAsFunction(
      &region->front(), builder.get(),
      /*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
      /*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
}

std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
@@ -1241,9 +1410,12 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
}  // namespace

Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
                           bool use_tuple_args, bool return_tuple) {
                           bool use_tuple_args, bool return_tuple,
                           const tensorflow::XlaCompiler::ShapeRepresentationFn
                               shape_representation_fn) {
  mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
  ConvertToHloModule converter(module, use_tuple_args, return_tuple);
  ConvertToHloModule converter(module, use_tuple_args, return_tuple,
                               shape_representation_fn);
  if (failed(converter.Run())) return diag_handler.ConsumeStatus();
  auto hlo_module = converter.ConsumeMainProto();
  hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
@@ -18,6 +18,7 @@ limitations under the License.

#include "mlir/IR/Module.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

@@ -31,7 +32,9 @@ namespace mlir {
// Multiple return values are always converted to a tuple and returned as a
// single value.
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
                           bool use_tuple_args, bool return_tuple);
                           bool use_tuple_args, bool return_tuple,
                           const tensorflow::XlaCompiler::ShapeRepresentationFn
                               shape_representation_fn = nullptr);

// Creates XlaOp equivalent of a given MLIR operation using the operand info
// from `value_lowering` map.
@@ -138,46 +138,6 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
  return Status::OK();
}

// There is a shape_representation_fn or sharding for an output, this function
// uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
  if (original_shape.IsTuple()) {
    std::vector<xla::XlaOp> elements;
    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
      TF_ASSIGN_OR_RETURN(auto element,
                          ReshapeWithCorrectRepresentationAndSharding(
                              builder, xla::GetTupleElement(original, i),
                              original_shape.tuple_shapes(i),
                              shape_representation_fn, subsharding, fast_mem));
      elements.push_back(element);
    }
    return xla::Tuple(builder, elements);
  }
  if (!original_shape.IsArray()) return original;
  TensorShape shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                          original_shape.element_type()));
  TF_ASSIGN_OR_RETURN(auto to_shape,
                      shape_representation_fn(shape, dtype, fast_mem));
  if (sharding) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));
    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
  }
  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
    for (int64 i = 0; i < original_shape.rank(); ++i) {
      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
    }
  }
  return xla::Reshape(to_shape, original);
}

// Builds the XLA computation.
// - `args` is the list of input arguments
// - `retvals` is the list of retvals produced by _Retval operators, in index
@@ -562,13 +522,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)

  // The default shape representation function is the identity.
  if (!options_.shape_representation_fn) {
    options_.shape_representation_fn =
        [](const TensorShape& shape, DataType dtype,
           bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
      xla::Shape xla_shape;
      TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
      return xla_shape;
    };
    options_.shape_representation_fn = IdentityShapeRepresentationFn();
  }
}

@@ -1502,6 +1456,15 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
  return iter->second;
}

XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() {
  return [](const TensorShape& shape, DataType dtype,
            bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
    return xla_shape;
  };
}

// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
@@ -1542,4 +1505,44 @@ Status RewriteLayoutWithShardedShape(
  return Status::OK();
}

// There is a shape_representation_fn or sharding for an output, this function
// uses a reshape to fix the layout.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
  if (original_shape.IsTuple()) {
    std::vector<xla::XlaOp> elements;
    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
      TF_ASSIGN_OR_RETURN(auto element,
                          ReshapeWithCorrectRepresentationAndSharding(
                              builder, xla::GetTupleElement(original, i),
                              original_shape.tuple_shapes(i),
                              shape_representation_fn, subsharding, fast_mem));
      elements.push_back(element);
    }
    return xla::Tuple(builder, elements);
  }
  if (!original_shape.IsArray()) return original;
  TensorShape shape;
  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
                                          original_shape.element_type()));
  TF_ASSIGN_OR_RETURN(auto to_shape,
                      shape_representation_fn(shape, dtype, fast_mem));
  if (sharding) {
    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
                        xla::HloSharding::FromProto(*sharding));
    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
  }
  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
    for (int64 i = 0; i < original_shape.rank(); ++i) {
      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
    }
  }
  return xla::Reshape(to_shape, original);
}

}  // namespace tensorflow
@@ -518,12 +518,22 @@ class XlaCompiler {
  TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
};

// Creates an identity shape representation function.
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn();

// Rewrites the layout of xla_shape if there is tiled sharding.
Status RewriteLayoutWithShardedShape(
    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    xla::Shape* xla_shape);

// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
// sharding is present.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
    absl::optional<xla::OpSharding> sharding, bool fast_mem);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
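To illustrate the declarations above, a hypothetical use of RewriteLayoutWithShardedShape: build an OpSharding (the devices=[1,2] tiling from the tests), convert it to an HloSharding, and let the function adjust the layout of an example f32[128,10] argument shape. The function name and the shape are illustrative; the call signatures come from the header above.

// Sketch: rewrite the layout of an argument shape under a tiled sharding.
tensorflow::Status RewriteExampleLayout(
    const tensorflow::XlaCompiler::ShapeRepresentationFn&
        shape_representation_fn) {
  xla::OpSharding op_sharding;  // devices=[1,2]0,1 tiling, as in the tests.
  op_sharding.set_type(xla::OpSharding::OTHER);
  op_sharding.add_tile_assignment_dimensions(1);
  op_sharding.add_tile_assignment_dimensions(2);
  op_sharding.add_tile_assignment_devices(0);
  op_sharding.add_tile_assignment_devices(1);

  TF_ASSIGN_OR_RETURN(xla::HloSharding sharding,
                      xla::HloSharding::FromProto(op_sharding));
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {128, 10});
  TF_RETURN_IF_ERROR(tensorflow::RewriteLayoutWithShardedShape(
      sharding, /*use_fast_memory=*/false, shape_representation_fn, &shape));
  // `shape` now carries the layout chosen for the sharded representation.
  return tensorflow::Status::OK();
}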