From 564cc016d1ae4c209b9ed2de0bd7ee48f30e40b9 Mon Sep 17 00:00:00 2001 From: Guangda Lai Date: Fri, 17 May 2019 21:50:10 -0700 Subject: [PATCH] Update the tests to use OpsTestBase::GetOutput(). PiperOrigin-RevId: 248834707 --- .../kernels/get_serialized_resource_op_test.cc | 7 ++++--- .../tf2tensorrt/kernels/trt_engine_op_test.cc | 15 +++++---------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc index ec038ebda07..d54cbf7836e 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/get_serialized_resource_op_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include #include + #include #include @@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) { TF_ASSERT_OK(RunOpKernel()); // Verify the result. - // TODO(laigd): OpsTestBase::GetOutput() doesn't work. - Tensor* output = context_->mutable_output(0); - EXPECT_EQ("my_serialized_str", output->scalar()()); + // string type output will remain on CPU, so we're not using GetOutput() here. + EXPECT_EQ("my_serialized_str", + context_->mutable_output(0)->scalar()()); } } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc index b62fdc5dc4b..d4077692235 100644 --- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc +++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op_test.cc @@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) { TF_ASSERT_OK(OpsTestBase::RunOpKernel()); // Verify the result. - // TODO(laigd): OpsTestBase::GetOutput() doesn't work. - Tensor* output = OpsTestBase::context_->mutable_output(0); - const auto& tensor_map = output->flat(); - std::vector output_data(tensor_map.size()); - ASSERT_EQ(0, cudaDeviceSynchronize()); - ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(), - sizeof(TypeParam) * tensor_map.size(), - cudaMemcpyDeviceToHost)); - EXPECT_THAT(absl::Span(output_data), - ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); + Tensor* output = OpsTestBase::GetOutput(0); + EXPECT_THAT( + absl::Span(output->template flat().data(), + output->NumElements()), + ElementsAre(TypeParam(0.0f), TypeParam(2.0f))); } } // namespace tensorrt