Update the tests to use OpsTestBase::GetOutput().
PiperOrigin-RevId: 248834707
This commit is contained in:
parent
b7a587f630
commit
564cc016d1
@ -15,6 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <dirent.h>
|
#include <dirent.h>
|
||||||
#include <string.h>
|
#include <string.h>
|
||||||
|
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
@ -68,9 +69,9 @@ TEST_F(GetSerializedResourceOpTest, Basic) {
|
|||||||
TF_ASSERT_OK(RunOpKernel());
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
// Verify the result.
|
// Verify the result.
|
||||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
// string type output will remain on CPU, so we're not using GetOutput() here.
|
||||||
Tensor* output = context_->mutable_output(0);
|
EXPECT_EQ("my_serialized_str",
|
||||||
EXPECT_EQ("my_serialized_str", output->scalar<string>()());
|
context_->mutable_output(0)->scalar<string>()());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
|
@ -87,16 +87,11 @@ TYPED_TEST(TRTEngineOpTest, Basic) {
|
|||||||
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
|
TF_ASSERT_OK(OpsTestBase::RunOpKernel());
|
||||||
|
|
||||||
// Verify the result.
|
// Verify the result.
|
||||||
// TODO(laigd): OpsTestBase::GetOutput() doesn't work.
|
Tensor* output = OpsTestBase::GetOutput(0);
|
||||||
Tensor* output = OpsTestBase::context_->mutable_output(0);
|
EXPECT_THAT(
|
||||||
const auto& tensor_map = output->flat<TypeParam>();
|
absl::Span<const TypeParam>(output->template flat<TypeParam>().data(),
|
||||||
std::vector<TypeParam> output_data(tensor_map.size());
|
output->NumElements()),
|
||||||
ASSERT_EQ(0, cudaDeviceSynchronize());
|
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
|
||||||
ASSERT_EQ(0, cudaMemcpy(output_data.data(), tensor_map.data(),
|
|
||||||
sizeof(TypeParam) * tensor_map.size(),
|
|
||||||
cudaMemcpyDeviceToHost));
|
|
||||||
EXPECT_THAT(absl::Span<const TypeParam>(output_data),
|
|
||||||
ElementsAre(TypeParam(0.0f), TypeParam(2.0f)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace tensorrt
|
} // namespace tensorrt
|
||||||
|
Loading…
Reference in New Issue
Block a user