[TF:XLA] Add implementation of depthwise convolution.

This implementation expands the depthwise convolution kernels into a regular convolution kernel, which may not scale to large feature depths.

PiperOrigin-RevId: 163705408
This commit is contained in:
Peter Hawkins 2017-07-31 09:46:18 -07:00 committed by TensorFlower Gardener
parent f6f07b0275
commit 99b190a1f1
10 changed files with 664 additions and 402 deletions

View File

@ -46,13 +46,6 @@ py_library(
],
)
cc_library(
name = "depthwise_conv2d_test_kernel",
testonly = 1,
srcs = ["depthwise_conv2d_test_kernel.cc"],
deps = ["//tensorflow/core:framework_lite"],
)
tf_xla_py_test(
name = "adagrad_test",
size = "small",
@ -159,6 +152,22 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "depthwise_conv_op_test",
size = "medium",
srcs = ["depthwise_conv_op_test.py"],
shard_count = 5,
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:nn",
"//tensorflow/python:nn_ops",
"//tensorflow/python:nn_ops_gen",
"//tensorflow/python:platform_test",
],
)
tf_xla_py_test(
name = "dynamic_stitch_test",
size = "small",

View File

@ -26,10 +26,8 @@ import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
@ -447,80 +445,5 @@ class Conv2DBackpropFilterTest(XLATestCase):
expected=expected_output)
class DepthwiseConv2DTest(XLATestCase):
CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"
def ConfigsToTest(self):
  """Yields the depthwise convolution configurations exercised by testBasic.

  Every configuration is written as a single row so that its fields cannot
  drift out of step with each other. (The previous parallel lists carried one
  more padding than there were shapes; zip() silently dropped it.)

  Yields:
    Tuple (input_size, filter_size, stride, padding).
  """
  # pylint: disable=invalid-name
  VALID = "VALID"
  SAME = "SAME"
  # pylint: enable=invalid-name
  configs = [
      # input_size, filter_size, stride, padding
      ([4, 35, 35, 2], [5, 5, 2, 1], 1, SAME),
      ([4, 147, 147, 2], [3, 3, 2, 8], 3, VALID),
      ([3, 299, 299, 3], [2, 2, 3, 8], 2, SAME),
      ([5, 183, 183, 1], [5, 5, 1, 2], 2, SAME),
  ]
  for config in configs:
    yield config
def _VerifyValues(self, input_size, filter_size, stride, padding):
  """Checks that depthwise_conv2d agrees between the test scope and the CPU.

  The same random image and filter are fed to nn_impl.depthwise_conv2d twice:
  once built under self.test_scope() and once pinned to self.CPU_DEVICE. The
  two results must be close.

  Args:
    input_size: Shape of the random input image.
    filter_size: Shape of the random filter.
    stride: Stride applied to both middle dimensions ([1, s, s, 1]).
    padding: Padding string forwarded to depthwise_conv2d.
  """
  # The image must be drawn before the filter to keep the random stream
  # identical between runs.
  imag = np.random.rand(*input_size).astype(np.float32)
  filt = np.random.rand(*filter_size).astype(np.float32)
  strides = [1, stride, stride, 1]

  def _Evaluate(make_scope):
    # A fresh session per evaluation; the scope is entered inside it.
    with self.test_session():
      with make_scope():
        image_ph = array_ops.placeholder(dtypes.float32, shape=input_size)
        kernel_ph = array_ops.placeholder(dtypes.float32, shape=filter_size)
        conv = nn_impl.depthwise_conv2d(image_ph, kernel_ph, strides, padding)
        return conv.eval(feed_dict={image_ph: imag, kernel_ph: filt})

  xla_out = _Evaluate(self.test_scope)
  cpu_out = _Evaluate(lambda: ops.device(self.CPU_DEVICE))
  self.assertAllClose(xla_out, cpu_out)
# This is disabled because we need a mechanism to set command-line flags,
# i.e. an implementation of SetCommandLineOption() below.
#
# def _VerifyDummy(self, input_size, filter_size, stride, padding):
# imag = np.random.rand(*input_size).astype(np.float32)
# filt = np.random.rand(*filter_size).astype(np.float32)
# strides = [1, stride, stride, 1]
#
# with self.test_session():
# with self.test_scope():
# imag_ph = tf.placeholder(tf.float32, shape=input_size)
# filt_ph = tf.placeholder(tf.float32, shape=filter_size)
# feed_dict = {imag_ph: imag, filt_ph: filt}
# SetCommandLineOption(
# "tf_tla_depthwise_conv2d_custom_func",
# "DummyDepthwiseConv2dKernel")
# xla_out = tf.nn.depthwise_conv2d(
# imag_ph, filt_ph, strides, padding).eval(feed_dict=feed_dict)
# SetCommandLineOption(
# "tf_tla_depthwise_conv2d_custom_func", "")
#
# expected = np.array(range(np.ravel(xla_out).shape[0]), dtype=np.float32)
# self.assertAllClose(np.ravel(xla_out), expected)
def testBasic(self):
for i, f, s, p in self.ConfigsToTest():
self._VerifyValues(i, f, s, p)
# Test disabled until _VerifyDummy(), above can be implemented.
# def testCustomFunc(self):
# if self.has_custom_call:
# for i, f, s, p in self.ConfigsToTest():
# self._VerifyDummy(i, f, s, p)
if __name__ == "__main__":
googletest.main()

View File

@ -1,30 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/platform/types.h"
using tensorflow::int64;
// A dummy implementation that fills the output with 0, 1, 2,...
// to test the custom call implementation of DepthwiseConv2dNative op.
// TODO(keveman): Test this after adding a real implementation for the kernel.
extern "C" void DummyDepthwiseConv2dKernel(float* output, void** inputs) {
const int64* output_size = reinterpret_cast<const int64*>(inputs[4]);
const int64 total_size =
output_size[0] * output_size[1] * output_size[2] * output_size[3];
for (int64 i = 0; i < total_size; ++i) {
*(output + i) = i;
}
}

View File

@ -0,0 +1,389 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for depthwise convolutional operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
# Reference implementation of depthwise_conv2d
def ReferenceDepthwiseConv2D(input_tensor, filter_tensor, strides, padding,
                             data_format=None):
  """Computes a depthwise convolution out of per-channel regular convolutions.

  Each input channel is sliced out, convolved with the matching slice of the
  filter through nn_ops.conv2d, and the per-channel results are concatenated
  along the channel axis.

  Args:
    input_tensor: Input to convolve. Channels are read from axis 1 when
      data_format is "NCHW" and from axis 3 otherwise.
    filter_tensor: Filter laid out as [H, W, InC, DepthMultiplier].
    strides: Forwarded unchanged to nn_ops.conv2d.
    padding: Forwarded unchanged to nn_ops.conv2d.
    data_format: Forwarded to nn_ops.conv2d; only "NCHW" changes the axis used
      for slicing and concatenation.

  Returns:
    The concatenation of the per-channel convolution outputs.
  """
  channels_first = data_format == "NCHW"
  channel_axis = 1 if channels_first else 3
  per_channel_outputs = []
  for c in xrange(filter_tensor.shape[2]):
    # Keep the channel axis (size 1) so conv2d still sees a 4-D tensor.
    if channels_first:
      image_c = input_tensor[:, c:c + 1, :, :]
    else:
      image_c = input_tensor[:, :, :, c:c + 1]
    kernel_c = filter_tensor[:, :, c:c + 1, :]
    per_channel_outputs.append(
        nn_ops.conv2d(image_c, kernel_c, strides, padding,
                      data_format=data_format,
                      name="depthwise_slice_%d" % c))
  return array_ops.concat(per_channel_outputs, channel_axis)
def ConfigsToTest():
  """Iterator for different convolution shapes, strides and paddings.

  Every configuration is written as a single row so that its five fields
  cannot drift out of step with each other. (The previous parallel lists
  carried ten paddings for nine shapes; zip() silently dropped the last one.)

  Yields:
    Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
    convolution parameters.
  """
  # pylint: disable=invalid-name
  VALID = "VALID"
  SAME = "SAME"
  # pylint: enable=invalid-name
  configs = [
      # input_size, filter_size, out_size, stride, padding
      ([4, 5, 5, 48], [1, 1, 48, 2], [4, 5, 5, 96], 1, SAME),
      ([4, 8, 8, 84], [1, 3, 84, 1], [4, 8, 8, 84], 1, SAME),
      ([4, 17, 17, 48], [3, 1, 48, 4], [4, 17, 17, 192], 1, SAME),
      ([4, 9, 27, 8], [3, 3, 8, 1], [4, 9, 27, 8], 1, SAME),
      ([4, 31, 31, 7], [3, 3, 7, 1], [4, 31, 31, 7], 1, SAME),
      ([4, 35, 35, 2], [5, 5, 2, 1], [4, 35, 35, 2], 1, SAME),
      ([4, 147, 147, 2], [3, 3, 2, 8], [4, 49, 49, 16], 3, VALID),
      ([3, 299, 299, 3], [2, 2, 3, 8], [3, 150, 150, 24], 2, SAME),
      ([5, 183, 183, 1], [5, 5, 1, 2], [5, 92, 92, 2], 2, SAME),
  ]
  for config in configs:
    yield config
def CheckGradConfigsToTest():
  """Iterator over small convolution configurations for gradient checks.

  compute_gradient_error() is very expensive, so these shapes are kept
  deliberately small.

  Yields:
    Tuple (input_size, filter_size, out_size, stride, padding), the depthwise
    convolution parameters.
  """
  # pylint: disable=invalid-name
  VALID = "VALID"
  SAME = "SAME"
  # pylint: enable=invalid-name
  configs = [
      # input_size, filter_size, out_size, stride, padding
      ([2, 5, 8, 1], [4, 4, 1, 2], [2, 5, 8, 2], 1, SAME),
      ([4, 5, 5, 1], [2, 2, 1, 2], [4, 2, 2, 2], 2, VALID),
      ([2, 4, 4, 2], [3, 1, 2, 2], [2, 4, 4, 4], 1, SAME),
      ([1, 15, 15, 2], [1, 3, 2, 1], [1, 15, 15, 2], 1, SAME),
      ([2, 15, 16, 1], [3, 3, 1, 2], [2, 5, 5, 2], 3, VALID),
  ]
  for config in configs:
    yield config
class DepthwiseConv2DTest(XLATestCase):
# This tests that depthwise_conv2d_native, built under the XLA test scope,
# produces the same results as ReferenceDepthwiseConv2D. It also tests that
# the NCHW and NHWC formats agree, by comparing depthwise_conv2d_native run
# in 'NCHW' format (with transposition) against the reference implementation
# run in 'NHWC' format.
def _VerifyValues(self,
                  tensor_in_sizes,
                  filter_in_sizes,
                  stride,
                  padding,
                  data_type,
                  data_format="NHWC"):
  """Checks depthwise_conv2d_native against the sliced reference convolution.

  The native op is built under self.test_scope(); the reference is built on
  the CPU device and always runs in NHWC. For "NCHW" the input is transposed
  before the native op and its output transposed back so both results share
  the NHWC layout.

  Args:
    tensor_in_sizes: Input tensor dimensions in
      [batch, input_rows, input_cols, input_depth].
    filter_in_sizes: Filter tensor dimensions in
      [filter_rows, filter_cols, input_depth, depth_multiplier].
    stride: Stride.
    padding: Padding type.
    data_type: The data type to use (np.float32 or np.float64).
    data_format: The data_format of the input. "NHWC" or "NCHW".
  """

  def _CountingTensor(sizes):
    # Tensor of the given shape filled with 1.0, 2.0, 3.0, ...
    count = 1
    for dim in sizes:
      count *= dim
    values = [float(v) for v in range(1, count + 1)]
    return np.array(values, dtype=data_type).reshape(sizes)

  x1 = _CountingTensor(tensor_in_sizes)
  x2 = _CountingTensor(filter_in_sizes)
  channels_first = data_format == "NCHW"

  with self.test_session() as sess:
    if data_type == np.float32:
      tolerance = 1e-5
    else:
      self.assertEqual(data_type, np.float64)
      tolerance = 1e-8

    t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=data_type)
    t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=data_type)

    if channels_first:
      # Transpose from NHWC input to NCHW.
      # Ex. [4, 5, 5, 48] to [4, 48, 5, 5]
      native_input = array_ops.transpose(t1, [0, 3, 1, 2])
      native_strides = [1, 1, stride, stride]
    else:
      native_input = t1
      native_strides = [1, stride, stride, 1]

    with self.test_scope():
      conv_native = nn_ops.depthwise_conv2d_native(
          native_input,
          t2,
          strides=native_strides,
          data_format=data_format,
          padding=padding)

    if channels_first:
      # Transpose back from NCHW to NHWC.
      conv_native = array_ops.transpose(conv_native, [0, 2, 3, 1])

    with ops.device("CPU"):
      conv_interface = ReferenceDepthwiseConv2D(
          t1, t2, strides=[1, stride, stride, 1], padding=padding)

    feeds = {t1: x1, t2: x2}
    native_result = sess.run(conv_native, feeds)
    interface_result = sess.run(conv_interface, feeds)

  print("data_type:", data_type, "max diff = ",
        np.amax(np.absolute(native_result - interface_result)))
  self.assertAllClose(
      np.ravel(native_result), np.ravel(interface_result), rtol=tolerance)
def testDepthwiseConv2D(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2D,", index, "th config:", input_size, "*",
filter_size, "stride:", stride, "padding:", padding)
for data_type in self.float_types:
# TODO(phawkins): the reference implementation only supports float32.
if data_type == np.float32:
self._VerifyValues(
input_size, filter_size, stride, padding, data_type)
def testDepthwiseConv2DFormat(self):
for index, (input_size, filter_size, _, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DFormat,", index, "th config:", input_size,
"*", filter_size, "stride:", stride, "padding:", padding)
for data_type in self.float_types:
# TODO(phawkins): the reference implementation only supports float32.
if data_type == np.float32:
self._VerifyValues(
input_size,
filter_size,
stride,
padding,
data_type,
data_format="NCHW")
# This is testing against hand calculated results.
def _VerifyHandValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,
                      expected):
  """Checks depthwise_conv2d_native against hand-computed output values.

  Input and filter are both filled with 1.0, 2.0, 3.0, ... in row-major
  order, and the op is built under self.test_scope().

  Args:
    tensor_in_sizes: Input tensor dimensions in
      [batch, input_rows, input_cols, input_depth].
    filter_in_sizes: Filter tensor dimensions in
      [filter_rows, filter_cols, input_depth, depth_multiplier].
    stride: Stride.
    padding: Padding type.
    expected: An array containing the expected operation outputs.
  """

  def _CountingTensor(sizes):
    # float32 tensor of the given shape holding 1.0, 2.0, 3.0, ...
    count = 1
    for dim in sizes:
      count *= dim
    values = [float(v) for v in range(1, count + 1)]
    return np.array(values, dtype=np.float32).reshape(sizes)

  x1 = _CountingTensor(tensor_in_sizes)
  x2 = _CountingTensor(filter_in_sizes)

  with self.test_session() as sess:
    t1 = array_ops.placeholder(shape=tensor_in_sizes, dtype=np.float32)
    t2 = array_ops.placeholder(shape=filter_in_sizes, dtype=np.float32)
    with self.test_scope():
      conv = nn_ops.depthwise_conv2d_native(
          t1, t2, strides=[1, stride, stride, 1], padding=padding)
    value = sess.run(conv, {t1: x1, t2: x2})

  print("value = ", value)
  self.assertArrayNear(expected, np.ravel(value), 1e-5)
  self.assertShapeEqual(value, conv)
def testConv2D2x2Filter(self):
# The inputs look like this (it's a 3 x 2 matrix, each of depth 2):
#
# [ (1.0, 2.0), (3.0, 4.0), ( 5.0, 6.0) ]
# [ (7.0, 8.0), (9.0, 10.0), (11.0, 12.0) ]
# We can view this as two inputs
#
# input depth 0:
#
# [ 1.0, 3.0, 5.0 ]
# [ 7.0, 9.0, 11.0 ]
#
# input depth 1:
#
# [ 2.0, 4.0, 6.0 ]
# [ 8.0, 10.0, 12.0 ]
#
# The filter looks like this (it has two 2 x 2 patches, each generating 2
# depths):
#
# filter #0:
#
# [ (1.0, 3.0), ( 5.0, 7.0)]
# [ (9.0, 11.0), (13.0, 15.0)]
#
# filter #1:
#
# [ ( 2.0, 4.0), ( 6.0, 8.0)]
# [ (10.0, 12.0), (14.0, 16.0)]
#
# So the outputs are:
#
# (position 0, 0: in_depth 0, output_depth 0 -- using filter #0)
# 1.0 * 1.0 + 7.0 * 9.0 + 3.0 * 5.0 + 9.0 * 13.0 = 196
# (position 0, 0: in_depth 0, output_depth 1 -- using filter #1)
# 1.0 * 2.0 + 7.0 * 10.0 + 3.0 * 6.0 + 9.0 * 14.0 = 216
# (position 0, 0: in_depth 1, output_depth 2 -- using filter #0)
# 2.0 * 3.0 + 8.0 * 11.0 + 4.0 * 7.0 + 10.0 * 15.0 = 272
# (position 0, 0: in_depth 1, output_depth 3 -- using filter #1)
# 2.0 * 4.0 + 8.0 * 12.0 + 4.0 * 8.0 + 10.0 * 16.0 = 296
#
# (position 1, 0: in_depth 0, output_depth 0 -- using filter #0)
# 3.0 * 1.0 + 9.0 * 9.0 + 5.0 * 5.0 + 11.0 * 13.0 = 252
# (position 1, 0: in_depth 0, output_depth 1 -- using filter #1)
# 3.0 * 2.0 + 9.0 * 10.0 + 5.0 * 6.0 + 11.0 * 14.0 = 280
# (position 1, 0: in_depth 1, output_depth 2 -- using filter #0)
# 4.0 * 3.0 + 10.0 * 11.0 + 6.0 * 7.0 + 12.0 * 15.0 = 344
# (position 1, 0: in_depth 1, output_depth 3 -- using filter #1)
# 4.0 * 4.0 + 10.0 * 12.0 + 6.0 * 8.0 + 12.0 * 16.0 = 376
expected_output = [196, 216, 272, 296, 252, 280, 344, 376]
self._VerifyHandValues(
tensor_in_sizes=[1, 2, 3, 2],
filter_in_sizes=[2, 2, 2, 2],
stride=1,
padding="VALID",
expected=expected_output)
def _CompareBackpropInput(self, input_sizes, filter_sizes, output_sizes,
                          stride, padding):
  """Compares depthwise input gradients with and without the test scope.

  Random filter and output-gradient values are fed to
  depthwise_conv2d_native_backprop_input twice: once built under
  self.test_scope() and once built outside of it. The results must be close.

  Args:
    input_sizes: Shape of the forward input, passed to the op as a constant.
    filter_sizes: Shape of the random filter.
    output_sizes: Shape of the random output gradient.
    stride: Stride applied to both middle dimensions ([1, s, s, 1]).
    padding: Padding type.
  """
  # Filter values are drawn before gradient values; keep this order so the
  # random stream is consumed identically.
  filter_values = np.random.rand(*filter_sizes).astype(np.float32)
  grad_values = np.random.rand(*output_sizes).astype(np.float32)

  def _BuildOp(sizes, filt, grad):
    return nn_ops.depthwise_conv2d_native_backprop_input(
        sizes, filt, grad, strides=[1, stride, stride, 1], padding=padding)

  def _Evaluate(use_xla):
    with self.test_session():
      sizes = constant_op.constant(input_sizes, shape=[len(input_sizes)])
      filt = array_ops.placeholder(np.float32, shape=filter_sizes)
      grad = array_ops.placeholder(np.float32, shape=output_sizes)
      if use_xla:
        with self.test_scope():
          backprop = _BuildOp(sizes, filt, grad)
      else:
        backprop = _BuildOp(sizes, filt, grad)
      result = backprop.eval({filt: filter_values, grad: grad_values})
      self.assertShapeEqual(result, backprop)
      return result

  xla_value = _Evaluate(use_xla=True)
  plain_value = _Evaluate(use_xla=False)
  self.assertAllClose(plain_value, xla_value, rtol=1e-4, atol=1e-4)
def testDepthwiseConv2DInputGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DInputGradCompare,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
self._CompareBackpropInput(input_size, filter_size, output_size, stride,
padding)
def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes,
                           stride, padding):
  """Compares depthwise filter gradients with and without the test scope.

  Random input and output-gradient values are fed to
  depthwise_conv2d_native_backprop_filter twice: once built under
  self.test_scope() and once built outside of it. The results must be close.

  Args:
    input_sizes: Shape of the random forward input.
    filter_sizes: Shape of the filter, passed to the op as a constant.
    output_sizes: Shape of the random output gradient.
    stride: Stride applied to both middle dimensions ([1, s, s, 1]).
    padding: Padding type.
  """
  # Input values are drawn before gradient values; keep this order so the
  # random stream is consumed identically.
  input_values = np.random.rand(*input_sizes).astype(np.float32)
  grad_values = np.random.rand(*output_sizes).astype(np.float32)

  def _BuildOp(inp, sizes, grad):
    return nn_ops.depthwise_conv2d_native_backprop_filter(
        inp, sizes, grad, strides=[1, stride, stride, 1], padding=padding)

  def _Evaluate(use_xla):
    with self.test_session():
      inp = array_ops.placeholder(np.float32, shape=input_sizes)
      sizes = constant_op.constant(filter_sizes, shape=[len(filter_sizes)])
      grad = array_ops.placeholder(np.float32, shape=output_sizes)
      if use_xla:
        with self.test_scope():
          backprop = _BuildOp(inp, sizes, grad)
      else:
        backprop = _BuildOp(inp, sizes, grad)
      result = backprop.eval({inp: input_values, grad: grad_values})
      self.assertShapeEqual(result, backprop)
      return result

  xla_value = _Evaluate(use_xla=True)
  plain_value = _Evaluate(use_xla=False)
  self.assertAllClose(plain_value, xla_value, rtol=1e-4, atol=1e-4)
def testDepthwiseConv2DFilterGradCompare(self):
for index, (input_size, filter_size, output_size, stride,
padding) in enumerate(ConfigsToTest()):
print("Testing DepthwiseConv2DFilterGradCompare,", index, "th config:",
input_size, "*", filter_size, "stride:", stride, "padding:",
padding)
self._CompareBackpropFilter(input_size, filter_size, output_size,
stride, padding)
if __name__ == "__main__":
test.main()

View File

@ -368,11 +368,11 @@ OpTest::OpTest() {
void OpTest::Repeatedly(const std::function<TestResult(void)>& fn) {
int const max_repetitions = tf_xla_test_repetitions;
int valid_test_runs = 0;
// We run up to 10 * max_repetitions times; the idea is that if we roll the
// We run up to 20 * max_repetitions times; the idea is that if we roll the
// dice enough times we will find some valid parameters. We want to put an
// upper limit on the number iterations just in case the probability of
// finding feasible parameters is very low.
for (int i = 0; !HasFailure() && i < max_repetitions * 10 &&
for (int i = 0; !HasFailure() && i < max_repetitions * 20 &&
valid_test_runs < max_repetitions;
++i) {
TestResult result = fn();
@ -1326,6 +1326,77 @@ TEST_F(OpTest, Conv3DBackpropInput) {
});
}
// Builds a DepthwiseConv2dNative op over random float inputs and hands it to
// ExpectTfAndXlaOutputsAreClose. The lambda is re-run by Repeatedly().
TEST_F(OpTest, DepthwiseConv2DNative) {
Repeatedly([this]() {
// Window description for two spatial dimensions (input/kernel sizes, strides,
// padding).
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
// Input depth and depth multiplier are each drawn from [1, 5]. The order of
// the draws below determines how the shared generator is consumed.
std::uniform_int_distribution<int> random_int(1, 5);
int features_in = random_int(generator());
int depth_multiplier = random_int(generator());
// Input is [batch, rows, cols, features_in]; the batch size comes from
// RandomDim().
std::vector<int64> input_dims = {RandomDim(), d.input_dims[0],
d.input_dims[1], features_in};
// Kernel is [rows, cols, features_in, depth_multiplier].
std::vector<int64> kernel_dims = {d.kernel_dims[0], d.kernel_dims[1],
features_in, depth_multiplier};
// No "data_format" attribute is set here, unlike the two backprop tests.
// NOTE(review): presumably relies on the op's default format — confirm.
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("DepthwiseConv2dNative")
.RandomInput(DT_FLOAT, input_dims)
.RandomInput(DT_FLOAT, kernel_dims)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID"));
});
}
// Builds a DepthwiseConv2dNativeBackpropFilter op (random activations, constant
// kernel shape, random output gradient) and hands it to
// ExpectTfAndXlaOutputsAreClose. The lambda is re-run by Repeatedly().
TEST_F(OpTest, DepthwiseConv2DBackpropFilter) {
Repeatedly([this]() {
// Window description for two spatial dimensions.
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
// Input depth and depth multiplier are each drawn from [1, 5]. The order of
// the draws below determines how the shared generator is consumed.
std::uniform_int_distribution<int> random_int(1, 5);
int features_in = random_int(generator());
int depth_multiplier = random_int(generator());
int32 batch = RandomDim();
// Shape of the forward activations, with features_in channels.
std::vector<int64> activations =
ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims);
// Shape of the output gradient: features_in * depth_multiplier channels over
// the window's output dimensions.
std::vector<int64> backprop = ImageDims(
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
// The filter shape is supplied as an int32 tensor input rather than random
// data.
Tensor kernel_shape = test::AsTensor<int32>(AsInt32s(
{d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}));
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("DepthwiseConv2dNativeBackpropFilter")
.RandomInput(DT_FLOAT, activations)
.Input(kernel_shape)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
});
}
// Builds a DepthwiseConv2dNativeBackpropInput op (constant input shape, random
// kernel, random output gradient) and hands it to
// ExpectTfAndXlaOutputsAreClose. The lambda is re-run by Repeatedly().
TEST_F(OpTest, DepthwiseConv2DBackpropInput) {
Repeatedly([this]() {
// Window description for two spatial dimensions.
WindowedSpatialDims d = ChooseWindowedSpatialDims(2);
// Input depth and depth multiplier are each drawn from [1, 5]. The order of
// the draws below determines how the shared generator is consumed.
std::uniform_int_distribution<int> random_int(1, 5);
int features_in = random_int(generator());
int depth_multiplier = random_int(generator());
int32 batch = RandomDim();
// The forward input shape is supplied as an int32 tensor input rather than
// random data.
Tensor in_shape = test::AsTensor<int32>(
AsInt32s(ImageDims(FORMAT_NHWC, batch, features_in, d.input_dims)));
// Shape of the output gradient: features_in * depth_multiplier channels over
// the window's output dimensions.
std::vector<int64> backprop = ImageDims(
FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims);
// Kernel is [rows, cols, features_in, depth_multiplier].
std::vector<int64> kernel = {d.kernel_dims[0], d.kernel_dims[1],
features_in, depth_multiplier};
return ExpectTfAndXlaOutputsAreClose(
OpTestBuilder("DepthwiseConv2dNativeBackpropInput")
.Input(in_shape)
.RandomInput(DT_FLOAT, kernel)
.RandomInput(DT_FLOAT, backprop)
.Attr("T", DT_FLOAT)
.Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims))
.Attr("padding", d.padding == SAME ? "SAME" : "VALID")
.Attr("data_format", "NHWC"));
});
}
TEST_F(OpTest, Diag) {
Repeatedly([this]() {
DataType type = Choose<DataType>({DT_INT32, DT_FLOAT});

View File

@ -48,6 +48,8 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Conv2DBackpropInput", "input_sizes"},
{"Conv3DBackpropFilterV2", "filter_sizes"},
{"Conv3DBackpropInputV2", "input_sizes"},
{"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"},
{"DepthwiseConv2dNativeBackpropInput", "input_sizes"},
{"DynamicStitch", "indices"},
{"ExpandDims", "dim"},
{"Fill", "dims"},

View File

@ -26,7 +26,6 @@ tf_kernel_library(
"conv_ops.cc",
"cross_op.cc",
"cwise_ops.cc",
"depthwise_conv_ops.cc",
"diag_op.cc",
"dynamic_stitch_op.cc",
"elu_op.cc",
@ -82,11 +81,8 @@ tf_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow_opensource",
"//tensorflow/core/kernels:concat_lib",
"//tensorflow/core/kernels:conv_2d",
"//tensorflow/core/kernels:conv_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:depthwise_conv_op",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:ops_util",
"//tensorflow/core/kernels:pooling_ops",

View File

@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/util/padding.h"
@ -35,12 +34,100 @@ namespace tensorflow {
namespace {
// Computes the shape a depthwise filter has once it is expanded into a regular
// convolution filter: the trailing [M, N] dimensions of `shape` become
// [M, M*N], and every leading dimension is left untouched.
// CHECK-fails if `shape` has fewer than two dimensions.
TensorShape ExpandedFilterShapeForDepthwiseConvolution(
    const TensorShape& shape) {
  const int rank = shape.dims();
  CHECK_GE(rank, 2);
  const int64 in_depth = shape.dim_size(rank - 2);
  const int64 depth_multiplier = shape.dim_size(rank - 1);
  TensorShape result = shape;
  result.set_dim(rank - 1, in_depth * depth_multiplier);
  return result;
}
// Turns a depthwise filter of shape [H, W, ..., M, N] into the equivalent
// regular convolution filter of shape [H, W, ..., M, M*N], in which every
// cross-depth entry is zero.
//
// The expansion is done in two steps:
//   1. Insert M zeros between consecutive elements of the input-depth
//      dimension, which grows it from M to M*M: [H, W, ..., M*M, N].
//   2. Reshape to [H, W, ..., M, M*N].
xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
    const TensorShape& filter_shape, DataType dtype,
    const xla::ComputationDataHandle& filter,
    xla::ComputationBuilder* builder) {
  const int rank = filter_shape.dims();
  const int in_depth_dim = rank - 2;
  const int64 in_depth = filter_shape.dim_size(in_depth_dim);

  // Step 1: interior padding on the input-depth dimension only.
  xla::PaddingConfig interior_padding = xla::MakeNoPaddingConfig(rank);
  interior_padding.mutable_dimensions(in_depth_dim)
      ->set_interior_padding(in_depth);
  auto zero = XlaHelpers::Zero(builder, dtype);
  auto dilated = builder->Pad(filter, zero, interior_padding);

  // Step 2: fold the dilated depth dimension into the output dimension.
  const TensorShape expanded_shape =
      ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
  return builder->Reshape(dilated, expanded_shape.dim_sizes());
}
// Inverse of ExpandFilterForDepthwiseConvolution: extracts the depthwise
// filter gradient of shape `filter_shape` ([H, W, ..., M, N]) from the
// gradient of the expanded filter ([H, W, ..., M, M*N]).
//
// The expanded gradient is reshaped to [H, W, ..., M*M, N] and then sliced
// along the M*M dimension with stride M + 1, which keeps exactly M entries.
// `dtype` is only needed by the alternate implementation described below.
xla::ComputationDataHandle ContractFilterForDepthwiseBackprop(
    const TensorShape& filter_shape, DataType dtype,
    const xla::ComputationDataHandle& filter_backprop,
    xla::ComputationBuilder* builder) {
  const int rank = filter_shape.dims();
  const int in_depth_dim = rank - 2;
  const int64 in_depth = filter_shape.dim_size(in_depth_dim);

  // Reshape to [H, W, ..., M*M, N].
  TensorShape flattened_shape = filter_shape;
  flattened_shape.set_dim(in_depth_dim, in_depth * in_depth);
  auto flattened =
      builder->Reshape(filter_backprop, flattened_shape.dim_sizes());

  // Strided slice over the whole tensor; only the M*M dimension is strided.
  std::vector<int64> start_indices(rank);
  std::vector<int64> slice_strides(rank, 1LL);
  slice_strides[in_depth_dim] = in_depth + 1;
  return builder->Slice(flattened, start_indices, flattened_shape.dim_sizes(),
                        slice_strides);

  // Alternate implementation for backends without strided Slice() support.
  // TODO(phawkins): Remove when all backends support strided slice.
  //   1. Pad the M*M dimension of `flattened` with M trailing zeros
  //      (edge_padding_high = M, zero from XlaHelpers::Zero(builder, dtype)),
  //      giving [..., M * (M + 1), N].
  //   2. Reshape to [..., M, M + 1, N].
  //   3. Slice with unit strides from the origin to [..., M, 1, N].
  //   4. Reshape to filter_shape, i.e. [..., M, N].
}
class ConvOp : public XlaOpKernel {
public:
explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims)
: XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
explicit ConvOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
: XlaOpKernel(ctx),
num_spatial_dims_(num_spatial_dims),
depthwise_(depthwise) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
string data_format;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
}
int num_dims() const { return num_spatial_dims_ + 2; }
@ -82,9 +169,16 @@ class ConvOp : public XlaOpKernel {
"input and filter must have the same depth: ", in_depth,
" vs ", input_shape.dim_size(feature_dim)));
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle filter = ctx->Input(1);
if (depthwise_) {
filter = ExpandFilterForDepthwiseConvolution(
filter_shape, ctx->input_type(0), filter, b);
}
xla::ConvolutionDimensionNumbers dims;
std::vector<int64> window_strides;
dims.set_batch_dimension(GetTensorBatchDimIndex(num_dims(), data_format_));
dims.set_feature_dimension(feature_dim);
for (int i = 0; i < num_spatial_dims_; ++i) {
@ -99,13 +193,14 @@ class ConvOp : public XlaOpKernel {
xla::Padding xla_padding =
(padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;
xla::ComputationDataHandle conv = ctx->builder()->ConvWithGeneralDimensions(
ctx->Input(0), ctx->Input(1), window_strides, xla_padding, dims);
xla::ComputationDataHandle conv = b->ConvWithGeneralDimensions(
ctx->Input(0), filter, window_strides, xla_padding, dims);
ctx->SetOutput(0, conv);
}
protected:
const int num_spatial_dims_;
const bool depthwise_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
@ -117,29 +212,38 @@ class ConvOp : public XlaOpKernel {
class Conv2DOp : public ConvOp {
public:
explicit Conv2DOp(OpKernelConstruction* ctx)
: ConvOp(ctx, /*num_spatial_dims=*/2) {
string data_format;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
}
: ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
};
REGISTER_XLA_OP(Name("Conv2D"), Conv2DOp);
class Conv3DOp : public ConvOp {
public:
explicit Conv3DOp(OpKernelConstruction* ctx)
: ConvOp(ctx, /*num_spatial_dims=*/3) {}
: ConvOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
};
REGISTER_XLA_OP(Name("Conv3D"), Conv3DOp);
// XLA kernel for the DepthwiseConv2dNative op: a ConvOp with two spatial
// dimensions and the depthwise flag set. With that flag, ConvOp expands the
// filter through ExpandFilterForDepthwiseConvolution before building the
// convolution.
class DepthwiseConv2DOp : public ConvOp {
public:
explicit DepthwiseConv2DOp(OpKernelConstruction* ctx)
: ConvOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
REGISTER_XLA_OP(Name("DepthwiseConv2dNative"), DepthwiseConv2DOp);
// Backprop for input.
class ConvBackpropInputOp : public XlaOpKernel {
public:
explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims)
: XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
explicit ConvBackpropInputOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
: XlaOpKernel(ctx),
num_spatial_dims_(num_spatial_dims),
depthwise_(depthwise) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
string data_format;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
}
int num_dims() const { return num_spatial_dims_ + 2; }
@ -162,13 +266,17 @@ class ConvBackpropInputOp : public XlaOpKernel {
const TensorShape filter_shape = ctx->InputShape(1);
const TensorShape out_backprop_shape = ctx->InputShape(2);
const TensorShape expanded_filter_shape =
depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
// Reuse dimension computation logic from conv_grad_ops.cc.
ConvBackpropDimensions dims;
OP_REQUIRES_OK(
ctx, ConvBackpropComputeDimensions(
type_string(), num_spatial_dims_, input_shape, filter_shape,
out_backprop_shape, strides_, padding_, data_format_, &dims));
OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
type_string(), num_spatial_dims_, input_shape,
expanded_filter_shape, out_backprop_shape, strides_,
padding_, data_format_, &dims));
xla::ComputationBuilder* b = ctx->builder();
auto filter = ctx->Input(1);
auto out_backprop = ctx->Input(2);
@ -200,13 +308,19 @@ class ConvBackpropInputOp : public XlaOpKernel {
lhs_dilation[i] = dims.spatial_dims[i].stride;
}
// If this is a depthwise convolution, expand the filter.
if (depthwise_) {
filter = ExpandFilterForDepthwiseConvolution(
filter_shape, ctx->input_type(1), filter, b);
}
// Mirror the filter in the spatial dimensions.
xla::ComputationDataHandle mirrored_weights =
ctx->builder()->Rev(filter, kernel_spatial_dims);
b->Rev(filter, kernel_spatial_dims);
// activation gradients
// = gradients (with padding and dilation) <conv> mirrored_weights
xla::ComputationDataHandle in_backprop = ctx->builder()->ConvGeneralDilated(
xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated(
out_backprop, mirrored_weights, /*window_strides=*/ones, padding,
lhs_dilation, /*rhs_dilation=*/ones, dnums);
@ -215,6 +329,7 @@ class ConvBackpropInputOp : public XlaOpKernel {
protected:
const int num_spatial_dims_;
const bool depthwise_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
@ -226,28 +341,38 @@ class ConvBackpropInputOp : public XlaOpKernel {
class Conv2DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2) {
string data_format;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
}
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {}
};
REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp);
class Conv3DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3) {}
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {}
};
REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp);
// Input-gradient kernel for 2D depthwise convolution.
//
// Thin wrapper that configures the shared ConvBackpropInputOp implementation
// with two spatial dimensions and depthwise=true. With that flag set, the base
// class computes its dimensions from the expanded filter shape and expands the
// filter via ExpandFilterForDepthwiseConvolution before mirroring it and
// emitting the convolution.
class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp {
public:
explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx)
: ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
// Registers the wrapper above as the XLA kernel for the
// "DepthwiseConv2dNativeBackpropInput" op.
REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"),
DepthwiseConv2DBackpropInputOp);
class ConvBackpropFilterOp : public XlaOpKernel {
public:
explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims)
: XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) {
explicit ConvBackpropFilterOp(OpKernelConstruction* ctx, int num_spatial_dims,
bool depthwise)
: XlaOpKernel(ctx),
num_spatial_dims_(num_spatial_dims),
depthwise_(depthwise) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
string data_format;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
}
int num_dims() const { return num_spatial_dims_ + 2; }
@ -266,13 +391,18 @@ class ConvBackpropFilterOp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
const TensorShape out_backprop_shape = ctx->InputShape(2);
const TensorShape expanded_filter_shape =
depthwise_ ? ExpandedFilterShapeForDepthwiseConvolution(filter_shape)
: filter_shape;
// Reuse dimension computation logic from conv_grad_ops.cc.
ConvBackpropDimensions dims;
OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
type_string(), num_spatial_dims_, activations_shape,
filter_shape, out_backprop_shape, strides_,
expanded_filter_shape, out_backprop_shape, strides_,
padding_, data_format_, &dims));
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle activations = ctx->Input(0);
xla::ComputationDataHandle gradients = ctx->Input(2);
@ -357,10 +487,10 @@ class ConvBackpropFilterOp : public XlaOpKernel {
//
// This is done by specifying the window dilation factors in the
// convolution HLO below.
auto filter_backprop = ctx->builder()->ConvGeneralDilated(
activations, gradients,
/*window_strides=*/ones, padding, /*lhs_dilation=*/ones, rhs_dilation,
dnums);
auto filter_backprop =
b->ConvGeneralDilated(activations, gradients,
/*window_strides=*/ones, padding,
/*lhs_dilation=*/ones, rhs_dilation, dnums);
// The layout of filter_backprop will match the layout of
// padded_activations
@ -375,12 +505,18 @@ class ConvBackpropFilterOp : public XlaOpKernel {
transpose_dims.push_back(c_dim);
transpose_dims.push_back(n_dim);
xla::ComputationDataHandle filter_backprop_reshaped =
ctx->builder()->Transpose(filter_backprop, transpose_dims);
b->Transpose(filter_backprop, transpose_dims);
if (depthwise_) {
filter_backprop_reshaped = ContractFilterForDepthwiseBackprop(
filter_shape, ctx->input_type(0), filter_backprop_reshaped, b);
}
ctx->SetOutput(0, filter_backprop_reshaped);
}
protected:
int num_spatial_dims_;
const int num_spatial_dims_;
const bool depthwise_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
@ -392,11 +528,7 @@ class ConvBackpropFilterOp : public XlaOpKernel {
class Conv2DBackpropFilterOp : public ConvBackpropFilterOp {
public:
explicit Conv2DBackpropFilterOp(OpKernelConstruction* ctx)
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2) {
string data_format;
OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format));
OP_REQUIRES(ctx, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {
}
};
REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp);
@ -404,9 +536,18 @@ REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp);
class Conv3DBackpropFilterOp : public ConvBackpropFilterOp {
public:
explicit Conv3DBackpropFilterOp(OpKernelConstruction* ctx)
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3) {}
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {
}
};
REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp);
// Filter-gradient kernel for 2D depthwise convolution.
//
// Thin wrapper that configures the shared ConvBackpropFilterOp implementation
// with two spatial dimensions and depthwise=true. With that flag set, the base
// class computes its dimensions from the expanded filter shape and passes the
// transposed filter gradient through ContractFilterForDepthwiseBackprop before
// setting it as the output.
class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp {
public:
explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx)
: ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {}
};
// Registers the wrapper above as the XLA kernel for the
// "DepthwiseConv2dNativeBackpropFilter" op.
REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"),
DepthwiseConv2DBackpropFilterOp);
} // namespace
} // namespace tensorflow

View File

@ -1,235 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// XLA-specific Ops for 2D depthwise convolution.
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
namespace {
// Symbol name of the custom-call function that implements depthwise 2D
// convolution. When empty (the default), no custom call is emitted; another
// possible value is "DummyDepthwiseConv2dKernel".
constexpr char kDepthwiseConv2dCustomFunc[] = "";
// XLA kernel for the DepthwiseConv2dNative op.
//
// Lowers a 2D depthwise convolution in one of two ways:
//   * if kDepthwiseConv2dCustomFunc is non-empty, as a single CustomCall to
//     that function;
//   * otherwise, as one XLA convolution per input channel, with the per-channel
//     results concatenated along the depth dimension.
//
// NOTE(review): all dimension indices below are hard-coded for a
// [batch, rows, cols, depth] input and no data_format attribute is read, so
// this presumably supports NHWC only -- confirm against the op definition.
class DepthwiseConv2dNativeOp : public XlaOpKernel {
 public:
  // Reads and validates the "strides" and "padding" attributes. Only strides
  // of the form [1, s, s, 1] are accepted.
  explicit DepthwiseConv2dNativeOp(OpKernelConstruction* ctx)
      : XlaOpKernel(ctx) {
    // TODO(keveman): Refactor this (and other XLA OpKernel constructors) so
    // that they use a common implementation shared with non-XLA kernels.
    OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
    OP_REQUIRES(ctx, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));
    OP_REQUIRES(ctx, strides_[1] == strides_[2],
                errors::InvalidArgument(
                    "Current implementation only supports equal length "
                    "strides in the row and column dimensions."));
    OP_REQUIRES(
        ctx, (strides_[0] == 1 && strides_[3] == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
  }

  // Validates the input and filter shapes, computes the output shape, and
  // emits the XLA computation whose result is set as output 0.
  void Compile(XlaOpKernelContext* ctx) override {
    // Input tensor is of the following dimensions:
    // [ batch, in_rows, in_cols, in_depth ]
    const TensorShape input_shape = ctx->InputShape(0);

    // Input filter is of the following dimensions:
    // [ filter_rows, filter_cols, in_depth, depth_multiplier]
    const TensorShape filter_shape = ctx->InputShape(1);

    // For 2D convolution, there should be 4 dimensions.
    OP_REQUIRES(ctx, input_shape.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional: ",
                                        input_shape.DebugString()));
    OP_REQUIRES(ctx, filter_shape.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter_shape.DebugString()));

    // The last dimension for input is in_depth. It must be the same as the
    // filter's in_depth.
    const int64 in_depth = input_shape.dim_size(3);
    OP_REQUIRES(
        ctx, in_depth == filter_shape.dim_size(2),
        errors::InvalidArgument("input and filter must have the same depth: ",
                                in_depth, " vs ", filter_shape.dim_size(2)));

    // The last dimension for filter is depth multiplier.
    const int64 depth_multiplier = filter_shape.dim_size(3);

    // The output depth is input depth x depth multiplier.
    const int64 out_depth = in_depth * depth_multiplier;

    // The second dimension for input is rows/height.
    // The first dimension for filter is rows/height.
    const int64 input_rows = input_shape.dim_size(1);
    const int64 filter_rows = filter_shape.dim_size(0);

    // The third dimension for input is columns/width.
    // The second dimension for filter is columns/width.
    const int64 input_cols = input_shape.dim_size(2);
    const int64 filter_cols = filter_shape.dim_size(1);

    // The first dimension for input is batch.
    const int64 batch = input_shape.dim_size(0);

    // For now we take the stride from the second dimension only (we
    // assume row = col stride, and do not support striding on the
    // batch or depth dimension).
    const int32 stride = strides_[1];

    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
    OP_REQUIRES_OK(ctx, GetWindowedOutputSize(input_rows, filter_rows, stride,
                                              padding_, &out_rows, &pad_rows));
    OP_REQUIRES_OK(ctx, GetWindowedOutputSize(input_cols, filter_cols, stride,
                                              padding_, &out_cols, &pad_cols));
    TensorShape out_shape({batch, out_rows, out_cols, out_depth});
    // Report the offending element count; the depth values are irrelevant to
    // this particular failure.
    OP_REQUIRES(
        ctx, out_shape.num_elements() <= 2147483647,
        errors::InvalidArgument("total number of outputs should be within the "
                                "range of int which is used in the GPU kernel: ",
                                out_shape.num_elements()));

    // Output tensor is of the following dimensions:
    // [ in_batch, out_rows, out_cols, out_depth ]
    VLOG(2) << "DepthwiseConv2dNative: "
            << " Input: [" << batch << ", " << input_rows << ", " << input_cols
            << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
            << filter_cols << ", " << in_depth << ", " << depth_multiplier
            << "]; stride = " << stride << ", pad_rows = " << pad_rows
            << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
            << out_rows << ", " << out_cols << ", " << out_depth << "]";

    xla::ComputationBuilder& b = *ctx->builder();
    xla::ComputationDataHandle input = ctx->Input(0);
    xla::ComputationDataHandle filter = ctx->Input(1);
    xla::ComputationDataHandle output;

    const string custom_function_name = kDepthwiseConv2dCustomFunc;
    if (!custom_function_name.empty()) {
      xla::Shape xla_out_shape;
      OP_REQUIRES_OK(
          ctx, TensorShapeToXLAShape(input_type(0), out_shape, &xla_out_shape));

      // The custom function for depthwise should interpret its arguments
      // as follows :
      // func(T* output,
      //      const T* input, const T* filter,
      //      const int32* input_size, const int32* filter_size,
      //      const int32* output_size,
      //      int32 stride, int32 pad_rows, int32 pad_cols)
      //
      // where T is the type of Tensor that this kernel is registered for.
      // Note that the custom call op uses the following calling
      // convention:
      // func(void* output, void** inputs)
      //
      // Therefore the custom function should first construct the above
      // inputs by unparsing the second argument passed to it.
      //
      // NOTE(review): the size, stride and padding operands built below are
      // int64 constants, while the signature sketched above says int32 --
      // confirm which width the custom kernel actually expects.
      output = b.CustomCall(
          custom_function_name,
          {input, filter,
           b.ConstantR1<int64>({batch, input_rows, input_cols, in_depth}),
           b.ConstantR1<int64>(
               {filter_rows, filter_cols, in_depth, depth_multiplier}),
           b.ConstantR1<int64>({batch, out_rows, out_cols, out_depth}),
           b.ConstantR0<int64>(stride), b.ConstantR0<int64>(pad_rows),
           b.ConstantR0<int64>(pad_cols)},
          xla_out_shape);
    } else {
      // These will be used to define the bounds of each slice.
      // Within the loop, the input_channel index will be modified.
      gtl::InlinedVector<int64, 4> filter_begin(4, 0);
      gtl::InlinedVector<int64, 4> filter_limits(4);
      gtl::InlinedVector<int64, 4> input_begin(4, 0);
      gtl::InlinedVector<int64, 4> input_limits(4);
      // Unit slice strides: every element inside the slice bounds is kept.
      gtl::InlinedVector<int64, 4> strides(4, 1);
      for (int i = 0; i < 4; ++i) {
        filter_limits[i] = filter_shape.dim_size(i);
        input_limits[i] = input_shape.dim_size(i);
      }

      // Row and column strides handed to the XLA convolution.
      std::vector<int64> window_strides{strides_[1], strides_[2]};

      xla::Padding xla_padding =
          (padding_ == VALID) ? xla::Padding::kValid : xla::Padding::kSame;

      xla::ConvolutionDimensionNumbers dims;
      dims.set_batch_dimension(0);
      dims.set_feature_dimension(3);
      dims.add_spatial_dimensions(1);
      dims.add_spatial_dimensions(2);

      // TF filter shape is [ H, W, inC, outC ]
      dims.add_kernel_spatial_dimensions(0);
      dims.add_kernel_spatial_dimensions(1);
      dims.set_kernel_input_feature_dimension(2);
      dims.set_kernel_output_feature_dimension(3);

      // Create one convolution for each input channel. The index is int64 to
      // match the type of in_depth.
      std::vector<xla::ComputationDataHandle> convs;
      convs.reserve(in_depth);
      for (int64 i = 0; i < in_depth; ++i) {
        // Narrow both slices to the single input channel i.
        filter_begin[2] = i;
        filter_limits[2] = i + 1;
        input_begin[3] = i;
        input_limits[3] = i + 1;

        xla::ComputationDataHandle filter_slice =
            b.Slice(filter, filter_begin, filter_limits, strides);
        xla::ComputationDataHandle input_slice =
            b.Slice(input, input_begin, input_limits, strides);
        convs.push_back(b.ConvWithGeneralDimensions(
            input_slice, filter_slice, window_strides, xla_padding, dims));
      }
      // Concatenate the per-channel convolutions along the depth dimension.
      output = b.ConcatInDim(convs, 3);
    }

    ctx->SetOutput(0, output);
  }

 private:
  // Value of the "strides" attribute; validated in the constructor to have
  // four entries of the form [1, s, s, 1].
  std::vector<int32> strides_;
  // Value of the "padding" attribute.
  Padding padding_;

  TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};
// Registers DepthwiseConv2dNativeOp as the XLA kernel for the
// "DepthwiseConv2dNative" op, with attribute "T" constrained to kFloatTypes.
REGISTER_XLA_OP(Name("DepthwiseConv2dNative").TypeConstraint("T", kFloatTypes),
DepthwiseConv2dNativeOp);
} // namespace
} // namespace tensorflow

View File

@ -1079,10 +1079,6 @@ class Conv2DTest(test.TestCase):
padding="VALID"))
# This is only a very simple test. More comprehensive tests live in
# //learning/dist_belief/experimental/brain_compatibility/conv_nn_test.py
# where we compare the numeric results of the depthwise conv op with the
# depthwise weighted sum transformer in dist_belief.
class DepthwiseConv2DTest(test.TestCase):
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, stride, padding,