Extend MemoryTypesForNode test
Add verification of "_input_hostmem" and "_output_hostmem" attributes.
This commit is contained in:
parent
6fe5847c19
commit
1da0eb2f47
@ -33,12 +33,14 @@ class DummyKernel : public OpKernel {
|
||||
|
||||
REGISTER_OP("HostMemoryTest")
|
||||
.Input("a: float")
|
||||
.Input("b: T")
|
||||
.Input("c: N * string")
|
||||
.Input("d: Tlist")
|
||||
.Input("e: Rlist")
|
||||
.Input("b: float")
|
||||
.Input("c: T")
|
||||
.Input("d: N * string")
|
||||
.Input("e: Tlist")
|
||||
.Input("f: Rlist")
|
||||
.Output("o: N * T")
|
||||
.Output("p: Tlist")
|
||||
.Output("p: N * T")
|
||||
.Output("r: Tlist")
|
||||
.Attr("T: type")
|
||||
.Attr("N: int")
|
||||
.Attr("Tlist: list(type)")
|
||||
@ -46,21 +48,25 @@ REGISTER_OP("HostMemoryTest")
|
||||
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
|
||||
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
|
||||
.Device(DEVICE_GPU)
|
||||
.HostMemory("a")
|
||||
.HostMemory("c")
|
||||
.HostMemory("b")
|
||||
.HostMemory("d")
|
||||
.HostMemory("o"),
|
||||
.HostMemory("e")
|
||||
.HostMemory("p"),
|
||||
DummyKernel);
|
||||
|
||||
TEST(MemoryTypesForNode, Simple) {
|
||||
NodeDef node_def;
|
||||
TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput())
|
||||
.Input(FakeInput(DT_BOOL))
|
||||
.Input(FakeInput(3))
|
||||
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
|
||||
.Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
|
||||
.Finalize(&node_def));
|
||||
AddNodeAttr("_input_hostmem", {0}, &node_def);
|
||||
AddNodeAttr("_output_hostmem", {6, 7}, &node_def);
|
||||
|
||||
MemoryTypeVector input, output;
|
||||
|
||||
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
|
||||
@ -68,24 +74,26 @@ TEST(MemoryTypesForNode, Simple) {
|
||||
// a:float, b:bool, c:3*string, d:(int32, float, int32),
|
||||
// e:(resource, string, resource)
|
||||
EXPECT_EQ(
|
||||
MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||
HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
|
||||
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||
DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
||||
input);
|
||||
// o:3*bool, p:(int32, float, int32)
|
||||
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
|
||||
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
|
||||
output);
|
||||
|
||||
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
|
||||
&input, &output));
|
||||
EXPECT_EQ(
|
||||
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||
MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
|
||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
||||
input);
|
||||
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
|
||||
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
|
||||
output);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user