[TF:NN:CONVOLUTION] Don't call space to batch in depthwise convolution on TPU.
PiperOrigin-RevId: 261242540
parent 5ee9c10e83
commit 20506ddda8
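
Note: background for the hunks below. Before this change, tf.nn.depthwise_conv2d emulated dilation on every backend by wrapping the non-dilated kernel in space_to_batch / batch_to_space; this commit teaches the TPU path to use natively dilated ops and keeps the wrapper only as the CPU reference. A minimal standalone sketch of the equivalence involved (hypothetical shapes, not code from this commit; assumes a TF build whose CPU kernel accepts forward dilations):

    import numpy as np
    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    rate = 2
    x = tf.placeholder(tf.float32, [1, 8, 8, 3])
    w = tf.placeholder(tf.float32, [3, 3, 3, 1])

    # Natively dilated depthwise convolution.
    native = tf.nn.depthwise_conv2d_native(
        x, w, strides=[1, 1, 1, 1], padding="VALID",
        dilations=[1, rate, rate, 1])

    # The space-to-batch emulation: run the non-dilated kernel on a
    # block-rearranged input, then undo the rearrangement.
    x_s2b = tf.space_to_batch(x, paddings=[[0, 0], [0, 0]], block_size=rate)
    y = tf.nn.depthwise_conv2d_native(
        x_s2b, w, strides=[1, 1, 1, 1], padding="VALID")
    emulated = tf.batch_to_space(y, crops=[[0, 0], [0, 0]], block_size=rate)

    with tf.Session() as sess:
      xv = np.random.rand(1, 8, 8, 3).astype(np.float32)
      wv = np.random.rand(3, 3, 3, 1).astype(np.float32)
      a, b = sess.run([native, emulated], {x: xv, w: wv})
      np.testing.assert_allclose(a, b, rtol=1e-5)

On TPU the extra space_to_batch/batch_to_space pair is pure overhead, since XLA can lower the dilated convolution directly.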
@@ -25,6 +25,7 @@ from tensorflow.compiler.tests import xla_test
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import nn_ops
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -87,6 +88,32 @@ def ConfigsToTest():
     yield i, f, o, s, p


+def ConfigsWithDilationsToTest():
+  """Iterator for different convolution shapes, strides and paddings.
+
+  Yields:
+    Tuple (input_size, filter_size, out_size, stride, dilation, padding), the
+    depthwise convolution parameters.
+  """
+  input_sizes = [[4, 6, 6, 48], [4, 8, 8, 84], [4, 36, 36, 2], [4, 148, 148, 2],
+                 [3, 300, 300, 3]]
+  filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [5, 5, 2, 1], [4, 4, 2, 8],
+                  [2, 2, 3, 8]]
+  out_sizes = [[4, 6, 6, 96], [4, 8, 8, 84], [4, 36, 36, 2], [4, 74, 74, 16],
+               [3, 296, 296, 24]]
+  strides = [1, 1, 2, 2, 1]
+  dilations = [2, 2, 4, 2, 4]
+  # pylint: disable=invalid-name
+  VALID = "VALID"
+  SAME = "SAME"
+  # pylint: enable=invalid-name
+  paddings = [SAME, SAME, SAME, SAME, VALID]
+  for i, f, o, s, d, p in zip(input_sizes, filter_sizes, out_sizes, strides,
+                              dilations, paddings):
+    yield i, f, o, s, d, p
+
+
 def CheckGradConfigsToTest():
   """Iterator for different convolution shapes, strides and paddings.

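Note: a quick sanity check on these configs. With dilation d, a k-tap filter spans k_eff = k + (k - 1)(d - 1) pixels, and a VALID output dimension at stride 1 is in - k_eff + 1. For the last config above, k_eff = 2 + 1 * 3 = 5 and 300 - 5 + 1 = 296, matching the [3, 296, 296, 24] out size (channels: 3 input channels * depth multiplier 8 = 24). A throwaway check:

    def effective_filter_size(k, d):
      # Span of a k-tap filter whose taps are d pixels apart.
      return k + (k - 1) * (d - 1)

    assert effective_filter_size(2, 4) == 5
    assert 300 - effective_filter_size(2, 4) + 1 == 296  # VALID, stride 1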
@@ -315,6 +342,118 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
         padding="VALID",
         expected=expected_output)

+  # This is testing that depthwise_conv2d with dilation produces
+  # the same results between CPU and TPU. It also tests that NCHW
+  # and NHWC formats agree.
+  def _VerifyValuesWithDilation(self,
+                                tensor_in_sizes,
+                                filter_in_sizes,
+                                stride,
+                                dilation,
+                                padding,
+                                data_type,
+                                data_format="NHWC"):
+    """Verifies the output values of the convolution function.
+
+    Args:
+      tensor_in_sizes: Input tensor dimensions in [batch, input_rows,
+        input_cols, input_depth].
+      filter_in_sizes: Filter tensor dimensions in [filter_rows, filter_cols,
+        input_depth, depth_multiplier].
+      stride: Stride.
+      dilation: Dilation.
+      padding: Padding type.
+      data_type: The data type to use.
+      data_format: The data_format of the input. "NHWC" or "NCHW".
+    """
+    total_size_1 = 1
+    total_size_2 = 1
+    for s in tensor_in_sizes:
+      total_size_1 *= s
+    for s in filter_in_sizes:
+      total_size_2 *= s
+    # Initializes the input and filter tensor with numbers incrementing from 1.
+    x1 = np.array([f * 1.0 for f in range(1, total_size_1 + 1)],
+                  dtype=data_type).reshape(tensor_in_sizes)
+    x2 = np.array([f * 1.0 for f in range(1, total_size_2 + 1)],
+                  dtype=data_type).reshape(filter_in_sizes)
+    with self.session() as sess:
+      if data_type == np.float32:
+        # TODO(b/64210055): Tolerance for TPU is high.
+        tolerance = 1e-2
+      else:
+        self.assertEqual(data_type, np.float64)
+        tolerance = 1e-8
+
+      t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type)
+      t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type)
+
+      native_t1 = t1
+      strides = [1, stride, stride, 1]
+      dilations = [dilation, dilation]
+      if data_format == "NCHW":
+        # Transpose from NHWC input to NCHW
+        # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
+        native_t1 = array_ops.transpose(t1, [0, 3, 1, 2])
+        strides = [1, 1, stride, stride]
+
+      with self.test_scope():
+        conv_native = nn_impl.depthwise_conv2d(
+            native_t1,
+            t2,
+            strides=strides,
+            rate=dilations,
+            data_format=data_format,
+            padding=padding)
+
+      if data_format == "NCHW":
+        # Transpose back from NCHW to NHWC
+        conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])
+
+      with ops.device("CPU"):
+        # CPU only supports the NHWC format.
+        strides = [1, stride, stride, 1]
+        conv_interface = nn_impl.depthwise_conv2d(
+            t1, t2, strides=strides, rate=dilations, padding=padding)
+
+      native_result = sess.run(conv_native, {t1: x1, t2: x2})
+      interface_result = sess.run(conv_interface, {t1: x1, t2: x2})
+
+    print("data_type:", data_type, "max diff = ",
+          np.amax(np.absolute(native_result - interface_result)))
+    self.assertAllClose(
+        np.ravel(native_result), np.ravel(interface_result), rtol=tolerance)
+
+  def testDilationDepthwiseConv2DWith(self):
+    for index, (input_size, filter_size, _, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2D,", index, "th config:", input_size,
+            "*", filter_size, "stride:", stride, "dilation:", dilation,
+            "padding:", padding)
+      for data_type in self.float_types:
+        # TODO(phawkins): the reference implementation only supports float32.
+        if data_type == np.float32:
+          self._VerifyValuesWithDilation(input_size, filter_size, stride,
+                                         dilation, padding, data_type)
+
+  def testDilationDepthwiseConv2DWithFormat(self):
+    for index, (input_size, filter_size, _, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DFormat,", index, "th config:",
+            input_size, "*", filter_size, "stride:", stride, "dilation:",
+            dilation, "padding:", padding)
+      for data_type in self.float_types:
+        # TODO(phawkins): the reference implementation only supports float32.
+        if data_type == np.float32:
+          self._VerifyValuesWithDilation(
+              input_size,
+              filter_size,
+              stride,
+              dilation,
+              padding,
+              data_type,
+              data_format="NCHW")
+
   def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
                             stride, padding):
     x1 = np.random.rand(*filter_sizes).astype(np.float32)
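Note: the transposes above are pure layout shuffles. Permutation [0, 3, 1, 2] moves the channel axis of an NHWC tensor into position 1 (NCHW), and [0, 2, 3, 1] inverts it. A NumPy illustration of the shapes from the comment:

    import numpy as np

    x = np.zeros((4, 5, 5, 48))           # NHWC
    nchw = np.transpose(x, (0, 3, 1, 2))  # -> (4, 48, 5, 5), NCHW
    back = np.transpose(nchw, (0, 2, 3, 1))
    assert back.shape == x.shape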
@@ -420,5 +559,139 @@ class DepthwiseConv2DTest(xla_test.XLATestCase):
           padding,
           data_format="NCHW")

+  def _CompareBackpropInputWithDilation(self, input_sizes, filter_sizes,
+                                        output_sizes, stride, dilation,
+                                        padding):
+    x1 = np.random.rand(*filter_sizes).astype(np.float32)
+    x2 = np.random.rand(*output_sizes).astype(np.float32)
+
+    def _GetVal(use_xla):
+      with self.session():
+        t1 = array_ops.placeholder(np.float32, shape=filter_sizes)
+        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+        if use_xla:
+          with self.test_scope():
+            t0 = constant_op.constant(input_sizes, shape=[len(input_sizes)])
+            backprop = nn_ops.depthwise_conv2d_native_backprop_input(
+                t0,
+                t1,
+                t2,
+                strides=[1, stride, stride, 1],
+                dilations=[1, dilation, dilation, 1],
+                padding=padding)
+        else:
+          # TODO(wangtao): figure out gradient with stride > 1.
+          # depthwise_conv2d_native_backprop_input on CPU doesn't support
+          # dilation.
+          t3 = array_ops.space_to_batch(
+              t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          input_sizes_transform = [
+              input_sizes[0] * dilation * dilation, input_sizes[1] // dilation,
+              input_sizes[2] // dilation, input_sizes[3]
+          ]
+          t0 = constant_op.constant(
+              input_sizes_transform, shape=[len(input_sizes)])
+          backprop_naive = nn_ops.depthwise_conv2d_native_backprop_input(
+              t0, t1, t3, strides=[1, stride, stride, 1], padding=padding)
+          backprop = array_ops.batch_to_space(
+              backprop_naive, [[0, 0], [0, 0]], block_size=dilation)
+
+        ret = backprop.eval({t1: x1, t2: x2})
+        self.assertShapeEqual(ret, backprop)
+        return ret
+
+    gpu_value = _GetVal(use_xla=True)
+    cpu_value = _GetVal(use_xla=False)
+
+    # TODO(b/64210055): Tolerance for TPU is high.
+    self.assertAllClose(cpu_value, gpu_value, rtol=1e-2, atol=1e-3)
+
+  def testDilationDepthwiseConv2DInputGradWithCompare(self):
+    for index, (input_size, filter_size, output_size, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DInputGradWithDilationCompare,",
+            index, "th config:", input_size, "*", filter_size, "stride:",
+            stride, "dilation:", dilation, "padding:", padding)
+      # TODO(wangtao): implement CPU grad computation with stride > 1.
+      if stride == 1:
+        self._CompareBackpropInputWithDilation(input_size, filter_size,
+                                               output_size, stride, dilation,
+                                               padding)
+
+  def _CompareBackpropFilterWithDilation(self,
+                                         input_sizes,
+                                         filter_sizes,
+                                         output_sizes,
+                                         stride,
+                                         dilation,
+                                         padding,
+                                         data_format="NHWC"):
+    x0 = np.random.rand(*input_sizes).astype(np.float32)
+    x2 = np.random.rand(*output_sizes).astype(np.float32)
+
+    def _GetVal(use_xla):
+      with self.session():
+        t0 = array_ops.placeholder(np.float32, shape=input_sizes)
+        t1 = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
+        t2 = array_ops.placeholder(np.float32, shape=output_sizes)
+        native_t0 = t0
+        native_t2 = t2
+        strides = [1, stride, stride, 1]
+        dilations = [1, dilation, dilation, 1]
+
+        if use_xla:
+          if data_format == "NCHW":
+            # Transpose from NHWC input to NCHW
+            # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
+            native_t0 = array_ops.transpose(t0, [0, 3, 1, 2])
+            native_t2 = array_ops.transpose(t2, [0, 3, 1, 2])
+            strides = [1, 1, stride, stride]
+            dilations = [1, 1, dilation, dilation]
+          with self.test_scope():
+            backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
+                native_t0,
+                t1,
+                native_t2,
+                strides=strides,
+                padding=padding,
+                dilations=dilations,
+                data_format=data_format)
+        else:
+          # For CPU, the format NCHW is not supported. Therefore we always use
+          # NHWC here.
+          # depthwise_conv2d_native_backprop_filter on CPU doesn't support
+          # dilation.
+          native_t3 = array_ops.space_to_batch(
+              native_t2, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          native_t0_transform = array_ops.space_to_batch(
+              native_t0, block_size=dilation, paddings=[[0, 0], [0, 0]])
+          backprop = nn_ops.depthwise_conv2d_native_backprop_filter(
+              native_t0_transform,
+              t1,
+              native_t3,
+              strides=strides,
+              padding=padding)
+        ret = backprop.eval({t0: x0, t2: x2})
+        self.assertShapeEqual(ret, backprop)
+        return ret
+
+    gpu_value = _GetVal(use_xla=True)
+    cpu_value = _GetVal(use_xla=False)
+    # TODO(b/64210055): Tolerance for TPU is high.
+    self.assertAllClose(cpu_value, gpu_value, rtol=1e-3, atol=1e-4)
+
+  def testDilationDepthwiseConv2DFilterGradCompare(self):
+    for index, (input_size, filter_size, output_size, stride, dilation,
+                padding) in enumerate(ConfigsWithDilationsToTest()):
+      print("Testing DilationDepthwiseConv2DFilterGradCompare,", index,
+            "th config:", input_size, "*", filter_size, "producing output",
+            output_size, "stride:", stride, "dilation:", dilation, "padding:",
+            padding)
+      if stride == 1:
+        # TODO(wangtao): implement CPU grad computation with stride > 1.
+        self._CompareBackpropFilterWithDilation(input_size, filter_size,
+                                                output_size, stride, dilation,
+                                                padding)
+
 if __name__ == "__main__":
   test.main()
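Note: the CPU reference in _CompareBackpropInputWithDilation leans on the shape algebra of space_to_batch: block size d multiplies the batch dimension by d * d and divides each spatial dimension by d, which is exactly where input_sizes_transform comes from. A shape-only sketch (the helper name is mine, sizes hypothetical):

    def space_to_batch_shape(n, h, w, c, d):
      # Shape effect of space_to_batch with block_size=d and zero paddings.
      assert h % d == 0 and w % d == 0
      return [n * d * d, h // d, w // d, c]

    assert space_to_batch_shape(4, 8, 8, 84, 2) == [16, 4, 4, 84]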
@@ -408,11 +408,15 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
                 errors::InvalidArgument(
                     "Current implementation does not yet support "
                     "dilations in the batch and depth dimensions."));
+    if (std::is_same<Device, CPUDevice>::value ||
+        std::is_same<Device, GPUDevice>::value) {
       // TODO(yangzihao): Add a CPU implementation for dilated convolution.
       OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
                   errors::InvalidArgument(
                       "Current libxsmm and customized CPU implementations do "
                       "not yet support dilation rates larger than 1."));
+      dilations_ = {1, 1, 1, 1};
+    }
   }

   void Compute(OpKernelContext* context) override {
@@ -434,8 +438,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
         context,
         ConvBackpropComputeDimensionsV2(
             "Conv2DCustomBackpropFilter", /*num_spatial_dims=*/2, input.shape(),
-            filter_shape, out_backprop.shape(), /*dilations=*/{1, 1, 1, 1},
-            strides_, padding_, explicit_paddings_, data_format_, &dims));
+            filter_shape, out_backprop.shape(), dilations_, strides_, padding_,
+            explicit_paddings_, data_format_, &dims));

     Tensor* filter_backprop;
     OP_REQUIRES_OK(context,
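Note: reading the two kernel hunks together, the CPU/GPU custom kernel still rejects spatial dilations larger than 1 and then normalizes dilations_ to {1, 1, 1, 1}, so passing dilations_ into ConvBackpropComputeDimensionsV2 is a no-op on those devices; the apparent intent is to let other compilation paths (the XLA/TPU lowering) see the real dilation values.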
@@ -628,15 +628,17 @@ def _DepthwiseConv2dNativeGrad(op, grad):
           array_ops.shape(op.inputs[0]),
           op.inputs[1],
           grad,
-          op.get_attr("strides"),
-          op.get_attr("padding"),
+          dilations=op.get_attr("dilations"),
+          strides=op.get_attr("strides"),
+          padding=op.get_attr("padding"),
           data_format=op.get_attr("data_format")),
       nn_ops.depthwise_conv2d_native_backprop_filter(
           op.inputs[0],
           array_ops.shape(op.inputs[1]),
          grad,
-          op.get_attr("strides"),
-          op.get_attr("padding"),
+          dilations=op.get_attr("dilations"),
+          strides=op.get_attr("strides"),
+          padding=op.get_attr("padding"),
           data_format=op.get_attr("data_format"))
   ]

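Note: with dilations forwarded, a gradient through a dilated depthwise convolution no longer silently drops the op's dilations attribute. A minimal graph-mode check that the registration wires up (hypothetical shapes, not from this commit):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [1, 8, 8, 3])
    w = tf.placeholder(tf.float32, [3, 3, 3, 1])
    y = tf.nn.depthwise_conv2d_native(
        x, w, strides=[1, 1, 1, 1], padding="SAME", dilations=[1, 2, 2, 1])
    # Gradient construction resolves through _DepthwiseConv2dNativeGrad.
    dx, dw = tf.gradients(y, [x, w])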
@@ -715,6 +715,22 @@ def zero_fraction(value, name=None):
   return array_ops.identity(zero_fraction_float32, "fraction")


+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while context is not None and not isinstance(
+      context, control_flow_ops.XLAControlFlowContext):
+    context = context.outer_context
+  return context
+
+
+# copybara:strip_end
+
+
 # pylint: disable=redefined-builtin
 @tf_export(v1=["nn.depthwise_conv2d"])
 def depthwise_conv2d(input,
@@ -774,6 +790,25 @@ def depthwise_conv2d(input,
   if rate is None:
     rate = [1, 1]

+  # copybara:strip_begin
+  # TODO(b/138808492): Remove code inside copybara
+  # to make TPU code and CPU code consistent.
+  # Use depthwise_conv2d_native if executing on TPU.
+  if _enclosing_tpu_context() is not None:
+    if data_format == "NCHW":
+      dilations = [1, 1, rate[0], rate[1]]
+    else:
+      dilations = [1, rate[0], rate[1], 1]
+    return nn_ops.depthwise_conv2d_native(
+        input=input,
+        filter=filter,
+        strides=strides,
+        padding=padding,
+        data_format=data_format,
+        dilations=dilations,
+        name=name)
+  # copybara:strip_end
+
   def op(input_converted, _, padding):
     return nn_ops.depthwise_conv2d_native(
         input=input_converted,
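Note: on the TPU path the 2-element rate is expanded into the 4-element dilations attribute according to the layout, spatial factors in positions 1-2 for NHWC and 2-3 for NCHW. A tiny sketch of that expansion (_expand_rate is a hypothetical name, not in the patch):

    def _expand_rate(rate, data_format):
      # Place the two spatial dilation factors according to the layout.
      if data_format == "NCHW":
        return [1, 1, rate[0], rate[1]]
      return [1, rate[0], rate[1], 1]

    assert _expand_rate([2, 2], "NCHW") == [1, 1, 2, 2]
    assert _expand_rate([2, 2], "NHWC") == [1, 2, 2, 1]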
@@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_nn_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
@@ -919,6 +920,22 @@ convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
     "filter", "filters")


+# copybara:strip_begin
+# TODO(b/138808492): Remove code inside copybara
+# to make TPU code and CPU code consistent.
+def _enclosing_tpu_context():
+  # pylint: disable=protected-access
+  run_context = ops.get_default_graph()._get_control_flow_context()
+  # pylint: enable=protected-access
+  while run_context is not None and not isinstance(
+      run_context, control_flow_ops.XLAControlFlowContext):
+    run_context = run_context.outer_context
+  return run_context
+
+
+# copybara:strip_end
+
+
 def convolution_internal(
     input,  # pylint: disable=redefined-builtin
     filters,
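Note: _enclosing_tpu_context is defined once per touched module (here and in the nn_impl hunk above) rather than shared; keeping a private copy per module presumably avoids introducing a new cross-module dependency. Both versions walk the graph's control-flow context chain looking for an XLAControlFlowContext, which only exists while the graph is being built inside an XLA/TPU compilation scope.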
@@ -958,8 +975,14 @@ def convolution_internal(

   conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}

-  if all(i == 1 for i in dilations):
-    # fast path if no dilation as gradient only supported on GPU for dilations
+  # copybara:strip_begin
+  # TODO(b/138808492): Remove code inside copybara
+  # to make TPU code and CPU code consistent.
+  if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
+    # fast path for TPU or if no dilation as gradient only supported on GPU
+    # for dilations
+    # copybara:strip_end
+    # copybara:insert if all(i == 1 for i in dilations):
     op = conv_ops[n]
     return op(
         input,
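Note on the copybara markers: lines between strip_begin and strip_end are removed when the code is exported to open source, and each copybara:insert line is materialized without its prefix. If that reading is right, the exported copy of the block above reduces to roughly:

    if all(i == 1 for i in dilations):
      op = conv_ops[n]

while the internal copy keeps the TPU-aware condition.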
@@ -1056,7 +1079,9 @@ class Convolution(object):
     self.filter_shape = filter_shape
     self.data_format = data_format
     self.strides = strides
+    self.padding = padding
     self.name = name
+    self.dilation_rate = dilation_rate
     self.conv_op = _WithSpaceToBatch(
         input_shape,
         dilation_rate=dilation_rate,
@@ -1076,7 +1101,23 @@ class Convolution(object):
         name=self.name)

   def __call__(self, inp, filter):  # pylint: disable=redefined-builtin
+    # copybara:strip_begin
+    # TODO(b/138808492): Remove code inside copybara
+    # to make TPU code and CPU code consistent.
+    # TPU convolution supports dilations greater than 1.
+    if _enclosing_tpu_context() is not None:
+      return convolution_internal(
+          inp,
+          filter,
+          strides=self.strides,
+          padding=self.padding,
+          data_format=self.data_format,
+          dilations=self.dilation_rate,
+          name=self.name)
+    else:
       return self.conv_op(inp, filter)
+    # copybara:strip_end
+    # copybara:insert return self.conv_op(inp, filter)


 @tf_export(v1=["nn.pool"])
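Note: end to end, this gives a dilated tf.nn.convolution a direct route on TPU. Convolution now records padding and dilation_rate so that __call__ can hand them to convolution_internal instead of going through the prebuilt _WithSpaceToBatch wrapper. A rough usage sketch (hypothetical shapes):

    import tensorflow.compat.v1 as tf
    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [1, 16, 16, 8])
    w = tf.placeholder(tf.float32, [3, 3, 8, 16])
    # Inside a TPU/XLA context this now lowers to a natively dilated conv;
    # elsewhere it still uses the space-to-batch wrapper.
    y = tf.nn.convolution(x, w, padding="SAME", dilation_rate=[2, 2])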