Correct graph check in broadcast_to gradient.

PiperOrigin-RevId: 294970783
Change-Id: I2774ddf6949a1ba6814cefd939695010ecda0861
This commit is contained in:
Alexandre Passos 2020-02-13 12:12:18 -08:00 committed by TensorFlower Gardener
parent 8212278e1a
commit be5116dd13

View File

@ -1135,7 +1135,7 @@ def _BroadcastToGrad(op, grad):
input_value = op.inputs[0]
broadcast_shape = op.inputs[1]
input_value_shape = array_ops.shape(input_value)
if not context.executing_eagerly():
if not isinstance(broadcast_shape, ops.EagerTensor):
broadcast_shape_static = tensor_shape.TensorShape(
pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output())) # pylint: disable=protected-access