Support Identity/StopGradient ops for TensorLists in tf2xla bridge.
This is needed to support back_prop=False in v2 control flow. PiperOrigin-RevId: 298720533 Change-Id: I33cd2f2603cff07193c0275878ff22c4c8338fa8
This commit is contained in:
parent
0d145969f6
commit
a0c8fee0e6
tensorflow/compiler
@ -240,6 +240,22 @@ class WhileTest(xla_test.XLATestCase):
|
||||
self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
|
||||
xla_context.Exit()
|
||||
|
||||
@test_util.enable_control_flow_v2
|
||||
def testMapBackPropFalse(self):
|
||||
if is_compile_on_demand():
|
||||
self.skipTest("list_ops are not supported in cpu_ondemand")
|
||||
with self.session(), self.test_scope():
|
||||
xla_context = control_flow_ops.XLAControlFlowContext()
|
||||
xla_context.Enter()
|
||||
nums = [1, 2, 3, 4, 5, 6]
|
||||
elems = constant_op.constant(nums, name="data")
|
||||
r = map_fn.map_fn(
|
||||
lambda x: math_ops.multiply(math_ops.add(x, 3), 2),
|
||||
elems,
|
||||
back_prop=False)
|
||||
self.assertAllEqual(r, np.array([(x + 3) * 2 for x in nums]))
|
||||
xla_context.Exit()
|
||||
|
||||
|
||||
def is_compile_on_demand():
|
||||
return ("TF_XLA_FLAGS" in os.environ and
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
|
||||
@ -25,10 +26,15 @@ class IdentityOp : public XlaOpKernel {
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
for (int i = 0; i < ctx->num_inputs(); ++i) {
|
||||
// Forwards using the underlying op_kernel_context so both tensor and
|
||||
// resource values are forwarded correctly.
|
||||
ctx->op_kernel_context()->set_output(i,
|
||||
ctx->op_kernel_context()->input(i));
|
||||
if (IsTensorListInput(ctx, i)) {
|
||||
ctx->SetTensorListOutput(i, ctx->Input(i));
|
||||
} else {
|
||||
DCHECK(ctx->input_type(i) != DT_VARIANT);
|
||||
// Forwards using the underlying op_kernel_context so both tensor and
|
||||
// resource values are forwarded correctly.
|
||||
ctx->op_kernel_context()->set_output(
|
||||
i, ctx->op_kernel_context()->input(i));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -48,7 +54,7 @@ REGISTER_XLA_OP(Name("IdentityN")
|
||||
IdentityOp);
|
||||
REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
|
||||
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
|
||||
REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
|
||||
REGISTER_XLA_OP(Name("StopGradient").AllowVariantTypes(), IdentityOp);
|
||||
REGISTER_XLA_OP(Name("Snapshot"), IdentityOp);
|
||||
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user