Add data_format attribute constraint for CPU kernels of SpaceToDepth and DepthToSpace ops.

Ops such as DepthToSpace and SpaceToDepth can only run on CPU if data_format=NHWC. Currently, when we run a graph containing these ops on both CPU and GPU devices, the op placement algorithm always prefers GPU over CPU, which avoids the error when data_format=NCHW.

In the future we may have an alternative op placement algorithm that puts these ops on CPU to e.g. minimize memory-copy overhead. It is better to explicitly enforce this constraint when we determine the supported device types for nodes.

PiperOrigin-RevId: 271168551
Authored by Dong Lin on 2019-09-25 11:32:45 -07:00; committed by TensorFlower Gardener
parent 62bbb4d2da
commit 83aecde638
5 changed files with 118 additions and 36 deletions
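
As an overview, the mechanical change in the kernel files below is to attach an attribute constraint at registration time. A minimal sketch of the pattern, not taken from the patch itself (the op and element type here are hypothetical placeholders):

// Sketch of the registration pattern adopted below: the CPU kernel declares
// that it only supports data_format == "NHWC", so the placer will not assign
// NCHW nodes to it. "SomeNhwcOnlyOp" and its functor are illustrative only.
REGISTER_KERNEL_BUILDER(Name("SomeNhwcOnlyOp")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<float>("T")
                            .AttrConstraint("data_format", "NHWC"),
                        SomeNhwcOnlyOp<CPUDevice, float>);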


@@ -37,44 +37,106 @@ Status KernelAttrsMatch(const KernelDef& kernel_def, AttrSlice attrs,
bool* match) {
*match = false;
for (const auto& constraint : kernel_def.constraint()) {
if (constraint.allowed_values().list().type_size() == 0) {
auto constraint_value_case = AttrValue::VALUE_NOT_SET;
int value_type_num = 0;
if (constraint.allowed_values().list().type_size() > 0) {
constraint_value_case = AttrValue::kType;
value_type_num++;
}
if (constraint.allowed_values().list().s_size() > 0) {
constraint_value_case = AttrValue::kS;
value_type_num++;
}
if (constraint.allowed_values().list().i_size() > 0) {
constraint_value_case = AttrValue::kI;
value_type_num++;
}
if (constraint.allowed_values().list().b_size() > 0) {
constraint_value_case = AttrValue::kB;
value_type_num++;
}
if (value_type_num == 0) {
return errors::Unimplemented(
"KernelDef '", kernel_def.ShortDebugString(),
" has constraint on attr '", constraint.name(),
"' with unsupported type: ",
SummarizeAttrValue(constraint.allowed_values()));
}
if (value_type_num > 1) {
return errors::InvalidArgument(
"KernelDef '", kernel_def.ShortDebugString(),
" has constraint on attr '", constraint.name(),
"' with more than one value type: ",
SummarizeAttrValue(constraint.allowed_values()));
}
const AttrValue* found = attrs.Find(constraint.name());
if (found) {
if (found->type() != DT_INVALID) {
if (!InTypeList(found->type(), constraint.allowed_values())) {
return Status::OK();
}
} else {
if (!AttrValueHasType(*found, "list(type)").ok()) {
return errors::InvalidArgument(
"KernelDef '", kernel_def.ShortDebugString(),
"' has constraint on attr '", constraint.name(),
"' that has value '", SummarizeAttrValue(*found),
"' that does not have type 'type' or 'list(type)' in NodeDef "
"'",
attrs.SummarizeNode(), "'");
}
for (int t : found->list().type()) {
if (!InTypeList(static_cast<DataType>(t),
constraint.allowed_values())) {
return Status::OK();
}
}
}
} else {
const AttrValue* attr_value = attrs.Find(constraint.name());
if (attr_value == nullptr) {
return errors::InvalidArgument(
"OpKernel '", kernel_def.op(), "' has constraint on attr '",
constraint.name(), "' not in NodeDef '", attrs.SummarizeNode(),
"', KernelDef: '", kernel_def.ShortDebugString(), "'");
}
#define RETURN_IF_ATTR_NOT_FOUND(n, oneof_case, type_str) \
do { \
if (constraint_value_case == AttrValue::oneof_case) { \
Status s = AttrValueHasType(*attr_value, type_str); \
if (!s.ok()) { \
return errors::InvalidArgument( \
"KernelDef '", kernel_def.ShortDebugString(), \
"' has constraint on attr '", constraint.name(), \
"' that has value '", SummarizeAttrValue(*attr_value), \
"' that does not have the same type in NodeDef " \
"'", \
attrs.SummarizeNode(), "'"); \
} \
bool found = false; \
for (auto& value : constraint.allowed_values().list().n()) { \
if (value == attr_value->n()) { \
found = true; \
break; \
} \
} \
if (!found) { \
return Status::OK(); \
} \
} \
} while (false)
RETURN_IF_ATTR_NOT_FOUND(s, kS, "string");
RETURN_IF_ATTR_NOT_FOUND(i, kI, "int");
RETURN_IF_ATTR_NOT_FOUND(b, kB, "bool");
#undef RETURN_IF_ATTR_NOT_FOUND
if (constraint_value_case != AttrValue::kType) {
continue;
}
if (attr_value->type() != DT_INVALID) {
if (!InTypeList(attr_value->type(), constraint.allowed_values())) {
return Status::OK();
}
} else {
if (!AttrValueHasType(*attr_value, "list(type)").ok()) {
return errors::InvalidArgument(
"KernelDef '", kernel_def.ShortDebugString(),
"' has constraint on attr '", constraint.name(),
"' that has value '", SummarizeAttrValue(*attr_value),
"' that does not have type 'type' or 'list(type)' in NodeDef "
"'",
attrs.SummarizeNode(), "'");
}
for (int t : attr_value->list().type()) {
if (!InTypeList(static_cast<DataType>(t),
constraint.allowed_values())) {
return Status::OK();
}
}
}
}
*match = true;
return Status::OK();
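
To illustrate the generalized matching above, here is a hedged, test-style sketch (assumed code, not part of this commit) of how a string-valued constraint now behaves: a KernelDef constrained to data_format=NHWC does not match a NodeDef whose data_format is NCHW, so KernelAttrsMatch reports no match and the CPU kernel is skipped during placement.

// Sketch only; assumes the headers below and the KernelAttrsMatch signature
// shown in the hunk above.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/kernel_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"

namespace tensorflow {

bool NchwNodeMatchesNhwcCpuKernel() {
  // KernelDef equivalent of .AttrConstraint("data_format", "NHWC").
  KernelDef kernel_def;
  kernel_def.set_op("DepthToSpace");
  kernel_def.set_device_type("CPU");
  KernelDef::AttrConstraint* constraint = kernel_def.add_constraint();
  constraint->set_name("data_format");
  constraint->mutable_allowed_values()->mutable_list()->add_s("NHWC");

  // NodeDef carrying data_format = "NCHW".
  NodeDef node_def;
  node_def.set_name("d2s");
  node_def.set_op("DepthToSpace");
  (*node_def.mutable_attr())["data_format"].set_s("NCHW");

  // The string constraint is handled by the AttrValue::kS branch of the
  // RETURN_IF_ATTR_NOT_FOUND macro: "NCHW" is not in {"NHWC"}, so the
  // function returns OK with *match left false.
  bool match = false;
  Status s = KernelAttrsMatch(kernel_def, AttrSlice(node_def), &match);
  return s.ok() && match;  // expected: s.ok() == true, match == false
}

}  // namespace tensorflow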


@@ -614,12 +614,14 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.Label("custom") \
.TypeConstraint<T>("T"), \
.TypeConstraint<T>("T") \
.AttrConstraint("data_format", "NHWC"), \
Conv2DCustomBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter") \
.Device(DEVICE_CPU) \
.Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
.TypeConstraint<T>("T") \
.AttrConstraint("data_format", "NHWC"), \
Conv2DBackpropFilterOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNELS);


@@ -175,10 +175,12 @@ struct DepthToSpaceOpFunctor<CPUDevice, T, FORMAT_NHWC> {
};
} // namespace functor
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("DepthToSpace").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
DepthToSpaceOp<CPUDevice, type>);
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER(Name("DepthToSpace") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.AttrConstraint("data_format", "NHWC"), \
DepthToSpaceOp<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER);
#undef REGISTER


@@ -190,10 +190,12 @@ struct SpaceToDepthOpFunctor<CPUDevice, T, FORMAT_NHWC> {
};
} // namespace functor
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER( \
Name("SpaceToDepth").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SpaceToDepthOp<CPUDevice, type>);
#define REGISTER(type) \
REGISTER_KERNEL_BUILDER(Name("SpaceToDepth") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("T") \
.AttrConstraint("data_format", "NHWC"), \
SpaceToDepthOp<CPUDevice, type>);
TF_CALL_ALL_TYPES(REGISTER);
TF_CALL_qint8(REGISTER);


@@ -21,8 +21,10 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.client import device_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -41,6 +43,18 @@ class DepthToSpaceTest(test.TestCase):
# test NHWC (default) on CPU
x_tf = array_ops.depth_to_space(input_nhwc, block_size)
self.assertAllEqual(x_tf.eval(), outputs)
# Run this test only if only CPU device is available
if all(x.device_type == "CPU" for x in device_lib.list_local_devices()):
input_nchw = test_util.NHWCToNCHW(input_nhwc)
output_nchw = array_ops.depth_to_space(
input_nchw, block_size, data_format="NCHW")
output_nhwc = test_util.NCHWToNHWC(output_nchw)
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
"No OpKernel was registered to support Op 'DepthToSpace'"):
output_nhwc.eval()
if test.is_gpu_available():
with self.cached_session(use_gpu=True):
# test NHWC (default) on GPU