diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index 5081df28a08..b51749bc332 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -296,10 +296,10 @@ Status XlaCompilationCache::CompileSingleOp(
       arg_shapes.push_back(absl::get<TensorShape>(arg.shape));
     }
     GraphDebugInfo debug_info;
-    return CompileGraphToXlaHlo(*graph, {arg_shapes.data(), arg_shapes.size()},
-                                compile_options.use_tuple_arg,
-                                *options.flib_def, debug_info,
-                                options.shape_representation_fn, result);
+    return CompileGraphToXlaHlo(
+        *graph, {arg_shapes.data(), arg_shapes.size()},
+        options.device_type.type_string(), compile_options.use_tuple_arg,
+        *options.flib_def, debug_info, options.shape_representation_fn, result);
   };
   return CompileImpl(options, name, args, compile_op,
                      /*compile_threshold=*/absl::nullopt,
diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD
index 63546db1eb0..2c222ac4cb7 100644
--- a/tensorflow/compiler/mlir/BUILD
+++ b/tensorflow/compiler/mlir/BUILD
@@ -58,6 +58,11 @@ cc_library(
         "//tensorflow/python:__subpackages__",
     ],
     deps = [
+        "@llvm-project//mlir:Affine",
+        "@llvm-project//mlir:QuantOps",
+        # Link in the jit lib to register the JIT devices required to run the
+        # xla-legalize-tf-with-tf2xla pass.
+        "//tensorflow/compiler/jit",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_dialect_registration",
        "//tensorflow/compiler/mlir/lite:tensorflow_lite_legalize_tf",
@@ -90,8 +95,6 @@ cc_library(
         "//tensorflow/compiler/mlir/xla:xla_lower",
         "//tensorflow/compiler/mlir/xla:xla_materialize_broadcasts",
         "//tensorflow/compiler/mlir/xla:xla_test_passes",
-        "@llvm-project//mlir:Affine",
-        "@llvm-project//mlir:QuantOps",
     ],
 )
diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD
index c2120ccc4ab..b4ef4cc0bb5 100644
--- a/tensorflow/compiler/mlir/tensorflow/BUILD
+++ b/tensorflow/compiler/mlir/tensorflow/BUILD
@@ -1078,6 +1078,7 @@ COMPILE_MLIR_UTIL_DEPS = [
     "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
     "//tensorflow/compiler/mlir/xla:type_to_shape",
     "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
+    "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:xla_compiler",
     "//tensorflow/core:framework",
@@ -1118,6 +1119,7 @@ tf_cc_test(
     srcs = ["utils/compile_mlir_util_test.cc"],
     deps = [
         ":compile_mlir_util",
+        "//tensorflow/compiler/jit",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/xla:test",
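Note: the new `//tensorflow/compiler/jit` dependencies above are load-bearing, not a layering cleanup. The xla-legalize-tf-with-tf2xla pass selects tf2xla kernels by XLA JIT device name, and those backends are only registered by static initializers in the jit library. A minimal sketch of what a registry check could look like, assuming the existing `XlaOpRegistry::DeviceKernels` query (the `HasTf2XlaKernels` helper itself is hypothetical, not part of this patch):

```cpp
#include <string>

#include "tensorflow/compiler/tf2xla/xla_op_registry.h"

// Hypothetical helper: returns true if tf2xla kernels are registered for
// `device_type` (e.g. "XLA_CPU_JIT"). Without //tensorflow/compiler/jit in
// deps, the backend registrations never run and the kernel list is empty.
bool HasTf2XlaKernels(const std::string& device_type) {
  return !tensorflow::XlaOpRegistry::DeviceKernels(
              device_type, /*include_compilation_only_kernels=*/true)
              .empty();
}
```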
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 3e250ec287b..7a627780f25 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -254,8 +254,9 @@ static void RegisterDialects() {
 }  // namespace
 
 Status ConvertMLIRToXlaComputation(
-    mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
-    bool use_tuple_args, bool return_tuple,
+    mlir::ModuleOp module_op, llvm::StringRef device_type,
+    xla::XlaComputation* xla_computation, bool use_tuple_args,
+    bool return_tuple,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn) {
   mlir::PassManager tf2xla(module_op.getContext());
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
@@ -268,6 +269,7 @@ Status ConvertMLIRToXlaComputation(
   // with a tuple argument which break the assumption of resource lifting
   // inside PromoteResourcesToArgs.
   tf2xla.addPass(mlir::xla_hlo::createLegalizeTFControlFlowPass());
+  tf2xla.addPass(mlir::xla_hlo::createLegalizeTfWithTf2XlaPass(device_type));
   // We need to run LegalizeTFPass 2 times because first
   // LegalizeTFPass(allow_partial_conversion=true) can expose more graph pruning
   // and canonicalization opportunities that are necessary for the second
@@ -308,7 +310,7 @@ Status ConvertMLIRToXlaComputation(
 
 static Status CompileMlirToXlaHlo(
     mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
-    bool use_tuple_args,
+    llvm::StringRef device_type, bool use_tuple_args,
     XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result) {
   if (VLOG_IS_ON(1))
@@ -326,7 +328,8 @@ static Status CompileMlirToXlaHlo(
   // Convert MLIR module to XLA HLO proto contained in XlaComputation.
   compilation_result->computation = std::make_shared<xla::XlaComputation>();
   TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
-      module_op, compilation_result->computation.get(), use_tuple_args,
+      module_op, device_type, compilation_result->computation.get(),
+      use_tuple_args,
       /*return_tuple=*/true, shape_representation_fn));
 
   // Construct mapping from XlaComputation's arg to input edges of execute
@@ -355,7 +358,7 @@ static Status CompileMlirToXlaHlo(
 
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
-    bool use_tuple_args,
+    llvm::StringRef device_type, bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result) {
   RegisterDialects();
@@ -364,14 +367,15 @@ Status CompileSerializedMlirToXlaHlo(
   TF_RETURN_IF_ERROR(
       ParseMlirModule(mlir_module_string, &mlir_context, &mlir_module));
 
-  return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, use_tuple_args,
-                             shape_representation_fn, compilation_result);
+  return CompileMlirToXlaHlo(mlir_module.get(), arg_shapes, device_type,
+                             use_tuple_args, shape_representation_fn,
+                             compilation_result);
 }
 
 Status CompileGraphToXlaHlo(
     const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
-    bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
-    const GraphDebugInfo& debug_info,
+    llvm::StringRef device_type, bool use_tuple_args,
+    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result) {
   RegisterDialects();
@@ -383,8 +387,8 @@ Status CompileGraphToXlaHlo(
   if (!module_or.ok()) return module_or.status();
 
   return CompileMlirToXlaHlo(module_or.ValueOrDie().get(), arg_shapes,
-                             use_tuple_args, shape_representation_fn,
-                             compilation_result);
+                             device_type, use_tuple_args,
+                             shape_representation_fn, compilation_result);
 }
 
 }  // namespace tensorflow
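For reference, a call site written against the new `ConvertMLIRToXlaComputation` signature looks like the sketch below (it mirrors the updated mlir_tf2xla.cc later in this patch; the surrounding setup and error propagation context are assumed):

```cpp
// `module_op` is an mlir::ModuleOp holding only tf-dialect ops. The new
// second parameter selects which tf2xla kernels the fallback pass may use.
xla::XlaComputation computation;
TF_RETURN_IF_ERROR(tensorflow::ConvertMLIRToXlaComputation(
    module_op, /*device_type=*/"XLA_CPU_JIT", &computation,
    /*use_tuple_args=*/true, /*return_tuple=*/true));
```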
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index 2ce0a31eb78..74c602a7afb 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -29,6 +29,8 @@ namespace tensorflow {
 // Lowers MLIR module to XLA HLO inside an XlaComputation. The input module
 // should only contain operations in tf dialect. If the input module contains
 // operation in the tf_executor dialect, for example, returns an error.
+// Exceptions to this are tf_executor dialect ops that are optimized away
+// through canonicalization.
 //
 // Operations in tf dialect are lowered to XLA HLO through the following steps:
 //   . Legalizes control flow operations.
@@ -39,6 +41,8 @@ namespace tensorflow {
 //   . Legalizes the operations to XLA HLO operations.
 //   . Canonicalizes the XLA HLO operations.
 //
+// device_type: XLA JIT device to use for compilation, such as "XLA_CPU_JIT",
+// "XLA_GPU_JIT" or "XLA_TPU_JIT".
 // use_tuple_args: when this is true, always create a tuple argument for the
 // entry computation.
 // return_tuple: when this is true, always create a tuple result for the
@@ -47,23 +51,24 @@ namespace tensorflow {
 //   will be used to determine argument and result shapes. Otherwise the
 //   original shape will be used as is.
 Status ConvertMLIRToXlaComputation(
-    mlir::ModuleOp module_op, xla::XlaComputation* xla_computation,
-    bool use_tuple_args, bool return_tuple,
+    mlir::ModuleOp module_op, llvm::StringRef device_type,
+    xla::XlaComputation* xla_computation, bool use_tuple_args,
+    bool return_tuple,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr);
 
 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
 // metadata and stores them in CompilationResult.
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
-    bool use_tuple_args,
+    llvm::StringRef device_type, bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result);
 
 // Same as the above but takes input as TensorFlow Graph.
 Status CompileGraphToXlaHlo(
     const Graph& graph, llvm::ArrayRef<TensorShape> arg_shapes,
-    bool use_tuple_args, const FunctionLibraryDefinition& flib_def,
-    const GraphDebugInfo& debug_info,
+    llvm::StringRef device_type, bool use_tuple_args,
+    const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result);
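Likewise for the serialized entry point. The tests below exercise it end to end; a minimal standalone call is sketched here (the MLIR module and scalar shapes are illustrative, and passing a null `shape_representation_fn` keeps the original shapes, as the CompileGraphToXlaHlo test does):

```cpp
// Scalar f32 add in the tf dialect, compiled for the CPU JIT device.
constexpr char kAddModule[] = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
  func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
    %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<f32>, tensor<f32>) -> tensor<f32>
    return %0 : tensor<f32>
  }
})";

std::vector<tensorflow::TensorShape> arg_shapes(2, tensorflow::TensorShape());
tensorflow::XlaCompiler::CompilationResult compilation_result;
tensorflow::Status s = tensorflow::CompileSerializedMlirToXlaHlo(
    kAddModule, arg_shapes, /*device_type=*/"XLA_CPU_JIT",
    /*use_tuple_args=*/true, /*shape_representation_fn=*/nullptr,
    &compilation_result);
```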
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index d406934c520..26c50a24f58 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -46,7 +46,7 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      invalid_mlir_module, arg_shapes,
+      invalid_mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
   EXPECT_EQ(s.ToString(),
@@ -68,7 +68,7 @@ TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      kBinaryAddModule, arg_shapes,
+      kBinaryAddModule, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -126,7 +126,7 @@ TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      kBinaryAddModule, arg_shapes,
+      kBinaryAddModule, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -197,7 +197,7 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes,
+      mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -236,7 +236,7 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes,
+      mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -267,7 +267,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes,
+      mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -325,7 +325,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes,
+      mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -362,7 +362,7 @@ module attributes {tf.versions = {producer = 179 : i32}} {
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes,
+      mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   ASSERT_FALSE(s.ok());
   EXPECT_EQ(s.error_message(),
@@ -384,7 +384,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
 
   XlaCompiler::CompilationResult compilation_result;
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes,
+      mlir_module, arg_shapes, "XLA_CPU_JIT",
       /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
@@ -424,9 +424,10 @@ TEST(CompileGraphToXlaHlo, Basic) {
   test::graph::Retval(&graph, 0, arg);
 
   XlaCompiler::CompilationResult result;
-  TF_ASSERT_OK(CompileGraphToXlaHlo(
-      graph, /*arg_shapes=*/{TensorShape()}, /*use_tuple_args=*/false, flib_def,
-      GraphDebugInfo(), /*shape_representation_fn=*/nullptr, &result));
+  TF_ASSERT_OK(
+      CompileGraphToXlaHlo(graph, /*arg_shapes=*/{TensorShape()}, "XLA_CPU_JIT",
+                           /*use_tuple_args=*/false, flib_def, GraphDebugInfo(),
+                           /*shape_representation_fn=*/nullptr, &result));
 
   const xla::HloModuleConfig module_config(
       result.computation->GetProgramShape().ValueOrDie());
diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD
index 0feb633948d..e20f8543e61 100644
--- a/tensorflow/compiler/mlir/xla/BUILD
+++ b/tensorflow/compiler/mlir/xla/BUILD
@@ -160,8 +160,6 @@ cc_library(
     deps = [
         ":hlo",
         ":mlir_hlo_builder",
-        "//tensorflow/compiler/jit:xla_cpu_device",
-        "//tensorflow/compiler/jit:xla_cpu_jit",
         "//tensorflow/compiler/mlir:op_or_arg_name_mapper",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:convert_type",
diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc
index 7526248baca..50240c84f9d 100644
--- a/tensorflow/compiler/mlir/xla/hlo_utils.cc
+++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc
@@ -123,6 +123,8 @@ StatusOr<mlir::Type> ConvertPrimitiveTypeToMLIRType(PrimitiveType element_type,
       return builder.getI1Type();
     case PrimitiveType::F16:
       return builder.getF16Type();
+    case PrimitiveType::BF16:
+      return builder.getBF16Type();
     case PrimitiveType::F32:
       return builder.getF32Type();
     case PrimitiveType::F64:
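Note: the BF16 case added above matters because tf2xla kernels can produce bfloat16 results, and without the mapping `ConvertPrimitiveTypeToMLIRType` returns an error when such HLO values are imported back into MLIR. An illustrative round trip, assuming an `mlir::MLIRContext` named `context` is in scope:

```cpp
// Maps xla::BF16 to mlir::BF16Type via the helper patched above.
mlir::Builder builder(&context);
auto type_or = xla::ConvertPrimitiveTypeToMLIRType(xla::BF16, builder);
if (type_or.ok()) {
  mlir::Type bf16_type = type_or.ValueOrDie();
}
```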
diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
index 1659fe4d467..7ae18eb0d34 100644
--- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc
@@ -337,6 +337,10 @@ class LegalizeTF : public PassWrapper<LegalizeTF, FunctionPass> {
  public:
   LegalizeTF() = default;
 
+  explicit LegalizeTF(llvm::StringRef device_type) {
+    device_type_ = device_type.str();
+  }
+
   LegalizeTF(const LegalizeTF&) {}
 
   void runOnFunction() override {
@@ -359,5 +363,10 @@ static PassRegistration<LegalizeTF> pass(
 
 }  // end namespace
 
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
+    llvm::StringRef device_type) {
+  return std::make_unique<LegalizeTF>(device_type);
+}
+
 }  // end namespace xla_hlo
 }  // end namespace mlir
diff --git a/tensorflow/compiler/mlir/xla/transforms/passes.h b/tensorflow/compiler/mlir/xla/transforms/passes.h
index 5b2ec24d9bc..2d0164981a3 100644
--- a/tensorflow/compiler/mlir/xla/transforms/passes.h
+++ b/tensorflow/compiler/mlir/xla/transforms/passes.h
@@ -38,6 +38,11 @@ namespace xla_hlo {
 std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
     bool allow_partial_conversion = false);
 
+/// Lowers from TF dialect to HLO dialect using tf2xla op kernels for the
+/// specified device type.
+std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
+    llvm::StringRef device_type);
+
 /// Lowers from TF dialect's control flow to HLO dialect's control flow.
 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeTFControlFlowPass();
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 1c5867a1312..a5332385994 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -167,6 +167,7 @@ cc_library(
         ":tf2xla_proto_cc",
         ":tf2xla_util",
         ":xla_compiler",
+        "//tensorflow/compiler/jit",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util",
         "//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
index 43404bc2267..daf261fa5d8 100644
--- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc
@@ -174,7 +174,8 @@ Status ConvertGraphDefToXlaViaMlir(
   // Convert the MLIR module to XLA computation. If the input graph can't be
   // lowered down to a single graph node with a single island by the previous
   // step, this step will return an error.
-  return ConvertMLIRToXlaComputation(*module, computation,
+  return ConvertMLIRToXlaComputation(*module, /*device_type=*/"XLA_CPU_JIT",
+                                     computation,
                                      /*use_tuple_args=*/false,
                                      /*return_tuple=*/true);
 }
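End to end, the device type now threads from the XlaCompiler options down to the new pass: xla_compilation_cache.cc (the first hunk of this patch) passes `options.device_type.type_string()`, while the tf2xla bridge above pins "XLA_CPU_JIT". A sketch of compiling a Graph for the GPU JIT device instead, assuming `graph` and `flib_def` are set up as in the CompileGraphToXlaHlo test:

```cpp
tensorflow::XlaCompiler::CompilationResult result;
// Requires the XLA_GPU_JIT backend to be linked in (see the
// //tensorflow/compiler/jit deps added in this patch).
TF_RETURN_IF_ERROR(tensorflow::CompileGraphToXlaHlo(
    graph, /*arg_shapes=*/{tensorflow::TensorShape()},
    /*device_type=*/"XLA_GPU_JIT", /*use_tuple_args=*/false, flib_def,
    tensorflow::GraphDebugInfo(), /*shape_representation_fn=*/nullptr,
    &result));
```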