Switch nn_ops shape fns to delegate to C++, for all that have a C++ implementation (the fractional pool ones don't yet).

Change the BiasAdd shape functions to require only rank 3, not 4, for NCHW. This matches the behavior of GetBiasValueDims in bias_op.cc.

Remove the now-unused functions common_shapes.bias_add_shape and common_shapes.bias_add_grad_shape.

Change: 132597521
parent 57d6a3ee56
commit c3a30a230f
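The user-visible effect of the BiasAdd change is what the new C++ tests below exercise: with data_format="NCHW", BiasAdd (and BiasAddGrad) now accept a rank-3 input instead of demanding rank 4. A minimal sketch of what now passes shape inference, assuming the TF Python API of this era (tf.placeholder, tf.nn.bias_add):

import tensorflow as tf

# Rank-3 NCHW input [C, H, W] = [10, 11, 12], mirroring the new test case below;
# before this change the shape function required rank >= 4 for NCHW.
x = tf.placeholder(tf.float32, shape=[10, 11, 12])
b = tf.placeholder(tf.float32, shape=[10])  # one bias value per channel
y = tf.nn.bias_add(x, b, data_format="NCHW")
print(y.get_shape())  # (10, 11, 12): output shape equals the input shape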
@@ -135,7 +135,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
   Status s = c->GetAttr("data_format", &data_format);
 
   if (s.ok() && data_format == "NCHW") {
-    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
+    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
   } else {
     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
   }
@@ -193,7 +193,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
   Status s = c->GetAttr("data_format", &data_format);
 
   if (s.ok() && data_format == "NCHW") {
-    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
+    TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
     c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
   } else {
     TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
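As the hunk above shows, the NCHW branch of BiasAddGradShape still reduces to a vector along dimension -3 (the channel dimension); only the minimum rank drops from 4 to 3. A tiny sketch of the rule, using the shapes from the new test below:

# Hedged sketch of the NCHW rule in BiasAddGradShape: the bias gradient is a
# vector whose length is dim(-3) of the incoming gradient.
grad_shape = [10, 11, 12]            # [C, H, W]; a leading batch dim is now optional
bias_grad_shape = [grad_shape[-3]]   # -> [10]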
@@ -242,6 +242,19 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
     EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
   }
 
+  {
+    // NCHW format with input rank 3
+    TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
+                    .Input("a", 0, DT_FLOAT)
+                    .Input("b", 0, DT_FLOAT)
+                    .Attr("data_format", "NCHW")
+                    .Finalize(&def));
+    InferenceContext c(&def, op_def, {"[10,11,12]", "[10]"}, {});
+    TF_EXPECT_OK(BiasAddShape(&c));
+    ShapeHandle output = c.output(0);
+    EXPECT_EQ("[10,11,12]", c.DebugString(output));
+  }
+
   {
     // Input rank not high enough
     InferenceContext c(&def, op_def, {"[3]", "[3]"}, {});
@@ -256,7 +269,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
     // NCHW format
-    InferenceContext c(&def, op_def, {"[2,3,4]", "[3]"}, {});
+    InferenceContext c(&def, op_def, {"[2,3]", "[3]"}, {});
     EXPECT_FALSE(BiasAddShape(&c).ok());
   }
 }
@@ -313,6 +326,18 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
     EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
   }
 
+  {
+    // NCHW format with input rank 3
+    TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
+                    .Input("a", 0, DT_FLOAT)
+                    .Attr("data_format", "NCHW")
+                    .Finalize(&def));
+    InferenceContext c(&def, op_def, {"[10,11,12]"}, {});
+    TF_EXPECT_OK(BiasAddGradShape(&c));
+    ShapeHandle output = c.output(0);
+    EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
+  }
+
   {
     // Input rank not high enough
     InferenceContext c(&def, op_def, {"[3]"}, {});
@@ -326,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
                     .Attr("data_format", "NCHW")
                     .Finalize(&def));
     // NCHW format
-    InferenceContext c(&def, op_def, {"[2,3,4]"}, {});
+    InferenceContext c(&def, op_def, {"[2,3]"}, {});
     EXPECT_FALSE(BiasAddGradShape(&c).ok());
   }
 }
@@ -102,45 +102,6 @@ def matmul_shape(op):
   return [tensor_shape.TensorShape([output_rows, output_cols])]
 
 
-def bias_add_shape(op):
-  """Shape function for a BiasAdd op."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  bias_shape = op.inputs[1].get_shape().with_rank(1)
-  if input_shape.ndims is not None:
-    # Output has the same shape as input, and matches the length of
-    # bias in its bias dimension.
-    try:
-      data_format = op.get_attr("data_format")
-    except ValueError:
-      data_format = None
-    if data_format == b"NCHW":
-      # Merge the length of bias_shape into the third-to-last dimension.
-      output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with(
-          bias_shape[0])).concatenate(input_shape[-2:])
-    else:
-      output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with(
-          bias_shape[0]))
-  else:
-    output_shape = tensor_shape.unknown_shape()
-  return [output_shape]
-
-
-def bias_add_grad_shape(op):
-  """Shape function for a BiasAddGrad op."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
-  try:
-    data_format = op.get_attr("data_format")
-  except ValueError:
-    data_format = None
-
-  if data_format == b"NCHW":
-    output_shape = input_shape[-3]
-  else:
-    output_shape = input_shape[-1]
-
-  return [output_shape]
-
-
 def get_conv_output_size(input_size, filter_size, strides, padding_type):
   """Returns the spatial size of a n-d convolution/pooling output."""
   input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
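The two Python shape functions deleted above are superseded by the C++ BiasAddShape and BiasAddGradShape changed earlier in this commit; the default NHWC behaviour they implemented (merge the bias length into the last dimension, output shape equal to the input shape) is unchanged. A quick check, sketched against the same era's Python API:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[2, 3, 4, 5])  # NHWC
b = tf.placeholder(tf.float32, shape=[5])           # length must match the last dim
y = tf.nn.bias_add(x, b)                            # default data_format is NHWC
print(y.get_shape())                                # (2, 3, 4, 5)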
@@ -929,30 +929,12 @@ class PoolingTest(tf.test.TestCase):
         pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
                   ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
 
-      # Illegal strides.
-      with self.assertRaisesRegexp(ValueError, "strides in the batch"):
-        tf.nn.max_pool_with_argmax(
-            tf.placeholder(tf.float32),
-            ksize=[1, 1, 1, 1],
-            strides=[2, 1, 1, 1],
-            padding="SAME")
-
-      # Filter larger than input.
-      for pool_func in [tf.nn.max_pool_with_argmax]:
-        with self.assertRaisesRegexp(ValueError,
-                                     "Filter must not be larger than the input"):
-          pool_func(tf.placeholder(tf.float32,
-                                   shape=[32, 20, 20, 3]),
-                    ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
-        with self.assertRaisesRegexp(ValueError,
-                                     "Filter must not be larger than the input"):
-          pool_func(tf.placeholder(tf.float32,
-                                   shape=[32, 20, 20, 3]),
-                    ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
-
   def testOpEdgeCases(self):
     with self.test_session() as sess:
-      for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
+      pool_funcs = [tf.nn.max_pool, tf.nn.avg_pool]
+      if tf.test.is_gpu_available():
+        pool_funcs.append(tf.nn.max_pool_with_argmax)
+      for pool_func in pool_funcs:
         # Illegal strides.
         with self.assertRaisesRegexp(
             tf.errors.UnimplementedError,
@@ -81,7 +81,7 @@ class TopKTest(tf.test.TestCase):
   def testKTooLarge(self):
     inputs = [[0.1, 0.2], [0.3, 0.4]]
     with self.assertRaisesRegexp(
-        ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"):
+        ValueError, r"must have last dimension >= k = 4"):
       tf.nn.top_k(inputs, 4)
 
   def testTopKGradients(self):
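The regex is loosened because the rank/size check now comes from the C++ shape function, whose error message does not echo the full input shape; the failure itself is unchanged. A hedged sketch of the behaviour the test still covers:

import tensorflow as tf

inputs = [[0.1, 0.2], [0.3, 0.4]]  # last dimension has size 2
tf.nn.top_k(inputs, 4)             # raises ValueError: "... must have last dimension >= k = 4"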
@@ -393,9 +393,10 @@ def bias_add(value, bias, data_format=None, name=None):
     return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name)
 
 
-ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
-
-ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
+ops.RegisterShape("BiasAddV1")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAdd")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAddGradV1")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("BiasAddGrad")(common_shapes.call_cpp_shape_fn)
 
 
 # pylint: disable=protected-access
@@ -426,11 +427,6 @@ def bias_add_v1(value, bias, name=None):
     return gen_nn_ops._bias_add_v1(value, bias, name=name)
 
 
-ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
-
-ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
-
-
 def crelu(features, name=None):
   """Computes Concatenated ReLU.
 
@@ -866,23 +862,12 @@ ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("TopK")(common_shapes.call_cpp_shape_fn)
 
 
-@ops.RegisterShape("TopK")
 @ops.RegisterShape("TopKV2")
-def _TopKShape(op):
-  """Shape function for TopK and TopKV2 ops."""
-  input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
-  if len(op.inputs) >= 2:
-    k = tensor_util.constant_value(op.inputs[1])
-  else:
-    k = op.get_attr("k")
-  last = input_shape[-1].value
-  if last is not None and k is not None and last < k:
-    raise ValueError("input.shape %s must have last dimension >= k = %d" %
-                     (input_shape, k))
-  output_shape = input_shape[:-1].concatenate([k])
-  return [output_shape, output_shape]
+def _TopKV2Shape(op):
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 ops.RegisterShape("BatchNormWithGlobalNormalization")(
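TopK (k as an attr) can be inferred entirely in C++, so it gets a plain call_cpp_shape_fn registration. TopKV2 takes k as a tensor input, so the small Python wrapper above forwards input 1 via input_tensors_needed; the C++ function then sees k's constant value, just as the deleted Python code did through tensor_util.constant_value. A sketch of the resulting static shapes for the attr form (hedged; assumes the tf.nn.top_k wrapper of this era):

import tensorflow as tf

values, indices = tf.nn.top_k([[0.1, 0.2, 0.3]], k=2)
print(values.get_shape())   # (1, 2): last dimension replaced by k
print(indices.get_shape())  # (1, 2)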
@@ -960,24 +945,12 @@ def _FusedResizeAndPadConv2DShape(op):
   return [tensor_shape.TensorShape(output_shape)]
 
 
-@ops.RegisterShape("MaxPoolWithArgmax")
-def _MaxPoolWithArgMaxShape(op):
-  """Shape function for MaxPoolWithArgmax op."""
-  return common_shapes.max_pool_shape(op) * 2
+ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)
 
 
 @ops.RegisterShape("AvgPoolGrad")
 def _AvgPoolGradShape(op):
-  """Shape function for the AvgPoolGrad op."""
-  orig_input_shape = tensor_util.constant_value(op.inputs[0])
-  if orig_input_shape is not None:
-    return [tensor_shape.TensorShape(orig_input_shape.tolist())]
-  else:
-    # NOTE(mrry): We could in principle work out the shape from the
-    # gradients and the attrs, but if we do not know orig_input_shape
-    # statically, then we are unlikely to know the shape of the
-    # gradients either.
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
 
 
 @ops.RegisterShape("FractionalMaxPool")
@@ -1015,50 +988,22 @@ def _fractional_avg_pool_grad_shape(op):
 
 @ops.RegisterShape("Conv2DBackpropFilter")
 def _Conv2DBackpropFilterShape(op):
-  """Shape function for the Conv2DBackpropFilter op."""
-  filter_shape = tensor_util.constant_value(op.inputs[1])
-  if filter_shape is not None:
-    return [tensor_shape.TensorShape(filter_shape.tolist())]
-  else:
-    # NOTE(mrry): We could in principle work out the shape from the
-    # gradients and the attrs, but if we do not know filter_shape
-    # statically, then we are unlikely to know the shape of the
-    # gradients either.
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 @ops.RegisterShape("Conv2DBackpropInput")
 def _Conv2DBackpropInputShape(op):
-  """Shape function for the Conv2DBackpropInput op."""
-  input_shape = tensor_util.constant_value(op.inputs[0])
-  if input_shape is not None:
-    return [tensor_shape.TensorShape(input_shape.tolist())]
-  else:
-    # NOTE(mrry): We could in principle work out the shape from the
-    # gradients and the attrs, but if we do not know input_shape
-    # statically, then we are unlikely to know the shape of the
-    # gradients either.
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
 
 
 @ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
 def _DepthwiseConv2dNativeBackpropFilterShape(op):
-  """Shape function for the DepthwiseConv2dNativeBackpropFilter op."""
-  filter_shape = tensor_util.constant_value(op.inputs[1])
-  if filter_shape is not None:
-    return [tensor_shape.TensorShape(filter_shape.tolist())]
-  else:
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
 
 
 @ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
 def _DepthwiseConv2dNativeBackpropInputShape(op):
-  """Shape function for the DepthwiseConv2dNativeBackpropInput op."""
-  input_shape = tensor_util.constant_value(op.inputs[0])
-  if input_shape is not None:
-    return [tensor_shape.TensorShape(input_shape.tolist())]
-  else:
-    return [tensor_shape.unknown_shape(ndims=4)]
+  return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
 
 
 ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
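The gradient ops above follow the same value-dependent pattern: AvgPoolGrad and the *BackpropInput ops read their output shape from input 0 (the original input sizes), while the *BackpropFilter ops read it from input 1 (the filter sizes), so those inputs are forwarded through input_tensors_needed. One place this surfaces for users is conv2d_transpose, which is built on Conv2DBackpropInput; a hedged sketch with illustrative shapes:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[1, 8, 8, 16])
w = tf.placeholder(tf.float32, shape=[3, 3, 8, 16])  # [height, width, out_channels, in_channels]
y = tf.nn.conv2d_transpose(x, w, output_shape=[1, 16, 16, 8],
                           strides=[1, 2, 2, 1], padding="SAME")
print(y.get_shape())  # (1, 16, 16, 8)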
@@ -1136,55 +1081,9 @@ def _calc_depthwise_conv_weight_params(graph, node):
                              filter_channel_multiplier))
 
 
-@ops.RegisterShape("Conv3D")
-def _Conv3DShape(op):
-  """Shape function for Conv3D."""
-  input_shape = op.inputs[0].get_shape().with_rank(5)
-  filter_shape = op.inputs[1].get_shape().with_rank(5)
-
-  batch_size = input_shape[0]
-  out_channels = filter_shape[4]
-  # Check that the input number of channels is compatible between
-  # input data and filter size.
-  input_shape[4].assert_is_compatible_with(filter_shape[3])
-
-  stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
-  assert stride_b == 1
-  assert stride_d == 1
-
-  padding_type = op.get_attr("padding")
-  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
-      input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
-      padding_type)
-
-  return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
-                                    out_channels])]
-
-
-@ops.RegisterShape("MaxPool3D")
-@ops.RegisterShape("AvgPool3D")
-def _Pool3DShape(op):
-  """Shape function for Max/AvgPool3D."""
-  input_shape = op.inputs[0].get_shape().with_rank(5)
-  ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
-  assert ksize_b == 1
-  assert ksize_d == 1
-
-  stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
-  assert stride_b == 1
-  assert stride_d == 1
-
-  batch_size = input_shape[0]
-  channels = input_shape[4]
-
-  padding = op.get_attr("padding")
-  out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
-      input_shape[1:4], (ksize_p, ksize_r, ksize_c),
-      (stride_p, stride_r, stride_c), padding)
-  return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
-                                    channels])]
-
-
+ops.RegisterShape("Conv3D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("MaxPool3D")(common_shapes.call_cpp_shape_fn)
+ops.RegisterShape("AvgPool3D")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
@@ -1392,46 +1291,7 @@ def conv1d(value, filters, stride, padding,
   return array_ops.squeeze(result, [1])
 
 
-@ops.RegisterShape("Dilation2D")
-def _Dilation2DShape(op):
-  """Shape function for Dilation2D op."""
-  input_shape = op.inputs[0].get_shape().with_rank(4)
-  filter_shape = op.inputs[1].get_shape().with_rank(3)
-
-  batch_size = input_shape[0]
-  in_rows = input_shape[1]
-  in_cols = input_shape[2]
-  depth = input_shape[3]
-
-  filter_rows = filter_shape[0]
-  filter_cols = filter_shape[1]
-  # Check that the input depths are compatible.
-  input_shape[3].assert_is_compatible_with(filter_shape[2])
-
-  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
-  if stride_b != 1 or stride_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "strides in the batch and depth dimensions.")
-
-  rate_b, rate_r, rate_c, rate_d = op.get_attr("rates")
-  if rate_b != 1 or rate_d != 1:
-    raise ValueError("Current implementation does not yet support "
-                     "rates in the batch and depth dimensions.")
-
-  filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1)
-  filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1)
-
-  padding = op.get_attr("padding")
-  out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
-                                                            filter_rows_eff,
-                                                            filter_cols_eff,
-                                                            stride_r, stride_c,
-                                                            padding)
-
-  output_shape = [batch_size, out_rows, out_cols, depth]
-  return [tensor_shape.TensorShape(output_shape)]
-
-
+ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
 ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)
 