Merge pull request #28278 from Intel-tensorflow:intel_pr_gathernd

PiperOrigin-RevId: 246949519
This commit is contained in:
TensorFlower Gardener 2019-05-06 21:41:11 -07:00
commit e98ef953db
3 changed files with 31 additions and 3 deletions

View File

@ -71,6 +71,7 @@ class GatherNdOp : public OpKernel {
//
// Same for the GPU kernel.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
#undef REGISTER_GATHER_ND_CPU

View File

@ -152,6 +152,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
REGISTER_GATHER_ND_FULL(type, int64)
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
} // namespace functor

View File

@ -57,9 +57,9 @@ namespace {
class GatherNdOpTest : public OpsTestBase {
protected:
void MakeOp(DataType index_type) {
void MakeOp(DataType param_type, DataType index_type) {
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(param_type))
.Input(FakeInput(index_type))
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
@ -67,7 +67,7 @@ class GatherNdOpTest : public OpsTestBase {
};
TEST_F(GatherNdOpTest, Simple) {
MakeOp(DT_INT32);
MakeOp(DT_FLOAT, DT_INT32);
// Feed and run
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
@ -80,6 +80,32 @@ TEST_F(GatherNdOpTest, Simple) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
TEST_F(GatherNdOpTest, Quantized_UINT8) {
MakeOp(DT_QUINT8, DT_INT32);
// Feed and run
AddInputFromArray<quint8>(TensorShape({5}), {0, 1, 2, 8, 4});
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
TF_ASSERT_OK(RunOpKernel());
// Check the output.
Tensor expected(allocator(), DT_QUINT8, TensorShape({2}));
test::FillValues<quint8>(&expected, {8, 4});
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
}
TEST_F(GatherNdOpTest, Quantized_INT8) {
MakeOp(DT_QINT8, DT_INT32);
AddInputFromArray<qint8>(TensorShape({5}), {0, 1, 2, 8, 4});
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_QINT8, TensorShape({2}));
test::FillValues<qint8>(&expected, {8, 4});
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
}
constexpr int kLookups = 2000;
template <typename Index>