Check that the Bcast result shape is equal to the output_shape in broadcast_to
PiperOrigin-RevId: 242557631
This commit is contained in:
parent
0eb15926b8
commit
c1ff471dec
@ -80,6 +80,10 @@ class BroadcastToOp : public OpKernel {
|
||||
errors::InvalidArgument(
|
||||
"Incompatible shapes: ", input_shape.DebugString(), " vs. ",
|
||||
output_shape.DebugString()));
|
||||
OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape,
|
||||
errors::InvalidArgument("Unable to broadcast tensor of shape ",
|
||||
input_shape, " to tensor of shape ",
|
||||
output_shape));
|
||||
|
||||
functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape,
|
||||
input_tensor, input_shape, bcast);
|
||||
|
@ -19,8 +19,10 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gradient_checker
|
||||
@ -127,6 +129,14 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
|
||||
# check shape inference when shape input is constant
|
||||
self.assertAllEqual(shape, v_np.shape)
|
||||
|
||||
def testBroadcastToBadOutputShape(self):
|
||||
with context.eager_mode():
|
||||
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||
"Unable to broadcast tensor of shape"):
|
||||
self.evaluate(
|
||||
array_ops.broadcast_to(
|
||||
constant_op.constant([0, 1]), constant_op.constant([2, 1])))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradientForScalar(self):
|
||||
x = constant_op.constant(1, dtype=dtypes.float32)
|
||||
|
Loading…
Reference in New Issue
Block a user