Support non tuple arguments in MLIR to HLO compilation helper function
Planning to use this in a follow-up change to use MLIR compilation passes for on demand XLA compilation. Note that the result always uses tuple type for the main computation so we don't need a corresponding parameter for the result. PiperOrigin-RevId: 301209747 Change-Id: Ic1d0fbdd2c69512f21d5feafa9880c0c024d0279
This commit is contained in:
parent
e861b664e6
commit
7680958e81
@ -67,6 +67,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
|
||||
// Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
|
||||
Status GetXlaInputShapes(
|
||||
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
const xla::CustomShapeRepresentationFn shape_representation_fn,
|
||||
std::vector<xla::Shape>* xla_input_shapes) {
|
||||
xla_input_shapes->clear();
|
||||
@ -88,8 +89,12 @@ Status GetXlaInputShapes(
|
||||
TF_ASSIGN_OR_RETURN(xla_shape,
|
||||
shape_representation_fn(arg_shapes[i], dtype));
|
||||
}
|
||||
xla_input_shapes->push_back(
|
||||
xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
|
||||
if (use_tuple_args) {
|
||||
xla_input_shapes->push_back(
|
||||
xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
|
||||
} else {
|
||||
*xla_input_shapes = individual_arg_shapes;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -257,6 +262,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result) {
|
||||
mlir::MLIRContext mlir_context;
|
||||
@ -278,7 +284,7 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||
module_op, compilation_result->computation.get(), /*use_tuple_args=*/true,
|
||||
module_op, compilation_result->computation.get(), use_tuple_args,
|
||||
/*return_tuple=*/true));
|
||||
|
||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||
@ -291,7 +297,7 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
};
|
||||
|
||||
// Compute all input shapes.
|
||||
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes,
|
||||
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
|
||||
shape_representation_fn_no_fast_memory,
|
||||
&compilation_result->xla_input_shapes));
|
||||
|
||||
|
@ -50,6 +50,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
||||
// metadata and stores them in CompilationResult.
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result);
|
||||
} // namespace tensorflow
|
||||
|
@ -41,30 +41,31 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
||||
std::vector<TensorShape> arg_shapes;
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
|
||||
TestShapeRepresentation,
|
||||
&compilation_result);
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
invalid_mlir_module, arg_shapes,
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
|
||||
EXPECT_EQ(s.ToString(),
|
||||
"Invalid argument: could not parse MLIR module<stdin>: error: "
|
||||
"custom op 'totally' is unknown\n");
|
||||
}
|
||||
|
||||
TEST(CompileSerializedMlirToXlaHloTest, Success) {
|
||||
string mlir_module = R"(
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
constexpr llvm::StringRef kBinaryAddModule = R"(
|
||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
)";
|
||||
}
|
||||
)";
|
||||
|
||||
TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
|
||||
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
|
||||
kBinaryAddModule, arg_shapes,
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
ASSERT_TRUE(s.ok());
|
||||
|
||||
const xla::HloModuleConfig module_config(
|
||||
@ -86,7 +87,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
||||
EXPECT_EQ(expected_hlo_module_string,
|
||||
status_or_hlo_module.ValueOrDie()->ToString());
|
||||
|
||||
// Expect an iota like input mapping.
|
||||
// Expect an in order input mapping.
|
||||
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
|
||||
|
||||
// Expect a single tuple-shape, containing two F32 scalars.
|
||||
@ -116,6 +117,62 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
||||
EXPECT_TRUE(compilation_result.resource_updates.empty());
|
||||
}
|
||||
|
||||
TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
|
||||
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
kBinaryAddModule, arg_shapes,
|
||||
/*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
|
||||
ASSERT_TRUE(s.ok());
|
||||
|
||||
const xla::HloModuleConfig module_config(
|
||||
compilation_result.computation->GetProgramShape().ValueOrDie());
|
||||
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||
compilation_result.computation->proto(), module_config);
|
||||
ASSERT_TRUE(status_or_hlo_module.ok());
|
||||
string expected_hlo_module_string = R"(HloModule main.5
|
||||
|
||||
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
|
||||
%Arg_0.1 = f32[] parameter(0)
|
||||
%Arg_1.2 = f32[] parameter(1)
|
||||
%add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
|
||||
ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
|
||||
}
|
||||
|
||||
)";
|
||||
EXPECT_EQ(expected_hlo_module_string,
|
||||
status_or_hlo_module.ValueOrDie()->ToString());
|
||||
|
||||
// Expect an in order input mapping.
|
||||
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
|
||||
|
||||
// Expect two inputs, each containing a F32 scalar.
|
||||
EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2);
|
||||
xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
|
||||
EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape);
|
||||
EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape);
|
||||
|
||||
// Expect output shape is a tuple shape containing a single F32 Scalar type.
|
||||
const xla::Shape output_shape =
|
||||
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
|
||||
const xla::Shape tuple_output_shape =
|
||||
xla::ShapeUtil::MakeTupleShape({output_shape});
|
||||
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
|
||||
|
||||
// Expect exactly 1 OutputDescription.
|
||||
EXPECT_EQ(compilation_result.outputs.size(), 1);
|
||||
const XlaCompiler::OutputDescription& output_desc =
|
||||
compilation_result.outputs.front();
|
||||
EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
|
||||
EXPECT_EQ(output_desc.shape, TensorShape());
|
||||
EXPECT_FALSE(output_desc.is_constant);
|
||||
EXPECT_FALSE(output_desc.is_tensor_list);
|
||||
|
||||
// Expect no resource updates from computation.
|
||||
EXPECT_TRUE(compilation_result.resource_updates.empty());
|
||||
}
|
||||
|
||||
// Tests that foldable ops are constant-folded to enable legalization of ops
|
||||
// that require compile time constant operand.
|
||||
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
||||
@ -136,7 +193,8 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
|
||||
mlir_module, arg_shapes,
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
ASSERT_TRUE(s.ok());
|
||||
|
||||
const xla::HloModuleConfig module_config(
|
||||
@ -174,7 +232,8 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
||||
XlaCompiler::CompilationResult compilation_result;
|
||||
|
||||
Status s = CompileSerializedMlirToXlaHlo(
|
||||
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
|
||||
mlir_module, arg_shapes,
|
||||
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||
TF_ASSERT_OK(s);
|
||||
|
||||
const xla::HloModuleConfig module_config(
|
||||
|
Loading…
Reference in New Issue
Block a user