Do not perform colocation checks for IdentityN since it just forwards its inputs.

PiperOrigin-RevId: 257230757
This commit is contained in:
Saurabh Saxena 2019-07-09 11:05:15 -07:00 committed by TensorFlower Gardener
parent 348fde8095
commit ca57a9d5a8
4 changed files with 12 additions and 1 deletions

View File

@ -1097,7 +1097,9 @@ tf_kernel_library(
tf_kernel_library(
name = "identity_n_op",
prefix = "identity_n_op",
deps = ARRAY_DEPS,
deps = ARRAY_DEPS + [
"//tensorflow/core:core_cpu_internal",
],
)
tf_kernel_library(

View File

@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#include "tensorflow/core/kernels/identity_n_op.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -24,5 +25,10 @@ limitations under the License.
namespace tensorflow {
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
// Do not worry about colocating IdentityN op with its resource inputs since
// it just forwards it's inputs anyway. This is needed because we create
// IdentityN nodes to club "all" outputs of functional ops while lowering to
// make the original functional op fetchable.
REGISTER_INPUT_COLOCATION_EXEMPTION("IdentityN");
} // namespace tensorflow

View File

@ -963,6 +963,7 @@ distribute_py_test(
deps = [
":single_loss_example",
"//tensorflow/contrib/tpu:tpu_lib",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:variables",
"//tensorflow/python/distribute:combinations",
"//tensorflow/python/distribute:strategy_combinations",

View File

@ -25,9 +25,11 @@ from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.distribute.single_loss_example import single_loss_example
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import test_util
from tensorflow.python.ops import variables
@test_util.with_control_flow_v2
class SingleLossStepTest(test.TestCase, parameterized.TestCase):
@combinations.generate(