diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
index 731f7235c12..10aad0a03ff 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc
@@ -67,6 +67,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
 // Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
 Status GetXlaInputShapes(
     mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
     const xla::CustomShapeRepresentationFn shape_representation_fn,
     std::vector<xla::Shape>* xla_input_shapes) {
   xla_input_shapes->clear();
@@ -88,8 +89,12 @@ Status GetXlaInputShapes(
     TF_ASSIGN_OR_RETURN(xla_shape,
                         shape_representation_fn(arg_shapes[i], dtype));
   }
-  xla_input_shapes->push_back(
-      xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
+  if (use_tuple_args) {
+    xla_input_shapes->push_back(
+        xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
+  } else {
+    *xla_input_shapes = individual_arg_shapes;
+  }
   return Status::OK();
 }
 
@@ -257,6 +262,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
 
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result) {
   mlir::MLIRContext mlir_context;
@@ -278,7 +284,7 @@ Status CompileSerializedMlirToXlaHlo(
   // Convert MLIR module to XLA HLO proto contained in XlaComputation.
   compilation_result->computation = std::make_shared<xla::XlaComputation>();
   TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
-      module_op, compilation_result->computation.get(), /*use_tuple_args=*/true,
+      module_op, compilation_result->computation.get(), use_tuple_args,
       /*return_tuple=*/true));
 
   // Construct mapping from XlaComputation's arg to input edges of execute
@@ -291,7 +297,7 @@ Status CompileSerializedMlirToXlaHlo(
   };
 
   // Compute all input shapes.
-  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes,
+  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                        shape_representation_fn_no_fast_memory,
                                        &compilation_result->xla_input_shapes));
 
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
index ed25aaf929e..41fa8b90e4f 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h
@@ -50,6 +50,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
 // metadata and stores them in CompilationResult.
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result);
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
index 8e0f9cb2497..b258dd68ae1 100644
--- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
+++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util_test.cc
@@ -41,30 +41,31 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
   std::vector<TensorShape> arg_shapes;
   XlaCompiler::CompilationResult compilation_result;
 
-  Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
-                                           TestShapeRepresentation,
-                                           &compilation_result);
+  Status s = CompileSerializedMlirToXlaHlo(
+      invalid_mlir_module, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
   EXPECT_EQ(s.ToString(),
             "Invalid argument: could not parse MLIR module: error: "
             "custom op 'totally' is unknown\n");
 }
 
-TEST(CompileSerializedMlirToXlaHloTest, Success) {
-  string mlir_module = R"(
-    module attributes {tf.versions = {producer = 179 : i32}} {
-      func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
-        %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
-        return %0 : tensor<f32>
-      }
+constexpr llvm::StringRef kBinaryAddModule = R"(
+  module attributes {tf.versions = {producer = 179 : i32}} {
+    func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+      %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      return %0 : tensor<f32>
     }
-  )";
+  }
+)";
 
+TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
   std::vector<TensorShape> arg_shapes(2, TensorShape());
   XlaCompiler::CompilationResult compilation_result;
 
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
+      kBinaryAddModule, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   ASSERT_TRUE(s.ok());
 
   const xla::HloModuleConfig module_config(
@@ -86,7 +87,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
   EXPECT_EQ(expected_hlo_module_string,
             status_or_hlo_module.ValueOrDie()->ToString());
 
-  // Expect an iota like input mapping.
+  // Expect an in-order input mapping.
   EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
 
   // Expect a single tuple-shape, containing two F32 scalars.
@@ -116,6 +117,62 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
   EXPECT_TRUE(compilation_result.resource_updates.empty());
 }
 
+TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
+  std::vector<TensorShape> arg_shapes(2, TensorShape());
+  XlaCompiler::CompilationResult compilation_result;
+
+  Status s = CompileSerializedMlirToXlaHlo(
+      kBinaryAddModule, arg_shapes,
+      /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
+  ASSERT_TRUE(s.ok());
+
+  const xla::HloModuleConfig module_config(
+      compilation_result.computation->GetProgramShape().ValueOrDie());
+  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
+      compilation_result.computation->proto(), module_config);
+  ASSERT_TRUE(status_or_hlo_module.ok());
+  string expected_hlo_module_string = R"(HloModule main.5
+
+ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
+  %Arg_0.1 = f32[] parameter(0)
+  %Arg_1.2 = f32[] parameter(1)
+  %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
+  ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
+}
+
+)";
+  EXPECT_EQ(expected_hlo_module_string,
+            status_or_hlo_module.ValueOrDie()->ToString());
+
+  // Expect an in-order input mapping.
+  EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
+
+  // Expect two inputs, each containing an F32 scalar.
+  EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2);
+  xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
+  EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape);
+  EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape);
+
+  // Expect the output shape to be a tuple containing a single F32 scalar.
+  const xla::Shape output_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
+  const xla::Shape tuple_output_shape =
+      xla::ShapeUtil::MakeTupleShape({output_shape});
+  EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
+
+  // Expect exactly one OutputDescription.
+  EXPECT_EQ(compilation_result.outputs.size(), 1);
+  const XlaCompiler::OutputDescription& output_desc =
+      compilation_result.outputs.front();
+  EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
+  EXPECT_EQ(output_desc.shape, TensorShape());
+  EXPECT_FALSE(output_desc.is_constant);
+  EXPECT_FALSE(output_desc.is_tensor_list);
+
+  // Expect no resource updates from the computation.
+  EXPECT_TRUE(compilation_result.resource_updates.empty());
+}
+
 // Tests that foldable ops are constant-folded to enable legalization of ops
 // that require compile time constant operand.
 TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
@@ -136,7 +193,8 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
   XlaCompiler::CompilationResult compilation_result;
 
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
+      mlir_module, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   ASSERT_TRUE(s.ok());
 
   const xla::HloModuleConfig module_config(
@@ -174,7 +232,8 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
   XlaCompiler::CompilationResult compilation_result;
 
   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
+      mlir_module, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);
 
   const xla::HloModuleConfig module_config(
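
Caller-side sketch (not part of the patch): the snippet below illustrates how the new use_tuple_args flag is expected to affect CompilationResult::xla_input_shapes for a two-argument module. The helper names IdentityShapeRepresentation and CompileBothWays are hypothetical, and the three-argument ShapeRepresentationFn signature is assumed from the shape_representation_fn_no_fast_memory wrapper in compile_mlir_util.cc.

#include <vector>

#include "llvm/ADT/StringRef.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// Hypothetical shape representation fn: maps (shape, dtype) to the default
// XLA shape and ignores fast memory, mirroring what the tests do.
xla::StatusOr<xla::Shape> IdentityShapeRepresentation(const TensorShape& shape,
                                                      DataType dtype,
                                                      bool use_fast_memory) {
  xla::Shape xla_shape;
  TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
  return xla_shape;
}

// Compiles the same two-argument module with both settings of use_tuple_args.
Status CompileBothWays(llvm::StringRef mlir_module) {
  std::vector<TensorShape> arg_shapes(2, TensorShape());

  // use_tuple_args=true: xla_input_shapes holds a single tuple shape that
  // wraps both arguments.
  XlaCompiler::CompilationResult tuple_result;
  TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
      mlir_module, arg_shapes, /*use_tuple_args=*/true,
      IdentityShapeRepresentation, &tuple_result));

  // use_tuple_args=false: xla_input_shapes holds one xla::Shape per argument,
  // in argument order.
  XlaCompiler::CompilationResult individual_result;
  TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
      mlir_module, arg_shapes, /*use_tuple_args=*/false,
      IdentityShapeRepresentation, &individual_result));

  return Status::OK();
}

}  // namespace tensorflow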