Fixing a crash in tf.fingerprint([]). We now just return []

PiperOrigin-RevId: 334598888
Change-Id: I77423c797f003bf32a50687a0dd389c19f5fa27d
This commit is contained in:
Rohan Jain 2020-09-30 07:58:58 -07:00 committed by TensorFlower Gardener
parent 686912c768
commit 3db7b6d07b
3 changed files with 20 additions and 1 deletions

View File

@ -91,7 +91,12 @@ class FingerprintOp : public OpKernel {
input.shape()));
const int64 dim0 = input.shape().dim_size(0);
const int64 dim1 = input.shape().num_elements() / dim0;
int64 dim1;
if (dim0 == 0) {
dim1 = 0;
} else {
dim1 = input.shape().num_elements() / dim0;
}
Tensor* output;
OP_REQUIRES_OK(context,

View File

@ -61,6 +61,15 @@ class FingerprintOpTest : public OpsTestBase {
Tensor method_;
};
TEST_F(FingerprintOpTest, Empty) {
Tensor tensor(DT_UINT8, {0});
TF_ASSERT_OK(MakeFingerprintOp(&tensor));
TF_ASSERT_OK(RunOpKernel());
EXPECT_EQ(GetOutput(0)->shape(), (TensorShape{0, 8}));
EXPECT_EQ(GetOutput(0)->tensor_data(), "");
}
// This test detects changes in fingerprint method.
TEST_F(FingerprintOpTest, GoldenValue) {
Tensor tensor(DT_UINT8, {1, 3, 4, 5, 6, 7});

View File

@ -37,6 +37,11 @@ class FingerprintTest(test.TestCase):
self.assertTupleEqual(fingerprint0.shape, fingerprint1.shape)
self.assertTrue(np.any(fingerprint0 != fingerprint1))
def test_empty(self):
f0 = self.evaluate(array_ops.fingerprint([]))
self.assertEqual(f0.ndim, 2)
self.assertEqual(f0.shape, (0, 8))
if __name__ == "__main__":
test.main()