Support non tuple arguments in MLIR to HLO compilation helper function

Planning to use this in a follow-up change to use MLIR compilation passes for on demand XLA compilation.

Note that the result always uses tuple type for the main computation so we don't need a corresponding parameter for the result.

PiperOrigin-RevId: 301209747
Change-Id: Ic1d0fbdd2c69512f21d5feafa9880c0c024d0279
This commit is contained in:
Smit Hinsu 2020-03-16 12:07:14 -07:00 committed by TensorFlower Gardener
parent e861b664e6
commit 7680958e81
3 changed files with 85 additions and 19 deletions

View File

@ -67,6 +67,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
// Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
Status GetXlaInputShapes(
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const xla::CustomShapeRepresentationFn shape_representation_fn,
std::vector<xla::Shape>* xla_input_shapes) {
xla_input_shapes->clear();
@ -88,8 +89,12 @@ Status GetXlaInputShapes(
TF_ASSIGN_OR_RETURN(xla_shape,
shape_representation_fn(arg_shapes[i], dtype));
}
xla_input_shapes->push_back(
xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
if (use_tuple_args) {
xla_input_shapes->push_back(
xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
} else {
*xla_input_shapes = individual_arg_shapes;
}
return Status::OK();
}
@ -257,6 +262,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result) {
mlir::MLIRContext mlir_context;
@ -278,7 +284,7 @@ Status CompileSerializedMlirToXlaHlo(
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
compilation_result->computation = std::make_shared<xla::XlaComputation>();
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
module_op, compilation_result->computation.get(), /*use_tuple_args=*/true,
module_op, compilation_result->computation.get(), use_tuple_args,
/*return_tuple=*/true));
// Construct mapping from XlaComputation's arg to input edges of execute
@ -291,7 +297,7 @@ Status CompileSerializedMlirToXlaHlo(
};
// Compute all input shapes.
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes,
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
shape_representation_fn_no_fast_memory,
&compilation_result->xla_input_shapes));

View File

@ -50,6 +50,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
// metadata and stores them in CompilationResult.
Status CompileSerializedMlirToXlaHlo(
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
bool use_tuple_args,
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
XlaCompiler::CompilationResult* compilation_result);
} // namespace tensorflow

View File

@ -41,30 +41,31 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
std::vector<TensorShape> arg_shapes;
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
TestShapeRepresentation,
&compilation_result);
Status s = CompileSerializedMlirToXlaHlo(
invalid_mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
EXPECT_EQ(s.ToString(),
"Invalid argument: could not parse MLIR module<stdin>: error: "
"custom op 'totally' is unknown\n");
}
TEST(CompileSerializedMlirToXlaHloTest, Success) {
string mlir_module = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
constexpr llvm::StringRef kBinaryAddModule = R"(
module attributes {tf.versions = {producer = 179 : i32}} {
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
)";
}
)";
TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
std::vector<TensorShape> arg_shapes(2, TensorShape());
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
kBinaryAddModule, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());
const xla::HloModuleConfig module_config(
@ -86,7 +87,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
EXPECT_EQ(expected_hlo_module_string,
status_or_hlo_module.ValueOrDie()->ToString());
// Expect an iota like input mapping.
// Expect an in order input mapping.
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
// Expect a single tuple-shape, containing two F32 scalars.
@ -116,6 +117,62 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
EXPECT_TRUE(compilation_result.resource_updates.empty());
}
TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
std::vector<TensorShape> arg_shapes(2, TensorShape());
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
kBinaryAddModule, arg_shapes,
/*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());
const xla::HloModuleConfig module_config(
compilation_result.computation->GetProgramShape().ValueOrDie());
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
compilation_result.computation->proto(), module_config);
ASSERT_TRUE(status_or_hlo_module.ok());
string expected_hlo_module_string = R"(HloModule main.5
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
%Arg_0.1 = f32[] parameter(0)
%Arg_1.2 = f32[] parameter(1)
%add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
}
)";
EXPECT_EQ(expected_hlo_module_string,
status_or_hlo_module.ValueOrDie()->ToString());
// Expect an in order input mapping.
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
// Expect two inputs, each containing a F32 scalar.
EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2);
xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape);
EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape);
// Expect output shape is a tuple shape containing a single F32 Scalar type.
const xla::Shape output_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
const xla::Shape tuple_output_shape =
xla::ShapeUtil::MakeTupleShape({output_shape});
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
// Expect exactly 1 OutputDescription.
EXPECT_EQ(compilation_result.outputs.size(), 1);
const XlaCompiler::OutputDescription& output_desc =
compilation_result.outputs.front();
EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
EXPECT_EQ(output_desc.shape, TensorShape());
EXPECT_FALSE(output_desc.is_constant);
EXPECT_FALSE(output_desc.is_tensor_list);
// Expect no resource updates from computation.
EXPECT_TRUE(compilation_result.resource_updates.empty());
}
// Tests that foldable ops are constant-folded to enable legalization of ops
// that require compile time constant operand.
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
@ -136,7 +193,8 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
ASSERT_TRUE(s.ok());
const xla::HloModuleConfig module_config(
@ -174,7 +232,8 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
XlaCompiler::CompilationResult compilation_result;
Status s = CompileSerializedMlirToXlaHlo(
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
mlir_module, arg_shapes,
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
TF_ASSERT_OK(s);
const xla::HloModuleConfig module_config(