From d5f5640cd9f0b955c1f8f9b27b976eaed2b3be71 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Wed, 8 Apr 2020 18:15:14 -0700 Subject: [PATCH] Add "xla-legalize-tf-with-tf2xla" pass to XLA compilation pass pipeline Adding this pass to the pipeline requires moving device deps outside of the pass. Otherwise, this will create circular dependency between XLA compilation cache and the device. Also, support BF16 type in ConvertPrimitiveTypeToMLIRType helper. PiperOrigin-RevId: 305594849 Change-Id: I469a6715511417b5db2bbc9b2a74fd2f24be5440 --- .../compiler/jit/xla_compilation_cache.cc | 8 +++--- tensorflow/compiler/mlir/BUILD | 7 +++-- tensorflow/compiler/mlir/tensorflow/BUILD | 2 ++ .../tensorflow/utils/compile_mlir_util.cc | 26 +++++++++++-------- .../mlir/tensorflow/utils/compile_mlir_util.h | 15 +++++++---- .../utils/compile_mlir_util_test.cc | 25 +++++++++--------- tensorflow/compiler/mlir/xla/BUILD | 2 -- tensorflow/compiler/mlir/xla/hlo_utils.cc | 2 ++ .../xla/transforms/legalize_tf_with_tf2xla.cc | 9 +++++++ .../compiler/mlir/xla/transforms/passes.h | 5 ++++ tensorflow/compiler/tf2xla/BUILD | 1 + tensorflow/compiler/tf2xla/mlir_tf2xla.cc | 3 ++- 12 files changed, 68 insertions(+), 37 deletions(-) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 5081df28a08..b51749bc332 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp( arg_shapes.push_back(absl::get(arg.shape)); } GraphDebugInfo debug_info; - return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()}, - compile_options.use_tuple_arg, - *options.flib_def, debug_info, - options.shape_representation_fn, result); + return CompileGraphToXlaHlo( + *graph, {arg_shapes.data(), arg_shapes.size()}, + options.device_type.type_string(), compile_options.use_tuple_arg, + *options.flib_def, debug_info, options.shape_representation_fn, result); }; return CompileImpl(options, name, args, compile_op, /*compile_threshold=*/absl::nullopt, diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 63546db1eb0..2c222ac4cb7 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -58,6 +58,11 @@ cc_library( "//tensorflow/python:__subpackages__", ], deps = [ + "@llvm-project//mlir:Affine", + "@llvm-project//mlir:QuantOps", + # Link jit lib to link JIT devices required to run + # xla-legalize-tf-with-tf2xla pass. + "//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration", "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf", @@ -90,8 +95,6 @@ cc_library( "//tensorflow/compiler/mlir/xla:xla_lower", "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts", "//tensorflow/compiler/mlir/xla:xla_test_passes", - "@llvm-project//mlir:Affine", - "@llvm-project//mlir:QuantOps", ], ) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index c2120ccc4ab..b4ef4cc0bb5 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1078,6 +1078,7 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/mlir/xla:xla_legalize_tf", + "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:framework", @@ -1118,6 +1119,7 @@ tf_cc_test( srcs = ["utils/compile_mlir_util_test.cc"], deps = [ ":compile_mlir_util", + "//tensorflow/compiler/jit", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:test", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 3e250ec287b..7a627780f25 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -254,8 +254,9 @@ static void RegisterDialects() { } // namespace Status ConvertMLIRToXlaComputation( - mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple, + mlir::ModuleOp module_op, llvm::StringRef device_type, + xla::XlaComputation* xla_computation, bool use_tuple_args, + bool return_tuple, const XlaCompiler::ShapeRepresentationFn shape_representation_fn) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass(mlir::createCanonicalizerPass()); @@ -268,6 +269,7 @@ Status ConvertMLIRToXlaComputation( // with a tuple argument which break the assumption of resource lifting // inside PromoteResourcesToArgs. tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass()); + tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type)); // We need to run LegalizeTFPass 2 times because first // LegalizeTFPass(allow_partial_conversion=true) can expose more graph pruning // and canonicalization opportunities that are necessary for the second @@ -308,7 +310,7 @@ Status ConvertMLIRToXlaComputation( static Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef arg_shapes, - bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { if (VLOG_IS_ON(1)) @@ -326,7 +328,8 @@ static Status CompileMlirToXlaHlo( // Convert MLIR module to XLA HLO proto contained in XlaComputation. compilation_result->computation = std::make_shared(); TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation( - module_op, compilation_result->computation.get(), use_tuple_args, + module_op, device_type, compilation_result->computation.get(), + use_tuple_args, /*return_tuple=*/true, shape_representation_fn)); // Construct mapping from XlaComputation's arg to input edges of execute @@ -355,7 +358,7 @@ static Status CompileMlirToXlaHlo( Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, - bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { RegisterDialects(); @@ -364,14 +367,15 @@ Status CompileSerializedMlirToXlaHlo( TF_RETURN_IF_ERROR( ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module)); - return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args, - shape_representation_fn, compilation_result); + return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type, + use_tuple_args, shape_representation_fn, + compilation_result); } Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef arg_shapes, - bool use_tuple_args, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info, + llvm::StringRef device_type, bool use_tuple_args, + const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result) { RegisterDialects(); @@ -383,8 +387,8 @@ Status CompileGraphToXlaHlo( if (!module_or.ok()) return module_or.status(); return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes, - use_tuple_args, shape_representation_fn, - compilation_result); + device_type, use_tuple_args, + shape_representation_fn, compilation_result); } } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 2ce0a31eb78..74c602a7afb 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -29,6 +29,8 @@ namespace tensorflow { // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module // should only contain operations in tf dialect. If the input module contains // operation in the tf_executor dialect, for example, returns an error. +// Exception to this are tf_executor dialect ops that are optimized away through +// canonicalization. // // Operations in tf dialect are lowered to XLA HLO through the following steps: // . Legalizes control flow operations. @@ -39,6 +41,8 @@ namespace tensorflow { // . Legalizes the operations to XLA HLO operations. // . Canonicalizes the XLA HLO operations. // +// device_type: XLA JIT device to use for compilation such as "XLA_CPU_JIT", +// "XLA_GPU_JIT" or "XLA_TPU_JIT". // use_tuple_args: when this is true, always create a tuple argument for the // entry computation. // return_tuple: when this is true, always create a tuple result for the @@ -47,23 +51,24 @@ namespace tensorflow { // will be used to determine argument and result shapes. Otherwise the // original shape will be used as is. Status ConvertMLIRToXlaComputation( - mlir::ModuleOp module_op, xla::XlaComputation* xla_computation, - bool use_tuple_args, bool return_tuple, + mlir::ModuleOp module_op, llvm::StringRef device_type, + xla::XlaComputation* xla_computation, bool use_tuple_args, + bool return_tuple, const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying // metadata and stores them in CompilationResult. Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef arg_shapes, - bool use_tuple_args, + llvm::StringRef device_type, bool use_tuple_args, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); // Same as the above but takes input as TensorFlow Graph. Status CompileGraphToXlaHlo( const Graph& graph, llvm::ArrayRef arg_shapes, - bool use_tuple_args, const FunctionLibraryDefinition& flib_def, - const GraphDebugInfo& debug_info, + llvm::StringRef device_type, bool use_tuple_args, + const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, const XlaCompiler::ShapeRepresentationFn shape_representation_fn, XlaCompiler::CompilationResult* compilation_result); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc index d406934c520..26c50a24f58 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc @@ -46,7 +46,7 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - invalid_mlir_module, arg_shapes, + invalid_mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT); EXPECT_EQ(s.ToString(), @@ -68,7 +68,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, + kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -126,7 +126,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - kBinaryAddModule, arg_shapes, + kBinaryAddModule, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -197,7 +197,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -236,7 +236,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -267,7 +267,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -325,7 +325,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -362,7 +362,7 @@ module attributes {tf.versions = {producer = 179 : i32}} { XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); ASSERT_FALSE(s.ok()); EXPECT_EQ(s.error_message(), @@ -384,7 +384,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr XlaCompiler::CompilationResult compilation_result; Status s = CompileSerializedMlirToXlaHlo( - mlir_module, arg_shapes, + mlir_module, arg_shapes, "XLA_CPU_JIT", /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result); TF_ASSERT_OK(s); @@ -424,9 +424,10 @@ TEST(CompileGraphToXlaHlo, Basic) { test::graph::Retval(&graph, 0, arg); XlaCompiler::CompilationResult result; - TF_ASSERT_OK(CompileGraphToXlaHlo( - graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def, - GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result)); + TF_ASSERT_OK( + CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT", + /*use_tuple_args=*/false, flib_def, GraphDebugInfo(), + /*shape_representation_fn=*/nullptr, &result)); const xla::HloModuleConfig module_config( result.computation->GetProgramShape().ValueOrDie()); diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 0feb633948d..e20f8543e61 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -160,8 +160,6 @@ cc_library( deps = [ ":hlo", ":mlir_hlo_builder", - "//tensorflow/compiler/jit:xla_cpu_device", - "//tensorflow/compiler/jit:xla_cpu_jit", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_type", diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 7526248baca..50240c84f9d 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -123,6 +123,8 @@ StatusOr ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type, return builder.getI1Type(); case PrimitiveType::F16: return builder.getF16Type(); + case PrimitiveType::BF16: + return builder.getBF16Type(); case PrimitiveType::F32: return builder.getF32Type(); case PrimitiveType::F64: diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 1659fe4d467..7ae18eb0d34 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -337,6 +337,10 @@ class LegalizeTF : public PassWrapper { public: LegalizeTF() = default; + explicit LegalizeTF(llvm::StringRef device_type) { + device_type_ = device_type.str(); + } + LegalizeTF(const LegalizeTF&) {} void runOnFunction() override { @@ -359,5 +363,10 @@ static PassRegistration pass( } // end namespace +std::unique_ptr> createLegalizeTfWithTf2XlaPass( + llvm::StringRef device_type) { + return std::make_unique(device_type); +} + } // end namespace xla_hlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h index 5b2ec24d9bc..2d0164981a3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/xla/transforms/passes.h @@ -38,6 +38,11 @@ namespace xla_hlo { std::unique_ptr> createLegalizeTFPass( bool allow_partial_conversion = false); +/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the +/// specified device type. +std::unique_ptr> createLegalizeTfWithTf2XlaPass( + llvm::StringRef device_type); + /// Lowers from TF dialect's control flow to HLO dialect's control flow. std::unique_ptr> createLegalizeTFControlFlowPass(); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 1c5867a1312..a5332385994 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -167,6 +167,7 @@ cc_library( ":tf2xla_proto_cc", ":tf2xla_util", ":xla_compiler", + "//tensorflow/compiler/jit", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util", "//tensorflow/compiler/mlir/tensorflow:convert_graphdef", diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 43404bc2267..daf261fa5d8 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -174,7 +174,8 @@ Status ConvertGraphDefToXlaViaMlir( // Convert the MLIR module to XLA computation. If the input graph can't be // lowered down to a single graph node with a single island by the previous // step, this step will return an error. - return ConvertMLIRToXlaComputation(*module, computation, + return ConvertMLIRToXlaComputation(*module, /*device_type=*/"XLA_CPU_JIT", + computation, /*use_tuple_args=*/false, /*always_return_tuple=*/true); }