Fixing a crash in tf.fingerprint([]). We now just return []
PiperOrigin-RevId: 334598888 Change-Id: I77423c797f003bf32a50687a0dd389c19f5fa27d
This commit is contained in:
parent
686912c768
commit
3db7b6d07b
@ -91,7 +91,12 @@ class FingerprintOp : public OpKernel {
|
||||
input.shape()));
|
||||
|
||||
const int64 dim0 = input.shape().dim_size(0);
|
||||
const int64 dim1 = input.shape().num_elements() / dim0;
|
||||
int64 dim1;
|
||||
if (dim0 == 0) {
|
||||
dim1 = 0;
|
||||
} else {
|
||||
dim1 = input.shape().num_elements() / dim0;
|
||||
}
|
||||
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(context,
|
||||
|
@ -61,6 +61,15 @@ class FingerprintOpTest : public OpsTestBase {
|
||||
Tensor method_;
|
||||
};
|
||||
|
||||
TEST_F(FingerprintOpTest, Empty) {
|
||||
Tensor tensor(DT_UINT8, {0});
|
||||
|
||||
TF_ASSERT_OK(MakeFingerprintOp(&tensor));
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
EXPECT_EQ(GetOutput(0)->shape(), (TensorShape{0, 8}));
|
||||
EXPECT_EQ(GetOutput(0)->tensor_data(), "");
|
||||
}
|
||||
|
||||
// This test detects changes in fingerprint method.
|
||||
TEST_F(FingerprintOpTest, GoldenValue) {
|
||||
Tensor tensor(DT_UINT8, {1, 3, 4, 5, 6, 7});
|
||||
|
@ -37,6 +37,11 @@ class FingerprintTest(test.TestCase):
|
||||
self.assertTupleEqual(fingerprint0.shape, fingerprint1.shape)
|
||||
self.assertTrue(np.any(fingerprint0 != fingerprint1))
|
||||
|
||||
def test_empty(self):
|
||||
f0 = self.evaluate(array_ops.fingerprint([]))
|
||||
self.assertEqual(f0.ndim, 2)
|
||||
self.assertEqual(f0.shape, (0, 8))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user