Fix parameter check for batchdim in Gather.
PiperOrigin-RevId: 312371119 Change-Id: I7537194147199136b5b847ce6d1ddd361e42a393
This commit is contained in:
parent
930709e46e
commit
d894109fe1
@ -88,18 +88,18 @@ class GatherOp : public OpKernel {
|
||||
}
|
||||
|
||||
if (batch_dims_ != 0) {
|
||||
if (batch_dims_ < 0) {
|
||||
batch_dims_ = indices.dims() + batch_dims_;
|
||||
}
|
||||
|
||||
if (!axis_is_set) axis = batch_dims_;
|
||||
|
||||
OP_REQUIRES(
|
||||
c, batch_dims_ >= -indices.dims() && batch_dims_ <= indices.dims(),
|
||||
errors::InvalidArgument("Expected batch_dims in the range [",
|
||||
-indices.dims(), ", ", indices.dims(),
|
||||
"], but got ", batch_dims_));
|
||||
|
||||
if (batch_dims_ < 0) {
|
||||
batch_dims_ = indices.dims() + batch_dims_;
|
||||
}
|
||||
|
||||
if (!axis_is_set) axis = batch_dims_;
|
||||
|
||||
OP_REQUIRES(c, batch_dims_ < params.dims(),
|
||||
errors::InvalidArgument("batch_dims (", batch_dims_,
|
||||
") must be less than rank(params) (",
|
||||
|
@ -40,11 +40,12 @@ namespace {
|
||||
|
||||
class GatherOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType data_type, DataType index_type) {
|
||||
void MakeOp(DataType data_type, DataType index_type, int batch_dims = 0) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherV2")
|
||||
.Input(FakeInput(data_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Input(FakeInput(index_type))
|
||||
.Attr("batch_dims", batch_dims)
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
@ -176,6 +177,20 @@ TEST_F(GatherOpTest, Error_IndexOutOfRange) {
|
||||
<< s;
|
||||
}
|
||||
|
||||
TEST_F(GatherOpTest, Error_BatchDimsOutOfRange) {
|
||||
MakeOp(DT_FLOAT, DT_INT32, 10);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5, 3}),
|
||||
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14});
|
||||
AddInputFromArray<int32>(TensorShape({4}), {0, 4, 99, 2});
|
||||
AddInputFromArray<int32>(TensorShape({}), {0});
|
||||
Status s = RunOpKernel();
|
||||
EXPECT_TRUE(absl::StrContains(
|
||||
s.ToString(), "Expected batch_dims in the range [-1, 1], but got 10"))
|
||||
<< s;
|
||||
}
|
||||
|
||||
constexpr int kLookups = 2000;
|
||||
|
||||
template <typename Index>
|
||||
|
Loading…
Reference in New Issue
Block a user