Minor change to depthwise convolution with dilation.
PiperOrigin-RevId: 284217866 Change-Id: Icd5419e62cef5870745182f8120b937ed791212f
This commit is contained in:
parent
e37dc0f10a
commit
722fc02b3f
@ -243,6 +243,7 @@ py_library(
|
|||||||
"**/*test.py",
|
"**/*test.py",
|
||||||
"**/benchmark.py", # In platform_benchmark.
|
"**/benchmark.py", # In platform_benchmark.
|
||||||
"**/analytics.py", # In platform_analytics.
|
"**/analytics.py", # In platform_analytics.
|
||||||
|
"**/device_context.py", # In platform_device_context.
|
||||||
],
|
],
|
||||||
) + ["platform/build_info.py"],
|
) + ["platform/build_info.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
@ -275,6 +276,16 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "platform_device_context",
|
||||||
|
srcs = ["platform/device_context.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":control_flow_ops",
|
||||||
|
":framework",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "platform_test",
|
name = "platform_test",
|
||||||
srcs = ["platform/googletest.py"],
|
srcs = ["platform/googletest.py"],
|
||||||
@ -3805,6 +3816,7 @@ py_library(
|
|||||||
":nn_grad",
|
":nn_grad",
|
||||||
":nn_ops",
|
":nn_ops",
|
||||||
":nn_ops_gen",
|
":nn_ops_gen",
|
||||||
|
":platform_device_context",
|
||||||
":rnn",
|
":rnn",
|
||||||
":sparse_ops",
|
":sparse_ops",
|
||||||
":util",
|
":util",
|
||||||
|
@ -38,6 +38,7 @@ from tensorflow.python.ops import math_ops
|
|||||||
from tensorflow.python.ops import nn_ops
|
from tensorflow.python.ops import nn_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
from tensorflow.python.ops.losses import util as losses_util
|
from tensorflow.python.ops.losses import util as losses_util
|
||||||
|
from tensorflow.python.platform import device_context
|
||||||
from tensorflow.python.util.deprecation import deprecated_args
|
from tensorflow.python.util.deprecation import deprecated_args
|
||||||
from tensorflow.python.util.deprecation import deprecated_argument_lookup
|
from tensorflow.python.util.deprecation import deprecated_argument_lookup
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
@ -707,22 +708,6 @@ def zero_fraction(value, name=None):
|
|||||||
return array_ops.identity(zero_fraction_float32, "fraction")
|
return array_ops.identity(zero_fraction_float32, "fraction")
|
||||||
|
|
||||||
|
|
||||||
# copybara:strip_begin
|
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
|
||||||
# to make TPU code and CPU code consistent.
|
|
||||||
def _enclosing_tpu_context():
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
context = ops.get_default_graph()._get_control_flow_context()
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
while context is not None and not isinstance(
|
|
||||||
context, control_flow_ops.XLAControlFlowContext):
|
|
||||||
context = context.outer_context
|
|
||||||
return context
|
|
||||||
|
|
||||||
|
|
||||||
# copybara:strip_end
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=redefined-builtin
|
# pylint: disable=redefined-builtin
|
||||||
@tf_export(v1=["nn.depthwise_conv2d"])
|
@tf_export(v1=["nn.depthwise_conv2d"])
|
||||||
def depthwise_conv2d(input,
|
def depthwise_conv2d(input,
|
||||||
@ -782,11 +767,8 @@ def depthwise_conv2d(input,
|
|||||||
if rate is None:
|
if rate is None:
|
||||||
rate = [1, 1]
|
rate = [1, 1]
|
||||||
|
|
||||||
# copybara:strip_begin
|
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
|
||||||
# to make TPU code and CPU code consistent.
|
|
||||||
# Use depthwise_conv2d_native if executing on TPU.
|
# Use depthwise_conv2d_native if executing on TPU.
|
||||||
if _enclosing_tpu_context() is not None:
|
if device_context.enclosing_tpu_context() is not None:
|
||||||
if data_format == "NCHW":
|
if data_format == "NCHW":
|
||||||
dilations = [1, 1, rate[0], rate[1]]
|
dilations = [1, 1, rate[0], rate[1]]
|
||||||
else:
|
else:
|
||||||
@ -799,7 +781,6 @@ def depthwise_conv2d(input,
|
|||||||
data_format=data_format,
|
data_format=data_format,
|
||||||
dilations=dilations,
|
dilations=dilations,
|
||||||
name=name)
|
name=name)
|
||||||
# copybara:strip_end
|
|
||||||
|
|
||||||
def op(input_converted, _, padding):
|
def op(input_converted, _, padding):
|
||||||
return nn_ops.depthwise_conv2d_native(
|
return nn_ops.depthwise_conv2d_native(
|
||||||
|
@ -36,10 +36,6 @@ from tensorflow.python.framework import tensor_shape
|
|||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import check_ops
|
from tensorflow.python.ops import check_ops
|
||||||
# copybara:strip_begin
|
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
|
||||||
from tensorflow.python.ops import control_flow_ops
|
|
||||||
# copybara:strip_end
|
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
from tensorflow.python.ops import gen_nn_ops
|
from tensorflow.python.ops import gen_nn_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -48,6 +44,7 @@ from tensorflow.python.ops import random_ops
|
|||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
from tensorflow.python.ops.gen_nn_ops import *
|
from tensorflow.python.ops.gen_nn_ops import *
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
|
from tensorflow.python.platform import device_context
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util.compat import collections_abc
|
from tensorflow.python.util.compat import collections_abc
|
||||||
@ -927,22 +924,6 @@ convolution_v2.__doc__ = deprecation.rewrite_argument_docstring(
|
|||||||
"filter", "filters")
|
"filter", "filters")
|
||||||
|
|
||||||
|
|
||||||
# copybara:strip_begin
|
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
|
||||||
# to make TPU code and CPU code consistent.
|
|
||||||
def _enclosing_tpu_context():
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
run_context = ops.get_default_graph()._get_control_flow_context()
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
while run_context is not None and not isinstance(
|
|
||||||
run_context, control_flow_ops.XLAControlFlowContext):
|
|
||||||
run_context = run_context.outer_context
|
|
||||||
return run_context
|
|
||||||
|
|
||||||
|
|
||||||
# copybara:strip_end
|
|
||||||
|
|
||||||
|
|
||||||
def convolution_internal(
|
def convolution_internal(
|
||||||
input, # pylint: disable=redefined-builtin
|
input, # pylint: disable=redefined-builtin
|
||||||
filters,
|
filters,
|
||||||
@ -980,28 +961,20 @@ def convolution_internal(
|
|||||||
strides = _get_sequence(strides, n, channel_index, "strides")
|
strides = _get_sequence(strides, n, channel_index, "strides")
|
||||||
dilations = _get_sequence(dilations, n, channel_index, "dilations")
|
dilations = _get_sequence(dilations, n, channel_index, "dilations")
|
||||||
|
|
||||||
# copybara:strip_begin
|
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
|
||||||
# to make TPU code and CPU code consistent.
|
|
||||||
scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
|
scopes = {1: "conv1d", 2: "Conv2D", 3: "Conv3D"}
|
||||||
if not call_from_convolution and _enclosing_tpu_context() is not None:
|
if not call_from_convolution and device_context.enclosing_tpu_context(
|
||||||
|
) is not None:
|
||||||
scope = scopes[n]
|
scope = scopes[n]
|
||||||
else:
|
else:
|
||||||
scope = "convolution"
|
scope = "convolution"
|
||||||
# copybara:strip_end
|
|
||||||
# copybara:insert scope = "convolution"
|
|
||||||
|
|
||||||
with ops.name_scope(name, scope, [input, filters]) as name:
|
with ops.name_scope(name, scope, [input, filters]) as name:
|
||||||
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
|
conv_ops = {1: conv1d, 2: gen_nn_ops.conv2d, 3: gen_nn_ops.conv3d}
|
||||||
|
|
||||||
# copybara:strip_begin
|
if device_context.enclosing_tpu_context() is not None or all(
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
i == 1 for i in dilations):
|
||||||
# to make TPU code and CPU code consistent.
|
|
||||||
if _enclosing_tpu_context() is not None or all(i == 1 for i in dilations):
|
|
||||||
# fast path for TPU or if no dilation as gradient only supported on GPU
|
# fast path for TPU or if no dilation as gradient only supported on GPU
|
||||||
# for dilations
|
# for dilations
|
||||||
# copybara:strip_end
|
|
||||||
# copybara:insert if all(i == 1 for i in dilations):
|
|
||||||
op = conv_ops[n]
|
op = conv_ops[n]
|
||||||
return op(
|
return op(
|
||||||
input,
|
input,
|
||||||
@ -1120,11 +1093,8 @@ class Convolution(object):
|
|||||||
name=self.name)
|
name=self.name)
|
||||||
|
|
||||||
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
|
def __call__(self, inp, filter): # pylint: disable=redefined-builtin
|
||||||
# copybara:strip_begin
|
|
||||||
# TODO(b/138808492): Remove code inside copybara
|
|
||||||
# to make TPU code and CPU code consistent.
|
|
||||||
# TPU convolution supports dilations greater than 1.
|
# TPU convolution supports dilations greater than 1.
|
||||||
if _enclosing_tpu_context() is not None:
|
if device_context.enclosing_tpu_context() is not None:
|
||||||
return convolution_internal(
|
return convolution_internal(
|
||||||
inp,
|
inp,
|
||||||
filter,
|
filter,
|
||||||
@ -1136,8 +1106,6 @@ class Convolution(object):
|
|||||||
call_from_convolution=False)
|
call_from_convolution=False)
|
||||||
else:
|
else:
|
||||||
return self.conv_op(inp, filter)
|
return self.conv_op(inp, filter)
|
||||||
# copybara:strip_end
|
|
||||||
# copybara:insert return self.conv_op(inp, filter)
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export(v1=["nn.pool"])
|
@tf_export(v1=["nn.pool"])
|
||||||
|
22
tensorflow/python/platform/device_context.py
Normal file
22
tensorflow/python/platform/device_context.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Helpers to get device context."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
def enclosing_tpu_context():
|
||||||
|
pass
|
Loading…
x
Reference in New Issue
Block a user