From d894109fe1203f2259819841b85a0354c7780609 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 19 May 2020 15:55:10 -0700 Subject: [PATCH] Fix parameter check for batchdim in Gather. PiperOrigin-RevId: 312371119 Change-Id: I7537194147199136b5b847ce6d1ddd361e42a393 --- tensorflow/core/kernels/gather_op.cc | 12 ++++++------ tensorflow/core/kernels/gather_op_test.cc | 17 ++++++++++++++++- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/tensorflow/core/kernels/gather_op.cc b/tensorflow/core/kernels/gather_op.cc index 3ff7afca7df..5e6bd1de9d6 100644 --- a/tensorflow/core/kernels/gather_op.cc +++ b/tensorflow/core/kernels/gather_op.cc @@ -88,18 +88,18 @@ class GatherOp : public OpKernel { } if (batch_dims_ != 0) { - if (batch_dims_ < 0) { - batch_dims_ = indices.dims() + batch_dims_; - } - - if (!axis_is_set) axis = batch_dims_; - OP_REQUIRES( c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(), errors::InvalidArgument("Expected batch_dims in the range [", -indices.dims(), ", ", indices.dims(), "], but got ", batch_dims_)); + if (batch_dims_ < 0) { + batch_dims_ = indices.dims() + batch_dims_; + } + + if (!axis_is_set) axis = batch_dims_; + OP_REQUIRES(c, batch_dims_ < params.dims(), errors::InvalidArgument("batch_dims (", batch_dims_, ") must be less than rank(params) (", diff --git a/tensorflow/core/kernels/gather_op_test.cc b/tensorflow/core/kernels/gather_op_test.cc index ecac2274ae8..e4c77881ea8 100644 --- a/tensorflow/core/kernels/gather_op_test.cc +++ b/tensorflow/core/kernels/gather_op_test.cc @@ -40,11 +40,12 @@ namespace { class GatherOpTest : public OpsTestBase { protected: - void MakeOp(DataType data_type, DataType index_type) { + void MakeOp(DataType data_type, DataType index_type, int batch_dims = 0) { TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2") .Input(FakeInput(data_type)) .Input(FakeInput(index_type)) .Input(FakeInput(index_type)) + .Attr("batch_dims", batch_dims) .Finalize(node_def())); TF_ASSERT_OK(InitOp()); } @@ -176,6 +177,20 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) { << s; } +TEST_F(GatherOpTest, Error_BatchDimsOutOfRange) { + MakeOp(DT_FLOAT, DT_INT32, 10); + + // Feed and run + AddInputFromArray(TensorShape({5, 3}), + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}); + AddInputFromArray(TensorShape({4}), {0, 4, 99, 2}); + AddInputFromArray(TensorShape({}), {0}); + Status s = RunOpKernel(); + EXPECT_TRUE(absl::StrContains( + s.ToString(), "Expected batch_dims in the range [-1, 1], but got 10")) + << s; +} + constexpr int kLookups = 2000; template