Minor change to depthwise convolution with dilation.

PiperOrigin-RevId: 284217866
Change-Id: Icd5419e62cef5870745182f8120b937ed791212f
This commit is contained in:
A. Unique TensorFlower 2019-12-06 10:52:23 -08:00 committed by TensorFlower Gardener
parent e37dc0f10a
commit 722fc02b3f
4 changed files with 42 additions and 59 deletions
tensorflow/python

View File

@ -243,6 +243,7 @@ py_library(
"**/*test.py",
"**/benchmark.py", # In platform_benchmark.
"**/analytics.py", # In platform_analytics.
"**/device_context.py", # In platform_device_context.
],
) + ["platform/build_info.py"],
srcs_version = "PY2AND3",
@ -275,6 +276,16 @@ py_library(
srcs_version = "PY2AND3",
)
py_library(
name = "platform_device_context",
srcs = ["platform/device_context.py"],
srcs_version = "PY2AND3",
deps = [
":control_flow_ops",
":framework",
],
)
py_library(
name = "platform_test",
srcs = ["platform/googletest.py"],
@ -3805,6 +3816,7 @@ py_library(
":nn_grad",
":nn_ops",
":nn_ops_gen",
":platform_device_context",
":rnn",
":sparse_ops",
":util",

View File

@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import util as losses_util
from tensorflow.python.platform import device_context
from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.deprecation import deprecated_argument_lookup
from tensorflow.python.util.tf_export import tf_export
@ -707,22 +708,6 @@ def zero_fraction(value, name=None):
return array_ops.identity(zero_fraction_float32, "fraction")
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
def _enclosing_tpu_context():
# pylint: disable=protected-access
context = ops.get_default_graph()._get_control_flow_context()
# pylint: enable=protected-access
while context is not None and not isinstance(
context, control_flow_ops.XLAControlFlowContext):
context = context.outer_context
return context
# copybara:strip_end
# pylint: disable=redefined-builtin
@tf_export(v1=["nn.depthwise_conv2d"])
def depthwise_conv2d(input,
@ -782,11 +767,8 @@ def depthwise_conv2d(input,
if rate is None:
rate = [1, 1]
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
# Use depthwise_conv2d_native if executing on TPU.
if _enclosing_tpu_context() is not None:
if device_context.enclosing_tpu_context() is not None:
if data_format == "NCHW":
dilations = [1, 1, rate[0], rate[1]]
else:
@ -799,7 +781,6 @@ def depthwise_conv2d(input,
data_format=data_format,
dilations=dilations,
name=name)
# copybara:strip_end
def op(input_converted, _, padding):
return nn_ops.depthwise_conv2d_native(

View File

@ -36,10 +36,6 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
from tensorflow.python.ops import control_flow_ops
# copybara:strip_end
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
@ -48,6 +44,7 @@ from tensorflow.python.ops import random_ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_nn_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.platform import device_context
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
@ -927,22 +924,6 @@ convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
"filter", "filters")
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
def _enclosing_tpu_context():
# pylint: disable=protected-access
run_context = ops.get_default_graph()._get_control_flow_context()
# pylint: enable=protected-access
while run_context is not None and not isinstance(
run_context, control_flow_ops.XLAControlFlowContext):
run_context = run_context.outer_context
return run_context
# copybara:strip_end
def convolution_internal(
input, # pylint: disable=redefined-builtin
filters,
@ -980,28 +961,20 @@ def convolution_internal(
strides = _get_sequence(strides, n, channel_index, "strides")
dilations = _get_sequence(dilations, n, channel_index, "dilations")
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
if not call_from_convolution and _enclosing_tpu_context() is not None:
if not call_from_convolution and device_context.enclosing_tpu_context(
) is not None:
scope = scopes[n]
else:
scope = "convolution"
# copybara:strip_end
# copybara:insert scope = "convolution"
with ops.name_scope(name, scope, [input, filters]) as name:
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
if device_context.enclosing_tpu_context() is not None or all(
i == 1 for i in dilations):
# fast path for TPU or if no dilation as gradient only supported on GPU
# for dilations
# copybara:strip_end
# copybara:insert if all(i == 1 for i in dilations):
op = conv_ops[n]
return op(
input,
@ -1120,11 +1093,8 @@ class Convolution(object):
name=self.name)
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
# copybara:strip_begin
# TODO(b/138808492): Remove code inside copybara
# to make TPU code and CPU code consistent.
# TPU convolution supports dilations greater than 1.
if _enclosing_tpu_context() is not None:
if device_context.enclosing_tpu_context() is not None:
return convolution_internal(
inp,
filter,
@ -1136,8 +1106,6 @@ class Convolution(object):
call_from_convolution=False)
else:
return self.conv_op(inp, filter)
# copybara:strip_end
# copybara:insert return self.conv_op(inp, filter)
@tf_export(v1=["nn.pool"])

View File

@ -0,0 +1,22 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to get device context."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def enclosing_tpu_context():
pass