Extend MemoryTypesForNode test

Add verification of "_input_hostmem" and "_output_hostmem" attributes.
This commit is contained in:
Krzysztof Laskowski 2020-08-06 23:42:39 +02:00
parent 6fe5847c19
commit 1da0eb2f47

View File

@ -33,12 +33,14 @@ class DummyKernel : public OpKernel {
REGISTER_OP("HostMemoryTest")
.Input("a: float")
.Input("b: T")
.Input("c: N * string")
.Input("d: Tlist")
.Input("e: Rlist")
.Input("b: float")
.Input("c: T")
.Input("d: N * string")
.Input("e: Tlist")
.Input("f: Rlist")
.Output("o: N * T")
.Output("p: Tlist")
.Output("p: N * T")
.Output("r: Tlist")
.Attr("T: type")
.Attr("N: int")
.Attr("Tlist: list(type)")
@ -46,21 +48,25 @@ REGISTER_OP("HostMemoryTest")
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
.Device(DEVICE_GPU)
.HostMemory("a")
.HostMemory("c")
.HostMemory("b")
.HostMemory("d")
.HostMemory("o"),
.HostMemory("e")
.HostMemory("p"),
DummyKernel);
TEST(MemoryTypesForNode, Simple) {
NodeDef node_def;
TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
.Input(FakeInput())
.Input(FakeInput())
.Input(FakeInput(DT_BOOL))
.Input(FakeInput(3))
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
.Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
.Finalize(&node_def));
AddNodeAttr("_input_hostmem", {0}, &node_def);
AddNodeAttr("_output_hostmem", {6, 7}, &node_def);
MemoryTypeVector input, output;
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
@ -68,24 +74,26 @@ TEST(MemoryTypesForNode, Simple) {
// a:float, b:bool, c:3*string, d:(int32, float, int32),
// e:(resource, string, resource)
EXPECT_EQ(
MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input);
// o:3*bool, p:(int32, float, int32)
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
output);
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
&input, &output));
EXPECT_EQ(
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
input);
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
output);
}