Support non-tuple arguments in MLIR to HLO compilation helper function

This is planned to be used in a follow-up change that applies the MLIR compilation passes to on-demand XLA compilation.

Note that the result always uses a tuple type for the main computation, so no corresponding parameter is needed for the result.
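Usage sketch (illustrative only, not part of the change; it mirrors the tests below, with `mlir_module_string` and `shape_representation_fn` standing in for whatever the caller already has):

    // Two scalar arguments, as in the tests below.
    std::vector<TensorShape> arg_shapes(2, TensorShape());
    XlaCompiler::CompilationResult tupled_result, individual_result;

    // Arguments packed into a single tuple parameter (previous behavior).
    TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
        mlir_module_string, arg_shapes,
        /*use_tuple_args=*/true, shape_representation_fn, &tupled_result));

    // Each argument becomes its own parameter (new behavior).
    TF_RETURN_IF_ERROR(CompileSerializedMlirToXlaHlo(
        mlir_module_string, arg_shapes,
        /*use_tuple_args=*/false, shape_representation_fn, &individual_result));

    // tupled_result.xla_input_shapes holds one tuple shape; individual_result
    // holds one shape per argument. Outputs are tupled in both cases.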

PiperOrigin-RevId: 301209747
Change-Id: Ic1d0fbdd2c69512f21d5feafa9880c0c024d0279
Authored by Smit Hinsu on 2020-03-16 12:07:14 -07:00; committed by TensorFlower Gardener
Commit 7680958e81 (parent e861b664e6)
3 changed files with 85 additions and 19 deletions

File 1 of 3: compilation helper implementation (.cc)

@@ -67,6 +67,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
 // Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
 Status GetXlaInputShapes(
     mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
     const xla::CustomShapeRepresentationFn shape_representation_fn,
     std::vector<xla::Shape>* xla_input_shapes) {
   xla_input_shapes->clear();
@@ -88,8 +89,12 @@ Status GetXlaInputShapes(
     TF_ASSIGN_OR_RETURN(xla_shape,
                         shape_representation_fn(arg_shapes[i], dtype));
   }
-  xla_input_shapes->push_back(
-      xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
+  if (use_tuple_args) {
+    xla_input_shapes->push_back(
+        xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
+  } else {
+    *xla_input_shapes = individual_arg_shapes;
+  }
   return Status::OK();
 }
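For illustration (not part of the diff): with two scalar f32 arguments, as in the tests in this change, the resulting xla_input_shapes differ as follows.

    xla::Shape scalar_f32 = xla::ShapeUtil::MakeShape(xla::F32, {});
    // use_tuple_args == true  -> a single tuple entry: (f32[], f32[])
    std::vector<xla::Shape> tupled = {
        xla::ShapeUtil::MakeTupleShape({scalar_f32, scalar_f32})};
    // use_tuple_args == false -> one entry per argument: f32[], f32[]
    std::vector<xla::Shape> individual = {scalar_f32, scalar_f32};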
@@ -257,6 +262,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result) {
   mlir::MLIRContext mlir_context;
@@ -278,7 +284,7 @@ Status CompileSerializedMlirToXlaHlo(
   // Convert MLIR module to XLA HLO proto contained in XlaComputation.
   compilation_result->computation = std::make_shared<xla::XlaComputation>();
   TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
-      module_op, compilation_result->computation.get(), /*use_tuple_args=*/true,
+      module_op, compilation_result->computation.get(), use_tuple_args,
       /*return_tuple=*/true));

   // Construct mapping from XlaComputation's arg to input edges of execute
@@ -291,7 +297,7 @@ Status CompileSerializedMlirToXlaHlo(
   };

   // Compute all input shapes.
-  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes,
+  TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
                                        shape_representation_fn_no_fast_memory,
                                        &compilation_result->xla_input_shapes));

File 2 of 3: compilation helper declaration (.h)

@@ -50,6 +50,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
 // metadata and stores them in CompilationResult.
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
+    bool use_tuple_args,
     const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
     XlaCompiler::CompilationResult* compilation_result);

 }  // namespace tensorflow

File 3 of 3: compilation helper tests

@@ -41,30 +41,31 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
   std::vector<TensorShape> arg_shapes;
   XlaCompiler::CompilationResult compilation_result;

-  Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
-                                           TestShapeRepresentation,
-                                           &compilation_result);
+  Status s = CompileSerializedMlirToXlaHlo(
+      invalid_mlir_module, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
   EXPECT_EQ(s.ToString(),
             "Invalid argument: could not parse MLIR module<stdin>: error: "
             "custom op 'totally' is unknown\n");
 }

-TEST(CompileSerializedMlirToXlaHloTest, Success) {
-  string mlir_module = R"(
-    module attributes {tf.versions = {producer = 179 : i32}} {
-      func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
-        %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
-        return %0 : tensor<f32>
-      }
-    }
-  )";
+constexpr llvm::StringRef kBinaryAddModule = R"(
+  module attributes {tf.versions = {producer = 179 : i32}} {
+    func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
+      %0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+      return %0 : tensor<f32>
+    }
+  }
+)";

+TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
   std::vector<TensorShape> arg_shapes(2, TensorShape());
   XlaCompiler::CompilationResult compilation_result;

   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
+      kBinaryAddModule, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   ASSERT_TRUE(s.ok());

   const xla::HloModuleConfig module_config(
@@ -86,7 +87,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
   EXPECT_EQ(expected_hlo_module_string,
             status_or_hlo_module.ValueOrDie()->ToString());

-  // Expect an iota like input mapping.
+  // Expect an in order input mapping.
   EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));

   // Expect a single tuple-shape, containing two F32 scalars.
@@ -116,6 +117,62 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
   EXPECT_TRUE(compilation_result.resource_updates.empty());
 }

+TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
+  std::vector<TensorShape> arg_shapes(2, TensorShape());
+  XlaCompiler::CompilationResult compilation_result;
+
+  Status s = CompileSerializedMlirToXlaHlo(
+      kBinaryAddModule, arg_shapes,
+      /*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
+  ASSERT_TRUE(s.ok());
+
+  const xla::HloModuleConfig module_config(
+      compilation_result.computation->GetProgramShape().ValueOrDie());
+  auto status_or_hlo_module = xla::HloModule::CreateFromProto(
+      compilation_result.computation->proto(), module_config);
+  ASSERT_TRUE(status_or_hlo_module.ok());
+  string expected_hlo_module_string = R"(HloModule main.5
+
+ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
+  %Arg_0.1 = f32[] parameter(0)
+  %Arg_1.2 = f32[] parameter(1)
+  %add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
+  ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
+}
+
+)";
+  EXPECT_EQ(expected_hlo_module_string,
+            status_or_hlo_module.ValueOrDie()->ToString());
+
+  // Expect an in order input mapping.
+  EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
+
+  // Expect two inputs, each containing a F32 scalar.
+  EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2);
+  xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
+  EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape);
+  EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape);
+
+  // Expect output shape is a tuple shape containing a single F32 Scalar type.
+  const xla::Shape output_shape =
+      xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
+  const xla::Shape tuple_output_shape =
+      xla::ShapeUtil::MakeTupleShape({output_shape});
+  EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
+
+  // Expect exactly 1 OutputDescription.
+  EXPECT_EQ(compilation_result.outputs.size(), 1);
+  const XlaCompiler::OutputDescription& output_desc =
+      compilation_result.outputs.front();
+  EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
+  EXPECT_EQ(output_desc.shape, TensorShape());
+  EXPECT_FALSE(output_desc.is_constant);
+  EXPECT_FALSE(output_desc.is_tensor_list);
+
+  // Expect no resource updates from computation.
+  EXPECT_TRUE(compilation_result.resource_updates.empty());
+}
+
 // Tests that foldable ops are constant-folded to enable legalization of ops
 // that require compile time constant operand.
 TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
@@ -136,7 +193,8 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
   XlaCompiler::CompilationResult compilation_result;

   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
+      mlir_module, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   ASSERT_TRUE(s.ok());

   const xla::HloModuleConfig module_config(
@@ -174,7 +232,8 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
   XlaCompiler::CompilationResult compilation_result;

   Status s = CompileSerializedMlirToXlaHlo(
-      mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
+      mlir_module, arg_shapes,
+      /*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
   TF_ASSERT_OK(s);

   const xla::HloModuleConfig module_config(