Minor change to depthwise convolution with dilation.
PiperOrigin-RevId: 284217866 Change-Id: Icd5419e62cef5870745182f8120b937ed791212f
This commit is contained in:
parent
e37dc0f10a
commit
722fc02b3f
tensorflow/python
@ -243,6 +243,7 @@ py_library(
|
||||
"**/*test.py",
|
||||
"**/benchmark.py", # In platform_benchmark.
|
||||
"**/analytics.py", # In platform_analytics.
|
||||
"**/device_context.py", # In platform_device_context.
|
||||
],
|
||||
) + ["platform/build_info.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
@ -275,6 +276,16 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "platform_device_context",
|
||||
srcs = ["platform/device_context.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":control_flow_ops",
|
||||
":framework",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "platform_test",
|
||||
srcs = ["platform/googletest.py"],
|
||||
@ -3805,6 +3816,7 @@ py_library(
|
||||
":nn_grad",
|
||||
":nn_ops",
|
||||
":nn_ops_gen",
|
||||
":platform_device_context",
|
||||
":rnn",
|
||||
":sparse_ops",
|
||||
":util",
|
||||
|
@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.losses import util as losses_util
|
||||
from tensorflow.python.platform import device_context
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
from tensorflow.python.util.deprecation import deprecated_argument_lookup
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
@ -707,22 +708,6 @@ def zero_fraction(value, name=None):
|
||||
return array_ops.identity(zero_fraction_float32, "fraction")
|
||||
|
||||
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
# to make TPU code and CPU code consistent.
|
||||
def _enclosing_tpu_context():
|
||||
# pylint: disable=protected-access
|
||||
context = ops.get_default_graph()._get_control_flow_context()
|
||||
# pylint: enable=protected-access
|
||||
while context is not None and not isinstance(
|
||||
context, control_flow_ops.XLAControlFlowContext):
|
||||
context = context.outer_context
|
||||
return context
|
||||
|
||||
|
||||
# copybara:strip_end
|
||||
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
@tf_export(v1=["nn.depthwise_conv2d"])
|
||||
def depthwise_conv2d(input,
|
||||
@ -782,11 +767,8 @@ def depthwise_conv2d(input,
|
||||
if rate is None:
|
||||
rate = [1, 1]
|
||||
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
# to make TPU code and CPU code consistent.
|
||||
# Use depthwise_conv2d_native if executing on TPU.
|
||||
if _enclosing_tpu_context() is not None:
|
||||
if device_context.enclosing_tpu_context() is not None:
|
||||
if data_format == "NCHW":
|
||||
dilations = [1, 1, rate[0], rate[1]]
|
||||
else:
|
||||
@ -799,7 +781,6 @@ def depthwise_conv2d(input,
|
||||
data_format=data_format,
|
||||
dilations=dilations,
|
||||
name=name)
|
||||
# copybara:strip_end
|
||||
|
||||
def op(input_converted, _, padding):
|
||||
return nn_ops.depthwise_conv2d_native(
|
||||
|
@ -36,10 +36,6 @@ from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
# copybara:strip_end
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_nn_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -48,6 +44,7 @@ from tensorflow.python.ops import random_ops
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_nn_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.platform import device_context
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util.compat import collections_abc
|
||||
@ -927,22 +924,6 @@ convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
|
||||
"filter", "filters")
|
||||
|
||||
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
# to make TPU code and CPU code consistent.
|
||||
def _enclosing_tpu_context():
|
||||
# pylint: disable=protected-access
|
||||
run_context = ops.get_default_graph()._get_control_flow_context()
|
||||
# pylint: enable=protected-access
|
||||
while run_context is not None and not isinstance(
|
||||
run_context, control_flow_ops.XLAControlFlowContext):
|
||||
run_context = run_context.outer_context
|
||||
return run_context
|
||||
|
||||
|
||||
# copybara:strip_end
|
||||
|
||||
|
||||
def convolution_internal(
|
||||
input, # pylint: disable=redefined-builtin
|
||||
filters,
|
||||
@ -980,28 +961,20 @@ def convolution_internal(
|
||||
strides = _get_sequence(strides, n, channel_index, "strides")
|
||||
dilations = _get_sequence(dilations, n, channel_index, "dilations")
|
||||
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
# to make TPU code and CPU code consistent.
|
||||
scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
|
||||
if not call_from_convolution and _enclosing_tpu_context() is not None:
|
||||
if not call_from_convolution and device_context.enclosing_tpu_context(
|
||||
) is not None:
|
||||
scope = scopes[n]
|
||||
else:
|
||||
scope = "convolution"
|
||||
# copybara:strip_end
|
||||
# copybara:insert scope = "convolution"
|
||||
|
||||
with ops.name_scope(name, scope, [input, filters]) as name:
|
||||
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
|
||||
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
# to make TPU code and CPU code consistent.
|
||||
if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
|
||||
if device_context.enclosing_tpu_context() is not None or all(
|
||||
i == 1 for i in dilations):
|
||||
# fast path for TPU or if no dilation as gradient only supported on GPU
|
||||
# for dilations
|
||||
# copybara:strip_end
|
||||
# copybara:insert if all(i == 1 for i in dilations):
|
||||
op = conv_ops[n]
|
||||
return op(
|
||||
input,
|
||||
@ -1120,11 +1093,8 @@ class Convolution(object):
|
||||
name=self.name)
|
||||
|
||||
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
|
||||
# copybara:strip_begin
|
||||
# TODO(b/138808492): Remove code inside copybara
|
||||
# to make TPU code and CPU code consistent.
|
||||
# TPU convolution supports dilations greater than 1.
|
||||
if _enclosing_tpu_context() is not None:
|
||||
if device_context.enclosing_tpu_context() is not None:
|
||||
return convolution_internal(
|
||||
inp,
|
||||
filter,
|
||||
@ -1136,8 +1106,6 @@ class Convolution(object):
|
||||
call_from_convolution=False)
|
||||
else:
|
||||
return self.conv_op(inp, filter)
|
||||
# copybara:strip_end
|
||||
# copybara:insert return self.conv_op(inp, filter)
|
||||
|
||||
|
||||
@tf_export(v1=["nn.pool"])
|
||||
|
22
tensorflow/python/platform/device_context.py
Normal file
22
tensorflow/python/platform/device_context.py
Normal file
@ -0,0 +1,22 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helpers to get device context."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
def enclosing_tpu_context():
|
||||
pass
|
Loading…
Reference in New Issue
Block a user