Update the tests to use OpsTestBase::GetOutput().
PiperOrigin-RevId: 248834707
This commit is contained in:
parent
b7a587f630
commit
564cc016d1
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include <dirent.h>
|
||||
#include <string.h>
|
||||
|
||||
#include <fstream>
|
||||
#include <vector>
|
||||
|
||||
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Verify the result.
|
||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
||||
Tensor* output = context_->mutable_output(0);
|
||||
EXPECT_EQ("my_serialized_str", output->scalar<string>()());
|
||||
// string type output will remain on CPU, so we're not using GetOutput() here.
|
||||
EXPECT_EQ("my_serialized_str",
|
||||
context_->mutable_output(0)->scalar<string>()());
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
|
@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
|
||||
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
|
||||
|
||||
// Verify the result.
|
||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
||||
Tensor* output = OpsTestBase::context_->mutable_output(0);
|
||||
const auto& tensor_map = output->flat<TypeParam>();
|
||||
std::vector<TypeParam> output_data(tensor_map.size());
|
||||
ASSERT_EQ(0, cudaDeviceSynchronize());
|
||||
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
|
||||
sizeof(TypeParam) * tensor_map.size(),
|
||||
cudaMemcpyDeviceToHost));
|
||||
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
|
||||
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
|
||||
Tensor* output = OpsTestBase::GetOutput(0);
|
||||
EXPECT_THAT(
|
||||
absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
|
||||
output->NumElements()),
|
||||
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
|
||||
}
|
||||
|
||||
} // namespace tensorrt
|
||||
|
Loading…
Reference in New Issue
Block a user