Do not perform colocation checks for IdentityN since it just forwards its inputs.
PiperOrigin-RevId: 257230757
This commit is contained in:
parent
348fde8095
commit
ca57a9d5a8
@ -1097,7 +1097,9 @@ tf_kernel_library(
|
||||
tf_kernel_library(
|
||||
name = "identity_n_op",
|
||||
prefix = "identity_n_op",
|
||||
deps = ARRAY_DEPS,
|
||||
deps = ARRAY_DEPS + [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
// See docs in ../ops/array_ops.cc.
|
||||
#include "tensorflow/core/kernels/identity_n_op.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -24,5 +25,10 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
|
||||
// Do not worry about colocating IdentityN op with its resource inputs since
|
||||
// it just forwards it's inputs anyway. This is needed because we create
|
||||
// IdentityN nodes to club "all" outputs of functional ops while lowering to
|
||||
// make the original functional op fetchable.
|
||||
REGISTER_INPUT_COLOCATION_EXEMPTION("IdentityN");
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -963,6 +963,7 @@ distribute_py_test(
|
||||
deps = [
|
||||
":single_loss_example",
|
||||
"//tensorflow/contrib/tpu:tpu_lib",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:variables",
|
||||
"//tensorflow/python/distribute:combinations",
|
||||
"//tensorflow/python/distribute:strategy_combinations",
|
||||
|
@ -25,9 +25,11 @@ from tensorflow.python.distribute import strategy_combinations
|
||||
from tensorflow.python.distribute.single_loss_example import single_loss_example
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.eager import test
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import variables
|
||||
|
||||
|
||||
@test_util.with_control_flow_v2
|
||||
class SingleLossStepTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(
|
||||
|
Loading…
x
Reference in New Issue
Block a user