[XLA:CPU] Add support for CustomCall targets that return tuples.
Populate the tuple index table of the return value in the emitted code; the callee cannot do this because it does not know the buffer assignments. Explicitly enable custom_call_test only for the CPU backend in the BUILD file, rather than disabling it on non-CPU backends, since these tests would not work on any non-CPU backend. PiperOrigin-RevId: 225048065
This commit is contained in:
parent
316660063a
commit
4fe05f35cf
@ -2271,6 +2271,22 @@ Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
|||||||
/*isVarArg=*/false)));
|
/*isVarArg=*/false)));
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
|
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
|
||||||
|
// Write the tuple table if the output is a tuple.
|
||||||
|
if (ShapeUtil::IsTuple(custom_call->shape())) {
|
||||||
|
std::vector<llvm::Value*> base_ptrs;
|
||||||
|
for (int i = 0; i < ShapeUtil::TupleElementCount(custom_call->shape());
|
||||||
|
++i) {
|
||||||
|
const Shape& elem_shape =
|
||||||
|
ShapeUtil::GetTupleElementShape(custom_call->shape(), i);
|
||||||
|
TF_RET_CHECK(!ShapeUtil::IsTuple(elem_shape))
|
||||||
|
<< "Nested tuples not implemented";
|
||||||
|
TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice,
|
||||||
|
assignment_.GetUniqueSlice(custom_call, {i}));
|
||||||
|
llvm::Value* addr = EmitBufferPointer(slice, elem_shape);
|
||||||
|
base_ptrs.push_back(addr);
|
||||||
|
}
|
||||||
|
llvm_ir::EmitTuple(GetIrArrayFor(custom_call), base_ptrs, &b_, module_);
|
||||||
|
}
|
||||||
auto* output_address_arg =
|
auto* output_address_arg =
|
||||||
PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
|
PointerCast(GetEmittedValueFor(custom_call), i8_ptr_type);
|
||||||
|
|
||||||
|
@ -1,6 +1,13 @@
|
|||||||
# Description:
|
# Description:
|
||||||
# Base testing infrastructure for XLA.
|
# Base testing infrastructure for XLA.
|
||||||
|
|
||||||
|
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites", "generate_backend_test_macros", "xla_test", "xla_test_library")
|
||||||
|
load(
|
||||||
|
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||||
|
"tf_cuda_tests_tags",
|
||||||
|
)
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
package(
|
package(
|
||||||
@ -23,17 +30,6 @@ filegroup(
|
|||||||
]),
|
]),
|
||||||
)
|
)
|
||||||
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test")
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library")
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites")
|
|
||||||
load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_test_macros")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
|
||||||
load(
|
|
||||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
|
||||||
"tf_cuda_tests_tags",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate test_suites for all backends, named "${backend}_tests".
|
# Generate test_suites for all backends, named "${backend}_tests".
|
||||||
generate_backend_suites()
|
generate_backend_suites()
|
||||||
|
|
||||||
@ -1348,6 +1344,7 @@ xla_test(
|
|||||||
xla_test(
|
xla_test(
|
||||||
name = "custom_call_test",
|
name = "custom_call_test",
|
||||||
srcs = ["custom_call_test.cc"],
|
srcs = ["custom_call_test.cc"],
|
||||||
|
backends = ["cpu"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
|
@ -54,11 +54,20 @@ void Add1ToValues(float* out, float** in) {
|
|||||||
out[2] = array[2] + 1;
|
out[2] = array[2] + 1;
|
||||||
out[3] = array[3] + 1;
|
out[3] = array[3] + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Custom-call target referenced by the TupleOutput test below
// (custom_call_target="F32TupleSwap"). Writes the float behind in[1] into the
// buffer behind out[0], and the float behind in[0] into the buffer behind
// out[1], i.e. it produces its two scalar operands in swapped order.
// `out` is indexed as a table of per-element output pointers; per the commit
// description, that table is filled in by the caller, not by this function.
void F32TupleSwap(float** out, float** in) {
|
||||||
|
// Annotate one float's worth of each operand as initialized (presumably for
// memory-checking tools -- confirm against the macro's definition).
TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[0], sizeof(float));
|
||||||
|
TF_ANNOTATE_MEMORY_IS_INITIALIZED(in[1], sizeof(float));
|
||||||
|
*out[0] = *in[1];
|
||||||
|
*out[1] = *in[0];
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
|
REGISTER_CUSTOM_CALL_TARGET(R0F32Add2);
|
||||||
REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
|
REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum);
|
||||||
REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
|
REGISTER_CUSTOM_CALL_TARGET(Add1ToValues);
|
||||||
|
REGISTER_CUSTOM_CALL_TARGET(F32TupleSwap);
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace {
|
namespace {
|
||||||
@ -69,7 +78,7 @@ class CustomCallTest : public HloTestBase {
|
|||||||
Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
|
Shape r2f32_ = ShapeUtil::MakeShape(F32, {2, 2});
|
||||||
};
|
};
|
||||||
|
|
||||||
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
|
XLA_TEST_F(CustomCallTest, CustomCallR0F32Add2) {
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
|
||||||
@ -84,7 +93,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) {
|
|||||||
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
|
LiteralTestUtil::ExpectR0Near<float>(44.0f, result, error_spec_);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
|
XLA_TEST_F(CustomCallTest, CustomCallR2F32Reduce) {
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto builder = HloComputation::Builder(TestName());
|
auto builder = HloComputation::Builder(TestName());
|
||||||
|
|
||||||
@ -105,7 +114,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
|
|||||||
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
|
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
|
XLA_TEST_F(CustomCallTest, UsedInOtherComputations) {
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto b = HloComputation::Builder(TestName());
|
auto b = HloComputation::Builder(TestName());
|
||||||
|
|
||||||
@ -129,7 +138,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
|
|||||||
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
|
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
|
XLA_TEST_F(CustomCallTest, InputAndOutputLayoutDiffer) {
|
||||||
auto module = CreateNewUnverifiedModule();
|
auto module = CreateNewUnverifiedModule();
|
||||||
auto b = HloComputation::Builder(TestName());
|
auto b = HloComputation::Builder(TestName());
|
||||||
|
|
||||||
@ -151,7 +160,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
|
|||||||
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
|
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
|
XLA_TEST_F(CustomCallTest, LayoutConstrained) {
|
||||||
// The argument and result of the computation are set to different layouts,
|
// The argument and result of the computation are set to different layouts,
|
||||||
// but the custom call is layout constrained to a fixed operand and result
|
// but the custom call is layout constrained to a fixed operand and result
|
||||||
// layout, so the correct result should be produced.
|
// layout, so the correct result should be produced.
|
||||||
@ -176,6 +185,26 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
|
|||||||
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
|
LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exercises a custom call whose result shape is the tuple (f32[], f32[]):
// the HLO module below calls the "F32TupleSwap" target on two scalar
// parameters, and the test expects the result tuple to hold the two
// arguments in swapped order.
XLA_TEST_F(CustomCallTest, TupleOutput) {
|
||||||
|
const char* kModuleStr = R"(
|
||||||
|
HloModule m
|
||||||
|
test {
|
||||||
|
p0 = f32[] parameter(0)
|
||||||
|
p1 = f32[] parameter(1)
|
||||||
|
ROOT %custom-call = (f32[], f32[]) custom-call(f32[] %p0, f32[] %p1), custom_call_target="F32TupleSwap", operand_layout_constraints={f32[], f32[]}
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||||
|
ParseAndReturnVerifiedModule(kModuleStr));
|
||||||
|
|
||||||
|
// Arguments are passed as (7, 42).
Literal arg0 = LiteralUtil::CreateR0<float>(7.f);
|
||||||
|
Literal arg1 = LiteralUtil::CreateR0<float>(42.f);
|
||||||
|
|
||||||
|
// The expected tuple lists arg1 first, then arg0, i.e. (42, 7).
Literal expected = LiteralUtil::MakeTuple({&arg1, &arg0});
|
||||||
|
Literal result = ExecuteAndTransfer(std::move(module), {&arg0, &arg1});
|
||||||
|
EXPECT_EQ(result, expected);
|
||||||
|
}
|
||||||
|
|
||||||
class CustomCallClientAPITest : public ClientLibraryTestBase {};
|
class CustomCallClientAPITest : public ClientLibraryTestBase {};
|
||||||
|
|
||||||
// When using the client API, CustomCall targets can't begin with '$' -- these
|
// When using the client API, CustomCall targets can't begin with '$' -- these
|
||||||
|
Loading…
Reference in New Issue
Block a user