Extend MemoryTypesForNode test
Add verification of "_input_hostmem" and "_output_hostmem" attributes.
This commit is contained in:
parent
6fe5847c19
commit
1da0eb2f47
@ -33,12 +33,14 @@ class DummyKernel : public OpKernel {
|
|||||||
|
|
||||||
REGISTER_OP("HostMemoryTest")
|
REGISTER_OP("HostMemoryTest")
|
||||||
.Input("a: float")
|
.Input("a: float")
|
||||||
.Input("b: T")
|
.Input("b: float")
|
||||||
.Input("c: N * string")
|
.Input("c: T")
|
||||||
.Input("d: Tlist")
|
.Input("d: N * string")
|
||||||
.Input("e: Rlist")
|
.Input("e: Tlist")
|
||||||
|
.Input("f: Rlist")
|
||||||
.Output("o: N * T")
|
.Output("o: N * T")
|
||||||
.Output("p: Tlist")
|
.Output("p: N * T")
|
||||||
|
.Output("r: Tlist")
|
||||||
.Attr("T: type")
|
.Attr("T: type")
|
||||||
.Attr("N: int")
|
.Attr("N: int")
|
||||||
.Attr("Tlist: list(type)")
|
.Attr("Tlist: list(type)")
|
||||||
@ -46,21 +48,25 @@ REGISTER_OP("HostMemoryTest")
|
|||||||
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
|
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel);
|
||||||
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
|
REGISTER_KERNEL_BUILDER(Name("HostMemoryTest")
|
||||||
.Device(DEVICE_GPU)
|
.Device(DEVICE_GPU)
|
||||||
.HostMemory("a")
|
.HostMemory("b")
|
||||||
.HostMemory("c")
|
|
||||||
.HostMemory("d")
|
.HostMemory("d")
|
||||||
.HostMemory("o"),
|
.HostMemory("e")
|
||||||
|
.HostMemory("p"),
|
||||||
DummyKernel);
|
DummyKernel);
|
||||||
|
|
||||||
TEST(MemoryTypesForNode, Simple) {
|
TEST(MemoryTypesForNode, Simple) {
|
||||||
NodeDef node_def;
|
NodeDef node_def;
|
||||||
TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
|
TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest")
|
||||||
|
.Input(FakeInput())
|
||||||
.Input(FakeInput())
|
.Input(FakeInput())
|
||||||
.Input(FakeInput(DT_BOOL))
|
.Input(FakeInput(DT_BOOL))
|
||||||
.Input(FakeInput(3))
|
.Input(FakeInput(3))
|
||||||
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
|
.Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32}))
|
||||||
.Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
|
.Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE}))
|
||||||
.Finalize(&node_def));
|
.Finalize(&node_def));
|
||||||
|
AddNodeAttr("_input_hostmem", {0}, &node_def);
|
||||||
|
AddNodeAttr("_output_hostmem", {6, 7}, &node_def);
|
||||||
|
|
||||||
MemoryTypeVector input, output;
|
MemoryTypeVector input, output;
|
||||||
|
|
||||||
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
|
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def,
|
||||||
@ -68,24 +74,26 @@ TEST(MemoryTypesForNode, Simple) {
|
|||||||
// a:float, b:bool, c:3*string, d:(int32, float, int32),
|
// a:float, b:bool, c:3*string, d:(int32, float, int32),
|
||||||
// e:(resource, string, resource)
|
// e:(resource, string, resource)
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
|
||||||
HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
||||||
input);
|
input);
|
||||||
// o:3*bool, p:(int32, float, int32)
|
// o:3*bool, p:(int32, float, int32)
|
||||||
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
|
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
|
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
|
||||||
output);
|
output);
|
||||||
|
|
||||||
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
|
TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def,
|
||||||
&input, &output));
|
&input, &output));
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY,
|
||||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||||
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}),
|
||||||
input);
|
input);
|
||||||
EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY,
|
||||||
DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}),
|
HOST_MEMORY, HOST_MEMORY, HOST_MEMORY,
|
||||||
|
HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}),
|
||||||
output);
|
output);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user