Add support for updating argument/result shapes and layouts with associated shardings of entry function.
Sharding is present with model parallelism. Depending on what type of sharding is present, argument/result shapes and layouts need to be updated. ShapeRepresentationFn and shardings are used to determine the new shapes and layouts. PiperOrigin-RevId: 303182568 Change-Id: I4185c1ae12de618b0b2ce9c07d2cd795c4e329b8
This commit is contained in:
parent
c5b4b6dc1f
commit
79ca75b618
|
@ -1065,6 +1065,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||||
":tensorflow_dialect_registration",
|
":tensorflow_dialect_registration",
|
||||||
":tensorflow_passes",
|
":tensorflow_passes",
|
||||||
":translate_utils",
|
":translate_utils",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:IR",
|
"@llvm-project//mlir:IR",
|
||||||
"@llvm-project//mlir:Parser",
|
"@llvm-project//mlir:Parser",
|
||||||
|
@ -1083,6 +1084,8 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/platform:logging",
|
"//tensorflow/core/platform:logging",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Prefer to link 'compile_mlir_util' library that also links necessary
|
# Prefer to link 'compile_mlir_util' library that also links necessary
|
||||||
|
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "llvm/ADT/ArrayRef.h"
|
#include "llvm/ADT/ArrayRef.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||||
|
@ -43,6 +44,8 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
@ -74,7 +77,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
|
||||||
Status GetXlaInputShapes(
|
Status GetXlaInputShapes(
|
||||||
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
|
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||||
bool use_tuple_args,
|
bool use_tuple_args,
|
||||||
const xla::CustomShapeRepresentationFn shape_representation_fn,
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
std::vector<xla::Shape>* xla_input_shapes) {
|
std::vector<xla::Shape>* xla_input_shapes) {
|
||||||
xla_input_shapes->clear();
|
xla_input_shapes->clear();
|
||||||
|
|
||||||
|
@ -93,7 +96,24 @@ Status GetXlaInputShapes(
|
||||||
DataType dtype;
|
DataType dtype;
|
||||||
TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
|
TF_RETURN_IF_ERROR(ConvertToDataType(func_type.getInput(i), &dtype));
|
||||||
TF_ASSIGN_OR_RETURN(xla_shape,
|
TF_ASSIGN_OR_RETURN(xla_shape,
|
||||||
shape_representation_fn(arg_shapes[i], dtype));
|
shape_representation_fn(arg_shapes[i], dtype,
|
||||||
|
/*use_fast_memory=*/false));
|
||||||
|
|
||||||
|
// Rewrite layout with sharding, if sharding is set.
|
||||||
|
auto sharding =
|
||||||
|
main_func.getArgAttrOfType<mlir::StringAttr>(i, "xla_hlo.sharding");
|
||||||
|
if (!sharding) continue;
|
||||||
|
|
||||||
|
absl::optional<xla::HloSharding> arg_sharding;
|
||||||
|
xla::OpSharding op_sharding;
|
||||||
|
if (!op_sharding.ParseFromString(sharding.getValue().str()))
|
||||||
|
return errors::InvalidArgument("failed to parse argument sharding ", i,
|
||||||
|
" '", sharding.getValue().str(), "'");
|
||||||
|
|
||||||
|
TF_ASSIGN_OR_RETURN(arg_sharding, xla::HloSharding::FromProto(op_sharding));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
RewriteLayoutWithShardedShape(arg_sharding, /*use_fast_memory=*/false,
|
||||||
|
shape_representation_fn, &xla_shape));
|
||||||
}
|
}
|
||||||
if (use_tuple_args) {
|
if (use_tuple_args) {
|
||||||
xla_input_shapes->push_back(
|
xla_input_shapes->push_back(
|
||||||
|
@ -108,9 +128,14 @@ Status GetXlaInputShapes(
|
||||||
// output based on static shapes in MLIR module
|
// output based on static shapes in MLIR module
|
||||||
Status GetOutputInfo(
|
Status GetOutputInfo(
|
||||||
mlir::ModuleOp module,
|
mlir::ModuleOp module,
|
||||||
const xla::CustomShapeRepresentationFn shape_representation_fn,
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
xla::Shape* xla_output_shape,
|
xla::Shape* xla_output_shape,
|
||||||
std::vector<XlaCompiler::OutputDescription>* outputs) {
|
std::vector<XlaCompiler::OutputDescription>* outputs) {
|
||||||
|
auto shape_representation_fn_no_fast_memory =
|
||||||
|
[shape_representation_fn](const TensorShape& shape, DataType dtype) {
|
||||||
|
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
|
||||||
|
};
|
||||||
|
|
||||||
mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
|
mlir::FuncOp main_func = module.lookupSymbol<mlir::FuncOp>("main");
|
||||||
mlir::FunctionType func_type = main_func.getType();
|
mlir::FunctionType func_type = main_func.getType();
|
||||||
|
|
||||||
|
@ -121,8 +146,9 @@ Status GetOutputInfo(
|
||||||
shapes.reserve(func_type.getNumResults());
|
shapes.reserve(func_type.getNumResults());
|
||||||
|
|
||||||
for (mlir::Type type : func_type.getResults()) {
|
for (mlir::Type type : func_type.getResults()) {
|
||||||
TF_ASSIGN_OR_RETURN(xla::Shape shape,
|
TF_ASSIGN_OR_RETURN(
|
||||||
TypeToShape(type, shape_representation_fn));
|
xla::Shape shape,
|
||||||
|
xla::TypeToShape(type, shape_representation_fn_no_fast_memory));
|
||||||
auto tensor_type = type.dyn_cast<mlir::RankedTensorType>();
|
auto tensor_type = type.dyn_cast<mlir::RankedTensorType>();
|
||||||
shapes.push_back(shape);
|
shapes.push_back(shape);
|
||||||
|
|
||||||
|
@ -226,11 +252,11 @@ static void RegisterDialects() {
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
// namespace
|
|
||||||
|
|
||||||
Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
Status ConvertMLIRToXlaComputation(
|
||||||
xla::XlaComputation* xla_computation,
|
mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
|
||||||
bool use_tuple_args, bool return_tuple) {
|
bool use_tuple_args, bool return_tuple,
|
||||||
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn) {
|
||||||
mlir::PassManager tf2xla(module_op.getContext());
|
mlir::PassManager tf2xla(module_op.getContext());
|
||||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||||
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
|
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
|
||||||
|
@ -273,7 +299,8 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||||
|
|
||||||
xla::HloProto hlo_proto;
|
xla::HloProto hlo_proto;
|
||||||
TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
|
TF_RETURN_IF_ERROR(mlir::ConvertMlirHloToHlo(module_op, &hlo_proto,
|
||||||
use_tuple_args, return_tuple));
|
use_tuple_args, return_tuple,
|
||||||
|
shape_representation_fn));
|
||||||
*xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
|
*xla_computation = xla::XlaComputation(hlo_proto.hlo_module());
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -281,7 +308,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||||
static Status CompileMlirToXlaHlo(
|
static Status CompileMlirToXlaHlo(
|
||||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
|
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||||
bool use_tuple_args,
|
bool use_tuple_args,
|
||||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompiler::CompilationResult* compilation_result) {
|
XlaCompiler::CompilationResult* compilation_result) {
|
||||||
if (VLOG_IS_ON(1))
|
if (VLOG_IS_ON(1))
|
||||||
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
||||||
|
@ -292,35 +319,28 @@ static Status CompileMlirToXlaHlo(
|
||||||
if (VLOG_IS_ON(1))
|
if (VLOG_IS_ON(1))
|
||||||
tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
|
tensorflow::DumpMlirOpToFile("mlir_compile_shape_refiner", module_op);
|
||||||
|
|
||||||
|
if (!shape_representation_fn)
|
||||||
|
shape_representation_fn = IdentityShapeRepresentationFn();
|
||||||
|
|
||||||
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
||||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||||
module_op, compilation_result->computation.get(), use_tuple_args,
|
module_op, compilation_result->computation.get(), use_tuple_args,
|
||||||
/*return_tuple=*/true));
|
/*return_tuple=*/true, shape_representation_fn));
|
||||||
|
|
||||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||||
// node.
|
// node.
|
||||||
GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
|
GetInputMappingForMlir(arg_shapes.size(), &compilation_result->input_mapping);
|
||||||
|
|
||||||
auto shape_representation_fn_no_fast_memory =
|
|
||||||
[shape_representation_fn](const TensorShape& shape,
|
|
||||||
DataType dtype) -> StatusOr<xla::Shape> {
|
|
||||||
if (shape_representation_fn)
|
|
||||||
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
|
|
||||||
xla::Shape xla_shape;
|
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
|
||||||
return xla_shape;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Compute all input shapes.
|
// Compute all input shapes.
|
||||||
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
|
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
|
||||||
shape_representation_fn_no_fast_memory,
|
shape_representation_fn,
|
||||||
&compilation_result->xla_input_shapes));
|
&compilation_result->xla_input_shapes));
|
||||||
|
|
||||||
// Compute all output descriptions.
|
// Compute all output descriptions.
|
||||||
TF_RETURN_IF_ERROR(GetOutputInfo(
|
TF_RETURN_IF_ERROR(GetOutputInfo(module_op, shape_representation_fn,
|
||||||
module_op, shape_representation_fn_no_fast_memory,
|
&compilation_result->xla_output_shape,
|
||||||
&compilation_result->xla_output_shape, &compilation_result->outputs));
|
&compilation_result->outputs));
|
||||||
|
|
||||||
// Compute what resource variables need to be updated after XlaComputation's
|
// Compute what resource variables need to be updated after XlaComputation's
|
||||||
// execution.
|
// execution.
|
||||||
|
|
|
@ -43,9 +43,13 @@ namespace tensorflow {
|
||||||
// entry computation.
|
// entry computation.
|
||||||
// return_tuple: when this is true, always create a tuple result for the
|
// return_tuple: when this is true, always create a tuple result for the
|
||||||
// entry computation.
|
// entry computation.
|
||||||
Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
// shape_representation_fn: when this is set, this shape representation function
|
||||||
xla::XlaComputation* xla_computation,
|
// will be used to determine argument and result shapes. Otherwise the
|
||||||
bool use_tuple_args, bool return_tuple);
|
// original shape will be used as is.
|
||||||
|
Status ConvertMLIRToXlaComputation(
|
||||||
|
mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
|
||||||
|
bool use_tuple_args, bool return_tuple,
|
||||||
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr);
|
||||||
|
|
||||||
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
||||||
// metadata and stores them in CompilationResult.
|
// metadata and stores them in CompilationResult.
|
||||||
|
|
|
@ -40,7 +40,8 @@ xla::StatusOr<xla::Shape> TestShapeRepresentation(const TensorShape& shape,
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
||||||
string invalid_mlir_module = "totally @invalid MLIR module {here} <-";
|
constexpr char invalid_mlir_module[] =
|
||||||
|
"totally @invalid MLIR module {here} <-";
|
||||||
std::vector<TensorShape> arg_shapes;
|
std::vector<TensorShape> arg_shapes;
|
||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
|
@ -76,7 +77,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
|
||||||
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
compilation_result.computation->proto(), module_config);
|
compilation_result.computation->proto(), module_config);
|
||||||
TF_ASSERT_OK(status_or_hlo_module.status());
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
string expected_hlo_module_string = R"(HloModule main.6
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.6
|
||||||
|
|
||||||
ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
||||||
%arg_tuple.1 = (f32[], f32[]) parameter(0)
|
%arg_tuple.1 = (f32[], f32[]) parameter(0)
|
||||||
|
@ -134,7 +135,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
|
||||||
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
compilation_result.computation->proto(), module_config);
|
compilation_result.computation->proto(), module_config);
|
||||||
TF_ASSERT_OK(status_or_hlo_module.status());
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
string expected_hlo_module_string = R"(HloModule main.5
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.5
|
||||||
|
|
||||||
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
|
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
|
||||||
%Arg_0.1 = f32[] parameter(0)
|
%Arg_0.1 = f32[] parameter(0)
|
||||||
|
@ -181,7 +182,7 @@ ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
||||||
// "tf.Shape" can only be folded away after shape inference. tf.Reshape can
|
// "tf.Shape" can only be folded away after shape inference. tf.Reshape can
|
||||||
// only be lowered when tf.Shape is folded into a constant.
|
// only be lowered when tf.Shape is folded into a constant.
|
||||||
string mlir_module = R"(
|
constexpr char mlir_module[] = R"(
|
||||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> {
|
func @main(%arg0: tensor<10x19xf32>, %arg1: tensor<19x10xf32> {tf_device.is_same_data_across_replicas = true}) -> tensor<10x19xf32> {
|
||||||
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
|
%0 = "tf.Shape"(%arg0) : (tensor<10x19xf32>) -> tensor<2xi64>
|
||||||
|
@ -205,7 +206,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
||||||
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
compilation_result.computation->proto(), module_config);
|
compilation_result.computation->proto(), module_config);
|
||||||
TF_ASSERT_OK(status_or_hlo_module.status());
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
string expected_hlo_module_string = R"(HloModule main.6
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.6
|
||||||
|
|
||||||
ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
|
ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
|
||||||
%arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true}
|
%arg_tuple.1 = (f32[10,19]{1,0}, f32[19,10]{1,0}) parameter(0), parameter_replication={false,true}
|
||||||
|
@ -221,7 +222,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[10,19], f32[19,10])) -> (f32[10,19]) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
||||||
string mlir_module = R"(
|
constexpr char mlir_module[] = R"(
|
||||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
|
func @main(%arg0: tensor<*xf32>, %arg1: tensor<?x19xf32>) -> tensor<?x19xf32> {
|
||||||
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<?x19xf32>) -> tensor<?x19xf32>
|
%0 = "tf.MatMul"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", transpose_a = false, transpose_b = false} : (tensor<*xf32>, tensor<?x19xf32>) -> tensor<?x19xf32>
|
||||||
|
@ -245,13 +246,14 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
||||||
compilation_result.computation->proto(), module_config);
|
compilation_result.computation->proto(), module_config);
|
||||||
TF_ASSERT_OK(status_or_hlo_module.status());
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
|
|
||||||
string expected_signature =
|
constexpr char expected_signature[] =
|
||||||
R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))";
|
R"((arg_tuple.1: (f32[10,17], f32[17,19])) -> (f32[10,19]))";
|
||||||
EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(),
|
EXPECT_THAT(status_or_hlo_module.ValueOrDie()->ToString(),
|
||||||
::testing::HasSubstr(expected_signature));
|
::testing::HasSubstr(expected_signature));
|
||||||
}
|
}
|
||||||
|
|
||||||
constexpr llvm::StringRef kBroadcastGradientArgsModule = R"(
|
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
|
||||||
|
constexpr char mlir_module[] = R"(
|
||||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
func @main() -> (tensor<0xi32>, tensor<0xi32>) {
|
func @main() -> (tensor<0xi32>, tensor<0xi32>) {
|
||||||
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi32>} : () -> tensor<0xi32>
|
||||||
|
@ -261,12 +263,11 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
}
|
}
|
||||||
)";
|
)";
|
||||||
|
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
|
|
||||||
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
Status s = CompileSerializedMlirToXlaHlo(
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
kBroadcastGradientArgsModule, arg_shapes,
|
mlir_module, arg_shapes,
|
||||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
TF_ASSERT_OK(s);
|
TF_ASSERT_OK(s);
|
||||||
|
|
||||||
|
@ -275,7 +276,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ConstantFoldHook) {
|
||||||
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
compilation_result.computation->proto(), module_config);
|
compilation_result.computation->proto(), module_config);
|
||||||
TF_ASSERT_OK(status_or_hlo_module.status());
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
string expected_hlo_module_string = R"(HloModule main.4
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.4
|
||||||
|
|
||||||
ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
|
ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
|
||||||
%arg_tuple.1 = () parameter(0)
|
%arg_tuple.1 = () parameter(0)
|
||||||
|
@ -288,6 +289,128 @@ ENTRY %main.4 (arg_tuple.1: ()) -> (s32[0], s32[0]) {
|
||||||
status_or_hlo_module.ValueOrDie()->ToString());
|
status_or_hlo_module.ValueOrDie()->ToString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The following xla::OpSharding protos are used:
|
||||||
|
// Serialized string:
|
||||||
|
// "\08\03\1A\02\01\02\22\02\00\01"
|
||||||
|
// Proto debug string:
|
||||||
|
// type: OTHER
|
||||||
|
// tile_assignment_dimensions: 1
|
||||||
|
// tile_assignment_dimensions: 2
|
||||||
|
// tile_assignment_devices: 0
|
||||||
|
// tile_assignment_devices: 1
|
||||||
|
//
|
||||||
|
// Serialized string:
|
||||||
|
// "\08\01\1A\01\01\22\01\00"
|
||||||
|
// Proto debug string:
|
||||||
|
// type: MAXIMAL
|
||||||
|
// tile_assignment_dimensions: 1
|
||||||
|
// tile_assignment_devices: 0
|
||||||
|
//
|
||||||
|
// Serialized string:
|
||||||
|
// ""
|
||||||
|
// Proto debug string (empty but would equivalent to):
|
||||||
|
// type: REPLICATED
|
||||||
|
TEST(CompileSerializedMlirToXlaHloTest, ArgumentSharding) {
|
||||||
|
constexpr char mlir_module[] = R"(
|
||||||
|
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
|
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, %arg1: tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
std::vector<TensorShape> arg_shapes{TensorShape({128, 10}),
|
||||||
|
TensorShape({10, 1024}),
|
||||||
|
TensorShape({128, 1024})};
|
||||||
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
|
mlir_module, arg_shapes,
|
||||||
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
|
TF_ASSERT_OK(s);
|
||||||
|
|
||||||
|
const xla::HloModuleConfig module_config(
|
||||||
|
compilation_result.computation->GetProgramShape().ValueOrDie());
|
||||||
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
|
compilation_result.computation->proto(), module_config);
|
||||||
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.6
|
||||||
|
|
||||||
|
ENTRY %main.6 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> () {
|
||||||
|
%arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
|
||||||
|
%get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
|
||||||
|
%get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
|
||||||
|
%get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
|
||||||
|
ROOT %tuple.5 = () tuple()
|
||||||
|
}
|
||||||
|
|
||||||
|
)";
|
||||||
|
EXPECT_EQ(expected_hlo_module_string,
|
||||||
|
status_or_hlo_module.ValueOrDie()->ToString());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CompileSerializedMlirToXlaHloTest, BadArgumentSharding) {
|
||||||
|
constexpr char mlir_module[] = R"(
|
||||||
|
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
|
func @main(%arg0: tensor<128x10xf32> {xla_hlo.sharding = "bad_sharding"}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
std::vector<TensorShape> arg_shapes{TensorShape({128, 10})};
|
||||||
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
|
mlir_module, arg_shapes,
|
||||||
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
|
ASSERT_FALSE(s.ok());
|
||||||
|
EXPECT_EQ(s.error_message(),
|
||||||
|
"failed to parse argument sharding 0 'bad_sharding'");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(CompileSerializedMlirToXlaHloTest, ResultSharding) {
|
||||||
|
constexpr char mlir_module[] = R"(
|
||||||
|
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 351 : i32}} {
|
||||||
|
func @main(%arg0: tensor<128x10xf32>, %arg1: tensor<10x1024xf32>, %arg2: tensor<128x1024xf32>) -> (tensor<128x10xf32> {xla_hlo.sharding = "\08\03\1A\02\01\02\22\02\00\01"}, tensor<10x1024xf32> {xla_hlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<128x1024xf32> {xla_hlo.sharding = ""}) {
|
||||||
|
return %arg0, %arg1, %arg2 : tensor<128x10xf32>, tensor<10x1024xf32>, tensor<128x1024xf32>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
std::vector<TensorShape> arg_shapes{TensorShape({128, 10}),
|
||||||
|
TensorShape({10, 1024}),
|
||||||
|
TensorShape({128, 1024})};
|
||||||
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
|
mlir_module, arg_shapes,
|
||||||
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
|
TF_ASSERT_OK(s);
|
||||||
|
|
||||||
|
const xla::HloModuleConfig module_config(
|
||||||
|
compilation_result.computation->GetProgramShape().ValueOrDie());
|
||||||
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
|
compilation_result.computation->proto(), module_config);
|
||||||
|
TF_ASSERT_OK(status_or_hlo_module.status());
|
||||||
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.9
|
||||||
|
|
||||||
|
ENTRY %main.9 (arg_tuple.1: (f32[128,10], f32[10,1024], f32[128,1024])) -> (f32[128,10], f32[10,1024], f32[128,1024]) {
|
||||||
|
%arg_tuple.1 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) parameter(0)
|
||||||
|
%get-tuple-element.2 = f32[128,10]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=0
|
||||||
|
%reshape.5 = f32[128,10]{1,0} reshape(f32[128,10]{1,0} %get-tuple-element.2)
|
||||||
|
%get-tuple-element.3 = f32[10,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=1
|
||||||
|
%reshape.6 = f32[10,1024]{1,0} reshape(f32[10,1024]{1,0} %get-tuple-element.3)
|
||||||
|
%get-tuple-element.4 = f32[128,1024]{1,0} get-tuple-element((f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) %arg_tuple.1), index=2
|
||||||
|
%reshape.7 = f32[128,1024]{1,0} reshape(f32[128,1024]{1,0} %get-tuple-element.4)
|
||||||
|
ROOT %tuple.8 = (f32[128,10]{1,0}, f32[10,1024]{1,0}, f32[128,1024]{1,0}) tuple(f32[128,10]{1,0} %reshape.5, f32[10,1024]{1,0} %reshape.6, f32[128,1024]{1,0} %reshape.7), sharding={{devices=[1,2]0,1}, {maximal device=0}, {replicated}}
|
||||||
|
}
|
||||||
|
|
||||||
|
)";
|
||||||
|
EXPECT_EQ(expected_hlo_module_string,
|
||||||
|
status_or_hlo_module.ValueOrDie()->ToString());
|
||||||
|
}
|
||||||
|
|
||||||
// Verify that conversion from Graph to MLIR and empty shape representation
|
// Verify that conversion from Graph to MLIR and empty shape representation
|
||||||
// function is successful.
|
// function is successful.
|
||||||
TEST(CompileGraphToXlaHlo, Basic) {
|
TEST(CompileGraphToXlaHlo, Basic) {
|
||||||
|
@ -311,7 +434,7 @@ TEST(CompileGraphToXlaHlo, Basic) {
|
||||||
result.computation->proto(), module_config);
|
result.computation->proto(), module_config);
|
||||||
ASSERT_TRUE(status_or_hlo_module.ok());
|
ASSERT_TRUE(status_or_hlo_module.ok());
|
||||||
|
|
||||||
string expected_hlo_module_string = R"(HloModule main.3
|
constexpr char expected_hlo_module_string[] = R"(HloModule main.3
|
||||||
|
|
||||||
ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
|
ENTRY %main.3 (Arg_0.1: f32[]) -> (f32[]) {
|
||||||
%Arg_0.1 = f32[] parameter(0)
|
%Arg_0.1 = f32[] parameter(0)
|
||||||
|
|
|
@ -618,7 +618,10 @@ cc_library(
|
||||||
":hlo",
|
":hlo",
|
||||||
":type_to_shape",
|
":type_to_shape",
|
||||||
":xla_dialect_registration",
|
":xla_dialect_registration",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||||
|
"//tensorflow/compiler/tf2xla:common",
|
||||||
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
"//tensorflow/compiler/xla:comparison_util",
|
"//tensorflow/compiler/xla:comparison_util",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
@ -629,6 +632,8 @@ cc_library(
|
||||||
"//tensorflow/compiler/xla/client/lib:quantize",
|
"//tensorflow/compiler/xla/client/lib:quantize",
|
||||||
"//tensorflow/compiler/xla/client/lib:slicing",
|
"//tensorflow/compiler/xla/client/lib:slicing",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/stream_executor/lib",
|
"//tensorflow/stream_executor/lib",
|
||||||
"@llvm-project//llvm:support",
|
"@llvm-project//llvm:support",
|
||||||
"@llvm-project//mlir:Analysis",
|
"@llvm-project//mlir:Analysis",
|
||||||
|
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||||
#include "llvm/ADT/DenseMap.h"
|
#include "llvm/ADT/DenseMap.h"
|
||||||
#include "llvm/ADT/DenseSet.h"
|
#include "llvm/ADT/DenseSet.h"
|
||||||
#include "llvm/ADT/STLExtras.h"
|
#include "llvm/ADT/STLExtras.h"
|
||||||
|
#include "llvm/ADT/SmallVector.h"
|
||||||
#include "llvm/ADT/StringRef.h"
|
#include "llvm/ADT/StringRef.h"
|
||||||
#include "llvm/Support/FormatVariadic.h"
|
#include "llvm/Support/FormatVariadic.h"
|
||||||
#include "llvm/Support/MemoryBuffer.h"
|
#include "llvm/Support/MemoryBuffer.h"
|
||||||
|
@ -37,8 +38,12 @@ limitations under the License.
|
||||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||||
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
#include "mlir/IR/StandardTypes.h" // from @llvm-project
|
||||||
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
|
||||||
|
#include "mlir/IR/UseDefLists.h" // from @llvm-project
|
||||||
|
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
|
||||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/quantize.h"
|
#include "tensorflow/compiler/xla/client/lib/quantize.h"
|
||||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||||
|
@ -49,6 +54,8 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/status_macros.h"
|
#include "tensorflow/compiler/xla/status_macros.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||||
|
|
||||||
using ::stream_executor::port::StatusOr;
|
using ::stream_executor::port::StatusOr;
|
||||||
|
@ -64,6 +71,7 @@ using ::tensorflow::uint8;
|
||||||
constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
|
constexpr char kPaddingMapAttr[] = "xla_hlo.padding_map";
|
||||||
constexpr char kShapeIndicesAttr[] = "shape_indices";
|
constexpr char kShapeIndicesAttr[] = "shape_indices";
|
||||||
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
|
constexpr char kPaddingArgIndicesAttr[] = "padding_arg_indices";
|
||||||
|
constexpr char kShardingAttr[] = "xla_hlo.sharding";
|
||||||
constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas";
|
constexpr char kRepicationAttr[] = "tf_device.is_same_data_across_replicas";
|
||||||
|
|
||||||
// Passes through everything except for unique_ptr, on which it calls get().
|
// Passes through everything except for unique_ptr, on which it calls get().
|
||||||
|
@ -377,7 +385,7 @@ static xla::ScatterDimensionNumbers Convert_scatter_dimension_numbers(
|
||||||
// returns absl::nullopt.
|
// returns absl::nullopt.
|
||||||
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
|
static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
|
||||||
mlir::Operation* op) {
|
mlir::Operation* op) {
|
||||||
auto sharding = op->getAttrOfType<mlir::StringAttr>("xla_hlo.sharding");
|
auto sharding = op->getAttrOfType<mlir::StringAttr>(kShardingAttr);
|
||||||
if (!sharding) {
|
if (!sharding) {
|
||||||
return absl::nullopt;
|
return absl::nullopt;
|
||||||
}
|
}
|
||||||
|
@ -389,6 +397,43 @@ static absl::optional<xla::OpSharding> CreateOpShardingFromAttribute(
|
||||||
return sharding_proto;
|
return sharding_proto;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks if all shardings are set.
|
||||||
|
static bool AllOptionalShardingsAreSet(
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> shardings) {
|
||||||
|
return llvm::all_of(shardings,
|
||||||
|
[](const absl::optional<xla::OpSharding>& sharding) {
|
||||||
|
return sharding.has_value();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts sharding from attribute string.
|
||||||
|
static absl::optional<xla::OpSharding> CreateOpShardingFromStringRef(
|
||||||
|
llvm::StringRef sharding) {
|
||||||
|
xla::OpSharding sharding_proto;
|
||||||
|
if (!sharding_proto.ParseFromString(sharding.str())) return absl::nullopt;
|
||||||
|
return sharding_proto;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extracts argument and result shardings from function.
|
||||||
|
static void ExtractShardingsFromFunction(
|
||||||
|
mlir::FuncOp function,
|
||||||
|
llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* arg_shardings,
|
||||||
|
llvm::SmallVectorImpl<absl::optional<xla::OpSharding>>* ret_shardings) {
|
||||||
|
arg_shardings->resize(function.getNumArguments(),
|
||||||
|
absl::optional<xla::OpSharding>());
|
||||||
|
for (int i = 0; i < function.getNumArguments(); ++i)
|
||||||
|
if (auto sharding =
|
||||||
|
function.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr))
|
||||||
|
(*arg_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
|
||||||
|
|
||||||
|
ret_shardings->resize(function.getNumResults(),
|
||||||
|
absl::optional<xla::OpSharding>());
|
||||||
|
for (int i = 0; i < function.getNumResults(); ++i)
|
||||||
|
if (auto sharding =
|
||||||
|
function.getResultAttrOfType<mlir::StringAttr>(i, kShardingAttr))
|
||||||
|
(*ret_shardings)[i] = CreateOpShardingFromStringRef(sharding.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertToHloModule {
|
class ConvertToHloModule {
|
||||||
|
@ -402,12 +447,17 @@ class ConvertToHloModule {
|
||||||
// are converted to a tuple even when there is only a single return value.
|
// are converted to a tuple even when there is only a single return value.
|
||||||
// Multiple return values are always converted to a tuple and returned as a
|
// Multiple return values are always converted to a tuple and returned as a
|
||||||
// single value.
|
// single value.
|
||||||
explicit ConvertToHloModule(mlir::ModuleOp module, bool use_tuple_args,
|
explicit ConvertToHloModule(
|
||||||
bool return_tuple)
|
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
|
||||||
|
tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn)
|
||||||
: module_(module),
|
: module_(module),
|
||||||
module_builder_("main"),
|
module_builder_("main"),
|
||||||
use_tuple_args_(use_tuple_args),
|
use_tuple_args_(use_tuple_args),
|
||||||
return_tuple_(return_tuple) {}
|
return_tuple_(return_tuple),
|
||||||
|
shape_representation_fn_(shape_representation_fn) {
|
||||||
|
if (!shape_representation_fn_)
|
||||||
|
shape_representation_fn_ = tensorflow::IdentityShapeRepresentationFn();
|
||||||
|
}
|
||||||
|
|
||||||
// Perform the lowering to XLA. This function returns failure if an error was
|
// Perform the lowering to XLA. This function returns failure if an error was
|
||||||
// encountered.
|
// encountered.
|
||||||
|
@ -432,6 +482,8 @@ class ConvertToHloModule {
|
||||||
LogicalResult LowerBasicBlockAsFunction(
|
LogicalResult LowerBasicBlockAsFunction(
|
||||||
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
|
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
|
||||||
const std::vector<bool>& entry_args_same_across_replicas,
|
const std::vector<bool>& entry_args_same_across_replicas,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||||
xla::XlaComputation* result);
|
xla::XlaComputation* result);
|
||||||
|
|
||||||
::xla::HloModuleProto ConsumeMainProto() {
|
::xla::HloModuleProto ConsumeMainProto() {
|
||||||
|
@ -445,11 +497,23 @@ class ConvertToHloModule {
|
||||||
ConvertToHloModule::ValueLoweringMap* value_lowering);
|
ConvertToHloModule::ValueLoweringMap* value_lowering);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
LogicalResult Lower(mlir::Operation* inst, bool is_entry_function,
|
LogicalResult Lower(
|
||||||
|
mlir::Operation* inst, bool is_entry_function,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||||
xla::XlaBuilder* builder,
|
xla::XlaBuilder* builder,
|
||||||
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
||||||
xla::XlaComputation* result);
|
xla::XlaComputation* result);
|
||||||
|
|
||||||
|
LogicalResult SetEntryTupleShapesAndLeafReplication(
|
||||||
|
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
|
||||||
|
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
|
||||||
|
std::vector<bool>* leaf_replication);
|
||||||
|
|
||||||
|
LogicalResult SetEntryTupleShardings(
|
||||||
|
Block* block, xla::XlaBuilder* builder,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
|
||||||
|
llvm::SmallVectorImpl<xla::Shape>* arg_shapes);
|
||||||
|
|
||||||
// The module being lowered.
|
// The module being lowered.
|
||||||
mlir::ModuleOp module_;
|
mlir::ModuleOp module_;
|
||||||
|
|
||||||
|
@ -465,6 +529,10 @@ class ConvertToHloModule {
|
||||||
// Whether to always return a tuple.
|
// Whether to always return a tuple.
|
||||||
bool return_tuple_;
|
bool return_tuple_;
|
||||||
|
|
||||||
|
// Shape representation function to determine entry function argument and
|
||||||
|
// result shapes.
|
||||||
|
tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||||
|
|
||||||
// Unique suffix to give to the name of the next lowered region.
|
// Unique suffix to give to the name of the next lowered region.
|
||||||
size_t region_id_ = 0;
|
size_t region_id_ = 0;
|
||||||
};
|
};
|
||||||
|
@ -876,7 +944,9 @@ StatusOr<xla::Literal> CreateLiteralFromAttr(Type type, ElementsAttr attr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult ConvertToHloModule::Lower(
|
LogicalResult ConvertToHloModule::Lower(
|
||||||
mlir::Operation* inst, bool is_entry_function, xla::XlaBuilder* builder,
|
mlir::Operation* inst, bool is_entry_function,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||||
|
xla::XlaBuilder* builder,
|
||||||
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
ConvertToHloModule::ValueLoweringMap* value_lowering,
|
||||||
xla::XlaComputation* result) {
|
xla::XlaComputation* result) {
|
||||||
if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
|
if (succeeded(ExportXlaOperator(inst, {value_lowering, this, builder}))) {
|
||||||
|
@ -906,11 +976,37 @@ LogicalResult ConvertToHloModule::Lower(
|
||||||
xla::XlaOp return_value;
|
xla::XlaOp return_value;
|
||||||
unsigned num_return_values = inst->getNumOperands();
|
unsigned num_return_values = inst->getNumOperands();
|
||||||
if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
|
if ((return_tuple_ && is_entry_function) || num_return_values > 1) {
|
||||||
|
const bool has_ret_shardings =
|
||||||
|
!ret_shardings.empty() && AllOptionalShardingsAreSet(ret_shardings);
|
||||||
|
|
||||||
std::vector<xla::XlaOp> returns(num_return_values);
|
std::vector<xla::XlaOp> returns(num_return_values);
|
||||||
for (unsigned i = 0, e = inst->getNumOperands(); i != e; ++i) {
|
for (OpOperand& ret : inst->getOpOperands()) {
|
||||||
returns[i] = value_map[inst->getOperand(i)];
|
unsigned index = ret.getOperandNumber();
|
||||||
|
returns[index] = value_map[ret.get()];
|
||||||
|
if (!is_entry_function || !has_ret_shardings) continue;
|
||||||
|
|
||||||
|
xla::Shape return_shape = xla::TypeToShape(ret.get().getType());
|
||||||
|
StatusOr<xla::XlaOp> reshape =
|
||||||
|
tensorflow::ReshapeWithCorrectRepresentationAndSharding(
|
||||||
|
builder, returns[index], return_shape, shape_representation_fn_,
|
||||||
|
ret_shardings[index], /*fast_mem=*/false);
|
||||||
|
if (!reshape.ok())
|
||||||
|
return inst->emitError() << reshape.status().error_message();
|
||||||
|
|
||||||
|
returns[index] = reshape.ValueOrDie();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (has_ret_shardings) {
|
||||||
|
xla::OpSharding sharding;
|
||||||
|
sharding.set_type(xla::OpSharding::TUPLE);
|
||||||
|
for (auto& ret_sharding : ret_shardings)
|
||||||
|
*sharding.add_tuple_shardings() = ret_sharding.value();
|
||||||
|
|
||||||
|
builder->SetSharding(sharding);
|
||||||
|
}
|
||||||
|
|
||||||
return_value = xla::Tuple(builder, returns);
|
return_value = xla::Tuple(builder, returns);
|
||||||
|
builder->ClearSharding();
|
||||||
} else if (num_return_values == 1) {
|
} else if (num_return_values == 1) {
|
||||||
return_value = value_map[inst->getOperand(0)];
|
return_value = value_map[inst->getOperand(0)];
|
||||||
}
|
}
|
||||||
|
@ -976,6 +1072,8 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||||
|
|
||||||
xla::XlaComputation computation;
|
xla::XlaComputation computation;
|
||||||
std::vector<bool> entry_args_same_across_replicas;
|
std::vector<bool> entry_args_same_across_replicas;
|
||||||
|
llvm::SmallVector<absl::optional<xla::OpSharding>, 4> arg_shardings;
|
||||||
|
llvm::SmallVector<absl::optional<xla::OpSharding>, 4> ret_shardings;
|
||||||
if (entry_function) {
|
if (entry_function) {
|
||||||
bool any_arg_replicated = false;
|
bool any_arg_replicated = false;
|
||||||
entry_args_same_across_replicas.reserve(f.getNumArguments());
|
entry_args_same_across_replicas.reserve(f.getNumArguments());
|
||||||
|
@ -1000,21 +1098,90 @@ LogicalResult ConvertToHloModule::RunOnFunction(mlir::FuncOp f) {
|
||||||
// means no replication. This avoids the need for unrelated tests to handle
|
// means no replication. This avoids the need for unrelated tests to handle
|
||||||
// this field.
|
// this field.
|
||||||
if (!any_arg_replicated) entry_args_same_across_replicas.clear();
|
if (!any_arg_replicated) entry_args_same_across_replicas.clear();
|
||||||
|
|
||||||
|
ExtractShardingsFromFunction(f, &arg_shardings, &ret_shardings);
|
||||||
}
|
}
|
||||||
if (failed(LowerBasicBlockAsFunction(&f.front(), &builder, entry_function,
|
if (failed(LowerBasicBlockAsFunction(
|
||||||
entry_args_same_across_replicas,
|
&f.front(), &builder, entry_function, entry_args_same_across_replicas,
|
||||||
&computation))) {
|
arg_shardings, ret_shardings, &computation))) {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
lowered_computation_[f] = std::move(computation);
|
lowered_computation_[f] = std::move(computation);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertToHloModule::SetEntryTupleShapesAndLeafReplication(
|
||||||
|
Block* block, const std::vector<bool>& entry_args_same_across_replicas,
|
||||||
|
llvm::SmallVectorImpl<xla::Shape>* arg_shapes,
|
||||||
|
std::vector<bool>* leaf_replication) {
|
||||||
|
arg_shapes->reserve(block->getNumArguments());
|
||||||
|
leaf_replication->reserve(block->getNumArguments());
|
||||||
|
for (BlockArgument& arg : block->getArguments()) {
|
||||||
|
arg_shapes->push_back(xla::TypeToShape(arg.getType()));
|
||||||
|
xla::Shape& arg_shape = arg_shapes->back();
|
||||||
|
tensorflow::TensorShape arg_tensor_shape;
|
||||||
|
auto status =
|
||||||
|
tensorflow::XLAShapeToTensorShape(arg_shape, &arg_tensor_shape);
|
||||||
|
if (!status.ok())
|
||||||
|
return block->getParentOp()->emitError() << status.error_message();
|
||||||
|
|
||||||
|
tensorflow::DataType dtype;
|
||||||
|
status = tensorflow::ConvertToDataType(arg.getType(), &dtype);
|
||||||
|
if (!status.ok())
|
||||||
|
return block->getParentOp()->emitError() << status.error_message();
|
||||||
|
|
||||||
|
auto arg_shape_status = shape_representation_fn_(arg_tensor_shape, dtype,
|
||||||
|
/*use_fast_memory=*/false);
|
||||||
|
if (!arg_shape_status.ok())
|
||||||
|
return block->getParentOp()->emitError()
|
||||||
|
<< arg_shape_status.status().error_message();
|
||||||
|
|
||||||
|
arg_shape = std::move(arg_shape_status.ValueOrDie());
|
||||||
|
|
||||||
|
if (entry_args_same_across_replicas.empty()) continue;
|
||||||
|
for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i)
|
||||||
|
leaf_replication->push_back(
|
||||||
|
entry_args_same_across_replicas[arg.getArgNumber()]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
LogicalResult ConvertToHloModule::SetEntryTupleShardings(
|
||||||
|
Block* block, xla::XlaBuilder* builder,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
|
||||||
|
llvm::SmallVectorImpl<xla::Shape>* arg_shapes) {
|
||||||
|
if (!arg_shardings.empty() && AllOptionalShardingsAreSet(arg_shardings)) {
|
||||||
|
xla::OpSharding sharding;
|
||||||
|
sharding.set_type(xla::OpSharding::TUPLE);
|
||||||
|
for (auto arg_sharding : llvm::enumerate(arg_shardings)) {
|
||||||
|
auto hlo_sharding =
|
||||||
|
xla::HloSharding::FromProto(arg_sharding.value().value());
|
||||||
|
if (!hlo_sharding.ok())
|
||||||
|
return block->getParentOp()->emitError()
|
||||||
|
<< hlo_sharding.status().error_message();
|
||||||
|
|
||||||
|
auto status = tensorflow::RewriteLayoutWithShardedShape(
|
||||||
|
hlo_sharding.ValueOrDie(), /*use_fast_memory=*/false,
|
||||||
|
shape_representation_fn_, &(*arg_shapes)[arg_sharding.index()]);
|
||||||
|
if (!status.ok())
|
||||||
|
return block->getParentOp()->emitError() << status.error_message();
|
||||||
|
|
||||||
|
*sharding.add_tuple_shardings() = arg_sharding.value().value();
|
||||||
|
}
|
||||||
|
|
||||||
|
builder->SetSharding(sharding);
|
||||||
|
}
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
||||||
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
|
Block* block, xla::XlaBuilder* builder, bool is_entry_function,
|
||||||
const std::vector<bool>& entry_args_same_across_replicas,
|
const std::vector<bool>& entry_args_same_across_replicas,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> arg_shardings,
|
||||||
|
llvm::ArrayRef<absl::optional<xla::OpSharding>> ret_shardings,
|
||||||
xla::XlaComputation* result) {
|
xla::XlaComputation* result) {
|
||||||
auto& bb = *block;
|
|
||||||
// Mapping from the Value to lowered XlaOp. The code below lowers in
|
// Mapping from the Value to lowered XlaOp. The code below lowers in
|
||||||
// program order and will fail if an operand is unseen. This can be improved.
|
// program order and will fail if an operand is unseen. This can be improved.
|
||||||
ValueLoweringMap lowering;
|
ValueLoweringMap lowering;
|
||||||
|
@ -1022,29 +1189,28 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
||||||
// If using tuples as input, then there is only one input parameter that is a
|
// If using tuples as input, then there is only one input parameter that is a
|
||||||
// tuple.
|
// tuple.
|
||||||
if (is_entry_function && use_tuple_args_) {
|
if (is_entry_function && use_tuple_args_) {
|
||||||
std::vector<xla::Shape> arg_shapes;
|
llvm::SmallVector<xla::Shape, 4> arg_shapes;
|
||||||
arg_shapes.reserve(bb.getNumArguments());
|
|
||||||
std::vector<bool> leaf_replication;
|
std::vector<bool> leaf_replication;
|
||||||
for (auto& arg : bb.getArguments()) {
|
if (failed(SetEntryTupleShapesAndLeafReplication(
|
||||||
arg_shapes.push_back(xla::TypeToShape(arg.getType()));
|
block, entry_args_same_across_replicas, &arg_shapes,
|
||||||
if (!entry_args_same_across_replicas.empty()) {
|
&leaf_replication)))
|
||||||
for (int i = 0; i < xla::ShapeUtil::GetLeafCount(arg_shapes.back());
|
return failure();
|
||||||
++i) {
|
|
||||||
leaf_replication.push_back(
|
if (failed(
|
||||||
entry_args_same_across_replicas[arg.getArgNumber()]);
|
SetEntryTupleShardings(block, builder, arg_shardings, &arg_shapes)))
|
||||||
}
|
return failure();
|
||||||
}
|
|
||||||
}
|
|
||||||
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
|
xla::Shape input_shape = xla::ShapeUtil::MakeTupleShape(arg_shapes);
|
||||||
auto tuple =
|
auto tuple =
|
||||||
xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
|
xla::Parameter(builder, 0, input_shape, "arg_tuple", leaf_replication);
|
||||||
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
|
||||||
lowering[it.value()] = xla::GetTupleElement(tuple, it.index());
|
builder->ClearSharding();
|
||||||
}
|
|
||||||
|
for (BlockArgument& arg : block->getArguments())
|
||||||
|
lowering[arg] = xla::GetTupleElement(tuple, arg.getArgNumber());
|
||||||
} else {
|
} else {
|
||||||
for (auto& it : llvm::enumerate(bb.getArguments())) {
|
for (BlockArgument& arg : block->getArguments()) {
|
||||||
auto arg = it.value();
|
auto num = arg.getArgNumber();
|
||||||
auto num = it.index();
|
|
||||||
xla::Shape shape = xla::TypeToShape(arg.getType());
|
xla::Shape shape = xla::TypeToShape(arg.getType());
|
||||||
if (entry_args_same_across_replicas.empty()) {
|
if (entry_args_same_across_replicas.empty()) {
|
||||||
lowering[arg] =
|
lowering[arg] =
|
||||||
|
@ -1058,8 +1224,9 @@ LogicalResult ConvertToHloModule::LowerBasicBlockAsFunction(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto& inst : bb)
|
for (auto& inst : *block)
|
||||||
if (failed(Lower(&inst, is_entry_function, builder, &lowering, result)))
|
if (failed(Lower(&inst, is_entry_function, ret_shardings, builder,
|
||||||
|
&lowering, result)))
|
||||||
return failure();
|
return failure();
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
|
@ -1069,8 +1236,10 @@ LogicalResult ConvertToHloModule::LowerRegionAsComputation(
|
||||||
mlir::Region* region, xla::XlaComputation* func) {
|
mlir::Region* region, xla::XlaComputation* func) {
|
||||||
std::unique_ptr<xla::XlaBuilder> builder =
|
std::unique_ptr<xla::XlaBuilder> builder =
|
||||||
module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
|
module_builder_.CreateSubBuilder(absl::StrCat("region_", region_id_++));
|
||||||
return LowerBasicBlockAsFunction(®ion->front(), builder.get(),
|
return LowerBasicBlockAsFunction(
|
||||||
/*is_entry_function=*/false, {}, func);
|
®ion->front(), builder.get(),
|
||||||
|
/*is_entry_function=*/false, /*entry_args_same_across_replicas=*/{},
|
||||||
|
/*arg_shardings=*/{}, /*ret_shardings=*/{}, func);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
|
std::string PaddingMapBadArrayAttrMsg(llvm::StringRef attr_name, int index) {
|
||||||
|
@ -1241,9 +1410,12 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
||||||
bool use_tuple_args, bool return_tuple) {
|
bool use_tuple_args, bool return_tuple,
|
||||||
|
const tensorflow::XlaCompiler::ShapeRepresentationFn
|
||||||
|
shape_representation_fn) {
|
||||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||||
ConvertToHloModule converter(module, use_tuple_args, return_tuple);
|
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
|
||||||
|
shape_representation_fn);
|
||||||
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
|
if (failed(converter.Run())) return diag_handler.ConsumeStatus();
|
||||||
auto hlo_module = converter.ConsumeMainProto();
|
auto hlo_module = converter.ConsumeMainProto();
|
||||||
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
hlo_proto->mutable_hlo_module()->Swap(&hlo_module);
|
||||||
|
|
|
@ -18,6 +18,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "mlir/IR/Module.h" // from @llvm-project
|
#include "mlir/IR/Module.h" // from @llvm-project
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
|
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
|
|
||||||
|
@ -31,7 +32,9 @@ namespace mlir {
|
||||||
// Multiple return values are always converted to a tuple and returned as a
|
// Multiple return values are always converted to a tuple and returned as a
|
||||||
// single value.
|
// single value.
|
||||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
||||||
bool use_tuple_args, bool return_tuple);
|
bool use_tuple_args, bool return_tuple,
|
||||||
|
const tensorflow::XlaCompiler::ShapeRepresentationFn
|
||||||
|
shape_representation_fn = nullptr);
|
||||||
|
|
||||||
// Creates XlaOp equivalent of a given MLIR operation using the operand info
|
// Creates XlaOp equivalent of a given MLIR operation using the operand info
|
||||||
// from `value_lowering` map.
|
// from `value_lowering` map.
|
||||||
|
|
|
@ -138,46 +138,6 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// There is a shape_representation_fn or sharding for an output, this function
|
|
||||||
// uses a reshape to fix the layout.
|
|
||||||
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
|
||||||
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
|
||||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
|
||||||
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
|
|
||||||
if (original_shape.IsTuple()) {
|
|
||||||
std::vector<xla::XlaOp> elements;
|
|
||||||
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
|
|
||||||
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
|
|
||||||
TF_ASSIGN_OR_RETURN(auto element,
|
|
||||||
ReshapeWithCorrectRepresentationAndSharding(
|
|
||||||
builder, xla::GetTupleElement(original, i),
|
|
||||||
original_shape.tuple_shapes(i),
|
|
||||||
shape_representation_fn, subsharding, fast_mem));
|
|
||||||
elements.push_back(element);
|
|
||||||
}
|
|
||||||
return xla::Tuple(builder, elements);
|
|
||||||
}
|
|
||||||
if (!original_shape.IsArray()) return original;
|
|
||||||
TensorShape shape;
|
|
||||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
|
|
||||||
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
|
|
||||||
original_shape.element_type()));
|
|
||||||
TF_ASSIGN_OR_RETURN(auto to_shape,
|
|
||||||
shape_representation_fn(shape, dtype, fast_mem));
|
|
||||||
if (sharding) {
|
|
||||||
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
|
|
||||||
xla::HloSharding::FromProto(*sharding));
|
|
||||||
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
|
|
||||||
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
|
|
||||||
}
|
|
||||||
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
|
|
||||||
for (int64 i = 0; i < original_shape.rank(); ++i) {
|
|
||||||
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return xla::Reshape(to_shape, original);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Builds the XLA computation.
|
// Builds the XLA computation.
|
||||||
// - `args` is the list of input arguments
|
// - `args` is the list of input arguments
|
||||||
// - `retvals` is the list of retvals produced by _Retval operators, in index
|
// - `retvals` is the list of retvals produced by _Retval operators, in index
|
||||||
|
@ -562,13 +522,7 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options)
|
||||||
|
|
||||||
// The default shape representation function is the identity.
|
// The default shape representation function is the identity.
|
||||||
if (!options_.shape_representation_fn) {
|
if (!options_.shape_representation_fn) {
|
||||||
options_.shape_representation_fn =
|
options_.shape_representation_fn = IdentityShapeRepresentationFn();
|
||||||
[](const TensorShape& shape, DataType dtype,
|
|
||||||
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
|
||||||
xla::Shape xla_shape;
|
|
||||||
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
|
||||||
return xla_shape;
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1502,6 +1456,15 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
|
||||||
return iter->second;
|
return iter->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() {
|
||||||
|
return [](const TensorShape& shape, DataType dtype,
|
||||||
|
bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
|
||||||
|
xla::Shape xla_shape;
|
||||||
|
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
|
||||||
|
return xla_shape;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
// Rewrites the layout of xla_shape if there is tiled sharding.
|
// Rewrites the layout of xla_shape if there is tiled sharding.
|
||||||
Status RewriteLayoutWithShardedShape(
|
Status RewriteLayoutWithShardedShape(
|
||||||
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
||||||
|
@ -1542,4 +1505,44 @@ Status RewriteLayoutWithShardedShape(
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// There is a shape_representation_fn or sharding for an output, this function
|
||||||
|
// uses a reshape to fix the layout.
|
||||||
|
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
||||||
|
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
||||||
|
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
|
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
|
||||||
|
if (original_shape.IsTuple()) {
|
||||||
|
std::vector<xla::XlaOp> elements;
|
||||||
|
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
|
||||||
|
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
|
||||||
|
TF_ASSIGN_OR_RETURN(auto element,
|
||||||
|
ReshapeWithCorrectRepresentationAndSharding(
|
||||||
|
builder, xla::GetTupleElement(original, i),
|
||||||
|
original_shape.tuple_shapes(i),
|
||||||
|
shape_representation_fn, subsharding, fast_mem));
|
||||||
|
elements.push_back(element);
|
||||||
|
}
|
||||||
|
return xla::Tuple(builder, elements);
|
||||||
|
}
|
||||||
|
if (!original_shape.IsArray()) return original;
|
||||||
|
TensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
|
||||||
|
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
|
||||||
|
original_shape.element_type()));
|
||||||
|
TF_ASSIGN_OR_RETURN(auto to_shape,
|
||||||
|
shape_representation_fn(shape, dtype, fast_mem));
|
||||||
|
if (sharding) {
|
||||||
|
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
|
||||||
|
xla::HloSharding::FromProto(*sharding));
|
||||||
|
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
|
||||||
|
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
|
||||||
|
}
|
||||||
|
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
|
||||||
|
for (int64 i = 0; i < original_shape.rank(); ++i) {
|
||||||
|
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return xla::Reshape(to_shape, original);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
|
@ -518,12 +518,22 @@ class XlaCompiler {
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
|
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Creates an identity shape representation function.
|
||||||
|
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn();
|
||||||
|
|
||||||
// Rewrites the layout of xla_shape if there is tiled sharding.
|
// Rewrites the layout of xla_shape if there is tiled sharding.
|
||||||
Status RewriteLayoutWithShardedShape(
|
Status RewriteLayoutWithShardedShape(
|
||||||
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
||||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
xla::Shape* xla_shape);
|
xla::Shape* xla_shape);
|
||||||
|
|
||||||
|
// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
|
||||||
|
// sharding is present.
|
||||||
|
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
||||||
|
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
||||||
|
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
|
absl::optional<xla::OpSharding> sharding, bool fast_mem);
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
|
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILER_H_
|
||||||
|
|
Loading…
Reference in New Issue