[INTEL MKL] Adding support for quantized type gather nd op registration
This commit is contained in:
parent
932874df5f
commit
a068d4a458
@ -71,6 +71,7 @@ class GatherNdOp : public OpKernel {
|
|||||||
//
|
//
|
||||||
// Same for the GPU kernel.
|
// Same for the GPU kernel.
|
||||||
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
||||||
|
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
|
||||||
|
|
||||||
#undef REGISTER_GATHER_ND_CPU
|
#undef REGISTER_GATHER_ND_CPU
|
||||||
|
|
||||||
|
@ -100,8 +100,8 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (slice_size_big > std::numeric_limits<Index>::max()) {
|
if (slice_size_big > std::numeric_limits<Index>::max()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument("slice size is too large for indexing: ",
|
||||||
"slice size is too large for indexing: ", slice_size_big, " > ",
|
slice_size_big, " > ",
|
||||||
std::numeric_limits<Index>::max());
|
std::numeric_limits<Index>::max());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,6 +32,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/util/util.h"
|
#include "tensorflow/core/util/util.h"
|
||||||
|
|
||||||
|
#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
@ -151,7 +153,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
|
|||||||
REGISTER_GATHER_ND_FULL(type, int32); \
|
REGISTER_GATHER_ND_FULL(type, int32); \
|
||||||
REGISTER_GATHER_ND_FULL(type, int64)
|
REGISTER_GATHER_ND_FULL(type, int64)
|
||||||
|
|
||||||
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
TF_CALL_DATASET_TYPES(REGISTER_GATHER_ND_CPU);
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
|
@ -57,9 +57,9 @@ namespace {
|
|||||||
|
|
||||||
class GatherNdOpTest : public OpsTestBase {
|
class GatherNdOpTest : public OpsTestBase {
|
||||||
protected:
|
protected:
|
||||||
void MakeOp(DataType index_type) {
|
void MakeOp(DataType param_type, DataType index_type) {
|
||||||
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
|
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
|
||||||
.Input(FakeInput(DT_FLOAT))
|
.Input(FakeInput(param_type))
|
||||||
.Input(FakeInput(index_type))
|
.Input(FakeInput(index_type))
|
||||||
.Finalize(node_def()));
|
.Finalize(node_def()));
|
||||||
TF_ASSERT_OK(InitOp());
|
TF_ASSERT_OK(InitOp());
|
||||||
@ -67,7 +67,7 @@ class GatherNdOpTest : public OpsTestBase {
|
|||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(GatherNdOpTest, Simple) {
|
TEST_F(GatherNdOpTest, Simple) {
|
||||||
MakeOp(DT_INT32);
|
MakeOp(DT_FLOAT, DT_INT32);
|
||||||
|
|
||||||
// Feed and run
|
// Feed and run
|
||||||
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
|
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||||
@ -80,6 +80,32 @@ TEST_F(GatherNdOpTest, Simple) {
|
|||||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(GatherNdOpTest, Quantized_UINT8) {
|
||||||
|
MakeOp(DT_QUINT8, DT_INT32);
|
||||||
|
|
||||||
|
// Feed and run
|
||||||
|
AddInputFromArray<quint8>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||||
|
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
// Check the output.
|
||||||
|
Tensor expected(allocator(), DT_QUINT8, TensorShape({2}));
|
||||||
|
test::FillValues<quint8>(&expected, {8, 4});
|
||||||
|
test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(GatherNdOpTest, Quantized_INT8) {
|
||||||
|
MakeOp(DT_QINT8, DT_INT32);
|
||||||
|
|
||||||
|
AddInputFromArray<qint8>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||||
|
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
Tensor expected(allocator(), DT_QINT8, TensorShape({2}));
|
||||||
|
test::FillValues<qint8>(&expected, {8, 4});
|
||||||
|
test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
constexpr int kLookups = 2000;
|
constexpr int kLookups = 2000;
|
||||||
|
|
||||||
template <typename Index>
|
template <typename Index>
|
||||||
|
Loading…
Reference in New Issue
Block a user