From 1da0eb2f4781a68383f71c2082e4d753872fb953 Mon Sep 17 00:00:00 2001 From: Krzysztof Laskowski Date: Thu, 6 Aug 2020 23:42:39 +0200 Subject: [PATCH] Extend MemoryTypesForNode test Add verification of "_input_hostmem" and "_output_hostmem" attributes. --- .../core/framework/memory_types_test.cc | 38 +++++++++++-------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/tensorflow/core/framework/memory_types_test.cc b/tensorflow/core/framework/memory_types_test.cc index 3126ea8e5f8..5228dbafc9b 100644 --- a/tensorflow/core/framework/memory_types_test.cc +++ b/tensorflow/core/framework/memory_types_test.cc @@ -33,12 +33,14 @@ class DummyKernel : public OpKernel { REGISTER_OP("HostMemoryTest") .Input("a: float") - .Input("b: T") - .Input("c: N * string") - .Input("d: Tlist") - .Input("e: Rlist") + .Input("b: float") + .Input("c: T") + .Input("d: N * string") + .Input("e: Tlist") + .Input("f: Rlist") .Output("o: N * T") - .Output("p: Tlist") + .Output("p: N * T") + .Output("r: Tlist") .Attr("T: type") .Attr("N: int") .Attr("Tlist: list(type)") @@ -46,21 +48,25 @@ REGISTER_OP("HostMemoryTest") REGISTER_KERNEL_BUILDER(Name("HostMemoryTest").Device(DEVICE_CPU), DummyKernel); REGISTER_KERNEL_BUILDER(Name("HostMemoryTest") .Device(DEVICE_GPU) - .HostMemory("a") - .HostMemory("c") + .HostMemory("b") .HostMemory("d") - .HostMemory("o"), + .HostMemory("e") + .HostMemory("p"), DummyKernel); TEST(MemoryTypesForNode, Simple) { NodeDef node_def; TF_ASSERT_OK(NodeDefBuilder("test", "HostMemoryTest") + .Input(FakeInput()) .Input(FakeInput()) .Input(FakeInput(DT_BOOL)) .Input(FakeInput(3)) .Input(FakeInput({DT_INT32, DT_FLOAT, DT_INT32})) .Input(FakeInput({DT_RESOURCE, DT_STRING, DT_RESOURCE})) .Finalize(&node_def)); + AddNodeAttr("_input_hostmem", {0}, &node_def); + AddNodeAttr("_output_hostmem", {6, 7}, &node_def); + MemoryTypeVector input, output; TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_CPU, node_def, @@ -68,24 +74,26 @@ TEST(MemoryTypesForNode, Simple) { // a:float, b:bool, c:3*string, d:(int32, float, int32), // e:(resource, string, resource) EXPECT_EQ( - MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, - HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), input); // o:3*bool, p:(int32, float, int32) EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, - DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), + DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}), output); TF_EXPECT_OK(MemoryTypesForNode(OpRegistry::Global(), DEVICE_GPU, node_def, &input, &output)); EXPECT_EQ( - MemoryTypeVector({HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, + MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, - HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, HOST_MEMORY}), input); - EXPECT_EQ(MemoryTypeVector({HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, - DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY}), + EXPECT_EQ(MemoryTypeVector({DEVICE_MEMORY, DEVICE_MEMORY, DEVICE_MEMORY, + HOST_MEMORY, HOST_MEMORY, HOST_MEMORY, + HOST_MEMORY, HOST_MEMORY, DEVICE_MEMORY}), output); }