Switch nn_ops shape fns to delegate to C++, for all that have a C++
implementation (fractional pool ones don't yet). Change BiasAdd functions to require only rank 3, not 4, for NHWC. This matches the behavior of GetBiasValueDims in bias_op.cc. Removed unused functions common_shapes.bias_add_shape and common_shapes.bias_add_grad_shape. Change: 132597521
This commit is contained in:
parent
57d6a3ee56
commit
c3a30a230f
@ -135,7 +135,7 @@ Status BiasAddShape(shape_inference::InferenceContext* c) {
|
|||||||
Status s = c->GetAttr("data_format", &data_format);
|
Status s = c->GetAttr("data_format", &data_format);
|
||||||
|
|
||||||
if (s.ok() && data_format == "NCHW") {
|
if (s.ok() && data_format == "NCHW") {
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
|
||||||
}
|
}
|
||||||
@ -193,7 +193,7 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c) {
|
|||||||
Status s = c->GetAttr("data_format", &data_format);
|
Status s = c->GetAttr("data_format", &data_format);
|
||||||
|
|
||||||
if (s.ok() && data_format == "NCHW") {
|
if (s.ok() && data_format == "NCHW") {
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape));
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 3, &input_shape));
|
||||||
c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
|
c->set_output(0, c->Vector(c->Dim(input_shape, -3)));
|
||||||
} else {
|
} else {
|
||||||
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
|
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input_shape));
|
||||||
|
@ -242,6 +242,19 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
|||||||
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
|
EXPECT_EQ("[8,6,4,2,3,4,5]", c.DebugString(output));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// NCHW format with input rank 3
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("test", "BiasAdd")
|
||||||
|
.Input("a", 0, DT_FLOAT)
|
||||||
|
.Input("b", 0, DT_FLOAT)
|
||||||
|
.Attr("data_format", "NCHW")
|
||||||
|
.Finalize(&def));
|
||||||
|
InferenceContext c(&def, op_def, {"[10,11,12]", "[10]"}, {});
|
||||||
|
TF_EXPECT_OK(BiasAddShape(&c));
|
||||||
|
ShapeHandle output = c.output(0);
|
||||||
|
EXPECT_EQ("[10,11,12]", c.DebugString(output));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// Input rank not high enough
|
// Input rank not high enough
|
||||||
InferenceContext c(&def, op_def, {"[3]", "[3]"}, {});
|
InferenceContext c(&def, op_def, {"[3]", "[3]"}, {});
|
||||||
@ -256,7 +269,7 @@ TEST(CommonShapeFnsTest, BiasAddShapeTest) {
|
|||||||
.Attr("data_format", "NCHW")
|
.Attr("data_format", "NCHW")
|
||||||
.Finalize(&def));
|
.Finalize(&def));
|
||||||
// NCHW format
|
// NCHW format
|
||||||
InferenceContext c(&def, op_def, {"[2,3,4]", "[3]"}, {});
|
InferenceContext c(&def, op_def, {"[2,3]", "[3]"}, {});
|
||||||
EXPECT_FALSE(BiasAddShape(&c).ok());
|
EXPECT_FALSE(BiasAddShape(&c).ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -313,6 +326,18 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
|||||||
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
EXPECT_EQ(3, c.Value(c.Dim(output, 0)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// NCHW format with input rank 3
|
||||||
|
TF_CHECK_OK(NodeDefBuilder("test", "BiasAddGrad")
|
||||||
|
.Input("a", 0, DT_FLOAT)
|
||||||
|
.Attr("data_format", "NCHW")
|
||||||
|
.Finalize(&def));
|
||||||
|
InferenceContext c(&def, op_def, {"[10,11,12]"}, {});
|
||||||
|
TF_EXPECT_OK(BiasAddGradShape(&c));
|
||||||
|
ShapeHandle output = c.output(0);
|
||||||
|
EXPECT_EQ(10, c.Value(c.Dim(output, 0)));
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
// Input rank not high enough
|
// Input rank not high enough
|
||||||
InferenceContext c(&def, op_def, {"[3]"}, {});
|
InferenceContext c(&def, op_def, {"[3]"}, {});
|
||||||
@ -326,7 +351,7 @@ TEST(CommonShapeFnsTest, BiasAddGradShapeTest) {
|
|||||||
.Attr("data_format", "NCHW")
|
.Attr("data_format", "NCHW")
|
||||||
.Finalize(&def));
|
.Finalize(&def));
|
||||||
// NCHW format
|
// NCHW format
|
||||||
InferenceContext c(&def, op_def, {"[2,3,4]"}, {});
|
InferenceContext c(&def, op_def, {"[2,3]"}, {});
|
||||||
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
EXPECT_FALSE(BiasAddGradShape(&c).ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -102,45 +102,6 @@ def matmul_shape(op):
|
|||||||
return [tensor_shape.TensorShape([output_rows, output_cols])]
|
return [tensor_shape.TensorShape([output_rows, output_cols])]
|
||||||
|
|
||||||
|
|
||||||
def bias_add_shape(op):
|
|
||||||
"""Shape function for a BiasAdd op."""
|
|
||||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
|
||||||
bias_shape = op.inputs[1].get_shape().with_rank(1)
|
|
||||||
if input_shape.ndims is not None:
|
|
||||||
# Output has the same shape as input, and matches the length of
|
|
||||||
# bias in its bias dimension.
|
|
||||||
try:
|
|
||||||
data_format = op.get_attr("data_format")
|
|
||||||
except ValueError:
|
|
||||||
data_format = None
|
|
||||||
if data_format == b"NCHW":
|
|
||||||
# Merge the length of bias_shape into the third-to-last dimension.
|
|
||||||
output_shape = input_shape[0:-3].concatenate(input_shape[-3].merge_with(
|
|
||||||
bias_shape[0])).concatenate(input_shape[-2:])
|
|
||||||
else:
|
|
||||||
output_shape = input_shape[0:-1].concatenate(input_shape[-1].merge_with(
|
|
||||||
bias_shape[0]))
|
|
||||||
else:
|
|
||||||
output_shape = tensor_shape.unknown_shape()
|
|
||||||
return [output_shape]
|
|
||||||
|
|
||||||
|
|
||||||
def bias_add_grad_shape(op):
|
|
||||||
"""Shape function for a BiasAddGrad op."""
|
|
||||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(2)
|
|
||||||
try:
|
|
||||||
data_format = op.get_attr("data_format")
|
|
||||||
except ValueError:
|
|
||||||
data_format = None
|
|
||||||
|
|
||||||
if data_format == b"NCHW":
|
|
||||||
output_shape = input_shape[-3]
|
|
||||||
else:
|
|
||||||
output_shape = input_shape[-1]
|
|
||||||
|
|
||||||
return [output_shape]
|
|
||||||
|
|
||||||
|
|
||||||
def get_conv_output_size(input_size, filter_size, strides, padding_type):
|
def get_conv_output_size(input_size, filter_size, strides, padding_type):
|
||||||
"""Returns the spatial size of a n-d convolution/pooling output."""
|
"""Returns the spatial size of a n-d convolution/pooling output."""
|
||||||
input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
|
input_size = tuple([tensor_shape.as_dimension(x).value for x in input_size])
|
||||||
|
@ -929,30 +929,12 @@ class PoolingTest(tf.test.TestCase):
|
|||||||
pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
|
pool_func(tf.placeholder(tf.float32, shape=[1, 3]),
|
||||||
ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
|
ksize=[1, 1, 1, 1], strides=[1, 1, 1, 1], padding="SAME")
|
||||||
|
|
||||||
# Illegal strides.
|
|
||||||
with self.assertRaisesRegexp(ValueError, "strides in the batch"):
|
|
||||||
tf.nn.max_pool_with_argmax(
|
|
||||||
tf.placeholder(tf.float32),
|
|
||||||
ksize=[1, 1, 1, 1],
|
|
||||||
strides=[2, 1, 1, 1],
|
|
||||||
padding="SAME")
|
|
||||||
|
|
||||||
# Filter larger than input.
|
|
||||||
for pool_func in [tf.nn.max_pool_with_argmax]:
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
|
||||||
"Filter must not be larger than the input"):
|
|
||||||
pool_func(tf.placeholder(tf.float32,
|
|
||||||
shape=[32, 20, 20, 3]),
|
|
||||||
ksize=[1, 20, 21, 1], strides=[1, 1, 1, 1], padding="SAME")
|
|
||||||
with self.assertRaisesRegexp(ValueError,
|
|
||||||
"Filter must not be larger than the input"):
|
|
||||||
pool_func(tf.placeholder(tf.float32,
|
|
||||||
shape=[32, 20, 20, 3]),
|
|
||||||
ksize=[1, 21, 20, 1], strides=[1, 1, 1, 1], padding="SAME")
|
|
||||||
|
|
||||||
def testOpEdgeCases(self):
|
def testOpEdgeCases(self):
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
|
pool_funcs = [tf.nn.max_pool, tf.nn.avg_pool]
|
||||||
|
if tf.test.is_gpu_available():
|
||||||
|
pool_funcs.append(tf.nn.max_pool_with_argmax)
|
||||||
|
for pool_func in pool_funcs:
|
||||||
# Illegal strides.
|
# Illegal strides.
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
tf.errors.UnimplementedError,
|
tf.errors.UnimplementedError,
|
||||||
|
@ -81,7 +81,7 @@ class TopKTest(tf.test.TestCase):
|
|||||||
def testKTooLarge(self):
|
def testKTooLarge(self):
|
||||||
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
inputs = [[0.1, 0.2], [0.3, 0.4]]
|
||||||
with self.assertRaisesRegexp(
|
with self.assertRaisesRegexp(
|
||||||
ValueError, r"input.shape \(2, 2\) must have last dimension >= k = 4"):
|
ValueError, r"must have last dimension >= k = 4"):
|
||||||
tf.nn.top_k(inputs, 4)
|
tf.nn.top_k(inputs, 4)
|
||||||
|
|
||||||
def testTopKGradients(self):
|
def testTopKGradients(self):
|
||||||
|
@ -393,9 +393,10 @@ def bias_add(value, bias, data_format=None, name=None):
|
|||||||
return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name)
|
return gen_nn_ops._bias_add(value, bias, data_format=data_format, name=name)
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("BiasAdd")(common_shapes.bias_add_shape)
|
ops.RegisterShape("BiasAddV1")(common_shapes.call_cpp_shape_fn)
|
||||||
|
ops.RegisterShape("BiasAdd")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("BiasAddGrad")(common_shapes.bias_add_grad_shape)
|
ops.RegisterShape("BiasAddGradV1")(common_shapes.call_cpp_shape_fn)
|
||||||
|
ops.RegisterShape("BiasAddGrad")(common_shapes.call_cpp_shape_fn)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
@ -426,11 +427,6 @@ def bias_add_v1(value, bias, name=None):
|
|||||||
return gen_nn_ops._bias_add_v1(value, bias, name=name)
|
return gen_nn_ops._bias_add_v1(value, bias, name=name)
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("BiasAddV1")(common_shapes.bias_add_shape)
|
|
||||||
|
|
||||||
ops.RegisterShape("BiasAddGradV1")(common_shapes.bias_add_grad_shape)
|
|
||||||
|
|
||||||
|
|
||||||
def crelu(features, name=None):
|
def crelu(features, name=None):
|
||||||
"""Computes Concatenated ReLU.
|
"""Computes Concatenated ReLU.
|
||||||
|
|
||||||
@ -866,23 +862,12 @@ ops.RegisterShape("LRNGrad")(common_shapes.call_cpp_shape_fn)
|
|||||||
ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Softmax")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("LogSoftmax")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("InTopK")(common_shapes.call_cpp_shape_fn)
|
||||||
|
ops.RegisterShape("TopK")(common_shapes.call_cpp_shape_fn)
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("TopK")
|
|
||||||
@ops.RegisterShape("TopKV2")
|
@ops.RegisterShape("TopKV2")
|
||||||
def _TopKShape(op):
|
def _TopKV2Shape(op):
|
||||||
"""Shape function for TopK and TopKV2 ops."""
|
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
|
||||||
input_shape = op.inputs[0].get_shape().with_rank_at_least(1)
|
|
||||||
if len(op.inputs) >= 2:
|
|
||||||
k = tensor_util.constant_value(op.inputs[1])
|
|
||||||
else:
|
|
||||||
k = op.get_attr("k")
|
|
||||||
last = input_shape[-1].value
|
|
||||||
if last is not None and k is not None and last < k:
|
|
||||||
raise ValueError("input.shape %s must have last dimension >= k = %d" %
|
|
||||||
(input_shape, k))
|
|
||||||
output_shape = input_shape[:-1].concatenate([k])
|
|
||||||
return [output_shape, output_shape]
|
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("BatchNormWithGlobalNormalization")(
|
ops.RegisterShape("BatchNormWithGlobalNormalization")(
|
||||||
@ -960,24 +945,12 @@ def _FusedResizeAndPadConv2DShape(op):
|
|||||||
return [tensor_shape.TensorShape(output_shape)]
|
return [tensor_shape.TensorShape(output_shape)]
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("MaxPoolWithArgmax")
|
ops.RegisterShape("MaxPoolWithArgmax")(common_shapes.call_cpp_shape_fn)
|
||||||
def _MaxPoolWithArgMaxShape(op):
|
|
||||||
"""Shape function for MaxPoolWithArgmax op."""
|
|
||||||
return common_shapes.max_pool_shape(op) * 2
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("AvgPoolGrad")
|
@ops.RegisterShape("AvgPoolGrad")
|
||||||
def _AvgPoolGradShape(op):
|
def _AvgPoolGradShape(op):
|
||||||
"""Shape function for the AvgPoolGrad op."""
|
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
|
||||||
orig_input_shape = tensor_util.constant_value(op.inputs[0])
|
|
||||||
if orig_input_shape is not None:
|
|
||||||
return [tensor_shape.TensorShape(orig_input_shape.tolist())]
|
|
||||||
else:
|
|
||||||
# NOTE(mrry): We could in principle work out the shape from the
|
|
||||||
# gradients and the attrs, but if we do not know orig_input_shape
|
|
||||||
# statically, then we are unlikely to know the shape of the
|
|
||||||
# gradients either.
|
|
||||||
return [tensor_shape.unknown_shape(ndims=4)]
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("FractionalMaxPool")
|
@ops.RegisterShape("FractionalMaxPool")
|
||||||
@ -1015,50 +988,22 @@ def _fractional_avg_pool_grad_shape(op):
|
|||||||
|
|
||||||
@ops.RegisterShape("Conv2DBackpropFilter")
|
@ops.RegisterShape("Conv2DBackpropFilter")
|
||||||
def _Conv2DBackpropFilterShape(op):
|
def _Conv2DBackpropFilterShape(op):
|
||||||
"""Shape function for the Conv2DBackpropFilter op."""
|
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
|
||||||
filter_shape = tensor_util.constant_value(op.inputs[1])
|
|
||||||
if filter_shape is not None:
|
|
||||||
return [tensor_shape.TensorShape(filter_shape.tolist())]
|
|
||||||
else:
|
|
||||||
# NOTE(mrry): We could in principle work out the shape from the
|
|
||||||
# gradients and the attrs, but if we do not know filter_shape
|
|
||||||
# statically, then we are unlikely to know the shape of the
|
|
||||||
# gradients either.
|
|
||||||
return [tensor_shape.unknown_shape(ndims=4)]
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("Conv2DBackpropInput")
|
@ops.RegisterShape("Conv2DBackpropInput")
|
||||||
def _Conv2DBackpropInputShape(op):
|
def _Conv2DBackpropInputShape(op):
|
||||||
"""Shape function for the Conv2DBackpropInput op."""
|
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
|
||||||
input_shape = tensor_util.constant_value(op.inputs[0])
|
|
||||||
if input_shape is not None:
|
|
||||||
return [tensor_shape.TensorShape(input_shape.tolist())]
|
|
||||||
else:
|
|
||||||
# NOTE(mrry): We could in principle work out the shape from the
|
|
||||||
# gradients and the attrs, but if we do not know input_shape
|
|
||||||
# statically, then we are unlikely to know the shape of the
|
|
||||||
# gradients either.
|
|
||||||
return [tensor_shape.unknown_shape(ndims=4)]
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
|
@ops.RegisterShape("DepthwiseConv2dNativeBackpropFilter")
|
||||||
def _DepthwiseConv2dNativeBackpropFilterShape(op):
|
def _DepthwiseConv2dNativeBackpropFilterShape(op):
|
||||||
"""Shape function for the DepthwiseConv2dNativeBackpropFilter op."""
|
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[1])
|
||||||
filter_shape = tensor_util.constant_value(op.inputs[1])
|
|
||||||
if filter_shape is not None:
|
|
||||||
return [tensor_shape.TensorShape(filter_shape.tolist())]
|
|
||||||
else:
|
|
||||||
return [tensor_shape.unknown_shape(ndims=4)]
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
|
@ops.RegisterShape("DepthwiseConv2dNativeBackpropInput")
|
||||||
def _DepthwiseConv2dNativeBackpropInputShape(op):
|
def _DepthwiseConv2dNativeBackpropInputShape(op):
|
||||||
"""Shape function for the DepthwiseConv2dNativeBackpropInput op."""
|
return common_shapes.call_cpp_shape_fn(op, input_tensors_needed=[0])
|
||||||
input_shape = tensor_util.constant_value(op.inputs[0])
|
|
||||||
if input_shape is not None:
|
|
||||||
return [tensor_shape.TensorShape(input_shape.tolist())]
|
|
||||||
else:
|
|
||||||
return [tensor_shape.unknown_shape(ndims=4)]
|
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("MaxPoolGrad")(common_shapes.call_cpp_shape_fn)
|
||||||
@ -1136,55 +1081,9 @@ def _calc_depthwise_conv_weight_params(graph, node):
|
|||||||
filter_channel_multiplier))
|
filter_channel_multiplier))
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("Conv3D")
|
ops.RegisterShape("Conv3D")(common_shapes.call_cpp_shape_fn)
|
||||||
def _Conv3DShape(op):
|
ops.RegisterShape("MaxPool3D")(common_shapes.call_cpp_shape_fn)
|
||||||
"""Shape function for Conv3D."""
|
ops.RegisterShape("AvgPool3D")(common_shapes.call_cpp_shape_fn)
|
||||||
input_shape = op.inputs[0].get_shape().with_rank(5)
|
|
||||||
filter_shape = op.inputs[1].get_shape().with_rank(5)
|
|
||||||
|
|
||||||
batch_size = input_shape[0]
|
|
||||||
out_channels = filter_shape[4]
|
|
||||||
# Check that the input number of channels is compatible between
|
|
||||||
# input data and filter size.
|
|
||||||
input_shape[4].assert_is_compatible_with(filter_shape[3])
|
|
||||||
|
|
||||||
stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
|
|
||||||
assert stride_b == 1
|
|
||||||
assert stride_d == 1
|
|
||||||
|
|
||||||
padding_type = op.get_attr("padding")
|
|
||||||
out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
|
|
||||||
input_shape[1:4], filter_shape[0:3], (stride_p, stride_r, stride_c),
|
|
||||||
padding_type)
|
|
||||||
|
|
||||||
return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
|
|
||||||
out_channels])]
|
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("MaxPool3D")
|
|
||||||
@ops.RegisterShape("AvgPool3D")
|
|
||||||
def _Pool3DShape(op):
|
|
||||||
"""Shape function for Max/AvgPool3D."""
|
|
||||||
input_shape = op.inputs[0].get_shape().with_rank(5)
|
|
||||||
ksize_b, ksize_p, ksize_r, ksize_c, ksize_d = op.get_attr("ksize")
|
|
||||||
assert ksize_b == 1
|
|
||||||
assert ksize_d == 1
|
|
||||||
|
|
||||||
stride_b, stride_p, stride_r, stride_c, stride_d = op.get_attr("strides")
|
|
||||||
assert stride_b == 1
|
|
||||||
assert stride_d == 1
|
|
||||||
|
|
||||||
batch_size = input_shape[0]
|
|
||||||
channels = input_shape[4]
|
|
||||||
|
|
||||||
padding = op.get_attr("padding")
|
|
||||||
out_planes, out_rows, out_cols = common_shapes.get_conv_output_size(
|
|
||||||
input_shape[1:4], (ksize_p, ksize_r, ksize_c),
|
|
||||||
(stride_p, stride_r, stride_c), padding)
|
|
||||||
return [tensor_shape.TensorShape([batch_size, out_planes, out_rows, out_cols,
|
|
||||||
channels])]
|
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Conv3DBackpropFilter")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Conv3DBackpropInput")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Conv3DBackpropFilterV2")(common_shapes.call_cpp_shape_fn)
|
||||||
@ -1392,46 +1291,7 @@ def conv1d(value, filters, stride, padding,
|
|||||||
return array_ops.squeeze(result, [1])
|
return array_ops.squeeze(result, [1])
|
||||||
|
|
||||||
|
|
||||||
@ops.RegisterShape("Dilation2D")
|
ops.RegisterShape("Dilation2D")(common_shapes.call_cpp_shape_fn)
|
||||||
def _Dilation2DShape(op):
|
|
||||||
"""Shape function for Dilation2D op."""
|
|
||||||
input_shape = op.inputs[0].get_shape().with_rank(4)
|
|
||||||
filter_shape = op.inputs[1].get_shape().with_rank(3)
|
|
||||||
|
|
||||||
batch_size = input_shape[0]
|
|
||||||
in_rows = input_shape[1]
|
|
||||||
in_cols = input_shape[2]
|
|
||||||
depth = input_shape[3]
|
|
||||||
|
|
||||||
filter_rows = filter_shape[0]
|
|
||||||
filter_cols = filter_shape[1]
|
|
||||||
# Check that the input depths are compatible.
|
|
||||||
input_shape[3].assert_is_compatible_with(filter_shape[2])
|
|
||||||
|
|
||||||
stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
|
|
||||||
if stride_b != 1 or stride_d != 1:
|
|
||||||
raise ValueError("Current implementation does not yet support "
|
|
||||||
"strides in the batch and depth dimensions.")
|
|
||||||
|
|
||||||
rate_b, rate_r, rate_c, rate_d = op.get_attr("rates")
|
|
||||||
if rate_b != 1 or rate_d != 1:
|
|
||||||
raise ValueError("Current implementation does not yet support "
|
|
||||||
"rates in the batch and depth dimensions.")
|
|
||||||
|
|
||||||
filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_r - 1)
|
|
||||||
filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_c - 1)
|
|
||||||
|
|
||||||
padding = op.get_attr("padding")
|
|
||||||
out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
|
|
||||||
filter_rows_eff,
|
|
||||||
filter_cols_eff,
|
|
||||||
stride_r, stride_c,
|
|
||||||
padding)
|
|
||||||
|
|
||||||
output_shape = [batch_size, out_rows, out_cols, depth]
|
|
||||||
return [tensor_shape.TensorShape(output_shape)]
|
|
||||||
|
|
||||||
|
|
||||||
ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Dilation2DBackpropInput")(common_shapes.call_cpp_shape_fn)
|
||||||
ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)
|
ops.RegisterShape("Dilation2DBackpropFilter")(common_shapes.call_cpp_shape_fn)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user