Make tf.reshape(tensor, shape)
work in the case where tensor
has a zero-sized dimension and shape
has an unknown (-1) dimension.
PiperOrigin-RevId: 243094911
This commit is contained in:
parent
86a6d44352
commit
36e10587b5
@ -51,6 +51,7 @@ class ReshapeOp : public XlaOpKernel {
|
|||||||
TensorShape shape;
|
TensorShape shape;
|
||||||
int64 product = 1;
|
int64 product = 1;
|
||||||
int unknown_index = -1;
|
int unknown_index = -1;
|
||||||
|
bool shape_has_zero_dim = false;
|
||||||
for (int d = 0; d < num_dims; ++d) {
|
for (int d = 0; d < num_dims; ++d) {
|
||||||
const int32 size = shape_input[d];
|
const int32 size = shape_input[d];
|
||||||
if (size == -1) {
|
if (size == -1) {
|
||||||
@ -60,6 +61,12 @@ class ReshapeOp : public XlaOpKernel {
|
|||||||
unknown_index, " and ", d));
|
unknown_index, " and ", d));
|
||||||
unknown_index = d;
|
unknown_index = d;
|
||||||
shape.AddDim(1);
|
shape.AddDim(1);
|
||||||
|
} else if (size == 0) {
|
||||||
|
// We don't include zero-sized dimension in product, so that we can
|
||||||
|
// still calculate number of elements for non-zero-sized dimensions and
|
||||||
|
// therefore infer their shapes.
|
||||||
|
shape.AddDim(size);
|
||||||
|
shape_has_zero_dim = true;
|
||||||
} else {
|
} else {
|
||||||
OP_REQUIRES(ctx, size >= 0,
|
OP_REQUIRES(ctx, size >= 0,
|
||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
@ -69,18 +76,28 @@ class ReshapeOp : public XlaOpKernel {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (unknown_index != -1) {
|
if (unknown_index != -1) {
|
||||||
OP_REQUIRES(
|
int64 input_num_elements = 1;
|
||||||
ctx, product > 0,
|
bool input_has_zero_dim = false;
|
||||||
errors::InvalidArgument("Reshape cannot infer the missing input size "
|
for (int dim = 0; dim < input_shape.dims(); dim++) {
|
||||||
"for an empty tensor unless all specified "
|
// For zero dimension, we don't count it into `input_num_elements`
|
||||||
"input sizes are non-zero"));
|
// unless `sizes` has no zero dimension, so we are still able to
|
||||||
const int64 missing = input_shape.num_elements() / product;
|
// infer shapes for other dimensions.
|
||||||
OP_REQUIRES(
|
if (input_shape.dim_size(dim) > 0 || !shape_has_zero_dim) {
|
||||||
ctx, product * missing == input_shape.num_elements(),
|
input_num_elements *= input_shape.dim_size(dim);
|
||||||
errors::InvalidArgument(
|
} else {
|
||||||
"Input to reshape is a tensor with ", input_shape.num_elements(),
|
input_has_zero_dim = true;
|
||||||
" values, but the requested shape requires a multiple of ",
|
}
|
||||||
product));
|
}
|
||||||
|
|
||||||
|
const int64 missing = input_num_elements / product;
|
||||||
|
if (!input_has_zero_dim) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, product * missing == input_num_elements,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input to reshape is a tensor with ", input_num_elements,
|
||||||
|
" values, but the requested shape requires a multiple of ",
|
||||||
|
product));
|
||||||
|
}
|
||||||
shape.set_dim(unknown_index, missing);
|
shape.set_dim(unknown_index, missing);
|
||||||
}
|
}
|
||||||
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
|
OP_REQUIRES(ctx, shape.num_elements() == input_shape.num_elements(),
|
||||||
|
@ -45,14 +45,17 @@ class ReshapeOp : public OpKernel {
|
|||||||
TensorShape shape;
|
TensorShape shape;
|
||||||
int64 product = 1;
|
int64 product = 1;
|
||||||
int unknown_index = -1;
|
int unknown_index = -1;
|
||||||
|
bool sizes_has_zero_dim;
|
||||||
switch (sizes.dtype()) {
|
switch (sizes.dtype()) {
|
||||||
case DT_INT32:
|
case DT_INT32:
|
||||||
OP_REQUIRES_OK(context, ValidateSizes<int32>(sizes, &product,
|
OP_REQUIRES_OK(context,
|
||||||
&unknown_index, &shape));
|
ValidateSizes<int32>(sizes, &product, &unknown_index,
|
||||||
|
&shape, &sizes_has_zero_dim));
|
||||||
break;
|
break;
|
||||||
case DT_INT64:
|
case DT_INT64:
|
||||||
OP_REQUIRES_OK(context, ValidateSizes<int64>(sizes, &product,
|
OP_REQUIRES_OK(context,
|
||||||
&unknown_index, &shape));
|
ValidateSizes<int64>(sizes, &product, &unknown_index,
|
||||||
|
&shape, &sizes_has_zero_dim));
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
context->CtxFailure(errors::InvalidArgument(
|
context->CtxFailure(errors::InvalidArgument(
|
||||||
@ -61,18 +64,28 @@ class ReshapeOp : public OpKernel {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (unknown_index != -1) {
|
if (unknown_index != -1) {
|
||||||
OP_REQUIRES(
|
int64 input_num_elements = 1;
|
||||||
context, product > 0,
|
bool input_has_zero_dim = false;
|
||||||
errors::InvalidArgument("Reshape cannot infer the missing input size "
|
for (int dim = 0; dim < input.dims(); dim++) {
|
||||||
"for an empty tensor unless all specified "
|
// For zero dimension, we don't count it into `input_num_elements`
|
||||||
"input sizes are non-zero"));
|
// unless `sizes` has no zero dimension, so we are still able to
|
||||||
const int64 missing = input.NumElements() / product;
|
// infer shapes for other dimensions.
|
||||||
OP_REQUIRES(
|
if (input.dim_size(dim) > 0 || !sizes_has_zero_dim) {
|
||||||
context, product * missing == input.NumElements(),
|
input_num_elements *= input.dim_size(dim);
|
||||||
errors::InvalidArgument(
|
} else {
|
||||||
"Input to reshape is a tensor with ", input.NumElements(),
|
input_has_zero_dim = true;
|
||||||
" values, but the requested shape requires a multiple of ",
|
}
|
||||||
product));
|
}
|
||||||
|
|
||||||
|
const int64 missing = input_num_elements / product;
|
||||||
|
if (!input_has_zero_dim) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, product * missing == input_num_elements,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input to reshape is a tensor with ", input_num_elements,
|
||||||
|
" values, but the requested shape requires a multiple of ",
|
||||||
|
product));
|
||||||
|
}
|
||||||
shape.set_dim(unknown_index, missing);
|
shape.set_dim(unknown_index, missing);
|
||||||
}
|
}
|
||||||
OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
|
OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
|
||||||
@ -92,9 +105,10 @@ class ReshapeOp : public OpKernel {
|
|||||||
private:
|
private:
|
||||||
template <typename Tshape>
|
template <typename Tshape>
|
||||||
Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
|
Status ValidateSizes(const Tensor& sizes, int64* product, int* unknown_index,
|
||||||
TensorShape* shape) {
|
TensorShape* shape, bool* has_zero_dim) {
|
||||||
*product = 1;
|
*product = 1;
|
||||||
*unknown_index = -1;
|
*unknown_index = -1;
|
||||||
|
*has_zero_dim = false;
|
||||||
const int64 num_dims = sizes.NumElements();
|
const int64 num_dims = sizes.NumElements();
|
||||||
auto Svec = sizes.flat<Tshape>();
|
auto Svec = sizes.flat<Tshape>();
|
||||||
for (int d = 0; d < num_dims; ++d) {
|
for (int d = 0; d < num_dims; ++d) {
|
||||||
@ -110,6 +124,12 @@ class ReshapeOp : public OpKernel {
|
|||||||
} else if (size < 0) {
|
} else if (size < 0) {
|
||||||
return errors::InvalidArgument("Size ", d,
|
return errors::InvalidArgument("Size ", d,
|
||||||
" must be non-negative, not ", size);
|
" must be non-negative, not ", size);
|
||||||
|
} else if (size == 0) {
|
||||||
|
// We don't include zero-sized dimension in product, so that we can
|
||||||
|
// still calculate number of elements for non-zero-sized dimensions and
|
||||||
|
// therefore infer their shapes.
|
||||||
|
shape->AddDim(size);
|
||||||
|
*has_zero_dim = true;
|
||||||
} else {
|
} else {
|
||||||
shape->AddDim(size);
|
shape->AddDim(size);
|
||||||
(*product) *= size;
|
(*product) *= size;
|
||||||
|
@ -45,6 +45,18 @@ class ReshapeTest(test.TestCase):
|
|||||||
self.assertEqual(tf_ans.get_shape(), out.shape)
|
self.assertEqual(tf_ans.get_shape(), out.shape)
|
||||||
self.assertShapeEqual(np_ans, tf_ans)
|
self.assertShapeEqual(np_ans, tf_ans)
|
||||||
|
|
||||||
|
def _testZeroDimReshape(self, x, shape, expected, use_gpu=False):
|
||||||
|
with self.cached_session(use_gpu=use_gpu):
|
||||||
|
y = array_ops.reshape(x, shape)
|
||||||
|
out = self.evaluate(y)
|
||||||
|
self.assertEqual(expected, out.shape)
|
||||||
|
|
||||||
|
# Repeat with an int64 shape tensor.
|
||||||
|
shape64 = constant_op.constant(shape, dtype=dtypes.int64)
|
||||||
|
y = array_ops.reshape(x, shape64)
|
||||||
|
out = self.evaluate(y)
|
||||||
|
self.assertEqual(expected, out.shape)
|
||||||
|
|
||||||
def _testBothReshape(self, x, y):
|
def _testBothReshape(self, x, y):
|
||||||
self._testReshape(x, y, False)
|
self._testReshape(x, y, False)
|
||||||
self._testReshape(x, y, True)
|
self._testReshape(x, y, True)
|
||||||
@ -89,6 +101,18 @@ class ReshapeTest(test.TestCase):
|
|||||||
x = np.arange(1., 7.).reshape([6]).astype(np.float32)
|
x = np.arange(1., 7.).reshape([6]).astype(np.float32)
|
||||||
self._testBothReshape(x, [3, -1])
|
self._testBothReshape(x, [3, -1])
|
||||||
|
|
||||||
|
def testZeroDimBasic(self):
|
||||||
|
x = np.zeros([0, 6]).astype(np.float32)
|
||||||
|
self._testBothReshape(x, [0, 2, 3])
|
||||||
|
|
||||||
|
def testZeroDimReshapeR1(self):
|
||||||
|
x = np.zeros([0, 6]).astype(np.float32)
|
||||||
|
self._testBothReshape(x, [-1])
|
||||||
|
|
||||||
|
def testZeroDimReshapeR3(self):
|
||||||
|
x = np.zeros([0, 6]).astype(np.float32)
|
||||||
|
self._testBothReshape(x, [-1, 2, 3])
|
||||||
|
|
||||||
# TODO(vrv): Add tests for failure conditions once python test_util
|
# TODO(vrv): Add tests for failure conditions once python test_util
|
||||||
# reports errors.
|
# reports errors.
|
||||||
|
|
||||||
@ -113,6 +137,13 @@ class ReshapeTest(test.TestCase):
|
|||||||
self._testBothReshape(x, [0, 0, 0])
|
self._testBothReshape(x, [0, 0, 0])
|
||||||
self._testBothReshape(x, [1, -1, 5])
|
self._testBothReshape(x, [1, -1, 5])
|
||||||
|
|
||||||
|
def testZeroDimWithUnspecifiedDim(self):
|
||||||
|
for use_gpu in (True, False):
|
||||||
|
self._testZeroDimReshape(x=np.zeros([0, 6]).astype(np.float32),
|
||||||
|
shape=[0, -1, 3],
|
||||||
|
expected=(0, 2, 3),
|
||||||
|
use_gpu=use_gpu)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testErrors(self):
|
def testErrors(self):
|
||||||
y = constant_op.constant(0.0, shape=[23, 29, 31])
|
y = constant_op.constant(0.0, shape=[23, 29, 31])
|
||||||
|
Loading…
Reference in New Issue
Block a user