Switch nn_ops shape fns to delegate to C++, for all shape fns that have a C++ implementation (the fractional pool ones don't yet).

Change the BiasAdd shape functions to require only rank 3, not 4, for NCHW. This matches
the behavior of GetBiasValueDims in bias_op.cc.

Remove the now-unused functions common_shapes.bias_add_shape and
common_shapes.bias_add_grad_shape.
Change: 132597521
A. Unique TensorFlower 2016-09-08 13:02:19 -08:00 committed by TensorFlower Gardener
parent 57d6a3ee56
commit c3a30a230f
6 changed files with 51 additions and 223 deletions
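
As context for the diffs below, here is a minimal pure-Python sketch of the rank rule adopted above (bias_axis and min_rank_for_bias_add are hypothetical helper names, not TensorFlow API): with NCHW the bias is applied along the third-from-last dimension, so the input only needs rank >= 3 for that axis to exist, while the non-NCHW path keeps its rank >= 2 requirement.

# Hedged sketch, not TensorFlow code: where the bias/channel axis lives
# determines the minimum rank that BiasAddShape has to enforce.
def bias_axis(data_format):
    # NCHW: channel is third from the end; NHWC (the default): last axis.
    return -3 if data_format == "NCHW" else -1

def min_rank_for_bias_add(data_format):
    # Smallest input rank the shape fn now accepts for each format.
    return 3 if data_format == "NCHW" else 2

assert bias_axis("NCHW") == -3 and min_rank_for_bias_add("NCHW") == 3  # was 4
assert bias_axis("NHWC") == -1 and min_rank_for_bias_add("NHWC") == 2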

View File

@@ -135,7 +135,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
Status s = c->GetAttr("data_format", &data_format);
if (s.ok() && data_format == "NCHW") {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
} else {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
}
@@ -193,7 +193,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
Status s = c->GetAttr("data_format", &data_format);
if (s.ok() && data_format == "NCHW") {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
} else {
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));

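A hedged plain-Python illustration of the BiasAddGradShape rule in the hunk above (bias_add_grad_output_shape is an illustrative name, not the real shape-inference code): the gradient with respect to the bias is a vector whose length is the channel dimension, i.e. the third-from-last input dimension for NCHW and the last one otherwise.

# Sketch only; mirrors c->Vector(c->Dim(input_shape, -3)) vs. the last dim.
def bias_add_grad_output_shape(input_shape, data_format):
    channel = input_shape[-3] if data_format == "NCHW" else input_shape[-1]
    return [channel]

assert bias_add_grad_output_shape([10, 11, 12], "NCHW") == [10]
assert bias_add_grad_output_shape([2, 3, 4, 5], "NHWC") == [5]
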
View File

@@ -242,6 +242,19 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
}
{
// NCHW format with input rank 3
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
.Input("a", 0, DT_FLOAT)
.Input("b", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {"[10,11,12]", "[10]"}, {});
TF_EXPECT_OK(BiasAddShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ("[10,11,12]", c.DebugString(output));
}
{
// Input rank not high enough
InferenceContext c(&def, op_def, {"[3]", "[3]"}, {});
@@ -256,7 +269,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(&def, op_def, {"[2,3,4]", "[3]"}, {});
InferenceContext c(&def, op_def, {"[2,3]", "[3]"}, {});
EXPECT_FALSE(BiasAddShape(&c).ok());
}
}
@@ -313,6 +326,18 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
}
{
// NCHW format with input rank 3
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
.Input("a", 0, DT_FLOAT)
.Attr("data_format", "NCHW")
.Finalize(&def));
InferenceContext c(&def, op_def, {"[10,11,12]"}, {});
TF_EXPECT_OK(BiasAddGradShape(&c));
ShapeHandle output = c.output(0);
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
}
{
// Input rank not high enough
InferenceContext c(&def, op_def, {"[3]"}, {});
@@ -326,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
.Attr("data_format", "NCHW")
.Finalize(&def));
// NCHW format
InferenceContext c(&def, op_def, {"[2,3,4]"}, {});
InferenceContext c(&def, op_def, {"[2,3]"}, {});
EXPECT_FALSE(BiasAddGradShape(&c).ok());
}
}

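The new test cases above translate to the following graph-construction behavior. This is a hedged usage sketch against the graph-mode API used elsewhere in this change (tf.placeholder, tf.nn.bias_add), not part of the change itself.

import tensorflow as tf

# With the relaxed check, a rank-3 NCHW value now passes BiasAdd shape
# inference: only the channel axis (third from the end) has to exist.
value = tf.placeholder(tf.float32, shape=[10, 11, 12])  # read as [C, H, W]
bias = tf.placeholder(tf.float32, shape=[10])           # length must match C
out = tf.nn.bias_add(value, bias, data_format="NCHW")
print(out.get_shape())  # (10, 11, 12)
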
View File

@@ -102,45 +102,6 @@ def matmul_shape(op):
return [tensor_shape.TensorShape([output_rows, output_cols])]
def bias_add_shape(op):
"""Shape function for a BiasAdd op."""
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
bias_shape = op.inputs[1].get_shape().with_rank(1)
if input_shape.ndims is not None:
# Output has the same shape as input, and matches the length of
# bias in its bias dimension.
try:
data_format = op.get_attr("data_format")
except ValueError:
data_format = None
if data_format == b"NCHW":
# Merge the length of bias_shape into the third-to-last dimension.
output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with(
bias_shape[0])).concatenate(input_shape[-2:])
else:
output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with(
bias_shape[0]))
else:
output_shape = tensor_shape.unknown_shape()
return [output_shape]
def bias_add_grad_shape(op):
"""Shape function for a BiasAddGrad op."""
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
try:
data_format = op.get_attr("data_format")
except ValueError:
data_format = None
if data_format == b"NCHW":
output_shape = input_shape[-3]
else:
output_shape = input_shape[-1]
return [output_shape]
def get_conv_output_size(input_size, filter_size, strides, padding_type):
"""Returns the spatial size of a n-d convolution/pooling output."""
input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])

View File

@@ -929,30 +929,12 @@ class PoolingTest(tf.test.TestCase):
pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
# Illegal strides.
with self.assertRaisesRegexp(ValueError, "strides in the batch"):
tf.nn.max_pool_with_argmax(
tf.placeholder(tf.float32),
ksize=[1, 1, 1, 1],
strides=[2, 1, 1, 1],
padding="SAME")
# Filter larger than input.
for pool_func in [tf.nn.max_pool_with_argmax]:
with self.assertRaisesRegexp(ValueError,
"Filter must not be larger than the input"):
pool_func(tf.placeholder(tf.float32,
shape=[32, 20, 20, 3]),
ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
with self.assertRaisesRegexp(ValueError,
"Filter must not be larger than the input"):
pool_func(tf.placeholder(tf.float32,
shape=[32, 20, 20, 3]),
ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
def testOpEdgeCases(self):
with self.test_session() as sess:
for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
pool_funcs = [tf.nn.max_pool, tf.nn.avg_pool]
if tf.test.is_gpu_available():
pool_funcs.append(tf.nn.max_pool_with_argmax)
for pool_func in pool_funcs:
# Illegal strides.
with self.assertRaisesRegexp(
tf.errors.UnimplementedError,

View File

@@ -81,7 +81,7 @@ class TopKTest(tf.test.TestCase):
def testKTooLarge(self):
inputs = [[0.1, 0.2], [0.3, 0.4]]
with self.assertRaisesRegexp(
ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"):
ValueError, r"must have last dimension >= k = 4"):
tf.nn.top_k(inputs, 4)
def testTopKGradients(self):

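With the Python TopK shape function removed (see the registration changes in the last file below), the k-too-large error now comes from the C++ shape function, so the test only matches the tail of the message. A hedged standalone sketch of the behavior being tested:

import tensorflow as tf

inputs = [[0.1, 0.2], [0.3, 0.4]]  # last dimension has size 2
try:
    tf.nn.top_k(inputs, 4)         # k = 4 > 2, rejected at graph construction
except ValueError as e:
    print(e)                       # ends with "must have last dimension >= k = 4"
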
View File

@@ -393,9 +393,10 @@ def bias_add(value, bias, data_format=None, name=None):
return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name)
ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
ops.RegisterShape("BiasAddV1")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("BiasAdd")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("BiasAddGradV1")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("BiasAddGrad")(common_shapes.call_cpp_shape_fn)
# pylint: disable=protected-access
@@ -426,11 +427,6 @@ def bias_add_v1(value, bias, name=None):
return gen_nn_ops._bias_add_v1(value, bias, name=name)
ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
def crelu(features, name=None):
"""Computes Concatenated ReLU.
@@ -866,23 +862,12 @@ ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("TopK")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("TopK")
@ops.RegisterShape("TopKV2")
def _TopKShape(op):
"""Shape function for TopK and TopKV2 ops."""
input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
if len(op.inputs) >= 2:
k = tensor_util.constant_value(op.inputs[1])
else:
k = op.get_attr("k")
last = input_shape[-1].value
if last is not None and k is not None and last < k:
raise ValueError("input.shape %s must have last dimension >= k = %d" %
(input_shape, k))
output_shape = input_shape[:-1].concatenate([k])
return [output_shape, output_shape]
def _TopKV2Shape(op):
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
ops.RegisterShape("BatchNormWithGlobalNormalization")(
@@ -960,24 +945,12 @@ def _FusedResizeAndPadConv2DShape(op):
return [tensor_shape.TensorShape(output_shape)]
@ops.RegisterShape("MaxPoolWithArgmax")
def _MaxPoolWithArgMaxShape(op):
"""Shape function for MaxPoolWithArgmax op."""
return common_shapes.max_pool_shape(op) * 2
ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)
@ops.RegisterShape("AvgPoolGrad")
def _AvgPoolGradShape(op):
"""Shape function for the AvgPoolGrad op."""
orig_input_shape = tensor_util.constant_value(op.inputs[0])
if orig_input_shape is not None:
return [tensor_shape.TensorShape(orig_input_shape.tolist())]
else:
# NOTE(mrry): We could in principle work out the shape from the
# gradients and the attrs, but if we do not know orig_input_shape
# statically, then we are unlikely to know the shape of the
# gradients either.
return [tensor_shape.unknown_shape(ndims=4)]
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
@ops.RegisterShape("FractionalMaxPool")
@@ -1015,50 +988,22 @@ def _fractional_avg_pool_grad_shape(op):
@ops.RegisterShape("Conv2DBackpropFilter")
def _Conv2DBackpropFilterShape(op):
"""Shape function for the Conv2DBackpropFilter op."""
filter_shape = tensor_util.constant_value(op.inputs[1])
if filter_shape is not None:
return [tensor_shape.TensorShape(filter_shape.tolist())]
else:
# NOTE(mrry): We could in principle work out the shape from the
# gradients and the attrs, but if we do not know filter_shape
# statically, then we are unlikely to know the shape of the
# gradients either.
return [tensor_shape.unknown_shape(ndims=4)]
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("Conv2DBackpropInput")
def _Conv2DBackpropInputShape(op):
"""Shape function for the Conv2DBackpropInput op."""
input_shape = tensor_util.constant_value(op.inputs[0])
if input_shape is not None:
return [tensor_shape.TensorShape(input_shape.tolist())]
else:
# NOTE(mrry): We could in principle work out the shape from the
# gradients and the attrs, but if we do not know input_shape
# statically, then we are unlikely to know the shape of the
# gradients either.
return [tensor_shape.unknown_shape(ndims=4)]
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
@ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
def _DepthwiseConv2dNativeBackpropFilterShape(op):
"""Shape function for the DepthwiseConv2dNativeBackpropFilter op."""
filter_shape = tensor_util.constant_value(op.inputs[1])
if filter_shape is not None:
return [tensor_shape.TensorShape(filter_shape.tolist())]
else:
return [tensor_shape.unknown_shape(ndims=4)]
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
@ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
def _DepthwiseConv2dNativeBackpropInputShape(op):
"""Shape function for the DepthwiseConv2dNativeBackpropInput op."""
input_shape = tensor_util.constant_value(op.inputs[0])
if input_shape is not None:
return [tensor_shape.TensorShape(input_shape.tolist())]
else:
return [tensor_shape.unknown_shape(ndims=4)]
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
@@ -1136,55 +1081,9 @@ def _calc_depthwise_conv_weight_params(graph, node):
filter_channel_multiplier))
@ops.RegisterShape("Conv3D")
def _Conv3DShape(op):
"""Shape function for Conv3D."""
input_shape = op.inputs[0].get_shape().with_rank(5)
filter_shape = op.inputs[1].get_shape().with_rank(5)
batch_size = input_shape[0]
out_channels = filter_shape[4]
# Check that the input number of channels is compatible between
# input data and filter size.
input_shape[4].assert_is_compatible_with(filter_shape[3])
stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
assert stride_b == 1
assert stride_d == 1
padding_type = op.get_attr("padding")
out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
padding_type)
return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
out_channels])]
@ops.RegisterShape("MaxPool3D")
@ops.RegisterShape("AvgPool3D")
def _Pool3DShape(op):
"""Shape function for Max/AvgPool3D."""
input_shape = op.inputs[0].get_shape().with_rank(5)
ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
assert ksize_b == 1
assert ksize_d == 1
stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
assert stride_b == 1
assert stride_d == 1
batch_size = input_shape[0]
channels = input_shape[4]
padding = op.get_attr("padding")
out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
input_shape[1:4], (ksize_p, ksize_r, ksize_c),
(stride_p, stride_r, stride_c), padding)
return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
channels])]
ops.RegisterShape("Conv3D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("MaxPool3D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("AvgPool3D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
@@ -1392,46 +1291,7 @@ def conv1d(value, filters, stride, padding,
return array_ops.squeeze(result, [1])
@ops.RegisterShape("Dilation2D")
def _Dilation2DShape(op):
"""Shape function for Dilation2D op."""
input_shape = op.inputs[0].get_shape().with_rank(4)
filter_shape = op.inputs[1].get_shape().with_rank(3)
batch_size = input_shape[0]
in_rows = input_shape[1]
in_cols = input_shape[2]
depth = input_shape[3]
filter_rows = filter_shape[0]
filter_cols = filter_shape[1]
# Check that the input depths are compatible.
input_shape[3].assert_is_compatible_with(filter_shape[2])
stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
if stride_b != 1 or stride_d != 1:
raise ValueError("Current implementation does not yet support "
"strides in the batch and depth dimensions.")
rate_b, rate_r, rate_c, rate_d = op.get_attr("rates")
if rate_b != 1 or rate_d != 1:
raise ValueError("Current implementation does not yet support "
"rates in the batch and depth dimensions.")
filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1)
filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1)
padding = op.get_attr("padding")
out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
filter_rows_eff,
filter_cols_eff,
stride_r, stride_c,
padding)
output_shape = [batch_size, out_rows, out_cols, depth]
return [tensor_shape.TensorShape(output_shape)]
ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)