No lowering on gradient case op when input is DeviceIndex op
PiperOrigin-RevId: 314468964 Change-Id: If137eaf1e6c28ba3b6770dd93bb1a747bb36e836
This commit is contained in:
parent
e5cf28829c
commit
3379fb99c8
|
@ -29,6 +29,7 @@ from tensorflow.python.eager import backprop_util
|
|||
from tensorflow.python.framework import auto_control_deps_utils as acd
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import func_graph as func_graph_module
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
|
@ -1061,8 +1062,17 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
|
|||
# This modifies the graphs in branch_grad_graphs.
|
||||
_make_output_composite_tensors_match(_CASE, branch_grad_graphs)
|
||||
|
||||
outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
|
||||
branches_grad_inputs, name="gradient")
|
||||
try:
|
||||
lowering = case_op._get_attr_bool("_lower_using_switch_merge")
|
||||
except errors_impl.NotFoundError:
|
||||
lowering = None
|
||||
|
||||
outputs = _build_case(
|
||||
case_op.inputs[0],
|
||||
branch_grad_graphs,
|
||||
branches_grad_inputs,
|
||||
name="gradient",
|
||||
lower_using_switch_merge=lowering)
|
||||
|
||||
# The predicate has no gradient.
|
||||
return [None] + outputs
|
||||
|
|
|
@ -41,6 +41,7 @@ from tensorflow.python.framework import test_util
|
|||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import control_flow_util_v2
|
||||
from tensorflow.python.ops import control_flow_v2_toggles
|
||||
from tensorflow.python.ops import custom_gradient
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
|
@ -1022,7 +1023,16 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
self.assertEqual(expected, self.evaluate(case_out))
|
||||
|
||||
@parameterized.parameters((-1,), (1,), (4,), (5,))
|
||||
def testCase_gradient(self, bi):
|
||||
def testCase_gradient_disable_lowering(self, bi):
|
||||
self._testCase_gradient(True, bi)
|
||||
|
||||
@parameterized.parameters((-1,), (1,), (4,), (5,))
|
||||
def testCase_gradient_enable_lowering(self, bi):
|
||||
self._testCase_gradient(False, bi)
|
||||
|
||||
def _testCase_gradient(self, disable_lowering, bi):
|
||||
default_lowering = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
|
||||
control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = disable_lowering
|
||||
nbranches = 5
|
||||
inputs = [
|
||||
array_ops.constant(float(bi), name="br{}_in".format(bi))
|
||||
|
@ -1047,6 +1057,8 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|||
self.assertEqual(len(expected_grads), len(actual_grads))
|
||||
for expected, actual in zip(expected_grads, actual_grads):
|
||||
self.assertEqual(expected, self.evaluate(actual))
|
||||
# reset to default value
|
||||
control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = default_lowering
|
||||
|
||||
@parameterized.parameters((-2,), (2,), (5,))
|
||||
def testCase_gradient_diffShapedIntermediates(self, bi):
|
||||
|
|
Loading…
Reference in New Issue