Update the tests to use OpsTestBase::GetOutput().

PiperOrigin-RevId: 248834707
This commit is contained in:
Guangda Lai 2019-05-17 21:50:10 -07:00 committed by TensorFlower Gardener
parent b7a587f630
commit 564cc016d1
2 changed files with 9 additions and 13 deletions

View File

@ -15,6 +15,7 @@ limitations under the License.
#include <dirent.h>
#include <string.h>
#include <fstream>
#include <vector>
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
TF_ASSERT_OK(RunOpKernel());
// Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
Tensor* output = context_->mutable_output(0);
EXPECT_EQ("my_serialized_str", output->scalar<string>()());
// string type output will remain on CPU, so we're not using GetOutput() here.
EXPECT_EQ("my_serialized_str",
context_->mutable_output(0)->scalar<string>()());
}
} // namespace tensorrt

View File

@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
// Verify the result.
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
Tensor* output = OpsTestBase::context_->mutable_output(0);
const auto& tensor_map = output->flat<TypeParam>();
std::vector<TypeParam> output_data(tensor_map.size());
ASSERT_EQ(0, cudaDeviceSynchronize());
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
sizeof(TypeParam) * tensor_map.size(),
cudaMemcpyDeviceToHost));
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
Tensor* output = OpsTestBase::GetOutput(0);
EXPECT_THAT(
absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
output->NumElements()),
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
}
} // namespace tensorrt