Make tf.reshape(tensor, shape)
work when `tensor` has a
zero-sized dimension and `shape`
contains an unknown (-1) dimension.
PiperOrigin-RevId: 243094911
This commit is contained in:
parent
86a6d44352
commit
36e10587b5
@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
|
||||
TensorShape shape;
|
||||
int64 product = 1;
|
||||
int unknown_index = -1;
|
||||
bool shape_has_zero_dim = false;
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
const int32 size = shape_input[d];
|
||||
if (size == -1) {
|
||||
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
|
||||
unknown_index, " and ", d));
|
||||
unknown_index = d;
|
||||
shape.AddDim(1);
|
||||
} else if (size == 0) {
|
||||
// We don't include zero-sized dimension in product, so that we can
|
||||
// still calculate number of elements for non-zero-sized dimensions and
|
||||
// therefore infer their shapes.
|
||||
shape.AddDim(size);
|
||||
shape_has_zero_dim = true;
|
||||
} else {
|
||||
OP_REQUIRES(ctx, size >= 0,
|
||||
errors::InvalidArgument(
|
||||
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
|
||||
}
|
||||
}
|
||||
if (unknown_index != -1) {
|
||||
int64 input_num_elements = 1;
|
||||
bool input_has_zero_dim = false;
|
||||
for (int dim = 0; dim < input_shape.dims(); dim++) {
|
||||
// For zero dimension, we don't count it into `input_num_elements`
|
||||
// unless `sizes` has no zero dimension, so we are still able to
|
||||
// infer shapes for other dimensions.
|
||||
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
|
||||
input_num_elements *= input_shape.dim_size(dim);
|
||||
} else {
|
||||
input_has_zero_dim = true;
|
||||
}
|
||||
}
|
||||
|
||||
const int64 missing = input_num_elements / product;
|
||||
if (!input_has_zero_dim) {
|
||||
OP_REQUIRES(
|
||||
ctx, product > 0,
|
||||
errors::InvalidArgument("Reshape cannot infer the missing input size "
|
||||
"for an empty tensor unless all specified "
|
||||
"input sizes are non-zero"));
|
||||
const int64 missing = input_shape.num_elements() / product;
|
||||
OP_REQUIRES(
|
||||
ctx, product * missing == input_shape.num_elements(),
|
||||
ctx, product * missing == input_num_elements,
|
||||
errors::InvalidArgument(
|
||||
"Input to reshape is a tensor with ", input_shape.num_elements(),
|
||||
"Input to reshape is a tensor with ", input_num_elements,
|
||||
" values, but the requested shape requires a multiple of ",
|
||||
product));
|
||||
}
|
||||
shape.set_dim(unknown_index, missing);
|
||||
}
|
||||
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
|
||||
|
@ -45,14 +45,17 @@ class ReshapeOp : public OpKernel {
|
||||
TensorShape shape;
|
||||
int64 product = 1;
|
||||
int unknown_index = -1;
|
||||
bool sizes_has_zero_dim;
|
||||
switch (sizes.dtype()) {
|
||||
case DT_INT32:
|
||||
OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
|
||||
&unknown_index, &shape));
|
||||
OP_REQUIRES_OK(context,
|
||||
ValidateSizes<int32>(sizes, &product, &unknown_index,
|
||||
&shape, &sizes_has_zero_dim));
|
||||
break;
|
||||
case DT_INT64:
|
||||
OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
|
||||
&unknown_index, &shape));
|
||||
OP_REQUIRES_OK(context,
|
||||
ValidateSizes<int64>(sizes, &product, &unknown_index,
|
||||
&shape, &sizes_has_zero_dim));
|
||||
break;
|
||||
default:
|
||||
context->CtxFailure(errors::InvalidArgument(
|
||||
@ -61,18 +64,28 @@ class ReshapeOp : public OpKernel {
|
||||
return;
|
||||
}
|
||||
if (unknown_index != -1) {
|
||||
int64 input_num_elements = 1;
|
||||
bool input_has_zero_dim = false;
|
||||
for (int dim = 0; dim < input.dims(); dim++) {
|
||||
// For zero dimension, we don't count it into `input_num_elements`
|
||||
// unless `sizes` has no zero dimension, so we are still able to
|
||||
// infer shapes for other dimensions.
|
||||
if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) {
|
||||
input_num_elements *= input.dim_size(dim);
|
||||
} else {
|
||||
input_has_zero_dim = true;
|
||||
}
|
||||
}
|
||||
|
||||
const int64 missing = input_num_elements / product;
|
||||
if (!input_has_zero_dim) {
|
||||
OP_REQUIRES(
|
||||
context, product > 0,
|
||||
errors::InvalidArgument("Reshape cannot infer the missing input size "
|
||||
"for an empty tensor unless all specified "
|
||||
"input sizes are non-zero"));
|
||||
const int64 missing = input.NumElements() / product;
|
||||
OP_REQUIRES(
|
||||
context, product * missing == input.NumElements(),
|
||||
context, product * missing == input_num_elements,
|
||||
errors::InvalidArgument(
|
||||
"Input to reshape is a tensor with ", input.NumElements(),
|
||||
"Input to reshape is a tensor with ", input_num_elements,
|
||||
" values, but the requested shape requires a multiple of ",
|
||||
product));
|
||||
}
|
||||
shape.set_dim(unknown_index, missing);
|
||||
}
|
||||
OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
|
||||
@ -92,9 +105,10 @@ class ReshapeOp : public OpKernel {
|
||||
private:
|
||||
template <typename Tshape>
|
||||
Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
|
||||
TensorShape* shape) {
|
||||
TensorShape* shape, bool* has_zero_dim) {
|
||||
*product = 1;
|
||||
*unknown_index = -1;
|
||||
*has_zero_dim = false;
|
||||
const int64 num_dims = sizes.NumElements();
|
||||
auto Svec = sizes.flat<Tshape>();
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
@ -110,6 +124,12 @@ class ReshapeOp : public OpKernel {
|
||||
} else if (size < 0) {
|
||||
return errors::InvalidArgument("Size ", d,
|
||||
" must be non-negative, not ", size);
|
||||
} else if (size == 0) {
|
||||
// We don't include zero-sized dimension in product, so that we can
|
||||
// still calculate number of elements for non-zero-sized dimensions and
|
||||
// therefore infer their shapes.
|
||||
shape->AddDim(size);
|
||||
*has_zero_dim = true;
|
||||
} else {
|
||||
shape->AddDim(size);
|
||||
(*product) *= size;
|
||||
|
@ -45,6 +45,18 @@ class ReshapeTest(test.TestCase):
|
||||
self.assertEqual(tf_ans.get_shape(), out.shape)
|
||||
self.assertShapeEqual(np_ans, tf_ans)
|
||||
|
||||
def _testZeroDimReshape(self, x, shape, expected, use_gpu=False):
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
y = array_ops.reshape(x, shape)
|
||||
out = self.evaluate(y)
|
||||
self.assertEqual(expected, out.shape)
|
||||
|
||||
# Repeat with an int64 shape tensor.
|
||||
shape64 = constant_op.constant(shape, dtype=dtypes.int64)
|
||||
y = array_ops.reshape(x, shape64)
|
||||
out = self.evaluate(y)
|
||||
self.assertEqual(expected, out.shape)
|
||||
|
||||
def _testBothReshape(self, x, y):
|
||||
self._testReshape(x, y, False)
|
||||
self._testReshape(x, y, True)
|
||||
@ -89,6 +101,18 @@ class ReshapeTest(test.TestCase):
|
||||
x = np.arange(1., 7.).reshape([6]).astype(np.float32)
|
||||
self._testBothReshape(x, [3, -1])
|
||||
|
||||
def testZeroDimBasic(self):
|
||||
x = np.zeros([0, 6]).astype(np.float32)
|
||||
self._testBothReshape(x, [0, 2, 3])
|
||||
|
||||
def testZeroDimReshapeR1(self):
|
||||
x = np.zeros([0, 6]).astype(np.float32)
|
||||
self._testBothReshape(x, [-1])
|
||||
|
||||
def testZeroDimReshapeR3(self):
|
||||
x = np.zeros([0, 6]).astype(np.float32)
|
||||
self._testBothReshape(x, [-1, 2, 3])
|
||||
|
||||
# TODO(vrv): Add tests for failure conditions once python test_util
|
||||
# reports errors.
|
||||
|
||||
@ -113,6 +137,13 @@ class ReshapeTest(test.TestCase):
|
||||
self._testBothReshape(x, [0, 0, 0])
|
||||
self._testBothReshape(x, [1, -1, 5])
|
||||
|
||||
def testZeroDimWithUnspecifiedDim(self):
|
||||
for use_gpu in (True, False):
|
||||
self._testZeroDimReshape(x=np.zeros([0, 6]).astype(np.float32),
|
||||
shape=[0, -1, 3],
|
||||
expected=(0, 2, 3),
|
||||
use_gpu=use_gpu)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testErrors(self):
|
||||
y = constant_op.constant(0.0, shape=[23, 29, 31])
|
||||
|
Loading…
Reference in New Issue
Block a user