Add "xla-legalize-tf-with-tf2xla" pass to XLA compilation pass pipeline
Adding this pass to the pipeline requires moving device deps outside of the pass. Otherwise, this will create circular dependency between XLA compilation cache and the device. Also, support BF16 type in ConvertPrimitiveTypeToMLIRType helper. PiperOrigin-RevId: 305594849 Change-Id: I469a6715511417b5db2bbc9b2a74fd2f24be5440
This commit is contained in:
parent
5ba0711744
commit
d5f5640cd9
|
@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
|
|||
arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
|
||||
}
|
||||
GraphDebugInfo debug_info;
|
||||
return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
|
||||
compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info,
|
||||
options.shape_representation_fn, result);
|
||||
return CompileGraphToXlaHlo(
|
||||
*graph, {arg_shapes.data(), arg_shapes.size()},
|
||||
options.device_type.type_string(), compile_options.use_tuple_arg,
|
||||
*options.flib_def, debug_info, options.shape_representation_fn, result);
|
||||
};
|
||||
return CompileImpl(options, name, args, compile_op,
|
||||
/*compile_threshold=*/absl::nullopt,
|
||||
|
|
|
@ -58,6 +58,11 @@ cc_library(
|
|||
"//tensorflow/python:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
# Link jit lib to link JIT devices required to run
|
||||
# xla-legalize-tf-with-tf2xla pass.
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
|
||||
|
@ -90,8 +95,6 @@ cc_library(
|
|||
"//tensorflow/compiler/mlir/xla:xla_lower",
|
||||
"//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
|
||||
"//tensorflow/compiler/mlir/xla:xla_test_passes",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:QuantOps",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
@ -1078,6 +1078,7 @@ COMPILE_MLIR_UTIL_DEPS = [
|
|||
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
|
||||
"//tensorflow/compiler/mlir/xla:type_to_shape",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -1118,6 +1119,7 @@ tf_cc_test(
|
|||
srcs = ["utils/compile_mlir_util_test.cc"],
|
||||
deps = [
|
||||
":compile_mlir_util",
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
|
|
|
@ -254,8 +254,9 @@ static void RegisterDialects() {
|
|||
} // namespace
|
||||
|
||||
Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
|
||||
bool use_tuple_args, bool return_tuple,
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
|
@ -268,6 +269,7 @@ Status ConvertMLIRToXlaComputation(
|
|||
// with a tuple argument which break the assumption of resource lifting
|
||||
// inside PromoteResourcesToArgs.
|
||||
tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
|
||||
tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
|
||||
// We need to run LegalizeTFPass 2 times because first
|
||||
// LegalizeTFPass(allow_partial_conversion=true) can expose more graph pruning
|
||||
// and canonicalization opportunities that are necessary for the second
|
||||
|
@ -308,7 +310,7 @@ Status ConvertMLIRToXlaComputation(
|
|||
|
||||
static Status CompileMlirToXlaHlo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
if (VLOG_IS_ON(1))
|
||||
|
@ -326,7 +328,8 @@ static Status CompileMlirToXlaHlo(
|
|||
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||
module_op, compilation_result->computation.get(), use_tuple_args,
|
||||
module_op, device_type, compilation_result->computation.get(),
|
||||
use_tuple_args,
|
||||
/*return_tuple=*/true, shape_representation_fn));
|
||||
|
||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||
|
@ -355,7 +358,7 @@ static Status CompileMlirToXlaHlo(
|
|||
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
RegisterDialects();
|
||||
|
@ -364,14 +367,15 @@ Status CompileSerializedMlirToXlaHlo(
|
|||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
|
||||
return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args,
|
||||
shape_representation_fn, compilation_result);
|
||||
return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type,
|
||||
use_tuple_args, shape_representation_fn,
|
||||
compilation_result);
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
RegisterDialects();
|
||||
|
@ -383,8 +387,8 @@ Status CompileGraphToXlaHlo(
|
|||
if (!module_or.ok()) return module_or.status();
|
||||
|
||||
return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes,
|
||||
use_tuple_args, shape_representation_fn,
|
||||
compilation_result);
|
||||
device_type, use_tuple_args,
|
||||
shape_representation_fn, compilation_result);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
|
|
@ -29,6 +29,8 @@ namespace tensorflow {
|
|||
// Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
|
||||
// should only contain operations in tf dialect. If the input module contains
|
||||
// operation in the tf_executor dialect, for example, returns an error.
|
||||
// Exception to this are tf_executor dialect ops that are optimized away through
|
||||
// canonicalization.
|
||||
//
|
||||
// Operations in tf dialect are lowered to XLA HLO through the following steps:
|
||||
// . Legalizes control flow operations.
|
||||
|
@ -39,6 +41,8 @@ namespace tensorflow {
|
|||
// . Legalizes the operations to XLA HLO operations.
|
||||
// . Canonicalizes the XLA HLO operations.
|
||||
//
|
||||
// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT",
|
||||
// "XLA_GPU_JIT" or "XLA_TPU_JIT".
|
||||
// use_tuple_args: when this is true, always create a tuple argument for the
|
||||
// entry computation.
|
||||
// return_tuple: when this is true, always create a tuple result for the
|
||||
|
@ -47,23 +51,24 @@ namespace tensorflow {
|
|||
// will be used to determine argument and result shapes. Otherwise the
|
||||
// original shape will be used as is.
|
||||
Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
|
||||
bool use_tuple_args, bool return_tuple,
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr);
|
||||
|
||||
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
||||
// metadata and stores them in CompilationResult.
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result);
|
||||
|
||||
// Same as the above but takes input as TensorFlow Graph.
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
|
||||
const GraphDebugInfo& debug_info,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result);
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
invalid_mlir_module, arg_shapes,
|
||||
invalid_mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s.ToString(),
|
||||
|
@ -68,7 +68,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
kBinaryAddModule, arg_shapes,
|
||||
kBinaryAddModule, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -126,7 +126,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
kBinaryAddModule, arg_shapes,
|
||||
kBinaryAddModule, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -197,7 +197,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes,
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -236,7 +236,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes,
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -267,7 +267,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes,
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -325,7 +325,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes,
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -362,7 +362,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes,
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
ASSERT_FALSE(s.ok());
|
||||
EXPECT_EQ(s.error_message(),
|
||||
|
@ -384,7 +384,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
|
|||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes,
|
||||
mlir_module, arg_shapes, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
|
@ -424,9 +424,10 @@ TEST(CompileGraphToXlaHlo, Basic) {
|
|||
test::graph::Retval(&graph, 0, arg);
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
TF_ASSERT_OK(CompileGraphToXlaHlo(
|
||||
graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def,
|
||||
GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result));
|
||||
TF_ASSERT_OK(
|
||||
CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT",
|
||||
/*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
|
||||
/*shape_representation_fn=*/nullptr, &result));
|
||||
|
||||
const xla::HloModuleConfig module_config(
|
||||
result.computation->GetProgramShape().ValueOrDie());
|
||||
|
|
|
@ -160,8 +160,6 @@ cc_library(
|
|||
deps = [
|
||||
":hlo",
|
||||
":mlir_hlo_builder",
|
||||
"//tensorflow/compiler/jit:xla_cpu_device",
|
||||
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||
"//tensorflow/compiler/mlir:op_or_arg_name_mapper",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||
|
|
|
@ -123,6 +123,8 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
|
|||
return builder.getI1Type();
|
||||
case PrimitiveType::F16:
|
||||
return builder.getF16Type();
|
||||
case PrimitiveType::BF16:
|
||||
return builder.getBF16Type();
|
||||
case PrimitiveType::F32:
|
||||
return builder.getF32Type();
|
||||
case PrimitiveType::F64:
|
||||
|
|
|
@ -337,6 +337,10 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
|
|||
public:
|
||||
LegalizeTF() = default;
|
||||
|
||||
explicit LegalizeTF(llvm::StringRef device_type) {
|
||||
device_type_ = device_type.str();
|
||||
}
|
||||
|
||||
LegalizeTF(const LegalizeTF&) {}
|
||||
|
||||
void runOnFunction() override {
|
||||
|
@ -359,5 +363,10 @@ static PassRegistration<LegalizeTF> pass(
|
|||
|
||||
} // end namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
|
||||
llvm::StringRef device_type) {
|
||||
return std::make_unique<LegalizeTF>(device_type);
|
||||
}
|
||||
|
||||
} // end namespace xla_hlo
|
||||
} // end namespace mlir
|
||||
|
|
|
@ -38,6 +38,11 @@ namespace xla_hlo {
|
|||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
|
||||
bool allow_partial_conversion = false);
|
||||
|
||||
/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
|
||||
/// specified device type.
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
|
||||
llvm::StringRef device_type);
|
||||
|
||||
/// Lowers from TF dialect's control flow to HLO dialect's control flow.
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
|
||||
|
||||
|
|
|
@ -167,6 +167,7 @@ cc_library(
|
|||
":tf2xla_proto_cc",
|
||||
":tf2xla_util",
|
||||
":xla_compiler",
|
||||
"//tensorflow/compiler/jit",
|
||||
"//tensorflow/compiler/mlir/tensorflow",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
|
||||
|
|
|
@ -174,7 +174,8 @@ Status ConvertGraphDefToXlaViaMlir(
|
|||
// Convert the MLIR module to XLA computation. If the input graph can't be
|
||||
// lowered down to a single graph node with a single island by the previous
|
||||
// step, this step will return an error.
|
||||
return ConvertMLIRToXlaComputation(*module, computation,
|
||||
return ConvertMLIRToXlaComputation(*module, /*device_type=*/"XLA_CPU_JIT",
|
||||
computation,
|
||||
/*use_tuple_args=*/false,
|
||||
/*always_return_tuple=*/true);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue