Fix parameter check for batchdim in Gather.

PiperOrigin-RevId: 312371119
Change-Id: I7537194147199136b5b847ce6d1ddd361e42a393
This commit is contained in:
A. Unique TensorFlower 2020-05-19 15:55:10 -07:00 committed by TensorFlower Gardener
parent 930709e46e
commit d894109fe1
2 changed files with 22 additions and 7 deletions

View File

@ -88,18 +88,18 @@ class GatherOp : public OpKernel {
}
if (batch_dims_ != 0) {
if (batch_dims_ < 0) {
batch_dims_ = indices.dims() + batch_dims_;
}
if (!axis_is_set) axis = batch_dims_;
OP_REQUIRES(
c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
errors::InvalidArgument("Expected batch_dims in the range [",
-indices.dims(), ", ", indices.dims(),
"], but got ", batch_dims_));
if (batch_dims_ < 0) {
batch_dims_ = indices.dims() + batch_dims_;
}
if (!axis_is_set) axis = batch_dims_;
OP_REQUIRES(c, batch_dims_ < params.dims(),
errors::InvalidArgument("batch_dims (", batch_dims_,
") must be less than rank(params) (",

View File

@ -40,11 +40,12 @@ namespace {
class GatherOpTest : public OpsTestBase {
protected:
void MakeOp(DataType data_type, DataType index_type) {
void MakeOp(DataType data_type, DataType index_type, int batch_dims = 0) {
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2")
.Input(FakeInput(data_type))
.Input(FakeInput(index_type))
.Input(FakeInput(index_type))
.Attr("batch_dims", batch_dims)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
}
@ -176,6 +177,20 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) {
<< s;
}
TEST_F(GatherOpTest, Error_BatchDimsOutOfRange) {
MakeOp(DT_FLOAT, DT_INT32, 10);
// Feed and run
AddInputFromArray<float>(TensorShape({5, 3}),
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2});
AddInputFromArray<int32>(TensorShape({}), {0});
Status s = RunOpKernel();
EXPECT_TRUE(absl::StrContains(
s.ToString(), "Expected batch_dims in the range [-1, 1], but got 10"))
<< s;
}
constexpr int kLookups = 2000;
template <typename Index>