diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 35d6a81751f..571c6aea635 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1097,7 +1097,9 @@ tf_kernel_library( tf_kernel_library( name = "identity_n_op", prefix = "identity_n_op", - deps = ARRAY_DEPS, + deps = ARRAY_DEPS + [ + "//tensorflow/core:core_cpu_internal", + ], ) tf_kernel_library( diff --git a/tensorflow/core/kernels/identity_n_op.cc b/tensorflow/core/kernels/identity_n_op.cc index 9746b1fab3e..746a29bf5aa 100644 --- a/tensorflow/core/kernels/identity_n_op.cc +++ b/tensorflow/core/kernels/identity_n_op.cc @@ -16,6 +16,7 @@ limitations under the License. // See docs in ../ops/array_ops.cc. #include "tensorflow/core/kernels/identity_n_op.h" +#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" @@ -24,5 +25,10 @@ limitations under the License. namespace tensorflow { REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp); +// Do not worry about colocating IdentityN op with its resource inputs since +// it just forwards it's inputs anyway. This is needed because we create +// IdentityN nodes to club "all" outputs of functional ops while lowering to +// make the original functional op fetchable. +REGISTER_INPUT_COLOCATION_EXEMPTION("IdentityN"); } // namespace tensorflow diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 28f3493ecc4..4f750277e44 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -963,6 +963,7 @@ distribute_py_test( deps = [ ":single_loss_example", "//tensorflow/contrib/tpu:tpu_lib", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:variables", "//tensorflow/python/distribute:combinations", "//tensorflow/python/distribute:strategy_combinations", diff --git a/tensorflow/python/distribute/step_fn_test.py b/tensorflow/python/distribute/step_fn_test.py index 3a1a7e6eca4..28e6ad28c77 100644 --- a/tensorflow/python/distribute/step_fn_test.py +++ b/tensorflow/python/distribute/step_fn_test.py @@ -25,9 +25,11 @@ from tensorflow.python.distribute import strategy_combinations from tensorflow.python.distribute.single_loss_example import single_loss_example from tensorflow.python.eager import context from tensorflow.python.eager import test +from tensorflow.python.framework import test_util from tensorflow.python.ops import variables +@test_util.with_control_flow_v2 class SingleLossStepTest(test.TestCase, parameterized.TestCase): @combinations.generate(