Check that the Bcast result shape is equal to the output_shape in broadcast_to
PiperOrigin-RevId: 242557631
This commit is contained in:
parent
0eb15926b8
commit
c1ff471dec
@ -80,6 +80,10 @@ class BroadcastToOp : public OpKernel {
|
|||||||
errors::InvalidArgument(
|
errors::InvalidArgument(
|
||||||
"Incompatible shapes: ", input_shape.DebugString(), " vs. ",
|
"Incompatible shapes: ", input_shape.DebugString(), " vs. ",
|
||||||
output_shape.DebugString()));
|
output_shape.DebugString()));
|
||||||
|
OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape,
|
||||||
|
errors::InvalidArgument("Unable to broadcast tensor of shape ",
|
||||||
|
input_shape, " to tensor of shape ",
|
||||||
|
output_shape));
|
||||||
|
|
||||||
functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape,
|
functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape,
|
||||||
input_tensor, input_shape, bcast);
|
input_tensor, input_shape, bcast);
|
||||||
|
@ -19,8 +19,10 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import gradient_checker
|
from tensorflow.python.ops import gradient_checker
|
||||||
@ -127,6 +129,14 @@ class BroadcastToTest(test_util.TensorFlowTestCase):
|
|||||||
# check shape inference when shape input is constant
|
# check shape inference when shape input is constant
|
||||||
self.assertAllEqual(shape, v_np.shape)
|
self.assertAllEqual(shape, v_np.shape)
|
||||||
|
|
||||||
|
def testBroadcastToBadOutputShape(self):
|
||||||
|
with context.eager_mode():
|
||||||
|
with self.assertRaisesRegexp(errors.InvalidArgumentError,
|
||||||
|
"Unable to broadcast tensor of shape"):
|
||||||
|
self.evaluate(
|
||||||
|
array_ops.broadcast_to(
|
||||||
|
constant_op.constant([0, 1]), constant_op.constant([2, 1])))
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_deprecated_v1
|
||||||
def testGradientForScalar(self):
|
def testGradientForScalar(self):
|
||||||
x = constant_op.constant(1, dtype=dtypes.float32)
|
x = constant_op.constant(1, dtype=dtypes.float32)
|
||||||
|
Loading…
Reference in New Issue
Block a user