[INTEL MKL] Add support for quantized type GatherNd op registration

Xiaoming (Jason) Cui 2019-04-30 01:10:08 -07:00
parent 932874df5f
commit a068d4a458
4 changed files with 36 additions and 7 deletions


@@ -71,6 +71,7 @@ class GatherNdOp : public OpKernel {
 //
 // Same for the GPU kernel.
 TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
+TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
 #undef REGISTER_GATHER_ND_CPU
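For context on what the one-line addition above does: TF_CALL_QUANTIZED_TYPES applies the given macro once per quantized dtype (qint8, quint8, and qint32 in register_types.h at the time), so the CPU GatherNd kernel is now registered for those types in addition to everything covered by TF_CALL_ALL_TYPES. The snippet below is a standalone, simplified sketch of that X-macro pattern; the *_SKETCH names and the string registry are illustrative stand-ins, not TensorFlow's real REGISTER_KERNEL_BUILDER machinery.

// Standalone sketch of the X-macro registration pattern used above.
#include <iostream>
#include <string>
#include <vector>

// Returns the process-wide list of "registered kernels" for this sketch.
std::vector<std::string>& Registry() {
  static std::vector<std::string> registry;
  return registry;
}

// Stand-in for REGISTER_GATHER_ND_CPU(type): record one kernel per type.
#define REGISTER_GATHER_ND_CPU_SKETCH(type) \
  static const bool kRegistered_##type = (Registry().push_back(#type), true);

// Stand-in for TF_CALL_QUANTIZED_TYPES(m): apply m to each quantized dtype.
#define CALL_QUANTIZED_TYPES_SKETCH(m) m(quint8) m(qint8) m(qint32)

// One macro invocation stamps out a registration per quantized type.
CALL_QUANTIZED_TYPES_SKETCH(REGISTER_GATHER_ND_CPU_SKETCH)

int main() {
  for (const std::string& name : Registry())
    std::cout << "GatherNd CPU kernel registered for " << name << "\n";
  return 0;
}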


@@ -100,8 +100,8 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
   }
   if (slice_size_big > std::numeric_limits<Index>::max()) {
-    return errors::InvalidArgument(
-        "slice size is too large for indexing: ", slice_size_big, " > ",
-        std::numeric_limits<Index>::max());
+    return errors::InvalidArgument("slice size is too large for indexing: ",
+                                   slice_size_big, " > ",
+                                   std::numeric_limits<Index>::max());
   }
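Aside on the guard being re-wrapped above: slice_size_big is the product of the trailing params dimensions, and the check rejects gathers whose per-index slice would not fit in the 32-bit Index specialization. A tiny standalone arithmetic check of such a case (shapes chosen purely for illustration, not TensorFlow code):

// A slice of (1 << 16) * (1 << 16) elements overflows a 32-bit signed index,
// so DoGatherNd would return InvalidArgument for it.
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  int64_t slice_size_big = 1;
  slice_size_big *= int64_t{1} << 16;  // e.g. params.dim_size(1)
  slice_size_big *= int64_t{1} << 16;  // e.g. params.dim_size(2)
  const bool too_large = slice_size_big > std::numeric_limits<int32_t>::max();
  std::cout << "slice size " << slice_size_big
            << (too_large ? " would be rejected\n" : " is fine\n");
  return 0;
}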


@@ -32,6 +32,8 @@ limitations under the License.
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/util.h"
 
+#define TF_CALL_DATASET_TYPES(m) TF_CALL_ALL_TYPES(m) TF_CALL_QUANTIZED_TYPES(m)
+
 namespace tensorflow {
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -151,7 +153,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
   REGISTER_GATHER_ND_FULL(type, int32); \
   REGISTER_GATHER_ND_FULL(type, int64)
 
-TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
+TF_CALL_DATASET_TYPES(REGISTER_GATHER_ND_CPU);
 
 } // namespace functor
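The cpu_impl change swaps the type list used to stamp out the GatherNdSlice functor: TF_CALL_DATASET_TYPES is defined just above as TF_CALL_ALL_TYPES plus TF_CALL_QUANTIZED_TYPES, so the CPU slice implementation now gets instantiated for the quantized element types as well, for both int32 and int64 indices. Below is a standalone, simplified sketch of that "one FULL instantiation per index type, repeated per element type" stacking; the functor body, the *_SKETCH macros, and the type list are stand-ins rather than the real definitions.

// Simplified stand-in for the GatherNdSlice functor and its registration macros.
#include <cstdint>
#include <cstdio>

template <typename T, typename Index>
struct GatherNdSliceSketch {
  void operator()() const { std::puts("slice functor instantiated"); }
};

// Stand-in for REGISTER_GATHER_ND_FULL(type, index_type): force an explicit
// instantiation so the definition is emitted in this translation unit.
#define REGISTER_FULL_SKETCH(type, index_type) \
  template struct GatherNdSliceSketch<type, index_type>;

// Stand-in for REGISTER_GATHER_ND_CPU(type): one instantiation per index type.
#define REGISTER_CPU_SKETCH(type)     \
  REGISTER_FULL_SKETCH(type, int32_t) \
  REGISTER_FULL_SKETCH(type, int64_t)

// Stand-in for TF_CALL_DATASET_TYPES(m): "all" types plus quantized stand-ins.
#define CALL_DATASET_TYPES_SKETCH(m) m(float) m(double) m(uint8_t) m(int8_t)

CALL_DATASET_TYPES_SKETCH(REGISTER_CPU_SKETCH)

int main() {
  GatherNdSliceSketch<uint8_t, int32_t>{}();  // one of the stamped-out pairs
  return 0;
}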


@@ -57,9 +57,9 @@ namespace {
 class GatherNdOpTest : public OpsTestBase {
  protected:
-  void MakeOp(DataType index_type) {
+  void MakeOp(DataType param_type, DataType index_type) {
     TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
-                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(param_type))
                      .Input(FakeInput(index_type))
                      .Finalize(node_def()));
     TF_ASSERT_OK(InitOp());
@@ -67,7 +67,7 @@ class GatherNdOpTest : public OpsTestBase {
 };
 
 TEST_F(GatherNdOpTest, Simple) {
-  MakeOp(DT_INT32);
+  MakeOp(DT_FLOAT, DT_INT32);
 
   // Feed and run
   AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
@@ -80,6 +80,32 @@ TEST_F(GatherNdOpTest, Simple) {
   test::ExpectTensorEqual<float>(expected, *GetOutput(0));
 }
 
+TEST_F(GatherNdOpTest, Quantized_UINT8) {
+  MakeOp(DT_QUINT8, DT_INT32);
+
+  // Feed and run
+  AddInputFromArray<quint8>(TensorShape({5}), {0, 1, 2, 8, 4});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Check the output.
+  Tensor expected(allocator(), DT_QUINT8, TensorShape({2}));
+  test::FillValues<quint8>(&expected, {8, 4});
+  test::ExpectTensorEqual<quint8>(expected, *GetOutput(0));
+}
+
+TEST_F(GatherNdOpTest, Quantized_INT8) {
+  MakeOp(DT_QINT8, DT_INT32);
+
+  AddInputFromArray<qint8>(TensorShape({5}), {0, 1, 2, 8, 4});
+  AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
+  TF_ASSERT_OK(RunOpKernel());
+
+  Tensor expected(allocator(), DT_QINT8, TensorShape({2}));
+  test::FillValues<qint8>(&expected, {8, 4});
+  test::ExpectTensorEqual<qint8>(expected, *GetOutput(0));
+}
+
 constexpr int kLookups = 2000;
 
 template <typename Index>
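The two new tests drive the freshly registered quint8 and qint8 kernels with the same data as the float Simple test: params has shape {5}, indices has shape {2, 1}, and each row of indices is a complete index into params, so {3} and {4} select params[3] = 8 and params[4] = 4. A plain-C++ illustration of that lookup, with ordinary integer types standing in for the quantized wrappers (not TensorFlow code):

// GatherNd on a rank-1 params tensor with index depth 1 reduces to params[i].
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::array<uint8_t, 5> params = {0, 1, 2, 8, 4};  // params tensor, shape {5}
  std::vector<int> indices = {3, 4};                // the two {i} rows of the {2, 1} indices tensor
  std::vector<uint8_t> out;                         // output tensor, shape {2}
  for (int i : indices) out.push_back(params[i]);
  assert(out[0] == 8 && out[1] == 4);  // matches the expected values {8, 4} in both tests
  return 0;
}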