Support non tuple arguments in MLIR to HLO compilation helper function
Planning to use this in a follow-up change to use MLIR compilation passes for on demand XLA compilation. Note that the result always uses tuple type for the main computation so we don't need a corresponding parameter for the result. PiperOrigin-RevId: 301209747 Change-Id: Ic1d0fbdd2c69512f21d5feafa9880c0c024d0279
This commit is contained in:
parent
e861b664e6
commit
7680958e81
@ -67,6 +67,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
|
|||||||
// Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
|
// Converts arg_shapes to xla::Shape's and store into xla_input_shapes.
|
||||||
Status GetXlaInputShapes(
|
Status GetXlaInputShapes(
|
||||||
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
|
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||||
|
bool use_tuple_args,
|
||||||
const xla::CustomShapeRepresentationFn shape_representation_fn,
|
const xla::CustomShapeRepresentationFn shape_representation_fn,
|
||||||
std::vector<xla::Shape>* xla_input_shapes) {
|
std::vector<xla::Shape>* xla_input_shapes) {
|
||||||
xla_input_shapes->clear();
|
xla_input_shapes->clear();
|
||||||
@ -88,8 +89,12 @@ Status GetXlaInputShapes(
|
|||||||
TF_ASSIGN_OR_RETURN(xla_shape,
|
TF_ASSIGN_OR_RETURN(xla_shape,
|
||||||
shape_representation_fn(arg_shapes[i], dtype));
|
shape_representation_fn(arg_shapes[i], dtype));
|
||||||
}
|
}
|
||||||
xla_input_shapes->push_back(
|
if (use_tuple_args) {
|
||||||
xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
|
xla_input_shapes->push_back(
|
||||||
|
xla::ShapeUtil::MakeTupleShape(individual_arg_shapes));
|
||||||
|
} else {
|
||||||
|
*xla_input_shapes = individual_arg_shapes;
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -257,6 +262,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
|||||||
|
|
||||||
Status CompileSerializedMlirToXlaHlo(
|
Status CompileSerializedMlirToXlaHlo(
|
||||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||||
|
bool use_tuple_args,
|
||||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompiler::CompilationResult* compilation_result) {
|
XlaCompiler::CompilationResult* compilation_result) {
|
||||||
mlir::MLIRContext mlir_context;
|
mlir::MLIRContext mlir_context;
|
||||||
@ -278,7 +284,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||||||
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
// Convert MLIR module to XLA HLO proto contained in XlaComputation.
|
||||||
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
compilation_result->computation = std::make_shared<xla::XlaComputation>();
|
||||||
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
TF_RETURN_IF_ERROR(ConvertMLIRToXlaComputation(
|
||||||
module_op, compilation_result->computation.get(), /*use_tuple_args=*/true,
|
module_op, compilation_result->computation.get(), use_tuple_args,
|
||||||
/*return_tuple=*/true));
|
/*return_tuple=*/true));
|
||||||
|
|
||||||
// Construct mapping from XlaComputation's arg to input edges of execute
|
// Construct mapping from XlaComputation's arg to input edges of execute
|
||||||
@ -291,7 +297,7 @@ Status CompileSerializedMlirToXlaHlo(
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Compute all input shapes.
|
// Compute all input shapes.
|
||||||
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes,
|
TF_RETURN_IF_ERROR(GetXlaInputShapes(module_op, arg_shapes, use_tuple_args,
|
||||||
shape_representation_fn_no_fast_memory,
|
shape_representation_fn_no_fast_memory,
|
||||||
&compilation_result->xla_input_shapes));
|
&compilation_result->xla_input_shapes));
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ Status ConvertMLIRToXlaComputation(mlir::ModuleOp module_op,
|
|||||||
// metadata and stores them in CompilationResult.
|
// metadata and stores them in CompilationResult.
|
||||||
Status CompileSerializedMlirToXlaHlo(
|
Status CompileSerializedMlirToXlaHlo(
|
||||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||||
|
bool use_tuple_args,
|
||||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||||
XlaCompiler::CompilationResult* compilation_result);
|
XlaCompiler::CompilationResult* compilation_result);
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -41,30 +41,31 @@ TEST(CompileSerializedMlirToXlaHloTest, InvalidSerializedMlirModule) {
|
|||||||
std::vector<TensorShape> arg_shapes;
|
std::vector<TensorShape> arg_shapes;
|
||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
Status s = CompileSerializedMlirToXlaHlo(invalid_mlir_module, arg_shapes,
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
TestShapeRepresentation,
|
invalid_mlir_module, arg_shapes,
|
||||||
&compilation_result);
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
|
EXPECT_EQ(s.code(), tensorflow::errors::Code::INVALID_ARGUMENT);
|
||||||
EXPECT_EQ(s.ToString(),
|
EXPECT_EQ(s.ToString(),
|
||||||
"Invalid argument: could not parse MLIR module<stdin>: error: "
|
"Invalid argument: could not parse MLIR module<stdin>: error: "
|
||||||
"custom op 'totally' is unknown\n");
|
"custom op 'totally' is unknown\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, Success) {
|
constexpr llvm::StringRef kBinaryAddModule = R"(
|
||||||
string mlir_module = R"(
|
module attributes {tf.versions = {producer = 179 : i32}} {
|
||||||
module attributes {tf.versions = {producer = 179 : i32}} {
|
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
|
||||||
func @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
|
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
%0 = "tf.AddV2"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", name = "add"} : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
return %0 : tensor<f32>
|
||||||
return %0 : tensor<f32>
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
)";
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TEST(CompileSerializedMlirToXlaHloTest, TupleArgs) {
|
||||||
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
Status s = CompileSerializedMlirToXlaHlo(
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
|
kBinaryAddModule, arg_shapes,
|
||||||
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
ASSERT_TRUE(s.ok());
|
ASSERT_TRUE(s.ok());
|
||||||
|
|
||||||
const xla::HloModuleConfig module_config(
|
const xla::HloModuleConfig module_config(
|
||||||
@ -86,7 +87,7 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
|||||||
EXPECT_EQ(expected_hlo_module_string,
|
EXPECT_EQ(expected_hlo_module_string,
|
||||||
status_or_hlo_module.ValueOrDie()->ToString());
|
status_or_hlo_module.ValueOrDie()->ToString());
|
||||||
|
|
||||||
// Expect an iota like input mapping.
|
// Expect an in order input mapping.
|
||||||
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
|
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
|
||||||
|
|
||||||
// Expect a single tuple-shape, containing two F32 scalars.
|
// Expect a single tuple-shape, containing two F32 scalars.
|
||||||
@ -116,6 +117,62 @@ ENTRY %main.6 (arg_tuple.1: (f32[], f32[])) -> (f32[]) {
|
|||||||
EXPECT_TRUE(compilation_result.resource_updates.empty());
|
EXPECT_TRUE(compilation_result.resource_updates.empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(CompileSerializedMlirToXlaHloTest, IndividualArgs) {
|
||||||
|
std::vector<TensorShape> arg_shapes(2, TensorShape());
|
||||||
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
|
kBinaryAddModule, arg_shapes,
|
||||||
|
/*use_tuple_args=*/false, TestShapeRepresentation, &compilation_result);
|
||||||
|
ASSERT_TRUE(s.ok());
|
||||||
|
|
||||||
|
const xla::HloModuleConfig module_config(
|
||||||
|
compilation_result.computation->GetProgramShape().ValueOrDie());
|
||||||
|
auto status_or_hlo_module = xla::HloModule::CreateFromProto(
|
||||||
|
compilation_result.computation->proto(), module_config);
|
||||||
|
ASSERT_TRUE(status_or_hlo_module.ok());
|
||||||
|
string expected_hlo_module_string = R"(HloModule main.5
|
||||||
|
|
||||||
|
ENTRY %main.5 (Arg_0.1: f32[], Arg_1.2: f32[]) -> (f32[]) {
|
||||||
|
%Arg_0.1 = f32[] parameter(0)
|
||||||
|
%Arg_1.2 = f32[] parameter(1)
|
||||||
|
%add.3 = f32[] add(f32[] %Arg_0.1, f32[] %Arg_1.2)
|
||||||
|
ROOT %tuple.4 = (f32[]) tuple(f32[] %add.3)
|
||||||
|
}
|
||||||
|
|
||||||
|
)";
|
||||||
|
EXPECT_EQ(expected_hlo_module_string,
|
||||||
|
status_or_hlo_module.ValueOrDie()->ToString());
|
||||||
|
|
||||||
|
// Expect an in order input mapping.
|
||||||
|
EXPECT_EQ(compilation_result.input_mapping, std::vector<int>({0, 1}));
|
||||||
|
|
||||||
|
// Expect two inputs, each containing a F32 scalar.
|
||||||
|
EXPECT_EQ(compilation_result.xla_input_shapes.size(), 2);
|
||||||
|
xla::Shape expected_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
|
||||||
|
EXPECT_EQ(compilation_result.xla_input_shapes[0], expected_input_shape);
|
||||||
|
EXPECT_EQ(compilation_result.xla_input_shapes[1], expected_input_shape);
|
||||||
|
|
||||||
|
// Expect output shape is a tuple shape containing a single F32 Scalar type.
|
||||||
|
const xla::Shape output_shape =
|
||||||
|
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {});
|
||||||
|
const xla::Shape tuple_output_shape =
|
||||||
|
xla::ShapeUtil::MakeTupleShape({output_shape});
|
||||||
|
EXPECT_EQ(compilation_result.xla_output_shape, tuple_output_shape);
|
||||||
|
|
||||||
|
// Expect exactly 1 OutputDescription.
|
||||||
|
EXPECT_EQ(compilation_result.outputs.size(), 1);
|
||||||
|
const XlaCompiler::OutputDescription& output_desc =
|
||||||
|
compilation_result.outputs.front();
|
||||||
|
EXPECT_EQ(output_desc.type, DataType::DT_FLOAT);
|
||||||
|
EXPECT_EQ(output_desc.shape, TensorShape());
|
||||||
|
EXPECT_FALSE(output_desc.is_constant);
|
||||||
|
EXPECT_FALSE(output_desc.is_tensor_list);
|
||||||
|
|
||||||
|
// Expect no resource updates from computation.
|
||||||
|
EXPECT_TRUE(compilation_result.resource_updates.empty());
|
||||||
|
}
|
||||||
|
|
||||||
// Tests that foldable ops are constant-folded to enable legalization of ops
|
// Tests that foldable ops are constant-folded to enable legalization of ops
|
||||||
// that require compile time constant operand.
|
// that require compile time constant operand.
|
||||||
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
||||||
@ -136,7 +193,8 @@ TEST(CompileSerializedMlirToXlaHloTest, CompileTimeConstantFoldedSuccess) {
|
|||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
Status s = CompileSerializedMlirToXlaHlo(
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
|
mlir_module, arg_shapes,
|
||||||
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
ASSERT_TRUE(s.ok());
|
ASSERT_TRUE(s.ok());
|
||||||
|
|
||||||
const xla::HloModuleConfig module_config(
|
const xla::HloModuleConfig module_config(
|
||||||
@ -174,7 +232,8 @@ TEST(CompileSerializedMlirToXlaHloTest, ShapeInference) {
|
|||||||
XlaCompiler::CompilationResult compilation_result;
|
XlaCompiler::CompilationResult compilation_result;
|
||||||
|
|
||||||
Status s = CompileSerializedMlirToXlaHlo(
|
Status s = CompileSerializedMlirToXlaHlo(
|
||||||
mlir_module, arg_shapes, TestShapeRepresentation, &compilation_result);
|
mlir_module, arg_shapes,
|
||||||
|
/*use_tuple_args=*/true, TestShapeRepresentation, &compilation_result);
|
||||||
TF_ASSERT_OK(s);
|
TF_ASSERT_OK(s);
|
||||||
|
|
||||||
const xla::HloModuleConfig module_config(
|
const xla::HloModuleConfig module_config(
|
||||||
|
Loading…
Reference in New Issue
Block a user