Merge pull request #28278 from Intel-tensorflow:intel_pr_gathernd
PiperOrigin-RevId: 246949519
This commit is contained in:
commit
e98ef953db
@ -71,6 +71,7 @@ class GatherNdOp : public OpKernel {
|
||||
//
|
||||
// Same for the GPU kernel.
|
||||
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
||||
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
|
||||
|
||||
#undef REGISTER_GATHER_ND_CPU
|
||||
|
||||
|
@ -152,6 +152,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
|
||||
REGISTER_GATHER_ND_FULL(type, int64)
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
||||
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
|
||||
|
||||
} // namespace functor
|
||||
|
||||
|
@ -57,9 +57,9 @@ namespace {
|
||||
|
||||
class GatherNdOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType index_type) {
|
||||
void MakeOp(DataType param_type, DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(param_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
@ -67,7 +67,7 @@ class GatherNdOpTest : public OpsTestBase {
|
||||
};
|
||||
|
||||
TEST_F(GatherNdOpTest, Simple) {
|
||||
MakeOp(DT_INT32);
|
||||
MakeOp(DT_FLOAT, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||
@ -80,6 +80,32 @@ TEST_F(GatherNdOpTest, Simple) {
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(GatherNdOpTest, Quantized_UINT8) {
|
||||
MakeOp(DT_QUINT8, DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<quint8>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the output.
|
||||
Tensor expected(allocator(), DT_QUINT8, TensorShape({2}));
|
||||
test::FillValues<quint8>(&expected, {8, 4});
|
||||
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
TEST_F(GatherNdOpTest, Quantized_INT8) {
|
||||
MakeOp(DT_QINT8, DT_INT32);
|
||||
|
||||
AddInputFromArray<qint8>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
Tensor expected(allocator(), DT_QINT8, TensorShape({2}));
|
||||
test::FillValues<qint8>(&expected, {8, 4});
|
||||
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
constexpr int kLookups = 2000;
|
||||
|
||||
template <typename Index>
|
||||
|
Loading…
Reference in New Issue
Block a user