Extend MemoryTypesForNode test

Add verification of "_input_hostmem" and "_output_hostmem" attributes.
This commit is contained in:
Krzysztof Laskowski 2020-08-06 23:42:39 +02:00
parent 6fe5847c19
commit 1da0eb2f47

View File

@ -33,12 +33,14 @@ class DummyKernel : public OpKernel {
REGISTER_OP("HostMemoryTest") REGISTER_OP("HostMemoryTest")
.Input("a: float") .Input("a: float")
.Input("b: T") .Input("b: float")
.Input("c: N * string") .Input("c: T")
.Input("d: Tlist") .Input("d: N * string")
.Input("e: Rlist") .Input("e: Tlist")
.Input("f: Rlist")
.Output("o: N * T") .Output("o: N * T")
.Output("p: Tlist") .Output("p: N * T")
.Output("r: Tlist")
.Attr("T: type") .Attr("T: type")
.Attr("N: int") .Attr("N: int")
.Attr("Tlist: list(type)") .Attr("Tlist: list(type)")
@ -46,21 +48,25 @@ REGISTER_OP("HostMemoryTest")
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
.Device(DEVICE_GPU) .Device(DEVICE_GPU)
.HostMemory("a") .HostMemory("b")
.HostMemory("c")
.HostMemory("d") .HostMemory("d")
.HostMemory("o"), .HostMemory("e")
.HostMemory("p"),
DummyKernel); DummyKernel);
TEST(MemoryTypesForNode, Simple) { TEST(MemoryTypesForNode, Simple) {
NodeDef node_def; NodeDef node_def;
TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
.Input(FakeInput())
.Input(FakeInput()) .Input(FakeInput())
.Input(FakeInput(DT_BOOL)) .Input(FakeInput(DT_BOOL))
.Input(FakeInput(3)) .Input(FakeInput(3))
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32})) .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
.Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE})) .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
.Finalize(&node_def)); .Finalize(&node_def));
AddNodeAttr("_input_hostmem", {0}, &node_def);
AddNodeAttr("_output_hostmem", {6, 7}, &node_def);
MemoryTypeVector input, output; MemoryTypeVector input, output;
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
@ -68,24 +74,26 @@ TEST(MemoryTypesForNode, Simple) {
// a:float, b:bool, c:3*string, d:(int32, float, int32), // a:float, b:bool, c:3*string, d:(int32, float, int32),
// e:(resource, string, resource) // e:(resource, string, resource)
EXPECT_EQ( EXPECT_EQ(
MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input); input);
// o:3*bool, p:(int32, float, int32) // o:3*bool, p:(int32, float, int32)
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
output); output);
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
&input, &output)); &input, &output));
EXPECT_EQ( EXPECT_EQ(
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input); input);
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
output); output);
} }