Update the tests to use OpsTestBase::GetOutput().

PiperOrigin-RevId: 248834707
This commit is contained in:
Guangda Lai 2019-05-17 21:50:10 -07:00 committed by TensorFlower Gardener
parent b7a587f630
commit 564cc016d1
2 changed files with 9 additions and 13 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <dirent.h> #include <dirent.h>
#include <string.h> #include <string.h>
#include <fstream> #include <fstream>
#include <vector> #include <vector>
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
TF_ASSERT_OK(RunOpKernel()); TF_ASSERT_OK(RunOpKernel());
// Verify the result. // Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work. // string type output will remain on CPU, so we're not using GetOutput() here.
Tensor* output = context_->mutable_output(0); EXPECT_EQ("my_serialized_str",
EXPECT_EQ("my_serialized_str", output->scalar<string>()()); context_->mutable_output(0)->scalar<string>()());
} }
} // namespace tensorrt } // namespace tensorrt

View File

@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
TF_ASSERT_OK(OpsTestBase::RunOpKernel()); TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Verify the result. // Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work. Tensor* output = OpsTestBase::GetOutput(0);
Tensor* output = OpsTestBase::context_->mutable_output(0); EXPECT_THAT(
const auto& tensor_map = output->flat<TypeParam>(); absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
std::vector<TypeParam> output_data(tensor_map.size()); output->NumElements()),
ASSERT_EQ(0, cudaDeviceSynchronize()); ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
sizeof(TypeParam) * tensor_map.size(),
cudaMemcpyDeviceToHost));
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
} }
} // namespace tensorrt } // namespace tensorrt