Correct graph check in broadcast_to gradient.
PiperOrigin-RevId: 294970783 Change-Id: I2774ddf6949a1ba6814cefd939695010ecda0861
This commit is contained in:
parent
8212278e1a
commit
be5116dd13
@ -1135,7 +1135,7 @@ def _BroadcastToGrad(op, grad):
|
||||
input_value = op.inputs[0]
|
||||
broadcast_shape = op.inputs[1]
|
||||
input_value_shape = array_ops.shape(input_value)
|
||||
if not context.executing_eagerly():
|
||||
if not isinstance(broadcast_shape, ops.EagerTensor):
|
||||
broadcast_shape_static = tensor_shape.TensorShape(
|
||||
pywrap_tf_session.TF_TryEvaluateConstant_wrapper(
|
||||
broadcast_shape.graph._c_graph, broadcast_shape._as_tf_output())) # pylint: disable=protected-access
|
||||
|
Loading…
Reference in New Issue
Block a user