No lowering on gradient case op when input is DeviceIndex op

PiperOrigin-RevId: 314468964
Change-Id: If137eaf1e6c28ba3b6770dd93bb1a747bb36e836
This commit is contained in:
Yanhua Sun 2020-06-02 22:16:48 -07:00 committed by TensorFlower Gardener
parent e5cf28829c
commit 3379fb99c8
2 changed files with 25 additions and 3 deletions

View File

@ -29,6 +29,7 @@ from tensorflow.python.eager import backprop_util
from tensorflow.python.framework import auto_control_deps_utils as acd
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
@ -1061,8 +1062,17 @@ def _CaseGrad(op, *grads): # pylint: disable=invalid-name
# This modifies the graphs in branch_grad_graphs.
_make_output_composite_tensors_match(_CASE, branch_grad_graphs)
outputs = _build_case(case_op.inputs[0], branch_grad_graphs,
branches_grad_inputs, name="gradient")
try:
lowering = case_op._get_attr_bool("_lower_using_switch_merge")
except errors_impl.NotFoundError:
lowering = None
outputs = _build_case(
case_op.inputs[0],
branch_grad_graphs,
branches_grad_inputs,
name="gradient",
lower_using_switch_merge=lowering)
# The predicate has no gradient.
return [None] + outputs

View File

@ -41,6 +41,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.ops import control_flow_v2_toggles
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import embedding_ops
@ -1022,7 +1023,16 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertEqual(expected, self.evaluate(case_out))
@parameterized.parameters((-1,), (1,), (4,), (5,))
def testCase_gradient(self, bi):
def testCase_gradient_disable_lowering(self, bi):
self._testCase_gradient(True, bi)
@parameterized.parameters((-1,), (1,), (4,), (5,))
def testCase_gradient_enable_lowering(self, bi):
self._testCase_gradient(False, bi)
def _testCase_gradient(self, disable_lowering, bi):
default_lowering = control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE
control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = disable_lowering
nbranches = 5
inputs = [
array_ops.constant(float(bi), name="br{}_in".format(bi))
@ -1047,6 +1057,8 @@ class IndexedCaseTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertEqual(len(expected_grads), len(actual_grads))
for expected, actual in zip(expected_grads, actual_grads):
self.assertEqual(expected, self.evaluate(actual))
# reset to default value
control_flow_util_v2._DISABLE_LOWER_USING_SWITCH_MERGE = default_lowering
@parameterized.parameters((-2,), (2,), (5,))
def testCase_gradient_diffShapedIntermediates(self, bi):