Do not perform colocation checks for IdentityN since it just forwards its inputs.
PiperOrigin-RevId: 257230757
This commit is contained in:
parent
348fde8095
commit
ca57a9d5a8
@ -1097,7 +1097,9 @@ tf_kernel_library(
|
|||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
name = "identity_n_op",
|
name = "identity_n_op",
|
||||||
prefix = "identity_n_op",
|
prefix = "identity_n_op",
|
||||||
deps = ARRAY_DEPS,
|
deps = ARRAY_DEPS + [
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
tf_kernel_library(
|
tf_kernel_library(
|
||||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
|||||||
// See docs in ../ops/array_ops.cc.
|
// See docs in ../ops/array_ops.cc.
|
||||||
#include "tensorflow/core/kernels/identity_n_op.h"
|
#include "tensorflow/core/kernels/identity_n_op.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
@ -24,5 +25,10 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
|
REGISTER_KERNEL_BUILDER(Name("IdentityN").Device(DEVICE_DEFAULT), IdentityNOp);
|
||||||
|
// Do not worry about colocating IdentityN op with its resource inputs since
|
||||||
|
// it just forwards it's inputs anyway. This is needed because we create
|
||||||
|
// IdentityN nodes to club "all" outputs of functional ops while lowering to
|
||||||
|
// make the original functional op fetchable.
|
||||||
|
REGISTER_INPUT_COLOCATION_EXEMPTION("IdentityN");
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -963,6 +963,7 @@ distribute_py_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":single_loss_example",
|
":single_loss_example",
|
||||||
"//tensorflow/contrib/tpu:tpu_lib",
|
"//tensorflow/contrib/tpu:tpu_lib",
|
||||||
|
"//tensorflow/python:framework_test_lib",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
"//tensorflow/python/distribute:combinations",
|
"//tensorflow/python/distribute:combinations",
|
||||||
"//tensorflow/python/distribute:strategy_combinations",
|
"//tensorflow/python/distribute:strategy_combinations",
|
||||||
|
@ -25,9 +25,11 @@ from tensorflow.python.distribute import strategy_combinations
|
|||||||
from tensorflow.python.distribute.single_loss_example import single_loss_example
|
from tensorflow.python.distribute.single_loss_example import single_loss_example
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.eager import test
|
from tensorflow.python.eager import test
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.with_control_flow_v2
|
||||||
class SingleLossStepTest(test.TestCase, parameterized.TestCase):
|
class SingleLossStepTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user