Add data_format attribute constraint for CPU kernels of SpaceToDepth and DepthToSpace ops.
Ops such as DepthToSpace and SpaceToDepth can only run on CPU if data_format=NHWC. Currently, when we run a graph containing these ops on both CPU and GPU devices, the op placement algorithm always prefers GPU over CPU, which avoids errors when data_format=NCHW. In the future we may have an alternative op placement algorithm that places these ops on CPU, e.g. to minimize memory-copy overhead. It is better to explicitly enforce this constraint when we determine the supported device types for nodes. PiperOrigin-RevId: 271168551
This commit is contained in:
parent
62bbb4d2da
commit
83aecde638
tensorflow
core
python/kernel_tests
@ -37,44 +37,106 @@ Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
|
||||
bool* match) {
|
||||
*match = false;
|
||||
for (const auto& constraint : kernel_def.constraint()) {
|
||||
if (constraint.allowed_values().list().type_size() == 0) {
|
||||
auto constraint_value_case = AttrValue::VALUE_NOT_SET;
|
||||
int value_type_num = 0;
|
||||
if (constraint.allowed_values().list().type_size() > 0) {
|
||||
constraint_value_case = AttrValue::kType;
|
||||
value_type_num++;
|
||||
}
|
||||
if (constraint.allowed_values().list().s_size() > 0) {
|
||||
constraint_value_case = AttrValue::kS;
|
||||
value_type_num++;
|
||||
}
|
||||
if (constraint.allowed_values().list().i_size() > 0) {
|
||||
constraint_value_case = AttrValue::kI;
|
||||
value_type_num++;
|
||||
}
|
||||
if (constraint.allowed_values().list().b_size() > 0) {
|
||||
constraint_value_case = AttrValue::kB;
|
||||
value_type_num++;
|
||||
}
|
||||
|
||||
if (value_type_num == 0) {
|
||||
return errors::Unimplemented(
|
||||
"KernelDef '", kernel_def.ShortDebugString(),
|
||||
" has constraint on attr '", constraint.name(),
|
||||
"' with unsupported type: ",
|
||||
SummarizeAttrValue(constraint.allowed_values()));
|
||||
}
|
||||
if (value_type_num > 1) {
|
||||
return errors::InvalidArgument(
|
||||
"KernelDef '", kernel_def.ShortDebugString(),
|
||||
" has constraint on attr '", constraint.name(),
|
||||
"' with more than one value type: ",
|
||||
SummarizeAttrValue(constraint.allowed_values()));
|
||||
}
|
||||
|
||||
const AttrValue* found = attrs.Find(constraint.name());
|
||||
if (found) {
|
||||
if (found->type() != DT_INVALID) {
|
||||
if (!InTypeList(found->type(), constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
if (!AttrValueHasType(*found, "list(type)").ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"KernelDef '", kernel_def.ShortDebugString(),
|
||||
"' has constraint on attr '", constraint.name(),
|
||||
"' that has value '", SummarizeAttrValue(*found),
|
||||
"' that does not have type 'type' or 'list(type)' in NodeDef "
|
||||
"'",
|
||||
attrs.SummarizeNode(), "'");
|
||||
}
|
||||
|
||||
for (int t : found->list().type()) {
|
||||
if (!InTypeList(static_cast<DataType>(t),
|
||||
constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const AttrValue* attr_value = attrs.Find(constraint.name());
|
||||
if (attr_value == nullptr) {
|
||||
return errors::InvalidArgument(
|
||||
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
|
||||
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
|
||||
"', KernelDef: '", kernel_def.ShortDebugString(), "'");
|
||||
}
|
||||
|
||||
#define RETURN_IF_ATTR_NOT_FOUND(n, oneof_case, type_str) \
|
||||
do { \
|
||||
if (constraint_value_case == AttrValue::oneof_case) { \
|
||||
Status s = AttrValueHasType(*attr_value, type_str); \
|
||||
if (!s.ok()) { \
|
||||
return errors::InvalidArgument( \
|
||||
"KernelDef '", kernel_def.ShortDebugString(), \
|
||||
"' has constraint on attr '", constraint.name(), \
|
||||
"' that has value '", SummarizeAttrValue(*attr_value), \
|
||||
"' that does not have the same type in NodeDef " \
|
||||
"'", \
|
||||
attrs.SummarizeNode(), "'"); \
|
||||
} \
|
||||
bool found = false; \
|
||||
for (auto& value : constraint.allowed_values().list().n()) { \
|
||||
if (value == attr_value->n()) { \
|
||||
found = true; \
|
||||
break; \
|
||||
} \
|
||||
} \
|
||||
if (!found) { \
|
||||
return Status::OK(); \
|
||||
} \
|
||||
} \
|
||||
} while (false)
|
||||
|
||||
RETURN_IF_ATTR_NOT_FOUND(s, kS, "string");
|
||||
RETURN_IF_ATTR_NOT_FOUND(i, kI, "int");
|
||||
RETURN_IF_ATTR_NOT_FOUND(b, kB, "bool");
|
||||
|
||||
#undef RETURN_IF_ATTR_NOT_FOUND
|
||||
|
||||
if (constraint_value_case != AttrValue::kType) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (attr_value->type() != DT_INVALID) {
|
||||
if (!InTypeList(attr_value->type(), constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
} else {
|
||||
if (!AttrValueHasType(*attr_value, "list(type)").ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"KernelDef '", kernel_def.ShortDebugString(),
|
||||
"' has constraint on attr '", constraint.name(),
|
||||
"' that has value '", SummarizeAttrValue(*attr_value),
|
||||
"' that does not have type 'type' or 'list(type)' in NodeDef "
|
||||
"'",
|
||||
attrs.SummarizeNode(), "'");
|
||||
}
|
||||
|
||||
for (int t : attr_value->list().type()) {
|
||||
if (!InTypeList(static_cast<DataType>(t),
|
||||
constraint.allowed_values())) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
*match = true;
|
||||
return Status::OK();
|
||||
|
@ -614,12 +614,14 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
|
||||
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.Label("custom") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
.TypeConstraint<T>("T") \
|
||||
.AttrConstraint("data_format", "NHWC"), \
|
||||
Conv2DCustomBackpropFilterOp<CPUDevice, T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.Label("eigen_tensor") \
|
||||
.TypeConstraint<T>("T"), \
|
||||
.TypeConstraint<T>("T") \
|
||||
.AttrConstraint("data_format", "NHWC"), \
|
||||
Conv2DBackpropFilterOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_half(REGISTER_CPU_KERNELS);
|
||||
|
@ -175,10 +175,12 @@ struct DepthToSpaceOpFunctor<CPUDevice, T, FORMAT_NHWC> {
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("DepthToSpace").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
DepthToSpaceOp<CPUDevice, type>);
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("DepthToSpace") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.AttrConstraint("data_format", "NHWC"), \
|
||||
DepthToSpaceOp<CPUDevice, type>);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER);
|
||||
#undef REGISTER
|
||||
|
@ -190,10 +190,12 @@ struct SpaceToDepthOpFunctor<CPUDevice, T, FORMAT_NHWC> {
|
||||
};
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("SpaceToDepth").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
|
||||
SpaceToDepthOp<CPUDevice, type>);
|
||||
#define REGISTER(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SpaceToDepth") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.AttrConstraint("data_format", "NHWC"), \
|
||||
SpaceToDepthOp<CPUDevice, type>);
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER);
|
||||
TF_CALL_qint8(REGISTER);
|
||||
|
@ -21,8 +21,10 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import device_lib
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors_impl
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -41,6 +43,18 @@ class DepthToSpaceTest(test.TestCase):
|
||||
# test NHWC (default) on CPU
|
||||
x_tf = array_ops.depth_to_space(input_nhwc, block_size)
|
||||
self.assertAllEqual(x_tf.eval(), outputs)
|
||||
|
||||
# Run this test only if only CPU device is available
|
||||
if all(x.device_type == "CPU" for x in device_lib.list_local_devices()):
|
||||
input_nchw = test_util.NHWCToNCHW(input_nhwc)
|
||||
output_nchw = array_ops.depth_to_space(
|
||||
input_nchw, block_size, data_format="NCHW")
|
||||
output_nhwc = test_util.NCHWToNHWC(output_nchw)
|
||||
with self.assertRaisesRegexp(
|
||||
errors_impl.InvalidArgumentError,
|
||||
"No OpKernel was registered to support Op 'DepthToSpace'"):
|
||||
output_nhwc.eval()
|
||||
|
||||
if test.is_gpu_available():
|
||||
with self.cached_session(use_gpu=True):
|
||||
# test NHWC (default) on GPU
|
||||
|
Loading…
Reference in New Issue
Block a user