Support Identity/StopGradient ops for TensorLists in tf2xla bridge.

This is needed to support back_prop=False in v2 control flow.

PiperOrigin-RevId: 298720533
Change-Id: I33cd2f2603cff07193c0275878ff22c4c8338fa8
This commit is contained in:
Saurabh Saxena 2020-03-03 16:22:40 -08:00 committed by TensorFlower Gardener
parent 0d145969f6
commit a0c8fee0e6
2 changed files with 27 additions and 5 deletions
tensorflow/compiler
tests
tf2xla/kernels

View File

@ -240,6 +240,22 @@ class WhileTest(xla_test.XLATestCase):
self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
xla_context.Exit()
@test_util.enable_control_flow_v2
def testMapBackPropFalse(self):
if is_compile_on_demand():
self.skipTest("list_ops are not supported in cpu_ondemand")
with self.session(), self.test_scope():
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
nums = [1, 2, 3, 4, 5, 6]
elems = constant_op.constant(nums, name="data")
r = map_fn.map_fn(
lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
elems,
back_prop=False)
self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
xla_context.Exit()
def is_compile_on_demand():
return ("TF_XLA_FLAGS" in os.environ and

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@ -25,10 +26,15 @@ class IdentityOp : public XlaOpKernel {
void Compile(XlaOpKernelContext* ctx) override {
for (int i = 0; i < ctx->num_inputs(); ++i) {
// Forwards using the underlying op_kernel_context so both tensor and
// resource values are forwarded correctly.
ctx->op_kernel_context()->set_output(i,
ctx->op_kernel_context()->input(i));
if (IsTensorListInput(ctx, i)) {
ctx->SetTensorListOutput(i, ctx->Input(i));
} else {
DCHECK(ctx->input_type(i) != DT_VARIANT);
// Forwards using the underlying op_kernel_context so both tensor and
// resource values are forwarded correctly.
ctx->op_kernel_context()->set_output(
i, ctx->op_kernel_context()->input(i));
}
}
}
@ -48,7 +54,7 @@ REGISTER_XLA_OP(Name("IdentityN")
IdentityOp);
REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
REGISTER_XLA_OP(Name("StopGradient").AllowVariantTypes(), IdentityOp);
REGISTER_XLA_OP(Name("Snapshot"), IdentityOp);
} // namespace