From a0c8fee0e68900f70f70ff301da398543421a10c Mon Sep 17 00:00:00 2001 From: Saurabh Saxena <srbs@google.com> Date: Tue, 3 Mar 2020 16:22:40 -0800 Subject: [PATCH] Support Identity/StopGradient ops for TensorLists in tf2xla bridge. This is needed to support back_prop=False in v2 control flow. PiperOrigin-RevId: 298720533 Change-Id: I33cd2f2603cff07193c0275878ff22c4c8338fa8 --- tensorflow/compiler/tests/while_test.py | 16 ++++++++++++++++ .../compiler/tf2xla/kernels/identity_op.cc | 16 +++++++++++----- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/tests/while_test.py b/tensorflow/compiler/tests/while_test.py index 420dc04bec3..f1f8b6c353c 100644 --- a/tensorflow/compiler/tests/while_test.py +++ b/tensorflow/compiler/tests/while_test.py @@ -240,6 +240,22 @@ class WhileTest(xla_test.XLATestCase): self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums])) xla_context.Exit() + @test_util.enable_control_flow_v2 + def testMapBackPropFalse(self): + if is_compile_on_demand(): + self.skipTest("list_ops are not supported in cpu_ondemand") + with self.session(), self.test_scope(): + xla_context = control_flow_ops.XLAControlFlowContext() + xla_context.Enter() + nums = [1, 2, 3, 4, 5, 6] + elems = constant_op.constant(nums, name="data") + r = map_fn.map_fn( + lambda x: math_ops.multiply(math_ops.add(x, 3), 2), + elems, + back_prop=False) + self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums])) + xla_context.Exit() + def is_compile_on_demand(): return ("TF_XLA_FLAGS" in os.environ and diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 8b27e8e85a3..38d8056d3e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -25,10 +26,15 @@ class IdentityOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - // Forwards using the underlying op_kernel_context so both tensor and - // resource values are forwarded correctly. - ctx->op_kernel_context()->set_output(i, - ctx->op_kernel_context()->input(i)); + if (IsTensorListInput(ctx, i)) { + ctx->SetTensorListOutput(i, ctx->Input(i)); + } else { + DCHECK(ctx->input_type(i) != DT_VARIANT); + // Forwards using the underlying op_kernel_context so both tensor and + // resource values are forwarded correctly. + ctx->op_kernel_context()->set_output( + i, ctx->op_kernel_context()->input(i)); + } } } @@ -48,7 +54,7 @@ REGISTER_XLA_OP(Name("IdentityN") IdentityOp); REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); -REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); +REGISTER_XLA_OP(Name("StopGradient").AllowVariantTypes(), IdentityOp); REGISTER_XLA_OP(Name("Snapshot"), IdentityOp); } // namespace