Merge pull request #16991 from yifeif/branch_185565363

Branch 185565363
This commit is contained in:
Martin Wicke 2018-02-13 14:47:30 -08:00 committed by GitHub
commit 17103a0b8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
151 changed files with 7188 additions and 1998 deletions

View File

@ -224,9 +224,6 @@ def tf_library(name, graph, config,
# TODO(cwhipkey): only depend on kernel code that the model actually needed.
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
"//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
"//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
"//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
"//tensorflow/compiler/xla/service/cpu:runtime_matmul",
"//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",

View File

@ -639,6 +639,7 @@ tf_xla_py_test(
name = "variable_ops_test",
size = "small",
srcs = ["variable_ops_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
@ -677,6 +678,19 @@ tf_xla_py_test(
],
)
tf_xla_py_test(
name = "scatter_nd_op_test",
size = "medium",
srcs = ["scatter_nd_op_test.py"],
tags = ["optonly"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "xla_device_test",
size = "small",
@ -801,6 +815,17 @@ tf_library(
tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"],
)
tf_xla_py_test(
name = "fake_quant_ops_test",
size = "medium",
srcs = ["fake_quant_ops_test.py"],
deps = [
":xla_test",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform_test",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -0,0 +1,452 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import googletest
class FakeQuantWithMinMaxArgsTest(XLATestCase):
"""Test cases for FakeQuantWithMinMaxArgs operation."""
# 8 bits, wide range.
def testOp_with8BitsNoScalingNoNudging(self):
self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
def testOp_with8BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
def testOp_with8BitsScalingAndNudgingUp(self):
self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
def testOp_with8BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
# 8 bits, narrow range.
def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
# 7 bits, wide range.
def testOp_with7BitsNoScalingNoNudging(self):
self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
def testOp_with7BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
def testOp_with7BitsScalingAndNudgingUp(self):
self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
def testOp_with7BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
# 7 bits, narrow range.
def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
def _TestOp(self, input_min, input_max, num_bits, narrow_range,
expected_nudged_input_min, expected_nudged_input_max,
expected_step):
inputs = np.array(
[
expected_nudged_input_min - expected_step,
expected_nudged_input_min - 0.01, expected_nudged_input_min,
expected_nudged_input_min + 0.01,
expected_nudged_input_min + expected_step - 0.01,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step + 0.01,
expected_nudged_input_max - 0.01, expected_nudged_input_max,
expected_nudged_input_max + 0.01,
expected_nudged_input_max + expected_step
],
dtype=np.float32)
expected = np.array(
[
expected_nudged_input_min, expected_nudged_input_min,
expected_nudged_input_min, expected_nudged_input_min,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step,
expected_nudged_input_max, expected_nudged_input_max,
expected_nudged_input_max, expected_nudged_input_max
],
dtype=np.float32)
with self.test_session() as session:
with self.test_scope():
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
outputs = array_ops.fake_quant_with_min_max_args(
input_placeholder,
min=input_min,
max=input_max,
num_bits=num_bits,
narrow_range=narrow_range)
result = session.run(outputs, {input_placeholder: inputs})
self.assertAllCloseAccordingToType(
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
"""Test cases for FakeQuantWithMinMaxArgsGradient operation."""
# 8 bits, wide range.
def testOp_with8BitsNoScalingNoNudging(self):
self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
def testOp_with8BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
def testOp_with8BitsScalingAndNudgingUp(self):
self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
def testOp_with8BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
# 8 bits, narrow range.
def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
# 7 bits, wide range.
def testOp_with7BitsNoScalingNoNudging(self):
self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
def testOp_with7BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
def testOp_with7BitsScalingAndNudgingUp(self):
self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
def testOp_with7BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
# 7 bits, narrow range.
def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
def _TestOp(self, input_min, input_max, num_bits, narrow_range,
expected_nudged_input_min, expected_nudged_input_max,
expected_step):
inputs = np.array(
[
expected_nudged_input_min - expected_step,
expected_nudged_input_min - 0.01, expected_nudged_input_min,
expected_nudged_input_min + 0.01,
expected_nudged_input_min + expected_step - 0.01,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step + 0.01,
expected_nudged_input_max - 0.01, expected_nudged_input_max,
expected_nudged_input_max + 0.01,
expected_nudged_input_max + expected_step
],
dtype=np.float32)
gradients = np.arange(1, len(inputs) + 1, dtype=np.float32)
expected_backprops = np.array(
[0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
dtype=np.float32)
with self.test_session() as session:
with self.test_scope():
gradient_placeholder = array_ops.placeholder(
dtypes.float32, gradients.shape, name="gradients")
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
outputs = gen_array_ops.fake_quant_with_min_max_args_gradient(
gradient_placeholder,
input_placeholder,
min=input_min,
max=input_max,
num_bits=num_bits,
narrow_range=narrow_range)
backprops = session.run(outputs, {
gradient_placeholder: gradients,
input_placeholder: inputs
})
self.assertAllCloseAccordingToType(
backprops,
expected_backprops,
rtol=1e-3,
atol=1e-5,
bfloat16_rtol=0.03)
class FakeQuantWithMinMaxVarsTest(XLATestCase):
"""Test cases for FakeQuantWithMinMaxVars operation."""
# 8 bits, wide range.
def testOp_with8BitsNoScalingNoNudging(self):
self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
def testOp_with8BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
def testOp_with8BitsScalingAndNudgingUp(self):
self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
def testOp_with8BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
# 8 bits, narrow range.
def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
# 7 bits, wide range.
def testOp_with7BitsNoScalingNoNudging(self):
self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
def testOp_with7BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
def testOp_with7BitsScalingAndNudgingUp(self):
self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
def testOp_with7BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
# 7 bits, narrow range.
def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
def _TestOp(self, input_min, input_max, num_bits, narrow_range,
expected_nudged_input_min, expected_nudged_input_max,
expected_step):
inputs = np.array(
[
expected_nudged_input_min - expected_step,
expected_nudged_input_min - 0.01, expected_nudged_input_min,
expected_nudged_input_min + 0.01,
expected_nudged_input_min + expected_step - 0.01,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step + 0.01,
expected_nudged_input_max - 0.01, expected_nudged_input_max,
expected_nudged_input_max + 0.01,
expected_nudged_input_max + expected_step
],
dtype=np.float32)
expected = np.array(
[
expected_nudged_input_min, expected_nudged_input_min,
expected_nudged_input_min, expected_nudged_input_min,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step,
expected_nudged_input_max, expected_nudged_input_max,
expected_nudged_input_max, expected_nudged_input_max
],
dtype=np.float32)
with self.test_session() as session:
with self.test_scope():
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min")
max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max")
outputs = array_ops.fake_quant_with_min_max_vars(
input_placeholder,
min_placeholder,
max_placeholder,
num_bits=num_bits,
narrow_range=narrow_range)
result = session.run(
outputs, {
input_placeholder: inputs,
min_placeholder: input_min,
max_placeholder: input_max
})
self.assertAllCloseAccordingToType(
result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)
class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
"""Test cases for FakeQuantWithMinMaxVarsGradient operation."""
# 8 bits, wide range.
def testOp_with8BitsNoScalingNoNudging(self):
self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)
def testOp_with8BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)
def testOp_with8BitsScalingAndNudgingUp(self):
self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)
def testOp_with8BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)
# 8 bits, narrow range.
def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)
def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)
def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)
# 7 bits, wide range.
def testOp_with7BitsNoScalingNoNudging(self):
self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)
def testOp_with7BitsScalingAndNudgingDown(self):
self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)
def testOp_with7BitsScalingAndNudgingUp(self):
self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)
def testOp_with7BitsScalingAndNudgingBetween(self):
self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)
# 7 bits, narrow range.
def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)
def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)
def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)
def _TestOp(self, input_min, input_max, num_bits, narrow_range,
expected_nudged_input_min, expected_nudged_input_max,
expected_step):
inputs = np.array(
[
expected_nudged_input_min - expected_step,
expected_nudged_input_min - 0.01, expected_nudged_input_min,
expected_nudged_input_min + 0.01,
expected_nudged_input_min + expected_step - 0.01,
expected_nudged_input_min + expected_step,
expected_nudged_input_min + expected_step + 0.01,
expected_nudged_input_max - 0.01, expected_nudged_input_max,
expected_nudged_input_max + 0.01,
expected_nudged_input_max + expected_step
],
dtype=np.float32)
gradients = np.arange(1, len(inputs) + 1, dtype=np.float32)
expected_backprops_wrt_input = np.array(
[0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
dtype=np.float32)
expected_backprops_wrt_min = 1.0 + 2.0
expected_backprops_wrt_max = 10.0 + 11.0
with self.test_session() as session:
with self.test_scope():
gradient_placeholder = array_ops.placeholder(
dtypes.float32, gradients.shape, name="gradients")
input_placeholder = array_ops.placeholder(
dtypes.float32, inputs.shape, name="inputs")
min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min")
max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max")
outputs = array_ops.fake_quant_with_min_max_vars_gradient(
gradient_placeholder,
input_placeholder,
min_placeholder,
max_placeholder,
num_bits=num_bits,
narrow_range=narrow_range)
backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = session.run(
outputs, {
gradient_placeholder: gradients,
input_placeholder: inputs,
min_placeholder: input_min,
max_placeholder: input_max
})
self.assertAllCloseAccordingToType(
backprops_wrt_input,
expected_backprops_wrt_input,
rtol=1e-3,
atol=1e-5,
bfloat16_rtol=0.03)
self.assertAllCloseAccordingToType(
backprops_wrt_min,
expected_backprops_wrt_min,
rtol=1e-3,
atol=1e-5,
bfloat16_rtol=0.03)
self.assertAllCloseAccordingToType(
backprops_wrt_max,
expected_backprops_wrt_max,
rtol=1e-3,
atol=1e-5,
bfloat16_rtol=0.03)
if __name__ == "__main__":
googletest.main()

View File

@ -0,0 +1,188 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.scatter_nd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
def _AsType(v, vtype):
return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)
def _FlatInnerDims(tensor, ndims=2):
shape = list(tensor.shape)
return tensor.reshape(
[functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)] +
shape[-ndims + 1:])
def _FlatOuterDims(tensor, ndims=2):
shape = list(tensor.shape)
return tensor.reshape(
shape[:ndims - 1] +
[functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)])
def _NumpyScatterNd(ref, indices, updates, op):
ixdim = indices.shape[-1]
num_updates = indices.size // ixdim
total_nd = len(ref.shape)
slice_size = 1
for i in range(ixdim, total_nd):
slice_size *= ref.shape[i]
flat_indices = _FlatInnerDims(indices)
flat_updates = updates.reshape((num_updates, slice_size))
output_flat = _FlatOuterDims(ref, ixdim + 1)
for ix_updates, ix_output in enumerate(flat_indices):
ix_output = tuple(ix_output)
output_flat[ix_output] = op(output_flat[ix_output],
flat_updates[ix_updates])
return output_flat.reshape(ref.shape)
def _NumpyUpdate(indices, updates, shape):
ref = np.zeros(shape, dtype=updates.dtype)
return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)
class ScatterNdTest(XLATestCase):
def _VariableRankTest(self,
np_scatter,
tf_scatter,
vtype,
itype,
repeat_indices=False):
np.random.seed(8)
ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
num_updates = indices_shape[0]
ixdim = indices_shape[-1]
indexable_area_shape = ()
for i in range(ixdim):
indexable_area_shape += (ref_shape[i],)
all_indices = [
list(coord)
for coord, _ in np.ndenumerate(np.empty(indexable_area_shape, vtype))
]
np.random.shuffle(all_indices)
indices = np.array(all_indices[:num_updates])
if num_updates > 1 and repeat_indices:
indices = indices[:num_updates // 2]
for _ in range(num_updates - num_updates // 2):
indices = np.append(
indices, [indices[np.random.randint(num_updates // 2)]], axis=0)
np.random.shuffle(indices)
indices = _AsType(indices[:num_updates], itype)
updates_shape = (num_updates,)
for i in range(ixdim, len(ref_shape)):
updates_shape += (ref_shape[i],)
updates = _AsType(np.random.randn(*(updates_shape)), vtype)
# Scatter via numpy
np_out = np_scatter(indices, updates, ref_shape)
# Scatter via tensorflow
tf_out = tf_scatter(indices, updates, ref_shape)
self.assertAllClose(np_out, tf_out)
def _VariableRankTests(self, np_scatter, tf_scatter):
for vtype in self.numeric_types:
for itype in set([np.int32, np.int64]).intersection(set(self.int_types)):
self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)
def _runScatterNd(self, indices, updates, shape):
with self.test_session():
updates_placeholder = array_ops.placeholder(updates.dtype)
indices_placeholder = array_ops.placeholder(indices.dtype)
with self.test_scope():
output = array_ops.scatter_nd(indices_placeholder, updates_placeholder,
shape)
feed_dict = {updates_placeholder: updates, indices_placeholder: indices}
return output.eval(feed_dict=feed_dict)
def testSimple(self):
indices = np.array([[4], [3], [1], [7]], dtype=np.int32)
updates = np.array([9, 10, 11, 12], dtype=np.float32)
expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32)
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8]))
def testSimple2(self):
indices = np.array([[1, 0], [1, 1]], dtype=np.int32)
updates = np.array([11., 12.], dtype=np.float32)
expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32)
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2]))
def testSimple3(self):
indices = np.array([[1]], dtype=np.int32)
updates = np.array([[11., 12.]], dtype=np.float32)
expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2]))
def testVariableRankUpdate(self):
self._VariableRankTests(_NumpyUpdate, self._runScatterNd)
def testExtraIndicesDimensions(self):
indices = np.zeros([1, 1, 2], np.int32)
updates = np.zeros([1, 1], np.int32)
expected = np.zeros([2, 2], dtype=np.int32)
self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))
def testRank3InvalidShape1(self):
indices = np.zeros([3, 2, 2], np.int32)
updates = np.zeros([2, 2, 2], np.int32)
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"Must have updates.shape"):
self._runScatterNd(indices, updates, [2, 2, 2])
def testRank3InvalidShape2(self):
indices = np.zeros([2, 2, 1], np.int32)
updates = np.zeros([2, 2], np.int32)
with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
"Must have updates.shape"):
self._runScatterNd(indices, updates, [2, 2, 2])
def testScatterOutOfRange(self):
updates = np.array([-3, -4, -5]).astype(np.float32)
# Indices all in range, no problem.
indices = np.array([[2], [0], [5]], dtype=np.int32)
self._runScatterNd(indices, updates, [6])
# Indices out of range should not fail. It produces implementation-defined
# output.
indices = np.array([[-1], [0], [5]], dtype=np.int32)
self._runScatterNd(indices, updates, [6])
indices = np.array([[2], [0], [6]], dtype=np.int32)
self._runScatterNd(indices, updates, [6])
if __name__ == "__main__":
test.main()

View File

@ -60,6 +60,14 @@ class SegmentReductionOpsTest(XLATestCase):
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))
def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self):
for dtype in self.numeric_types:
self.assertAllClose(
np.array([0, 3, 2, 5], dtype=dtype),
self.UnsortedSegmentSum(
np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
np.array([3, -1, 2, 1, -1, 3], dtype=np.int32), 4))
def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
for dtype in self.numeric_types:
data = np.array(

View File

@ -285,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame,
Status FunctionalizeLoop(Graph* graph, Frame* frame,
FunctionLibraryDefinition* library) {
VLOG(2) << "Frame " << frame->name << " before: "
<< dump_graph::DumpGraphToFile("functionalize_before", *graph);
<< dump_graph::DumpGraphToFile("functionalize_before", *graph,
library);
// Split loop-varying Enter nodes with multiple successors. If the same
// Tensor is fed as input to multiple loop arguments, we may end up with a
@ -470,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
VLOG(2) << "Frame " << frame->name << " condition: "
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
<< dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
<< " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
static std::atomic<int64> sequence_num(0LL);
@ -551,7 +552,8 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
frame->parent->nodes.insert(while_node);
VLOG(2) << "Frame " << frame->name << " after: "
<< dump_graph::DumpGraphToFile("functionalize_after", *graph);
<< dump_graph::DumpGraphToFile("functionalize_after", *graph,
library);
return Status::OK();
}
@ -584,11 +586,11 @@ class FunctionalizeCond {
explicit CondArgNode(Node* input) : input(input) {}
string ToString() const {
return strings::StrCat("input=", input->name(),
" switches=", NodesToString(switch_nodes));
" switches=", NodesToString(switches));
}
Node* input;
std::vector<Node*> switch_nodes;
std::vector<Node*> switches;
};
using CondArgNodes = std::vector<CondArgNode>;
@ -602,15 +604,22 @@ class FunctionalizeCond {
int count;
};
struct PredicateSwitches {
explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
// Group of switch nodes that will be part of the same XlaIf.
struct SwitchCluster {
explicit SwitchCluster(Node* predicate) : predicate(predicate) {}
string ToString() const {
return strings::StrCat(name, " predicate=", predicate->name(),
" switches=", NodesToString(switches));
}
string name;
Node* predicate;
std::vector<Node*> switches;
};
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
: library_(library), graph_(graph) {}
FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
bool dump_graphs)
: library_(library), graph_(graph), dump_graphs_(dump_graphs) {}
// Perform the actual cond functionalization. Iterate over groups of switch
// nodes (linked by common predicate), from innermost to outermost, and
@ -621,27 +630,25 @@ class FunctionalizeCond {
// frontier (the nodes where the cond ends).
StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
std::unordered_set<Node*>>>
DetermineBranchMapAndFrontier(const std::vector<Node*>& switches);
DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);
// Returns XlaIf node created from subgraph of merge and switch nodes. This
// encapsulates the process of extracting the bodies needed for the then and
// else branch, creates a XlaIf node, removing the nodes of the branches from
// the graph and replacing the merge node with a XlaIf.
StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& merge_nodes,
Node* predicate);
const SwitchCluster& switch_cluster,
const std::vector<Node*>& switches);
// Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& merge_nodes,
Node* predicate);
const SwitchCluster& switch_cluster,
const std::vector<Node*>& merge_nodes);
// Extracts a function body corresponding to the given input edge of the merge
// node.
Status ExtractBody(const CondArgNodes& cond_arg_nodes,
const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& switches,
const std::vector<Node*>& merge_nodes, int input_edge,
Graph* body);
@ -652,9 +659,9 @@ class FunctionalizeCond {
// Adds all output edges from the `if_node`.
Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);
// Returns the switches of graph_ (along with grouping predicates) in
// postorder. Dead switch nodes are skipped and removed from the graph.
std::vector<PredicateSwitches> DeterminePredicateSwitchOrder();
// Returns the switch clusters of graph_ in postorder. Dead switch nodes are
// skipped and removed from the graph.
StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();
// Update the state for destination based on the state of source and the node
// being updated.
@ -677,6 +684,7 @@ class FunctionalizeCond {
FunctionLibraryDefinition* library_;
Graph* graph_;
bool dump_graphs_;
};
bool IsDeadSwitch(const Node* node) {
@ -724,10 +732,13 @@ Status FunctionalizeCond::ValidateFrontier(
") in both Else and Then branch should be in Both.");
}
}
if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
pending[kElseBranch].empty()) {
return errors::Internal("Unexpected empty frontier for switch nodes");
}
// An empty frontier indicates a dead switch. Above we attempt to remove dead
// switch nodes, but not all are removed so don't treat it as an error yet.
// TODO(jpienaar): Find out why dead switch nodes remain.
// if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
// pending[kElseBranch].empty()) {
// return errors::Internal("Unexpected empty frontier for switch nodes");
// }
return Status::OK();
}
@ -754,33 +765,191 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
return Status::OK();
}
std::vector<FunctionalizeCond::PredicateSwitches>
StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
FunctionalizeCond::DeterminePredicateSwitchOrder() {
struct Cluster {
bool operator==(const Cluster& other) const {
return representative == other.representative;
}
int representative = -1;
};
// Perform a DFS over the graph and
// * Determine the reverse topological order of the nodes (there should be no
// cycles at this point so the post-order numbering corresponds to the
// reverse topological sorting);
// * Identify dead switches;
// * Initialize the cluster's representative;
std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
std::vector<Node*> dead_switches;
std::vector<Node*> switch_order;
DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
std::vector<Node*> rev_topo_sorted_nodes;
DFS(*graph_, nullptr, [&](Node* n) {
clusters[n->id()].Get().representative = n->id();
if (IsSwitch(n)) {
if (IsDeadSwitch(n)) {
dead_switches.push_back(n);
} else {
rev_topo_sorted_nodes.push_back(n);
switch_order.push_back(n);
}
} else if (n->IsOp()) {
// Exclude src and sink nodes from further consideration.
rev_topo_sorted_nodes.push_back(n);
}
});
std::vector<SwitchCluster> switch_clusters;
// Return early if there are no switches in the graph.
if (switch_order.empty()) {
return switch_clusters;
}
// Remove all dead switch nodes.
for (Node* n : dead_switches) {
VLOG(2) << "Removing dead switch: " << n->DebugString();
graph_->RemoveNode(n);
}
std::vector<PredicateSwitches> predicate_switch_order;
if (switch_order.empty()) {
return predicate_switch_order;
// Identify switch nodes that are part of the same control flow context by
// considering the operands of operations: an operation is part of the same
// control context as its operands unless the operation is a switch. Control
// dependencies are considered part of the same control flow context if the
// switch depth is the same (see comment below).
// entry_cluster records the input cluster to a switch node. This is used when
// merging with a merge node where the dst's cluster is merged with the entry
// cluster of the merge node's cluster (which corresponds to a switch cluster
// and so has an entry cluster).
std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;
// Returns the output cluster of a node. Where the output cluster is cluster
// where the output of the node is used. For non-merge nodes this is simply
// the cluster they are part of, while for merge nodes it is the entry cluster
// of the cluster they are part of (this will correspond to the entry node of
// a switch node that dominates the merge).
auto find_output_cluster = [&](Node* n) {
UnionFind<Cluster>* cluster = &clusters[n->id()];
if (!IsMerge(n)) return cluster;
auto it = entry_cluster.find(clusters[n->id()].Get().representative);
// If the cluster is not found in the entry_cluster map then an
// instruction not dominated by a switch node has been merged into the
// cluster of the merge. This indicates a failure of the clustering.
CHECK(it != entry_cluster.end())
<< "Unable to find entry for n=" << n->id() << " ("
<< cluster->Get().representative << ")";
return it->second;
};
// TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
std::vector<int> switch_depth(graph_->num_node_ids());
for (auto it = rev_topo_sorted_nodes.rbegin();
it != rev_topo_sorted_nodes.rend(); ++it) {
Node* n = *it;
// Compute switch depth.
int new_switch_depth = 0;
for (const Edge* e : n->in_edges()) {
Node* src = e->src();
new_switch_depth = std::max(
new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
}
switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);
// Only merge the input operands of a switch. The switch's clustering itself
// is determined by the interaction of the switch's outputs.
if (IsSwitch(n)) {
Node* input;
TF_CHECK_OK(n->input_node(0, &input));
entry_cluster[n->id()] = &clusters[input->id()];
UnionFind<Cluster>* cluster = find_output_cluster(input);
int cluster_depth = switch_depth[cluster->Get().representative];
// Merge the inputs of the switch node with one another. This results in
// predicates and control input residing in the same cluster.
for (const Edge* e : n->in_edges()) {
Node* src = e->src();
UnionFind<Cluster>* src_cluster = find_output_cluster(src);
int src_cluster_depth = switch_depth[src_cluster->Get().representative];
if (cluster_depth != src_cluster_depth) {
return errors::InvalidArgument(
"Unable to functionalize control flow in graph: Switch ('",
n->name(), "') has operands ('", input->name(), "' and '",
src->name(), "') that have different switch depths (",
cluster_depth, " != ", src_cluster_depth, ")");
}
cluster->Merge(src_cluster);
}
continue;
}
for (const Edge* e : n->in_edges()) {
Node* src = e->src();
if (!src->IsOp()) continue;
UnionFind<Cluster>* cluster = find_output_cluster(src);
// Merge a node with its data operands and with its control operands if
// the src and dst are in the same ControlContext. The ControlContext is
// not explicitly available here, and instead the switch depth is used as
// a proxy here. Due to the invariant that control edges can only be from
// a containing scope to an inner scope or from the inner scope to its
// containing scope (for exit nodes), the switch depth will only match if
// the src and dst are in the same ControlContext. Control edges between
// ControlContexts are handled during the extraction.
int src_id = cluster->Get().representative;
int src_depth = switch_depth[src_id];
if (!e->IsControlEdge() || new_switch_depth == src_depth) {
if (src_depth != new_switch_depth) {
return errors::InvalidArgument(
"Unable to functionalize control flow in graph: Operand ('",
src->name(), "') and operator ('", n->name(),
"') have different switch depths (", src_depth,
" != ", new_switch_depth, ")");
}
cluster->Merge(&clusters[n->id()]);
}
}
}
if (dump_graphs_) {
// Mark the switch cluster each node is part of.
for (Node* n : graph_->nodes()) {
n->ClearAttr("_XlaFunctionalizeSwitchGroup");
n->AddAttr("_XlaFunctionalizeSwitchGroup",
clusters[n->id()].Get().representative);
}
LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
<< dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
library_);
}
// Verify all the nodes of a cluster are at the same depth.
std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
for (Node* n : graph_->nodes()) {
int depth = switch_depth[n->id()];
int cluster_rep = clusters[n->id()].Get().representative;
auto it = cluster_to_depth_node.find(cluster_rep);
if (it == cluster_to_depth_node.end()) {
cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
} else {
if (it->second.first != depth) {
return errors::Internal(
"Illegal clustering created, mismatch in depths:", "\n\t",
n->DebugString(), "(", clusters[n->id()].Get().representative,
") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
"(", clusters[n->id()].Get().representative, ") at depth ",
it->second.first);
}
}
}
struct Hash {
size_t operator()(const std::pair<Node*, Cluster>& item) const {
return Hash64Combine(hash<Node*>()(item.first),
std::hash<int>()(item.second.representative));
}
};
// Merge Switch nodes with common predicate.
std::unordered_map<Node*, int> predicate_index;
std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
// The nodes in switch_order are in reverse topological order, but the
// clustered switches need not be (i.e., when considered as a cluster one
// element of a cluster may be later in the topological order than another
@ -789,13 +958,19 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() {
for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
Node* pred;
TF_CHECK_OK((*it)->input_node(1, &pred));
if (predicate_index.find(pred) == predicate_index.end()) {
predicate_index[pred] = predicate_switch_order.size();
predicate_switch_order.emplace_back(pred);
auto repr = std::make_pair(pred, clusters[(*it)->id()].Get());
if (predicate_index.find(repr) == predicate_index.end()) {
predicate_index[repr] = switch_clusters.size();
switch_clusters.emplace_back(pred);
// Generate a name by concatenating with the cluster representative as
// there could be multiple switch clusters with the same predicate.
switch_clusters[predicate_index[repr]].name =
strings::StrCat(pred->name(), "_", repr.second.representative, "_If");
}
predicate_switch_order[predicate_index[pred]].switches.push_back(*it);
switch_clusters[predicate_index[repr]].switches.push_back(*it);
}
return predicate_switch_order;
return switch_clusters;
}
StatusOr<std::vector<Node*>>
@ -843,10 +1018,10 @@ StatusOr<
std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
std::unordered_set<Node*>>>
FunctionalizeCond::DetermineBranchMapAndFrontier(
const std::vector<Node*>& switches) {
const SwitchCluster& switch_cluster) {
std::unordered_map<Node*, ForwardFlowNode> branch_map;
std::unordered_set<Node*> frontier;
std::vector<Node*> stack = switches;
std::vector<Node*> stack = switch_cluster.switches;
std::vector<bool> visited(graph_->num_node_ids(), false);
while (!stack.empty()) {
Node* n = stack.back();
@ -888,7 +1063,7 @@ FunctionalizeCond::DetermineBranchMapAndFrontier(
}
}
if (VLOG_IS_ON(2)) {
if (dump_graphs_) {
for (const auto& kv : branch_map) {
// Append attribute to the graph if running with logging to make the
// changes clearer in the visualization.
@ -900,8 +1075,8 @@ FunctionalizeCond::DetermineBranchMapAndFrontier(
}
Status FunctionalizeCond::FunctionalizeInternal() {
std::vector<PredicateSwitches> predicate_switch_order =
DeterminePredicateSwitchOrder();
TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
DeterminePredicateSwitchOrder());
// Iterate from innermost set of clustered switches to outermost, replacing
// matching switch->merge subgraphs with single XlaIf nodes.
@ -914,10 +1089,12 @@ Status FunctionalizeCond::FunctionalizeInternal() {
std::unordered_map<Node*, ForwardFlowNode> branch_map;
std::unordered_set<Node*> frontier;
TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
DetermineBranchMapAndFrontier(ps.switches));
DetermineBranchMapAndFrontier(ps));
VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
<< dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
if (dump_graphs_)
LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
<< dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
library_);
TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));
// Sort the merge and switch nodes using NodeCmp. The switch-nodes are
@ -934,7 +1111,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
input_index[in] = cond_arg_nodes.size();
cond_arg_nodes.emplace_back(in);
}
cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node);
cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node);
}
std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());
@ -943,9 +1120,8 @@ Status FunctionalizeCond::FunctionalizeInternal() {
EnsureDominanceAndReturnNonDominatedControlNodes(
branch_map, ps.switches));
TF_ASSIGN_OR_RETURN(
Node * if_node,
ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate));
TF_ASSIGN_OR_RETURN(Node * if_node,
ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
for (Node* old : old_control_nodes) {
graph_->AddControlEdge(old, if_node);
}
@ -954,25 +1130,26 @@ Status FunctionalizeCond::FunctionalizeInternal() {
graph_->RemoveNode(del_kv.first);
}
for (auto& kv : cond_arg_nodes) {
for (Node* node : kv.switch_nodes) {
for (Node* node : kv.switches) {
graph_->RemoveNode(node);
}
}
VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): "
<< dump_graph::DumpGraphToFile("functionalize_ac", *graph_);
if (dump_graphs_)
LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
<< dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
library_);
}
return Status::OK();
}
StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& merge_nodes, Node* predicate) {
VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
<< NodesToString(switch_nodes);
const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
const std::vector<Node*>& merge_nodes) {
VLOG(2) << "Build if op for " << switch_cluster.name;
NodeDef if_def;
// Create a new If node using the name of the merge node.
NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf");
NodeDefBuilder builder(switch_cluster.name, "XlaIf");
string branch[] = {"else_branch", "then_branch"};
for (int i = 0; i < 2; ++i) {
static std::atomic<int64> sequence_num(0LL);
@ -982,12 +1159,9 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
body_name.set_name(
strings::StrCat("_functionalize_if_", branch[i], "_", id));
auto body = xla::MakeUnique<Graph>(graph_->op_registry());
TF_RETURN_IF_ERROR(
ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get()));
TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
merge_nodes, i, body.get()));
VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): "
<< dump_graph::DumpGraphToFile(
strings::StrCat("functionalize_", branch[i]), *body);
FunctionDef body_fdef;
TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));
@ -999,7 +1173,7 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
DataTypeVector in_arg_types;
for (auto& kv : cond_arg_nodes) {
bool inserted = false;
for (const Node* arg : kv.switch_nodes) {
for (const Node* arg : kv.switches) {
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
@ -1026,10 +1200,11 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
builder.Attr("Tout", out_type);
builder.Attr("Tcond", DT_BOOL);
builder.Device(predicate->assigned_device_name());
builder.Device(switch_cluster.predicate->assigned_device_name());
// Conditional should be the first input ...
builder.Input(
NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0)));
NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0,
switch_cluster.predicate->output_type(0)));
// ... followed by the other inputs.
builder.Input(inputs);
@ -1039,7 +1214,7 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
}
Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& switches,
const std::vector<Node*>& merge_nodes,
int input_edge, Graph* body) {
VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "
@ -1049,7 +1224,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
int arg_count = 0;
for (auto& kv : cond_arg_nodes) {
Node* arg_node = nullptr;
for (const auto* arg : kv.switch_nodes) {
for (const auto* arg : kv.switches) {
DataType dtype = arg->input_type(0);
if (arg_node == nullptr) {
TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));
@ -1073,8 +1248,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
node_map.at(in->id()) = body->CopyNode(in);
}
if (std::find(switch_nodes.begin(), switch_nodes.end(), in) ==
switch_nodes.end()) {
if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
node_map.at(node->id()), 0);
} else {
@ -1096,7 +1270,7 @@ Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
graph_->AddEdge(predicate, 0, if_node, index++);
for (auto& kv : cond_arg_nodes) {
bool inserted = false;
for (const Node* arg : kv.switch_nodes) {
for (const Node* arg : kv.switches) {
const Edge* in_edge;
TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
if (in_edge->IsControlEdge()) {
@ -1139,16 +1313,17 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
}
StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
const std::vector<Node*>& merge_nodes, Node* predicate) {
VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> "
const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
const std::vector<Node*>& merge_nodes) {
VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
<< NodesToString(merge_nodes);
// Extract bodies and builds a If operator.
TF_ASSIGN_OR_RETURN(
Node * if_node,
BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate));
TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node));
BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
TF_RETURN_IF_ERROR(
AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node));
TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));
return if_node;
@ -1157,18 +1332,19 @@ StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
Status FunctionalizeCond::Functionalize(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(1) << "FunctionalizeCond::Functionalize";
FunctionalizeCond fc(graph, library);
FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
return fc.FunctionalizeInternal();
}
} // namespace
// Transformation that converts Tensorflow's graph control flow constructs into
// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
FunctionLibraryDefinition* library) {
VLOG(2) << "FunctionalizeControlFlow (initial): "
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph);
<< dump_graph::DumpGraphToFile("functionalize_initial", *graph,
library);
// Note: BuildControlFlowInfo() requires that the graph's source node is
// connected to all source nodes in the graph. Many graphs violate this
// invariant.
@ -1180,7 +1356,8 @@ Status FunctionalizeControlFlow(Graph* graph,
for (Node* node : graph->op_nodes()) {
const ControlFlowInfo& cf = cf_info[node->id()];
VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
VLOG(2) << "node: " << node->name() << " (" << node->id()
<< ") frame_name: " << cf.frame_name
<< " frame: " << (cf.frame ? cf.frame->name() : "---")
<< " parent_frame: "
<< (cf.parent_frame ? cf.parent_frame->name() : "---");
@ -1248,7 +1425,8 @@ Status FunctionalizeControlFlow(Graph* graph,
TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));
VLOG(2) << "FunctionalizeControlFlow (final): "
<< dump_graph::DumpGraphToFile("functionalize_final", *graph);
<< dump_graph::DumpGraphToFile("functionalize_final", *graph,
library);
return Status::OK();
}

View File

@ -38,10 +38,11 @@ namespace {
// Returns the names of the "then" and "else" functions for the XlaIf node in a
// graph.
Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn,
NameAttrList* else_fn) {
Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
NameAttrList* then_fn, NameAttrList* else_fn) {
for (const NodeDef& node : graph.node()) {
if (node.op() == "XlaIf") {
*op_name = node.name();
const NameAttrList* result;
TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
*then_fn = *result;
@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) {
GraphDef graph_def;
graph.ToGraphDef(&graph_def);
string op_name;
NameAttrList then_fn;
NameAttrList else_fn;
TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn));
TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
InstantiationResultForTest else_result;
TF_EXPECT_OK(
InstantiateFunctionForTest(else_fn.name(), library, &else_result));
@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) {
auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less,
auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
std::initializer_list<Input>{less, y, x}, then_fn,
else_fn, {DT_INT32});
GraphDef expected;

View File

@ -134,7 +134,7 @@ Status GraphCompiler::Compile() {
TF_RET_CHECK(src->id() < output_registry.size());
const NodeOutputs& src_outputs = output_registry[src->id()];
tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()];
tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
}
OpKernelContext op_context(&params, n->num_outputs());

View File

@ -32,6 +32,7 @@ tf_kernel_library(
"dynamic_stitch_op.cc",
"elu_op.cc",
"extract_image_patches_op.cc",
"fake_quantize_ops.cc",
"fft_ops.cc",
"fill_op.cc",
"function_ops.cc",
@ -64,6 +65,7 @@ tf_kernel_library(
"reverse_op.cc",
"reverse_sequence_op.cc",
"scan_ops.cc",
"scatter_nd_op.cc",
"segment_reduction_ops.cc",
"select_op.cc",
"sendrecv_ops.cc",
@ -96,12 +98,15 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:batch_dot",
"//tensorflow/compiler/tf2xla/lib:cholesky",
"//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:triangular_solve",
"//tensorflow/compiler/tf2xla/lib:util",
"//tensorflow/compiler/tf2xla/lib:while_loop",
"//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",

View File

@ -0,0 +1,289 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
namespace {
// Gymnastics with nudged zero point is to ensure that the real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
void CpuNudge(const float min, const float max, const float quant_min,
const float quant_max, float* nudged_min, float* nudged_max,
float* scale) {
*scale = (max - min) / (quant_max - quant_min);
const float zero_point_from_min = quant_min - min / *scale;
float nudged_zero_point;
if (zero_point_from_min <= quant_min) {
nudged_zero_point = quant_min;
} else if (zero_point_from_min >= quant_max) {
nudged_zero_point = quant_max;
} else {
nudged_zero_point = std::round(zero_point_from_min);
}
*nudged_min = (quant_min - nudged_zero_point) * (*scale);
*nudged_max = (quant_max - nudged_zero_point) * (*scale);
}
// An XLA version of CpuNudge().
void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
const xla::ComputationDataHandle& min,
const xla::ComputationDataHandle& max,
const float quant_min_value, const float quant_max_value,
xla::ComputationDataHandle* nudged_min,
xla::ComputationDataHandle* nudged_max,
xla::ComputationDataHandle* scale) {
*scale = b->Div(b->Sub(max, min),
XlaHelpers::FloatLiteral(b, data_type,
quant_max_value - quant_min_value));
xla::ComputationDataHandle quant_min =
XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
xla::ComputationDataHandle zero_point_from_min =
b->Sub(quant_min, b->Div(min, *scale));
xla::ComputationDataHandle quant_max =
XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
xla::ComputationDataHandle nudged_zero_point =
b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
b->Round(zero_point_from_min)));
*nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
*nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
}
xla::ComputationDataHandle Quantize(
xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
const DataType data_type,
const xla::ComputationDataHandle& nudged_input_min,
const xla::ComputationDataHandle& nudged_input_max,
const xla::ComputationDataHandle& input_scale) {
xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
xla::ComputationDataHandle half =
XlaHelpers::FloatLiteral(b, data_type, 0.5f);
xla::ComputationDataHandle clamped =
b->Clamp(nudged_input_min, input, nudged_input_max);
xla::ComputationDataHandle clamped_shifted =
b->Sub(clamped, nudged_input_min);
xla::ComputationDataHandle rounded =
b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
}
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
float input_min, input_max;
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
&nudged_input_max_, &input_scale_);
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
xla::ComputationDataHandle nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
xla::ComputationDataHandle input_scale =
XlaHelpers::FloatLiteral(b, data_type, input_scale_);
xla::ComputationDataHandle output = Quantize(
b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
ctx->SetOutput(0, output);
}
private:
float quant_min_;
float quant_max_;
float nudged_input_min_;
float nudged_input_max_;
float input_scale_;
};
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);
class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
const float quant_min = narrow_range ? 1 : 0;
const float quant_max = (1 << num_bits) - 1;
float input_min, input_max, scale;
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
&nudged_input_max_, &scale);
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle gradient = ctx->Input(0);
const TensorShape gradient_shape = ctx->InputShape(0);
xla::ComputationDataHandle input = ctx->Input(1);
const DataType data_type = ctx->input_type(1);
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle nudged_input_min =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
xla::ComputationDataHandle nudged_input_max =
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
xla::ComputationDataHandle between_nudged_min_max =
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
xla::ComputationDataHandle zeroes = b->Broadcast(
XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
xla::ComputationDataHandle output =
b->Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output);
}
private:
float nudged_input_min_;
float nudged_input_max_;
};
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
FakeQuantWithMinMaxArgsGradOp);
class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
xla::ComputationDataHandle input_min = ctx->Input(1);
xla::ComputationDataHandle input_max = ctx->Input(2);
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
xla::ComputationDataHandle output = Quantize(
b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
ctx->SetOutput(0, output);
}
private:
float quant_min_;
float quant_max_;
};
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);
class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
public:
explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {
int num_bits;
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
errors::InvalidArgument("num_bits is out of range, expected "
"between 2 and 16, was: ",
num_bits));
bool narrow_range;
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
quant_min_ = narrow_range ? 1 : 0;
quant_max_ = (1 << num_bits) - 1;
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle gradient = ctx->Input(0);
const TensorShape gradient_shape = ctx->InputShape(0);
xla::ComputationDataHandle input = ctx->Input(1);
const DataType data_type = ctx->input_type(1);
xla::ComputationDataHandle input_min = ctx->Input(2);
xla::ComputationDataHandle input_max = ctx->Input(3);
xla::ComputationBuilder* b = ctx->builder();
xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
&nudged_input_min, &nudged_input_max, &input_scale);
xla::ComputationDataHandle between_nudged_min_max =
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
xla::ComputationDataHandle zeroes =
b->Broadcast(zero, gradient_shape.dim_sizes());
xla::ComputationDataHandle output0 =
b->Select(between_nudged_min_max, gradient, zeroes);
ctx->SetOutput(0, output0);
xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
xla::ComputationDataHandle output1 =
b->ReduceAll(b->Select(below_min, gradient, zeroes), zero,
*ctx->GetOrCreateAdd(data_type));
ctx->SetOutput(1, output1);
xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
xla::ComputationDataHandle output2 =
b->ReduceAll(b->Select(above_max, gradient, zeroes), zero,
*ctx->GetOrCreateAdd(data_type));
ctx->SetOutput(2, output2);
}
private:
float quant_min_;
float quant_max_;
};
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
FakeQuantWithMinMaxVarsGradOp);
} // namespace
} // namespace tensorflow

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
@ -32,12 +33,12 @@ Status XlaGather(const xla::ComputationDataHandle& input,
DataType dtype, DataType index_type,
xla::ComputationBuilder* builder,
xla::ComputationDataHandle* gather_output) {
// If the indices are N-dimensional, then the last dimension of indices should
// be of size N and correspond to the N indices.
int64 num_axes = 1;
// If the indices are N-dimensional, then the minor dimension of indices
// should be of size N and correspond to the N indices.
int64 num_index_dims = 1;
if (indices_are_nd) {
CHECK_GE(indices_shape.dims(), 1);
num_axes = indices_shape.dim_size(indices_shape.dims() - 1);
num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
indices_shape.RemoveLastDims(1);
}
@ -46,15 +47,15 @@ Status XlaGather(const xla::ComputationDataHandle& input,
// input, the output is returned with shape:
// input.shape[:axis] + indices.shape + input.shape[axis+1:]
const int num_indices = indices_shape.num_elements();
const int64 num_indices = indices_shape.num_elements();
TensorShape input_shape_pre_axis(input_shape);
input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
TensorShape input_shape_post_axis(input_shape);
input_shape_post_axis.RemoveDimRange(0, axis + num_axes);
input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
// Each slice of the input tensor has shape:
// [<input_shape_pre_axis>, 1, ..., 1, <input shape_post_axis>]
TensorShape slice_shape(input_shape);
for (int64 i = 0; i < num_axes; ++i) {
for (int64 i = 0; i < num_index_dims; ++i) {
slice_shape.set_dim(axis + i, 1);
}
@ -79,7 +80,7 @@ Status XlaGather(const xla::ComputationDataHandle& input,
return Status::OK();
}
for (int64 i = 0; i < num_axes; ++i) {
for (int64 i = 0; i < num_index_dims; ++i) {
if (input_shape.dim_size(axis + i) == 0) {
return errors::InvalidArgument("Gather dimension ", axis + i,
" is of size zero in tensor with shape ",
@ -91,57 +92,30 @@ Status XlaGather(const xla::ComputationDataHandle& input,
// iteration. If there is an axis dimension, we must leave it alone.
std::vector<int64> flat_indices_shape = {num_indices};
if (indices_are_nd) {
flat_indices_shape.push_back(num_axes);
flat_indices_shape.push_back(num_index_dims);
}
// Specify the shape of the loop-carried Tensor tuple.
xla::PrimitiveType ptype;
TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype));
xla::PrimitiveType idxtype;
TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype));
std::vector<xla::Shape> tuple_shapes(
{// The iteration counter i is a scalar, incremented each iteration.
xla::ShapeUtil::MakeShape(idxtype, {}),
// The input array has shape input_shape. Loop invariant.
xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()),
// The gather indices are reshaped to flat_indices_shape. Loop invariant.
xla::ShapeUtil::MakeShape(idxtype, flat_indices_shape),
// The output array, which is updated on each loop iteration.
xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())});
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
// Construct the initial values of the loop-carried Tensors.
auto init_i = XlaHelpers::Zero(builder, index_type);
auto flat_indices = builder->Reshape(indices, flat_indices_shape);
auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
loop_out_shape.dim_sizes());
auto flat_indices = builder->Reshape(indices, flat_indices_shape);
auto init = builder->Tuple({init_i, input, flat_indices, init_out});
// Construct the while loop condition (i < num_indices)
std::unique_ptr<xla::ComputationBuilder> condb =
builder->CreateSubBuilder("GatherWhileCond");
condb->Lt(condb->GetTupleElement(
condb->Parameter(0, tuple_shape, "GatherWhileTuple"), 0),
XlaHelpers::IntegerLiteral(condb.get(), index_type, num_indices));
auto cond_status = condb->Build();
auto cond = cond_status.ConsumeValueOrDie();
auto init = {input, flat_indices, init_out};
// Construct the while loop body's function. The implementation of gather is:
// for i in range(num_indices):
// index = dynamic-slice(indices, i)
// xi = dynamic-slice(input, index)
// output = dynamic-update-slice(output, xi, i)
std::unique_ptr<xla::ComputationBuilder> bodyb =
builder->CreateSubBuilder("GatherWhileBody");
{
// The four loop carried values.
auto loop_tuple = bodyb->Parameter(0, tuple_shape, "GatherWhileTuple");
auto i = bodyb->GetTupleElement(loop_tuple, 0);
auto input = bodyb->GetTupleElement(loop_tuple, 1);
auto indices = bodyb->GetTupleElement(loop_tuple, 2);
auto output = bodyb->GetTupleElement(loop_tuple, 3);
auto body_fn = [&](xla::ComputationDataHandle i,
gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
xla::ComputationBuilder* bodyb) {
auto input = loop_vars[0];
auto indices = loop_vars[1];
auto output = loop_vars[2];
auto zero_index = XlaHelpers::Zero(bodyb.get(), index_type);
auto zero_index = XlaHelpers::Zero(bodyb, index_type);
// Slice the i-th index from the indices array.
xla::ComputationDataHandle index;
@ -150,7 +124,7 @@ Status XlaGather(const xla::ComputationDataHandle& input,
// Slice out the entire nd index, if applicable.
indices_offset = bodyb->Pad(indices_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, 1}}));
index = bodyb->DynamicSlice(indices, indices_offset, {1, num_axes});
index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims});
index = bodyb->Collapse(index, {0, 1});
} else {
index = bodyb->DynamicSlice(indices, indices_offset, {1});
@ -174,16 +148,16 @@ Status XlaGather(const xla::ComputationDataHandle& input,
// Update the output Tensor
auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index);
bodyb->Tuple({bodyb->Add(i, XlaHelpers::One(bodyb.get(), index_type)),
input, indices, updated_output});
}
auto body_status = bodyb->Build();
auto body = body_status.ConsumeValueOrDie();
return std::vector<xla::ComputationDataHandle>{input, indices,
updated_output};
};
// Construct the While loop, extract and reshape the output.
auto gather_while = builder->While(cond, body, init);
auto result = builder->GetTupleElement(gather_while, 3);
*gather_output = builder->Reshape(result, out_shape.dim_sizes());
xla::PrimitiveType ptype;
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
init, "gather", builder));
*gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
return Status::OK();
}
@ -250,9 +224,10 @@ class GatherNdOp : public XlaOpKernel {
errors::InvalidArgument("params must be at least a vector"));
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
errors::InvalidArgument("indices must be at least a vector"));
const int64 num_axes = indices_shape.dim_size(indices_shape.dims() - 1);
const int64 num_index_dims =
indices_shape.dim_size(indices_shape.dims() - 1);
OP_REQUIRES(
context, num_axes <= params_shape.dims(),
context, num_index_dims <= params_shape.dims(),
errors::InvalidArgument(
"index innermost dimension length must be <= params rank; saw: ",
indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",

View File

@ -0,0 +1,121 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
namespace {
// Check whether updates.shape = indices.shape[:batch_dim] +
// buffer_shape[num_index_dims:]
Status ValidateUpdateShape(const TensorShape& buffer_shape,
const TensorShape& indices_shape,
const TensorShape& updates_shape) {
if (indices_shape.dims() < 1) {
return errors::InvalidArgument(
"indices shape must have >= 1 dimension; got ",
indices_shape.DebugString());
}
const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
const int64 batch_dim = indices_shape.dims() - 1;
auto shape_err = [&]() {
return errors::InvalidArgument(
"Must have updates.shape = indices.shape[:batch_dim] + ",
"buffer_shape[num_index_dims:], got updates.shape: ",
updates_shape.DebugString(),
", indices.shape: ", indices_shape.DebugString(),
", buffer_shape: ", buffer_shape.DebugString(),
", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
};
if (updates_shape.dims() < batch_dim) return shape_err();
if (buffer_shape.dims() <
num_index_dims + (updates_shape.dims() - batch_dim)) {
return shape_err();
}
if (updates_shape.dims() !=
batch_dim + buffer_shape.dims() - num_index_dims) {
return shape_err();
}
for (int d = 0; d < batch_dim; ++d) {
if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) {
return shape_err();
}
}
for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
if (updates_shape.dim_size(d + batch_dim) !=
buffer_shape.dim_size(d + num_index_dims)) {
return shape_err();
}
}
return Status::OK();
}
class ScatterNdOp : public XlaOpKernel {
public:
explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
void Compile(XlaOpKernelContext* context) override {
DataType dtype = context->input_type(1);
TensorShape indices_shape = context->InputShape(0);
TensorShape updates_shape = context->InputShape(1);
TensorShape buffer_shape;
OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape));
OP_REQUIRES(
context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
errors::InvalidArgument("Output must be at least 1-D, ",
"got shape: ", buffer_shape.DebugString()));
OP_REQUIRES(
context,
buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
updates_shape.num_elements() == 0),
errors::InvalidArgument(
"Indices and updates specified for empty output. indices shape: ",
indices_shape.DebugString()));
OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape,
updates_shape));
xla::ComputationBuilder* builder = context->builder();
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
buffer_shape.dim_sizes());
auto indices = context->Input(0);
auto updates = context->Input(1);
auto result =
XlaScatter(buffer, updates, indices,
/*indices_are_vectors=*/true, /*combiner=*/{}, builder);
OP_REQUIRES_OK(context, result.status());
context->SetOutput(0, result.ValueOrDie());
}
};
REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstInput("shape"), ScatterNdOp);
} // namespace
} // namespace tensorflow

View File

@ -1,39 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Helper methods for XLA Scatter Ops.
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
// Adds to builder an XLA computation that performs a scatter-add of input (of
// shape input_shape) keyed on indices (of shape indices_shape). The shape
// of the Tensor returned by this is num_segments input_shape[indices.dims():]
//
static xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice(
XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input,
const TensorShape& input_shape, const xla::ComputationDataHandle& indices,
const TensorShape& indices_shape, int64 num_segments, DataType dtype,
xla::ComputationBuilder* builder);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_

View File

@ -1,4 +1,4 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@ -13,125 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
namespace tensorflow {
xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice(
XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input,
const TensorShape& input_shape, const xla::ComputationDataHandle& indices,
const TensorShape& indices_shape, int64 num_segments, DataType dtype,
xla::ComputationBuilder* builder) {
// Flatten data for dynamic indexing via indices_1d.
TensorShape input_shape_i(input_shape);
for (int64 d = 0; d < indices_shape.dims(); ++d) {
input_shape_i.RemoveDim(0);
}
TensorShape flat_shape({indices_shape.num_elements()});
flat_shape.AppendShape(input_shape_i);
// output is same as flattened input shape with dim_size(0) = num_segments.
TensorShape out_shape(flat_shape);
out_shape.set_dim(0, num_segments);
// Slices from the input data are same shape as the input data, except dim 0.
TensorShape slice_shape(flat_shape);
slice_shape.set_dim(0, 1);
TensorShape loop_out_slice_shape(out_shape);
loop_out_slice_shape.set_dim(0, 1);
// Construct the initial values of the loop-carried variables
// Flatten the indices into 1-D for ease of iteration.
auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()});
// Flatten the data for ease of indexing via values in indices_1d.
auto data_flat = builder->Reshape(input, flat_shape.dim_sizes());
auto init_i = builder->ConstantR0<int32>(0);
auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
out_shape.dim_sizes());
xla::PrimitiveType ptype;
TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype));
std::vector<xla::Shape> tuple_shapes(
{// The loop iteration counter is a scalar, incremented each iteration.
xla::ShapeUtil::MakeShape(xla::S32, {}),
// The flattened input data is loop invariant.
xla::ShapeUtil::MakeShape(ptype, flat_shape.dim_sizes()),
// The scatter indices tensor is loop invariant.
xla::ShapeUtil::MakeShape(xla::S32, {indices_shape.num_elements()}),
// The output data array is updated each loop iteration.
xla::ShapeUtil::MakeShape(ptype, out_shape.dim_sizes())});
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
auto init = builder->Tuple({init_i, data_flat, indices_1d, init_out});
// Construct the while loop condition (i < num_indices)
xla::ComputationBuilder condb(ctx->builder()->client(),
"ScatterAddWhileCond");
condb.Lt(condb.GetTupleElement(
condb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"), 0),
condb.ConstantR0<int32>(indices_shape.num_elements()));
auto cond_status = condb.Build();
auto cond = cond_status.ConsumeValueOrDie();
// Construct the while loop body's function. The implementation of scatter is:
// for i in range(num_indices):
// index = dynamic-slice(indices, i)
// xi = dynamic-slice(input, i)
// output = dynamic-update-slice(output, xi, index)
xla::ComputationBuilder bodyb(ctx->builder()->client(),
"ScatterAddWhileBody");
{
auto input_tuple = bodyb.Parameter(0, tuple_shape, "ScatterAddWhileTuple");
auto i = bodyb.GetTupleElement(input_tuple, 0);
auto data = bodyb.GetTupleElement(input_tuple, 1);
auto idcs = bodyb.GetTupleElement(input_tuple, 2);
auto output = bodyb.GetTupleElement(input_tuple, 3);
// Index into the data array at i.
auto zero = bodyb.ConstantR1<int32>({0});
std::vector<xla::ComputationDataHandle> index_vals(flat_shape.dims(), zero);
index_vals[0] = bodyb.Reshape(i, {1});
auto index = bodyb.ConcatInDim(index_vals, 0);
auto data_slice =
bodyb.Reshape(bodyb.DynamicSlice(data, index, slice_shape.dim_sizes()),
loop_out_slice_shape.dim_sizes());
// Index into the output array.
std::vector<xla::ComputationDataHandle> out_index_vals(out_shape.dims(),
zero);
out_index_vals[0] = bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1});
auto out_index = bodyb.ConcatInDim(out_index_vals, 0);
// Slice the output array, update value, and update the output slice.
auto updated_output = bodyb.DynamicUpdateSlice(
output,
bodyb.Add(data_slice,
bodyb.DynamicSlice(output, out_index,
loop_out_slice_shape.dim_sizes())),
out_index);
auto ip1 = bodyb.Add(i, bodyb.ConstantR0<int32>(1));
bodyb.Tuple({ip1, data, idcs, updated_output});
}
auto body_status = bodyb.Build();
auto body = body_status.ConsumeValueOrDie();
auto gather_while = builder->While(cond, body, init);
return builder->GetTupleElement(gather_while, 3);
}
namespace {
class UnsortedSegmentSum : public XlaOpKernel {
@ -153,10 +41,10 @@ class UnsortedSegmentSum : public XlaOpKernel {
// as data with the first indices.rank dimensions are replaced
// by a single dimension with size num_segments.
auto data = ctx->Input(0);
auto data_shape = ctx->InputShape(0);
TensorShape data_shape = ctx->InputShape(0);
auto indices = ctx->Input(1);
auto indices_shape = ctx->InputShape(1);
TensorShape indices_shape = ctx->InputShape(1);
int64 num_segments;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
@ -174,10 +62,21 @@ class UnsortedSegmentSum : public XlaOpKernel {
d, " differs ", data_shape.dim_size(d), " vs. ",
indices_shape.dim_size(d)));
}
auto result = XlaComputeScatterAddDynamicSlice(
ctx, data, data_shape, indices, indices_shape, num_segments, dtype_,
ctx->builder());
ctx->SetOutput(0, result);
xla::ComputationBuilder* builder = ctx->builder();
TensorShape buffer_shape = data_shape;
buffer_shape.RemoveDimRange(0, indices_shape.dims());
buffer_shape.InsertDim(0, num_segments);
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_),
buffer_shape.dim_sizes());
auto combiner =
[](xla::ComputationDataHandle a, xla::ComputationDataHandle b,
xla::ComputationBuilder* builder) { return builder->Add(a, b); };
auto result = XlaScatter(buffer, /*updates=*/data, indices,
/*indices_are_vectors=*/false, combiner, builder);
OP_REQUIRES_OK(ctx, result.status());
ctx->SetOutput(0, result.ValueOrDie());
}
private:

View File

@ -49,6 +49,25 @@ cc_library(
],
)
cc_library(
name = "scatter",
srcs = ["scatter.cc"],
hdrs = ["scatter.h"],
deps = [
":util",
":while_loop",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/core:lib",
],
)
cc_library(
name = "triangular_solve",
srcs = ["triangular_solve.cc"],
@ -107,6 +126,21 @@ cc_library(
],
)
cc_library(
name = "while_loop",
srcs = ["while_loop.cc"],
hdrs = ["while_loop.h"],
deps = [
":util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -0,0 +1,189 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include <memory>
#include <vector>
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
const xla::ComputationDataHandle& buffer,
const xla::ComputationDataHandle& updates,
const xla::ComputationDataHandle& indices, bool indices_are_vectors,
const std::function<xla::ComputationDataHandle(
xla::ComputationDataHandle, xla::ComputationDataHandle,
xla::ComputationBuilder*)>& combiner,
xla::ComputationBuilder* builder) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape,
builder->GetShape(buffer));
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape,
builder->GetShape(updates));
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape,
builder->GetShape(indices));
gtl::ArraySlice<int64> indices_dims =
xla::AsInt64Slice(indices_shape->dimensions());
gtl::ArraySlice<int64> buffer_dims =
xla::AsInt64Slice(buffer_shape->dimensions());
// If the indices are N-dimensional, the minor dimension of indices contains
// the indices to update. Otherwise the indices are all scalars.
int64 num_index_dims = 1;
if (indices_are_vectors) {
TF_RET_CHECK(!indices_dims.empty());
num_index_dims = indices_dims.back();
if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) {
return errors::InvalidArgument(
"The size of the minor dimension of the indices (shape: ",
xla::ShapeUtil::HumanString(*indices_shape),
") must be <= the rank of the buffer (shape: ",
xla::ShapeUtil::HumanString(*buffer_shape), ")");
}
indices_dims.pop_back();
}
int64 num_indices = 1;
for (int64 dim : indices_dims) {
num_indices *= dim;
}
// Degenerate case: nothing to update. Return the buffer unchanged.
if (num_indices == 0) {
return buffer;
}
// If any of the indexed dimensions are zero in the buffer, the update cannot
// succeed since it updates a slice of size 1.
for (int64 i = 0; i < num_index_dims; ++i) {
if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) {
return errors::InvalidArgument(
"Scatter dimension ", i, " is of size zero in tensor with shape ",
xla::ShapeUtil::HumanString(*buffer_shape));
}
}
// Shape of the non-indexed dimensions of the buffer.
std::vector<int64> buffer_shape_post_axes(
buffer_dims.begin() + num_index_dims, buffer_dims.end());
// Flatten the major dimensions of indices and updates into a single dimension
// for ease of iteration.
std::vector<int64> flat_indices_shape({num_indices});
if (indices_are_vectors) {
flat_indices_shape.push_back(num_index_dims);
}
std::vector<int64> flat_updates_shape({num_indices});
flat_updates_shape.insert(flat_updates_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
// Construct the initial values of the loop-carried Tensors.
auto flat_indices = builder->Reshape(indices, flat_indices_shape);
auto flat_updates = builder->Reshape(updates, flat_updates_shape);
auto init = {flat_indices, flat_updates, buffer};
// Constructs the loop body. The implementation of scatter is essentially:
// for i in range(num_indices):
// index = dynamic-slice(indices, i)
// update = dynamic-slice(updates, i)
// buffer = dynamic-update-slice(buffer, update, index)
auto body_fn = [&](xla::ComputationDataHandle i,
gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
xla::ComputationBuilder* body_builder) {
auto indices = loop_vars[0];
auto updates = loop_vars[1];
auto buffer = loop_vars[2];
auto zero_index = body_builder->ConstantLiteral(
xla::Literal::Zero(indices_shape->element_type()));
// Slice the i-th index from the indices array.
xla::ComputationDataHandle index;
auto indices_offset = body_builder->Reshape(i, {1});
if (indices_are_vectors) {
indices_offset = body_builder->Pad(indices_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, 1}}));
index = body_builder->DynamicSlice(indices, indices_offset,
{1, num_index_dims});
index = body_builder->Collapse(index, {0, 1});
} else {
index = body_builder->DynamicSlice(indices, indices_offset, {1});
}
// Discard updates with negative indices, since some users expect this.
auto index_in_range =
body_builder->ReduceAll(body_builder->Le(zero_index, index),
body_builder->ConstantR0<bool>(true),
xla::CreateScalarAndComputation(body_builder));
index = body_builder->Pad(
index, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
// Slice the i-th index from the updates array.
auto updates_offset = body_builder->Reshape(i, {1});
updates_offset = body_builder->Pad(
updates_offset, zero_index,
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
std::vector<int64> flat_updates_slice_shape({1});
flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
auto update = body_builder->DynamicSlice(updates, updates_offset,
flat_updates_slice_shape);
// Unflatten the major (iteration) dimensions of the slice to their original
// shape.
std::vector<int64> updates_slice_shape(num_index_dims, 1);
updates_slice_shape.insert(updates_slice_shape.end(),
buffer_shape_post_axes.begin(),
buffer_shape_post_axes.end());
update = body_builder->Reshape(update, updates_slice_shape);
// Apply the update to the buffer. If there is a combiner, use it to merge
// the current values with the update.
if (combiner) {
auto current_value =
body_builder->DynamicSlice(buffer, index, updates_slice_shape);
update = combiner(current_value, update, body_builder);
}
// Apply the update if it is in range.
buffer = body_builder->Select(
index_in_range, body_builder->DynamicUpdateSlice(buffer, update, index),
buffer);
return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
};
TF_ASSIGN_OR_RETURN(
auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
body_fn, init, "scatter", builder));
return outputs[2];
}
} // namespace tensorflow

View File

@ -0,0 +1,53 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_
#include <functional>
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace tensorflow {
// Builds an XLA computation that performs a scatter operation on `buffer`,
// returning an updated buffer.
// For each i0, i1, ..., sets
// buffer[indices[i0, i1, ...], ...] := updates[i0, i1, ...]
//
// If `indices_are_vectors` is false, then each index in indices is a scalar,
// and the shape of `indices` must be a prefix of the shape of updates.
// Otherwise, `indices_are_vectors`, then indices are multidimensional and the
// minor dimension of `indices` represents a vector of indices.
//
// If any indices are negative, the corresponding update is discarded.
//
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
// existing values. The order of updates is implementation-defined.
xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
const xla::ComputationDataHandle& buffer,
const xla::ComputationDataHandle& updates,
const xla::ComputationDataHandle& indices, bool indices_are_vectors,
const std::function<xla::ComputationDataHandle(
xla::ComputationDataHandle, xla::ComputationDataHandle,
xla::ComputationBuilder*)>& combiner,
xla::ComputationBuilder* builder);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_

View File

@ -57,6 +57,61 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
}
}
xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
xla::PrimitiveType type,
int64 value) {
xla::Literal literal;
switch (type) {
case xla::U8:
literal = std::move(*xla::Literal::CreateR0<uint8>(value));
break;
case xla::U32:
literal = std::move(*xla::Literal::CreateR0<uint32>(value));
break;
case xla::U64:
literal = std::move(*xla::Literal::CreateR0<uint64>(value));
break;
case xla::S8:
literal = std::move(*xla::Literal::CreateR0<int8>(value));
break;
case xla::S32:
literal = std::move(*xla::Literal::CreateR0<int32>(value));
break;
case xla::S64:
literal = std::move(*xla::Literal::CreateR0<int64>(value));
break;
case xla::F32:
literal = std::move(*xla::Literal::CreateR0<float>(value));
break;
case xla::F64:
literal = std::move(*xla::Literal::CreateR0<double>(value));
break;
case xla::C64:
literal = std::move(*xla::Literal::CreateR0<complex64>(value));
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
case xla::S16:
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
literal = std::move(
*xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
break;
case xla::F16:
literal = std::move(
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
case xla::OPAQUE:
LOG(FATAL) << "opaque element type is not integral";
default:
LOG(FATAL) << "unhandled element type " << type;
}
return builder->ConstantLiteral(literal);
}
xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {

View File

@ -32,6 +32,11 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
xla::PrimitiveType type, double value);
// Returns a integer scalar constant of 'type' with 'value'.
// If 'type' is complex, returns a real value with zero imaginary component.
xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
xla::PrimitiveType type, int64 value);
// Performs a slice in the minor dimensions of a Tensor.
xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,

View File

@ -0,0 +1,125 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
namespace tensorflow {
xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
StringPiece name, xla::ComputationBuilder* builder) {
int arity = initial_values.size();
std::vector<xla::Shape> var_shapes;
var_shapes.reserve(arity);
for (const xla::ComputationDataHandle& input : initial_values) {
TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
var_shapes.push_back(std::move(*shape));
}
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);
// Unpacks a tuple into its component parts.
auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity,
xla::ComputationBuilder* builder) {
std::vector<xla::ComputationDataHandle> elements(arity);
for (int i = 0; i < arity; ++i) {
elements[i] = builder->GetTupleElement(tuple, i);
}
return elements;
};
// Build the condition.
std::unique_ptr<xla::ComputationBuilder> cond_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
{
auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");
TF_ASSIGN_OR_RETURN(
auto result,
condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
cond_builder.get()));
TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result));
}
TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());
// Build the body.
std::unique_ptr<xla::ComputationBuilder> body_builder =
builder->CreateSubBuilder(strings::StrCat(name, "_body"));
{
auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");
TF_ASSIGN_OR_RETURN(
auto result,
body_function(unpack_tuple(parameter, arity, body_builder.get()),
body_builder.get()));
TF_RET_CHECK(result.size() == initial_values.size());
body_builder->Tuple(result);
}
TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());
auto outputs = builder->While(cond, body, builder->Tuple(initial_values));
return unpack_tuple(outputs, arity, builder);
}
xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
StringPiece name, xla::ComputationBuilder* builder) {
auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
xla::ComputationBuilder* cond_builder)
-> xla::StatusOr<xla::ComputationDataHandle> {
return cond_builder->Lt(
values[0],
IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
};
auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
xla::ComputationBuilder* body_builder)
-> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
xla::ComputationDataHandle iteration = values[0];
std::vector<xla::ComputationDataHandle> updated_values;
updated_values.reserve(values.size());
updated_values.push_back(body_builder->Add(
iteration,
body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));
values.remove_prefix(1);
TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
body_function(iteration, values, body_builder));
updated_values.insert(updated_values.end(), body_outputs.begin(),
body_outputs.end());
return updated_values;
};
std::vector<xla::ComputationDataHandle> values;
values.reserve(initial_values.size() + 1);
values.push_back(
builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
values.insert(values.end(), initial_values.begin(), initial_values.end());
TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
name, builder));
values.erase(values.begin(), values.begin() + 1);
return values;
}
} // namespace tensorflow

View File

@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#include <functional>
#include <vector>
#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing if the condition succeeds.
typedef std::function<xla::StatusOr<xla::ComputationDataHandle>(
gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
LoopConditionFunction;
// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
LoopBodyFunction;
// Helper function for building an XLA while loop, where the values carried by
// the loop are a tuple of values, e.g., (a, b, c):
// while(
// condition: (a, b, c) -> bool,
// body: (a, b, c) -> (a, b, c)
// init: (a, b, c)
// )
// 'name' is a descriptive name for the loop.
xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
const LoopConditionFunction& condition_function,
const LoopBodyFunction& body_function,
gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
StringPiece name, xla::ComputationBuilder* builder);
// Builds an XLA loop that repeats a computation `num_iterations` times.
//
// The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
xla::ComputationDataHandle, gtl::ArraySlice<xla::ComputationDataHandle>,
xla::ComputationBuilder*)>
ForEachIndexBodyFunction;
xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
int64 num_iterations, xla::PrimitiveType num_iterations_type,
const ForEachIndexBodyFunction& body_function,
gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
StringPiece name, xla::ComputationBuilder* builder);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_

View File

@ -135,58 +135,9 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
xla::ComputationBuilder* b, DataType data_type, int64 value) {
xla::Literal literal;
xla::PrimitiveType type;
TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
switch (type) {
case xla::U8:
literal = std::move(*xla::Literal::CreateR0<uint8>(value));
break;
case xla::U32:
literal = std::move(*xla::Literal::CreateR0<uint32>(value));
break;
case xla::U64:
literal = std::move(*xla::Literal::CreateR0<uint64>(value));
break;
case xla::S8:
literal = std::move(*xla::Literal::CreateR0<int8>(value));
break;
case xla::S32:
literal = std::move(*xla::Literal::CreateR0<int32>(value));
break;
case xla::S64:
literal = std::move(*xla::Literal::CreateR0<int64>(value));
break;
case xla::F32:
literal = std::move(*xla::Literal::CreateR0<float>(value));
break;
case xla::F64:
literal = std::move(*xla::Literal::CreateR0<double>(value));
break;
case xla::C64:
literal = std::move(*xla::Literal::CreateR0<complex64>(value));
break;
case xla::PRED:
LOG(FATAL) << "pred element type is not integral";
case xla::S16:
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
case xla::BF16:
literal = std::move(
*xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
break;
case xla::F16:
literal = std::move(
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
break;
case xla::TUPLE:
LOG(FATAL) << "tuple element type is not integral";
case xla::OPAQUE:
LOG(FATAL) << "opaque element type is not integral";
default:
LOG(FATAL) << "unhandled element type " << type;
}
return b->ConstantLiteral(literal);
return ::tensorflow::IntegerLiteral(b, type, value);
}
xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,

View File

@ -43,6 +43,81 @@ filegroup(
]),
)
cc_library(
name = "bfloat16_support",
srcs = ["bfloat16_support.cc"],
hdrs = ["bfloat16_support.h"],
deps = [
":hlo",
],
)
cc_library(
name = "bfloat16_conversion_folding",
srcs = ["bfloat16_conversion_folding.cc"],
hdrs = ["bfloat16_conversion_folding.h"],
deps = [
":bfloat16_support",
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bfloat16_conversion_folding_test",
srcs = ["bfloat16_conversion_folding_test.cc"],
deps = [
":bfloat16_conversion_folding",
":bfloat16_support",
":hlo",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
cc_library(
name = "bfloat16_normalization",
srcs = ["bfloat16_normalization.cc"],
hdrs = ["bfloat16_normalization.h"],
deps = [
":bfloat16_support",
":hlo",
":hlo_pass",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "bfloat16_normalization_test",
srcs = ["bfloat16_normalization_test.cc"],
deps = [
":bfloat16_normalization",
":bfloat16_support",
":hlo",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
],
)
cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],

View File

@ -0,0 +1,184 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
public:
explicit BFloat16ConversionFoldingVisitor(
HloComputation* computation, const BFloat16Support* bfloat16_support)
: computation_(computation), bfloat16_support_(bfloat16_support) {}
Status DefaultAction(HloInstruction* hlo) override;
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
private:
// Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16
// conversion as output, and folds them to the HLO itself if feasible.
Status TryFoldBF16Conversions(HloInstruction* hlo);
// Folds the F32 -> BF16 conversions from the HLO's output.
//
// Precondition: all of the HLO's users are F32 -> BF16 conversions.
Status FoldOutputConversions(HloInstruction* hlo);
// Folds the BF16 -> F32 conversion operand to the HLO.
//
// Precondition: the operand is a F32 -> BF16 conversion.
Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
bool changed_ = false;
};
Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
HloInstruction* hlo) {
std::vector<HloInstruction*> materialized_users = hlo->users();
hlo->mutable_shape()->set_element_type(BF16);
for (auto user : materialized_users) {
CHECK_EQ(user->opcode(), HloOpcode::kConvert);
TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
changed_ = true;
}
return Status::OK();
}
Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
HloInstruction* hlo, int64 operand_index) {
// The operand is a convert from BF16 to F32.
auto operand = hlo->mutable_operand(operand_index);
CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
TF_RETURN_IF_ERROR(
hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
changed_ = true;
return Status::OK();
}
Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
HloInstruction* hlo) {
std::vector<int64> bf16_to_f32_operands;
bool has_other_f32_operands = false;
for (int64 i = 0; i < hlo->operands().size(); ++i) {
auto operand = hlo->operand(i);
if (operand->shape().element_type() == F32) {
if (operand->opcode() == HloOpcode::kConvert &&
operand->operand(0)->shape().element_type() == BF16 &&
bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
// Operand is a convert from BF16 to F32 and we support BF16 input
// directly in the current HLO at the operand index.
bf16_to_f32_operands.push_back(i);
} else {
has_other_f32_operands = true;
}
continue;
}
}
bool fold_output_conversion = hlo->user_count() > 0 &&
hlo->shape().element_type() == F32 &&
bfloat16_support_->SupportsBF16Output(*hlo) &&
hlo != computation_->root_instruction();
if (fold_output_conversion) {
for (auto user : hlo->users()) {
if (user->opcode() == HloOpcode::kConvert &&
user->shape().element_type() == BF16) {
continue;
}
// We should not change the output type if any user is not a conversion
// from F32 to BF16.
fold_output_conversion = false;
break;
}
}
if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
if (has_other_f32_operands ||
(!fold_output_conversion && hlo->shape().element_type() == F32)) {
// Some of the operands/output will remain F32, but we cannot use mixed
// precisions, so we cannot do anything here.
return Status::OK();
}
}
if (fold_output_conversion) {
TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
}
for (int64 i : bf16_to_f32_operands) {
TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
}
return Status::OK();
}
Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
// Do not fold BF16 conversions for instructions related to tuples, entry and
// exit of a computation, fusion, convert, and control flow.
if (hlo->opcode() == HloOpcode::kTuple || //
hlo->opcode() == HloOpcode::kGetTupleElement || //
hlo->opcode() == HloOpcode::kInfeed || //
hlo->opcode() == HloOpcode::kOutfeed || //
hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kParameter || //
hlo->opcode() == HloOpcode::kFusion || //
hlo->opcode() == HloOpcode::kConvert || //
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
if (hlo == computation_->root_instruction() &&
!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
// If hlo is the root instruction, we cannot change its output, so folding
// can only happen when it supports mixed precision so that we can change
// its operands.
return Status::OK();
}
return TryFoldBF16Conversions(hlo);
}
StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
XLA_VLOG_LINES(
2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeNonfusionComputations()) {
if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
changed = true;
}
}
XLA_VLOG_LINES(
2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
return changed;
}
} // namespace xla

View File

@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// A pass which folds F32 <-> BF16 conversions to their operands or users, when
// it is supported by the backend.
//
// This pass follows the passed-in backend-specific BF16 support rules, but can
// introduce mixed precision in individual HLOs which breaks the assumption of
// some other HLO passes. So it should be used at the end of the HLO
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changed made by this pass.
class BFloat16ConversionFolding : public HloPassInterface {
public:
explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
~BFloat16ConversionFolding() override = default;
tensorflow::StringPiece name() const override { return "bfloat16-fold"; }
// Run BF16 conversion folding on the given computation. Returns whether the
// computation was changed.
StatusOr<bool> Run(HloModule* module) override;
private:
const BFloat16Support* bfloat16_support_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_

View File

@ -0,0 +1,209 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
class TestBFloat16Support : public BFloat16Support {
public:
TestBFloat16Support() {}
~TestBFloat16Support() override {}
bool SupportsBF16Operand(const HloInstruction& hlo,
int64 operand_index) const override {
if (hlo.opcode() == HloOpcode::kAdd ||
hlo.opcode() == HloOpcode::kSubtract ||
hlo.opcode() == HloOpcode::kTuple ||
hlo.opcode() == HloOpcode::kGetTupleElement) {
return true;
}
return false;
}
bool SupportsBF16Output(const HloInstruction& hlo) const override {
if (hlo.opcode() == HloOpcode::kAdd ||
hlo.opcode() == HloOpcode::kSubtract ||
hlo.opcode() == HloOpcode::kTuple ||
hlo.opcode() == HloOpcode::kGetTupleElement) {
return true;
}
return false;
}
bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
hlo.opcode() == HloOpcode::kGetTupleElement) {
return true;
}
return false;
}
};
class BFloat16ConversionFoldingTest : public HloTestBase {
protected:
bool FoldConversions(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16ConversionFolding fold(&bfloat16_support_);
StatusOr<bool> result = fold.Run(module);
EXPECT_IS_OK(result.status());
return result.ValueOrDie();
}
};
TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32_shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32_shape, "c"));
HloInstruction* add0 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b));
HloInstruction* convert0 =
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0));
HloInstruction* convert1 = builder.AddInstruction(
HloInstruction::CreateConvert(f32_shape, convert0));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c));
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(FoldConversions(module.get()));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
EXPECT_EQ(add1->shape().element_type(), BF16);
EXPECT_EQ(add1->operand(0), add0);
}
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32_shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32_shape, "c"));
HloInstruction* mul0 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b));
HloInstruction* convert0 =
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0));
HloInstruction* convert1 = builder.AddInstruction(
HloInstruction::CreateConvert(f32_shape, convert0));
HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
f32_shape, HloOpcode::kMultiply, convert1, c));
HloInstruction* convert2 =
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_FALSE(FoldConversions(module.get()));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(mul0->shape().element_type(), F32);
EXPECT_EQ(mul1->shape().element_type(), F32);
EXPECT_EQ(mul1->operand(0), convert1);
}
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, f32_shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32_shape, "c"));
HloInstruction* sub0 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b));
HloInstruction* convert0 =
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0));
HloInstruction* convert1 = builder.AddInstruction(
HloInstruction::CreateConvert(f32_shape, convert0));
HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
f32_shape, HloOpcode::kSubtract, convert1, c));
HloInstruction* convert2 =
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_FALSE(FoldConversions(module.get()));
EXPECT_EQ(computation->root_instruction(), convert2);
EXPECT_EQ(sub0->shape().element_type(), F32);
EXPECT_EQ(sub1->shape().element_type(), F32);
EXPECT_EQ(sub1->operand(0), convert1);
}
TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_shape, "b"));
HloInstruction* convert0 =
builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b));
HloInstruction* tuple =
builder.AddInstruction(HloInstruction::CreateTuple({a, convert0}));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
HloInstruction* convert1 =
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_FALSE(FoldConversions(module.get()));
EXPECT_EQ(computation->root_instruction(), convert1);
EXPECT_EQ(gte->shape().element_type(), F32);
EXPECT_EQ(tuple->operand(1), convert0);
}
} // namespace xla

View File

@ -0,0 +1,351 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
public:
explicit BFloat16NormalizationVisitor(HloComputation* computation,
const BFloat16Support* bfloat16_support)
: computation_(computation), bfloat16_support_(bfloat16_support) {}
Status DefaultAction(HloInstruction* hlo) override;
// Special handling for cross-replica-sum which can have a tuple output.
Status HandleCrossReplicaSum(HloInstruction* crs) override;
static bool Run(HloComputation* computation,
const BFloat16Support* bfloat16_support) {
BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
TF_CHECK_OK(computation->Accept(&visitor));
return visitor.changed_;
}
private:
// Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
// conversions between F32 and BF16 to make it supported.
Status HandleInstruction(HloInstruction* hlo);
// Inserts a conversion HLO that changes the given HLO's output type.
Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
HloComputation* computation);
// Changes the output type to the specified type, then inserts a conversion
// to the original type.
Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
PrimitiveType to,
HloComputation* computation);
// Inserts a conversion HLO that changes the given HLO's operand type.
Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
PrimitiveType to,
HloComputation* computation);
// Inserts conversion HLOs to replace the called computations' BF16
// operands/outputs to F32.
Status ConvertCalledComputations(
HloInstruction* hlo,
tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);
HloComputation* computation_;
const BFloat16Support* bfloat16_support_;
bool changed_ = false;
};
Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
bool is_root = computation->root_instruction() == hlo;
std::vector<HloInstruction*> materialized_users = hlo->users();
// Use inst's shape temporarily, in order to pass checks in ReplaceUseWith.
auto convert = computation->AddInstruction(
HloInstruction::CreateConvert(hlo->shape(), hlo));
for (auto* user : materialized_users) {
TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
}
if (is_root) {
computation->set_root_instruction(convert);
}
convert->mutable_shape()->set_element_type(to);
changed_ = true;
return Status::OK();
}
Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
auto original_type = hlo->shape().element_type();
hlo->mutable_shape()->set_element_type(to);
return InsertConvertAfterOutput(hlo, original_type, computation);
}
Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
HloComputation* computation) {
auto operand = hlo->mutable_operand(operand_idx);
auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(operand->shape(), to), operand));
TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
changed_ = true;
return Status::OK();
}
Status BFloat16NormalizationVisitor::ConvertCalledComputations(
HloInstruction* hlo,
tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
std::map<HloComputation*, HloComputation*> cloned_computations;
for (auto& comp : bf16_called_comps) {
auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
cloned_computations[comp] = cloned;
changed_ = true;
}
hlo->ReplaceCalledComputations([&](HloComputation* comp) {
auto it = cloned_computations.find(comp);
if (it != cloned_computations.end()) {
return it->second;
}
return comp;
});
for (auto& comp_pair : cloned_computations) {
auto comp = comp_pair.second;
if (comp->root_instruction()->shape().element_type() == BF16) {
TF_RETURN_IF_ERROR(
InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
}
for (auto* param : comp->parameter_instructions()) {
if (param->shape().element_type() == BF16) {
// This changes the parameter to F32 then inserts a convert after it.
TF_RETURN_IF_ERROR(
ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
}
}
}
return Status::OK();
}
Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
HloInstruction* crs) {
if (!ShapeUtil::IsTuple(crs->shape())) {
return HandleInstruction(crs);
}
std::vector<PrimitiveType> operand_types(crs->operand_count());
std::vector<PrimitiveType> output_types(crs->operand_count());
bool has_f32 = false;
bool has_bf16 = false;
bool has_bf16_output = false;
for (int64 i = 0; i < crs->operand_count(); ++i) {
operand_types[i] = crs->operand(i)->shape().element_type();
output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
if (operand_types[i] == F32 || output_types[i] == F32) {
has_f32 = true;
} else if (operand_types[i] == BF16) {
has_bf16 = true;
}
if (output_types[i] == BF16) {
has_bf16 = true;
has_bf16_output = true;
}
}
for (int64 i = 0; i < crs->operand_count(); ++i) {
if (operand_types[i] != BF16) {
continue;
}
if (bfloat16_support_->SupportsBF16Operand(*crs, i) &&
(bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
continue;
}
TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
has_f32 = true;
}
if (!has_bf16_output) {
return Status::OK();
}
if (bfloat16_support_->SupportsBF16Output(*crs) &&
(bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
return Status::OK();
}
std::vector<HloInstruction*> output_elements(crs->operand_count());
auto original_shape = crs->shape();
for (int64 i = 0; i < crs->operand_count(); ++i) {
auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i});
if (output_types[i] != BF16) {
output_elements[i] = computation_->AddInstruction(
HloInstruction::CreateGetTupleElement(*subshape, crs, i));
continue;
}
subshape->set_element_type(F32);
auto gte = computation_->AddInstruction(
HloInstruction::CreateGetTupleElement(*subshape, crs, i));
output_elements[i] =
computation_->AddInstruction(HloInstruction::CreateConvert(
ShapeUtil::ChangeElementType(*subshape, BF16), gte));
}
auto tuple = computation_->AddInstruction(
HloInstruction::CreateTuple(output_elements));
std::vector<HloInstruction*> materialized_users = crs->users();
// Use the crs' shape temporarily, in order to pass checks in
// ReplaceUseWith.
*tuple->mutable_shape() = crs->shape();
for (auto* user : materialized_users) {
TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple));
}
*tuple->mutable_shape() = original_shape;
return Status::OK();
}
Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
std::vector<int64> bf16_operands;
std::vector<int64> f32_operands;
bool has_f32 = false;
bool has_bf16 = false;
for (int64 i = 0; i < hlo->operand_count(); ++i) {
if (hlo->operand(i)->shape().element_type() == F32) {
f32_operands.push_back(i);
has_f32 = true;
} else if (hlo->operand(i)->shape().element_type() == BF16) {
bf16_operands.push_back(i);
has_bf16 = true;
}
}
if (hlo->shape().element_type() == F32) {
has_f32 = true;
} else if (hlo->shape().element_type() == BF16) {
has_bf16 = true;
}
std::vector<HloComputation*> bf16_called_comps;
for (auto* comp : hlo->called_computations()) {
bool comp_has_bf16 = false;
if (comp->root_instruction()->shape().element_type() == F32) {
has_f32 = true;
} else if (comp->root_instruction()->shape().element_type() == BF16) {
has_bf16 = true;
comp_has_bf16 = true;
}
for (auto* param : comp->parameter_instructions()) {
if (param->shape().element_type() == F32) {
has_f32 = true;
} else if (param->shape().element_type() == BF16) {
has_bf16 = true;
comp_has_bf16 = true;
}
}
if (comp_has_bf16) {
bf16_called_comps.push_back(comp);
}
}
if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 &&
has_f32) {
// Resolve unsupported mixed precision.
//
// See if we can change everything to BF16.
if (hlo->called_computations().empty() &&
hlo->shape().element_type() == BF16) {
bool can_use_bf16 = true;
for (int i : f32_operands) {
if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
i) &&
bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
continue;
}
can_use_bf16 = false;
break;
}
if (can_use_bf16) {
for (int i : f32_operands) {
TF_RETURN_IF_ERROR(
InsertConvertBeforeOperand(hlo, i, BF16, computation_));
}
return Status::OK();
}
}
if (hlo->shape().element_type() == BF16) {
TF_RETURN_IF_ERROR(
ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
}
for (int i : bf16_operands) {
TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
}
return ConvertCalledComputations(hlo, bf16_called_comps);
}
for (int i : bf16_operands) {
if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
}
}
if (hlo->shape().element_type() == BF16 &&
!bfloat16_support_->SupportsBF16Output(*hlo)) {
TF_RETURN_IF_ERROR(
ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
}
return Status::OK();
}
Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
// Do not change instructions related to entry and exit of a computation,
// tuples, fusion, convert, and control flow.
if (hlo->opcode() == HloOpcode::kTuple || //
hlo->opcode() == HloOpcode::kGetTupleElement || //
hlo->opcode() == HloOpcode::kInfeed || //
hlo->opcode() == HloOpcode::kOutfeed || //
hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kParameter || //
hlo->opcode() == HloOpcode::kFusion || //
hlo->opcode() == HloOpcode::kConvert || //
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional) {
return Status::OK();
}
return HandleInstruction(hlo);
}
StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
XLA_VLOG_LINES(
2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
bool changed = false;
for (auto* comp : module->MakeComputationPostOrder()) {
if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
changed = true;
}
}
XLA_VLOG_LINES(2,
"BFloat16Normalization::Run(), after:\n" + module->ToString());
return changed;
}
} // namespace xla

View File

@ -0,0 +1,92 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
namespace xla {
// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
class BFloat16Normalization : public HloPassInterface {
public:
explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
: bfloat16_support_(bfloat16_support) {}
~BFloat16Normalization() override = default;
tensorflow::StringPiece name() const override { return "bf16-normalization"; }
// Run BF16 normalization on the given computation. Returns whether the
// computation was changed.
StatusOr<bool> Run(HloModule* module) override;
private:
const BFloat16Support* bfloat16_support_;
};
// A pass that unconditionally removes the mixed F32/BF16 uses in HLO
// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike
// BFloat16Normalization, this pass does not use a backend-specific
// BFloat16Support, and does not change HLOs that have BF16 data if they do not
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
class BFloat16MixedPrecisionRemoval : public HloPassInterface {
public:
BFloat16MixedPrecisionRemoval() {}
~BFloat16MixedPrecisionRemoval() override = default;
tensorflow::StringPiece name() const override {
return "bf16-mixed-precision-removal";
}
// Run mixed precision removal on the given computation. Returns whether the
// computation was changed.
StatusOr<bool> Run(HloModule* module) override {
BFloat16Normalization normalization(&no_mixed_precision_support_);
return normalization.Run(module);
}
private:
class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support {
public:
BFloat16SupportForMixedPrecisionRemoval() {}
~BFloat16SupportForMixedPrecisionRemoval() override = default;
bool SupportsBF16Operand(const HloInstruction& hlo,
int64 operand_index) const override {
return true;
}
bool SupportsBF16Output(const HloInstruction& hlo) const override {
return true;
}
bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
return false;
}
} no_mixed_precision_support_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_

View File

@ -0,0 +1,248 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
class TestBFloat16Support : public BFloat16Support {
public:
TestBFloat16Support() {}
~TestBFloat16Support() override {}
bool SupportsBF16Operand(const HloInstruction& hlo,
int64 operand_index) const override {
if (hlo.opcode() == HloOpcode::kAdd ||
hlo.opcode() == HloOpcode::kSubtract ||
hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kTuple ||
hlo.opcode() == HloOpcode::kGetTupleElement) {
return true;
}
return false;
}
bool SupportsBF16Output(const HloInstruction& hlo) const override {
if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
hlo.opcode() == HloOpcode::kSubtract ||
hlo.opcode() == HloOpcode::kTuple ||
hlo.opcode() == HloOpcode::kGetTupleElement) {
return true;
}
return false;
}
bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
hlo.opcode() == HloOpcode::kGetTupleElement) {
return true;
}
return false;
}
};
class BFloat16NormalizationTest : public HloTestBase {
protected:
bool Normalize(HloModule* module) {
TestBFloat16Support bfloat16_support_;
BFloat16Normalization normalization(&bfloat16_support_);
StatusOr<bool> result = normalization.Run(module);
EXPECT_IS_OK(result.status());
return result.ValueOrDie();
}
};
TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32_shape, "c"));
HloInstruction* add0 = builder.AddInstruction(
HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));
HloInstruction* add1 = builder.AddInstruction(
HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_FALSE(Normalize(module.get()));
EXPECT_EQ(computation->root_instruction(), add1);
EXPECT_EQ(add0->shape().element_type(), BF16);
EXPECT_EQ(add1->shape().element_type(), F32);
}
TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32_shape, "c"));
HloInstruction* mul0 = builder.AddInstruction(
HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));
HloInstruction* mul1 = builder.AddInstruction(
HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
EXPECT_EQ(mul0->shape().element_type(), F32);
EXPECT_EQ(mul1->shape().element_type(), F32);
EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
}
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_shape, "b"));
HloInstruction* c = builder.AddInstruction(
HloInstruction::CreateParameter(2, f32_shape, "c"));
HloInstruction* sub0 = builder.AddInstruction(
HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));
HloInstruction* sub1 = builder.AddInstruction(
HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
EXPECT_EQ(sub0->shape().element_type(), F32);
EXPECT_EQ(sub1->shape().element_type(), F32);
EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
}
TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});
Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4});
auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
reduce_comp_builder.AddInstruction(
HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
reduce_comp_param0, reduce_comp_param1));
auto module = CreateNewModule();
auto reduce_computation =
module->AddEmbeddedComputation(reduce_comp_builder.Build());
auto builder = HloComputation::Builder(TestName());
HloInstruction* input = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_input_shape, "a"));
HloInstruction* init = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
f32_output_shape, input, init, {0}, reduce_computation));
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
EXPECT_EQ(computation->root_instruction(), reduce);
EXPECT_EQ(reduce->called_computations().size(), 1);
EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
EXPECT_EQ(reduce->called_computations()[0]
->parameter_instruction(0)
->shape()
.element_type(),
F32);
EXPECT_EQ(reduce->called_computations()[0]
->parameter_instruction(1)
->shape()
.element_type(),
F32);
EXPECT_EQ(reduce->called_computations()[0]
->root_instruction()
->shape()
.element_type(),
F32);
EXPECT_EQ(reduce->shape().element_type(), F32);
EXPECT_EQ(reduce->operand(0), input);
EXPECT_EQ(input->shape().element_type(), F32);
EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
}
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
HloInstruction* a = builder.AddInstruction(
HloInstruction::CreateParameter(0, f32_shape, "a"));
HloInstruction* b = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_shape, "b"));
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
EXPECT_EQ(computation->root_instruction(), gte);
EXPECT_EQ(gte->shape().element_type(), BF16);
EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
}
} // namespace xla

View File

@ -0,0 +1,111 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
namespace xla {
bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
int64 operand_index) const {
switch (hlo.opcode()) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
return true;
case HloOpcode::kConvert:
CHECK_EQ(operand_index, 0);
return hlo.operand(0)->shape().element_type() == BF16;
default:
break;
}
return false;
}
bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
switch (hlo.opcode()) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
return true;
case HloOpcode::kConvert:
return hlo.shape().element_type() == BF16;
default:
break;
}
return false;
}
bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
switch (hlo.opcode()) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConvert:
case HloOpcode::kCustomCall:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
return true;
default:
break;
}
return false;
}
/* static */
bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
const HloInstruction& hlo, int64 operand_index) {
switch (hlo.opcode()) {
case HloOpcode::kAbs:
case HloOpcode::kBroadcast:
case HloOpcode::kClamp:
case HloOpcode::kConcatenate:
case HloOpcode::kCopy:
case HloOpcode::kGetTupleElement:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPad:
case HloOpcode::kReshape:
case HloOpcode::kReverse:
case HloOpcode::kSlice:
case HloOpcode::kSort:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
return true;
case HloOpcode::kDynamicSlice:
return operand_index == 0;
case HloOpcode::kDynamicUpdateSlice:
return operand_index == 0 || operand_index == 1;
case HloOpcode::kSelect:
return operand_index == 1 || operand_index == 2;
default:
break;
}
return false;
}
bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
const HloInstruction& hlo, int64 operand_index) const {
return false;
}
} // namespace xla

View File

@ -0,0 +1,60 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
namespace xla {
class BFloat16Support {
public:
BFloat16Support() {}
virtual ~BFloat16Support() {}
// Returns whether the backend supports BF16 operand for the HLO instruction
// at the given index.
virtual bool SupportsBF16Operand(const HloInstruction& hlo,
int64 operand_index) const;
// Returns whether the backend supports BF16 output for the HLO instruction.
virtual bool SupportsBF16Output(const HloInstruction& hlo) const;
// Returns whether the backend support mixed precision: the operands, output,
// and parameters/output of the called computations can have different
// precisions (BF16 and F32).
virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const;
// Returns whether the given HLO inherits its BF16 operand precision at the
// given index, so even if the output is F32, elements in the output that
// depend on the BF16 operand will still have BF16 effective precision even if
// they have F32 format. Similarly, this also means if the output is BF16 then
// increasing the operand precision from BF16 to F32 will not change the
// output. This typically includes HLOs that pass elements from the operand to
// the output without arithmetic operations.
static bool EffectiveOperandPrecisionIsOutputPrecision(
const HloInstruction& hlo, int64 operand_index);
// Returns if the backend only uses BF16 precision for the operand at the
// specified index, even if the operand is F32.
virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
int64 operand_index) const;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_

View File

@ -159,9 +159,6 @@ cc_library(
deps = [
":compiler_functor",
":cpu_runtime",
":cpu_runtime_avx",
":cpu_runtime_neon",
":cpu_runtime_sse4_1",
":custom_call_target_registry",
":disassembler",
":external_constant_pool",
@ -408,9 +405,6 @@ cc_library(
hdrs = ["compiler_functor.h"],
deps = [
":cpu_runtime",
":cpu_runtime_avx",
":cpu_runtime_neon",
":cpu_runtime_sse4_1",
":disassembler",
":llvm_ir_runtime",
"//tensorflow/compiler/xla:statusor",
@ -430,43 +424,6 @@ cc_library(
],
)
cc_library(
name = "cpu_runtime_sse4_1",
srcs = ["cpu_runtime_sse4_1.cc"],
hdrs = ["cpu_runtime_sse4_1.h"],
copts = ["-DEIGEN_AVOID_STL_ARRAY"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
)
cc_library(
name = "cpu_runtime_avx",
srcs = ["cpu_runtime_avx.cc"],
hdrs = ["cpu_runtime_avx.h"],
copts = ["-DEIGEN_AVOID_STL_ARRAY"],
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
)
cc_library(
name = "cpu_runtime_neon",
srcs = ["cpu_runtime_neon.cc"],
hdrs = ["cpu_runtime_neon.h"],
# runtime_copts() enables -mfpu=neon
copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(),
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_lite",
"//third_party/eigen3",
],
)
cc_library(
name = "cpu_runtime",
srcs = [

View File

@ -37,9 +37,6 @@ limitations under the License.
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@ -50,15 +47,6 @@ limitations under the License.
namespace xla {
namespace cpu {
/* static */ CompilerFunctor::VectorIntrinsics
CompilerFunctor::AllIntrinsics() {
VectorIntrinsics intrinsics;
intrinsics.sse_intrinsics = true;
intrinsics.avx_intrinsics = true;
intrinsics.neon_intrinsics = true;
return intrinsics;
}
/* Create filtered versions of the LLVM Pass Managers to filter out some
of the expensive passes.
Profiling:
@ -192,31 +180,8 @@ operator()(llvm::Module& module) const {
std::move(object_file), std::move(memory_buffer));
}
namespace {
// Returns the set of vectorized library functions supported for the target.
std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
llvm::Triple::ArchType arch, llvm::StringRef feature_string,
CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
std::vector<llvm::VecDesc> vector_functions;
const llvm::VecDesc four_wide_vector_functions_neon[] = {
{"logf", runtime::kLogV4F32NEONSymbolName, 4},
{"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4},
};
const llvm::VecDesc four_wide_vector_functions_sse[] = {
{"logf", runtime::kLogV4F32SSESymbolName, 4},
{"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4},
};
const llvm::VecDesc eight_wide_vector_functions_avx[] = {
{"logf", runtime::kLogV8F32AVXSymbolName, 8},
{"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8},
};
// These functions are generated by XLA as LLVM IR, so they're always
// available.
const llvm::VecDesc ir_vector_functions[] = {
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
std::vector<llvm::VecDesc> result = {
{"tanhf", runtime::kTanhV4F32SymbolName, 4},
{"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4},
@ -228,50 +193,15 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
{"expf", runtime::kExpV8F32SymbolName, 8},
{"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
{"logf", runtime::kLogV4F32SymbolName, 4},
{"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
{"logf", runtime::kLogV8F32SymbolName, 8},
{"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
};
llvm::SmallVector<llvm::StringRef, 32> features;
feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
auto has_feature = [&features](const llvm::StringRef feature) {
return std::find(features.begin(), features.end(), feature) !=
features.end();
};
switch (arch) {
case llvm::Triple::x86:
case llvm::Triple::x86_64: {
if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(four_wide_vector_functions_sse),
std::end(four_wide_vector_functions_sse));
}
if (has_feature("+avx") && available_intrinsics.avx_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(eight_wide_vector_functions_avx),
std::end(eight_wide_vector_functions_avx));
}
break;
}
case llvm::Triple::arm:
case llvm::Triple::aarch64: {
if (has_feature("+neon") && available_intrinsics.neon_intrinsics) {
vector_functions.insert(vector_functions.end(),
std::begin(four_wide_vector_functions_neon),
std::end(four_wide_vector_functions_neon));
}
break;
}
default:
break;
}
vector_functions.insert(vector_functions.end(),
std::begin(ir_vector_functions),
std::end(ir_vector_functions));
return vector_functions;
return result;
}
} // namespace
void CompilerFunctor::AddTargetInfoPasses(
llvm::legacy::PassManagerBase* passes) const {
@ -279,9 +209,7 @@ void CompilerFunctor::AddTargetInfoPasses(
auto target_library_info_impl =
MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
target_library_info_impl->addVectorizableFunctions(
VectorFunctionsForTargetLibraryInfoImpl(
target_triple.getArch(), target_machine_->getTargetFeatureString(),
available_intrinsics_));
VectorFunctionsForTargetLibraryInfoImpl());
passes->add(
new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
passes->add(createTargetTransformInfoWrapperPass(

View File

@ -31,21 +31,10 @@ namespace cpu {
// Orc JIT compile layer.
class CompilerFunctor {
public:
// Describes the set of vector intrinsics available to the generated code.
struct VectorIntrinsics {
bool sse_intrinsics;
bool avx_intrinsics;
bool neon_intrinsics;
};
// Returns a VectorIntrinsics where all intrinsics are available.
static VectorIntrinsics AllIntrinsics();
explicit CompilerFunctor(
llvm::TargetMachine* target_machine, const Disassembler* disassembler,
int opt_level, bool optimize_for_size, bool enable_fast_math,
bool disable_expensive_passes,
const VectorIntrinsics& available_intrinsics,
LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
: target_machine_(target_machine),
@ -54,7 +43,6 @@ class CompilerFunctor {
optimize_for_size_(optimize_for_size),
enable_fast_math_(enable_fast_math),
disable_expensive_passes_(disable_expensive_passes),
available_intrinsics_(available_intrinsics),
pre_optimization_hook_(pre_optimization_hook),
post_optimization_hook_(post_optimization_hook) {}
@ -78,7 +66,6 @@ class CompilerFunctor {
const bool optimize_for_size_;
const bool enable_fast_math_;
const bool disable_expensive_passes_;
const VectorIntrinsics available_intrinsics_;
LLVMCompiler::ModuleHook pre_optimization_hook_;
LLVMCompiler::ModuleHook post_optimization_hook_;
};

View File

@ -888,8 +888,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
options::OptimizeForSizeRequested(module->config()),
module->config().debug_options().xla_enable_fast_math(),
module->config().debug_options().xla_llvm_disable_expensive_passes(),
CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
post_optimization_ir_dump_hook);
pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook);
llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
compiler_functor(llvm_module);
llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();

View File

@ -1,37 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#define EIGEN_USE_THREADS
#include "third_party/eigen3/Eigen/Core"
#ifdef TF_XLA_HAS_AVX
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
xla::cpu::runtime::V8F32AVX x) {
return Eigen::internal::plog(x);
}
#endif // TF_XLA_HAS_AVX
namespace xla {
namespace cpu {
namespace runtime {
const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX";
} // namespace runtime
} // namespace cpu
} // namespace xla

View File

@ -1,59 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header declares functions which may be called by the generated code on
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
// which is used to cache expensive state and resources utilized by the
// aforementioned functions.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
#include "tensorflow/core/platform/macros.h"
#if defined(__AVX__)
#include <immintrin.h>
#define TF_XLA_HAS_AVX
#endif
namespace xla {
namespace cpu {
namespace runtime {
extern const char *const kLogV8F32AVXSymbolName;
#ifdef TF_XLA_HAS_AVX
typedef __m256 V8F32AVX;
#endif
} // namespace runtime
} // namespace cpu
} // namespace xla
extern "C" {
#ifdef TF_XLA_HAS_AVX
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
xla::cpu::runtime::V8F32AVX x);
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
xla::cpu::runtime::V8F32AVX x);
#endif
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_

View File

@ -1,46 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#define EIGEN_USE_THREADS
#include "third_party/eigen3/Eigen/Core"
#ifdef TF_XLA_HAS_NEON
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
xla::cpu::runtime::V4F32NEON x) {
return Eigen::internal::pexp(x);
}
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
xla::cpu::runtime::V4F32NEON x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::plog(p);
}
#endif // TF_XLA_HAS_NEON
namespace xla {
namespace cpu {
namespace runtime {
const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON";
const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON";
} // namespace runtime
} // namespace cpu
} // namespace xla

View File

@ -1,62 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
// This header declares functions which may be called by the generated code on
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
// xla::cpu::SimpleResolver.
#include "tensorflow/core/platform/macros.h"
#ifdef __ARM_NEON__
// For the other runtimes (AVX, SSE4.1) we define the vector type directly using
// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM
// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
#include <arm_neon.h>
#define TF_XLA_HAS_NEON
#endif // __ARM_NEON__
namespace xla {
namespace cpu {
namespace runtime {
extern const char *const kExpV4F32NEONSymbolName;
extern const char *const kLogV4F32NEONSymbolName;
#ifdef TF_XLA_HAS_NEON
typedef float32x4_t V4F32NEON;
#endif // TF_XLA_HAS_NEON
} // namespace runtime
} // namespace cpu
} // namespace xla
extern "C" {
#ifdef TF_XLA_HAS_NEON
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
xla::cpu::runtime::V4F32NEON x);
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
xla::cpu::runtime::V4F32NEON x);
#endif // TF_XLA_HAS_NEON
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_

View File

@ -1,40 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#define EIGEN_USE_THREADS
#include "third_party/eigen3/Eigen/Core"
#ifdef TF_XLA_HAS_SSE4_1
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
xla::cpu::runtime::V4F32SSE x) {
Eigen::internal::Packet4f p = x;
return Eigen::internal::plog(p);
}
#endif // TF_XLA_HAS_SSE4_1
namespace xla {
namespace cpu {
namespace runtime {
const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE";
} // namespace runtime
} // namespace cpu
} // namespace xla

View File

@ -1,59 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This header declares functions which may be called by the generated code on
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
// which is used to cache expensive state and resources utilized by the
// aforementioned functions.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
#include "tensorflow/core/platform/macros.h"
// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1
// when __AVX__ is defined, we should do the same.
#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
#include <smmintrin.h>
#define TF_XLA_HAS_SSE4_1
#endif
namespace xla {
namespace cpu {
namespace runtime {
extern const char *const kLogV4F32SSESymbolName;
#ifdef TF_XLA_HAS_SSE4_1
typedef __m128 V4F32SSE;
#endif
} // namespace runtime
} // namespace cpu
} // namespace xla
extern "C" {
#ifdef TF_XLA_HAS_SSE4_1
// The following functions are vectorized versions of a selection of libm
// library functions.
// References to these functions are created by the LLVM vectorizer.
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
xla::cpu::runtime::V4F32SSE x);
#endif
}
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@ -31,6 +32,8 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
namespace {
llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
@ -116,19 +119,19 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
// This implements the same polynomial approximation as implemented in Eigen3.
const double exp_hi = 88.3762626647950;
const double exp_lo = -88.3762626647949;
const float exp_hi = 88.3762626647950;
const float exp_lo = -88.3762626647949;
const double cephes_LOG2EF = 1.44269504088896341;
const double cephes_exp_C1 = 0.693359375;
const double cephes_exp_C2 = -2.12194440e-4;
const float cephes_LOG2EF = 1.44269504088896341;
const float cephes_exp_C1 = 0.693359375;
const float cephes_exp_C2 = -2.12194440e-4;
const double cephes_exp_p0 = 1.9875691500E-4;
const double cephes_exp_p1 = 1.3981999507E-3;
const double cephes_exp_p2 = 8.3334519073E-3;
const double cephes_exp_p3 = 4.1665795894E-2;
const double cephes_exp_p4 = 1.6666665459E-1;
const double cephes_exp_p5 = 5.0000001201E-1;
const float cephes_exp_p0 = 1.9875691500E-4;
const float cephes_exp_p1 = 1.3981999507E-3;
const float cephes_exp_p2 = 8.3334519073E-3;
const float cephes_exp_p3 = 4.1665795894E-2;
const float cephes_exp_p4 = 1.6666665459E-1;
const float cephes_exp_p5 = 5.0000001201E-1;
llvm::Value* input = &*vector_exp_function->arg_begin();
llvm::Value* input_clamped =
@ -146,7 +149,7 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
y = vsl.MulAdd(y, x, cephes_exp_p4);
y = vsl.MulAdd(y, x, cephes_exp_p5);
y = vsl.MulAdd(y, z, x);
y = vsl.Add(1.0, y);
y = vsl.Add(1.0f, y);
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
@ -167,9 +170,133 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
ir_builder.CreateRet(result);
CHECK(!llvm::verifyFunction(*vector_exp_function));
DCHECK(!llvm::verifyFunction(*vector_exp_function));
return vector_exp_function;
}
llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
llvm::StringRef function_name,
int vector_width,
bool enable_fast_math) {
llvm::Function* vector_log_function = module->getFunction(function_name);
if (vector_log_function == nullptr) {
// If the function declaration is not present in the module, there can't be
// any calls to resolve. Don't emit the function in this case.
return nullptr;
}
llvm::LLVMContext* context = &module->getContext();
llvm::BasicBlock* vector_log_body =
llvm::BasicBlock::Create(*context, "body", vector_log_function);
llvm::IRBuilder<> ir_builder(vector_log_body);
llvm::FastMathFlags fast_math_flags;
fast_math_flags.setFast();
ir_builder.setFastMathFlags(fast_math_flags);
llvm::Value* input = &*vector_log_function->arg_begin();
VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
const float half = 0.5;
// This implements the same polynomial approximation as implemented in Eigen3.
// Returns NaN for x < 0, -INF for x = 0
const float cephes_SQRTHF = 0.707106781186547524;
const float cephes_log_p0 = 7.0376836292E-2;
const float cephes_log_p1 = -1.1514610310E-1;
const float cephes_log_p2 = 1.1676998740E-1;
const float cephes_log_p3 = -1.2420140846E-1;
const float cephes_log_p4 = +1.4249322787E-1;
const float cephes_log_p5 = -1.6668057665E-1;
const float cephes_log_p6 = +2.0000714765E-1;
const float cephes_log_p7 = -2.4999993993E-1;
const float cephes_log_p8 = +3.3333331174E-1;
const float cephes_log_q1 = -2.12194440e-4;
const float cephes_log_q2 = 0.693359375;
// The smallest non denormalized float number.
const float min_norm_pos = tensorflow::bit_cast<float, int32>(0x00800000);
const float minus_inf = tensorflow::bit_cast<float, int32>(0xff800000);
// NB! This number is denormal and since TF sets the denormals-are-zero flag
// (and if TF didn't, -ffast-math would) trying to operate on this float using
// C++ operations (including, for instance, implicit conversion to double)
// will coerce this to zero.
const float inv_mant_mask = tensorflow::bit_cast<float, int32>(~0x7f800000);
// invalid_mask is set if x is negative or NaN (and therefore output
// must be NaN).
llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
// Cut off denormalized stuff.
input = vsl.Max(min_norm_pos, input);
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* vector_constant_0x7f =
ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
llvm::Value* vector_constant_23 =
ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
llvm::Type* i32_vector_type =
llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
llvm::Value* emm0 = ir_builder.CreateLShr(
ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23);
// Keep only the fractional part.
input = vsl.FloatAnd(input, inv_mant_mask);
input = vsl.FloatOr(input, half);
emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
llvm::Value* e =
vsl.Add(1.0f, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
// part2:
// if( x < SQRTHF ) {
// e -= 1;
// x = x + x - 1.0;
// } else { x = x - 1.0; }
llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
llvm::Value* tmp = vsl.FloatAnd(input, mask);
input = vsl.Sub(input, 1.0);
e = vsl.Sub(e, vsl.FloatAnd(mask, 1.0));
input = vsl.Add(input, tmp);
llvm::Value* x2 = vsl.Mul(input, input);
llvm::Value* x3 = vsl.Mul(x2, input);
llvm::Value *y, *y1, *y2;
y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1);
y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4);
y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7);
y = vsl.MulAdd(y, input, cephes_log_p2);
y1 = vsl.MulAdd(y1, input, cephes_log_p5);
y2 = vsl.MulAdd(y2, input, cephes_log_p8);
y = vsl.MulAdd(y, x3, y1);
y = vsl.MulAdd(y, x3, y2);
y = vsl.Mul(y, x3);
y1 = vsl.Mul(cephes_log_q1, e);
tmp = vsl.Mul(half, x2);
y = vsl.Add(y, y1);
input = vsl.Sub(input, tmp);
y2 = vsl.Mul(cephes_log_q2, e);
input = vsl.Add(input, y);
input = vsl.Add(input, y2);
// Negative arg will be NAN, 0 will be -INF.
llvm::Value* or_lhs =
vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask));
llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf);
llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs);
ir_builder.CreateRet(result);
DCHECK(!llvm::verifyFunction(*vector_log_function));
return vector_log_function;
}
} // namespace
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
@ -187,11 +314,21 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName,
/*vector_width=*/8, enable_fast_math);
auto* log_v4f32 =
EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName,
/*vector_width=*/4, enable_fast_math);
auto* log_v8f32 =
EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName,
/*vector_width=*/8, enable_fast_math);
// Gather all the call sites, force inline them and then delete the vector
// function bodies.
//
// TODO(b/73081976): Should we avoid inlining these intrinsics in some cases?
std::vector<llvm::CallInst*> calls_to_inline;
for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) {
for (auto* function :
{tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
if (function != nullptr) {
for (auto* user : function->users()) {
calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
@ -204,7 +341,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
}
for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) {
for (auto* function :
{tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
if (function != nullptr) {
function->eraseFromParent();
}

View File

@ -27,6 +27,8 @@ extern const char* const kTanhV4F32SymbolName;
extern const char* const kTanhV8F32SymbolName;
extern const char* const kExpV4F32SymbolName;
extern const char* const kExpV8F32SymbolName;
extern const char* const kLogV4F32SymbolName;
extern const char* const kLogV8F32SymbolName;
// The following CPU runtime functions have LLVM-IR only implementations:
//

View File

@ -28,9 +28,6 @@ limitations under the License.
#include "llvm/Support/Host.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
@ -101,27 +98,6 @@ llvm::StringRef GetHostCpuName() {
cpu_name.consume_back("-avx512");
return cpu_name;
}
CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
CompilerFunctor::VectorIntrinsics intrinsics;
#ifdef TF_XLA_HAS_SSE4_1
intrinsics.sse_intrinsics = true;
#else
intrinsics.sse_intrinsics = false;
#endif
#ifdef TF_XLA_HAS_AVX
intrinsics.avx_intrinsics = true;
#else
intrinsics.avx_intrinsics = false;
#endif
#ifdef TF_XLA_HAS_NEON
intrinsics.neon_intrinsics = true;
#else
intrinsics.neon_intrinsics = false;
#endif
return intrinsics;
}
} // namespace
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
@ -169,13 +145,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
orc_jit_memory_mapper::GetInstance());
},
[this](llvm::orc::VModuleKey K) { return symbol_resolver_; }),
compile_layer_(
object_layer_,
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
optimize_for_size, enable_fast_math,
disable_expensive_passes, GetAvailableIntrinsics(),
std::move(pre_optimization_hook),
std::move(post_optimization_hook))) {
compile_layer_(object_layer_,
CompilerFunctor(target_machine_.get(), &disassembler_,
opt_level, optimize_for_size,
enable_fast_math, disable_expensive_passes,
std::move(pre_optimization_hook),
std::move(post_optimization_hook))) {
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
<< " features: " << target_machine_->getTargetFeatureString().str();
}
@ -240,15 +215,6 @@ bool RegisterKnownJITSymbols() {
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
#ifdef TF_XLA_HAS_NEON
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON);
#endif
#ifdef TF_XLA_HAS_SSE4_1
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE);
#endif
#ifdef TF_XLA_HAS_AVX
REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX);
#endif
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);

View File

@ -103,15 +103,92 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
}
}
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, double low,
double high) {
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, float low,
float high) {
AssertCorrectTypes({a});
llvm::Type* type = a->getType();
CHECK_LT(low, high);
CHECK(scalar_type_->isFloatingPointTy());
return llvm_ir::EmitFloatMin(
llvm_ir::EmitFloatMax(a, llvm::ConstantFP::get(type, low), ir_builder_),
llvm::ConstantFP::get(type, high), ir_builder_);
llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
GetConstantFloat(type, high), ir_builder_);
}
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name()));
}
llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name()));
}
llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name()));
}
llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
return ir_builder()->CreateBitCast(
ir_builder()->CreateSExt(i1, integer_type, name()),
is_vector ? vector_type() : scalar_type(), name());
}
llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
CHECK(scalar_type()->isFloatingPointTy());
const llvm::DataLayout& data_layout =
ir_builder()->GetInsertBlock()->getModule()->getDataLayout();
int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits);
if (vector) {
return llvm::VectorType::get(scalar_int_type, vector_size());
} else {
return scalar_int_type;
}
}
llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
CHECK_EQ(x->getType(), scalar_type());
return ir_builder()->CreateVectorSplat(vector_size(), x, name());
}
llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
llvm::Type* int_type =
IntegerTypeForFloatSize(lhs->getType() == vector_type());
return ir_builder()->CreateBitCast(
ir_builder()->CreateAnd(
ir_builder()->CreateBitCast(lhs, int_type, name()),
ir_builder()->CreateBitCast(rhs, int_type, name()), name()),
vector_type());
}
llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
AssertCorrectTypes({lhs});
llvm::Type* int_type =
IntegerTypeForFloatSize(lhs->getType() == vector_type());
return ir_builder()->CreateBitCast(
ir_builder()->CreateNot(
ir_builder()->CreateBitCast(lhs, int_type, name()), name()),
vector_type());
}
llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
AssertCorrectTypes({lhs, rhs});
llvm::Type* int_type =
IntegerTypeForFloatSize(lhs->getType() == vector_type());
return ir_builder()->CreateBitCast(
ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()),
ir_builder()->CreateBitCast(rhs, int_type, name()),
name()),
vector_type(), name());
}
llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,

View File

@ -41,40 +41,82 @@ class VectorSupportLibrary {
llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
return Mul(ir_builder()->getInt64(lhs), rhs);
}
llvm::Value* Mul(double lhs, llvm::Value* rhs) {
return Mul(llvm::ConstantFP::get(rhs->getType(), lhs), rhs);
llvm::Value* Mul(float lhs, llvm::Value* rhs) {
return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
}
llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
return Add(ir_builder()->getInt64(lhs), rhs);
}
llvm::Value* Add(double lhs, llvm::Value* rhs) {
return Add(llvm::ConstantFP::get(vector_type(), lhs), rhs);
llvm::Value* Add(float lhs, llvm::Value* rhs) {
return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
}
llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Sub(llvm::Value* lhs, float rhs) {
return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* Max(float lhs, llvm::Value* rhs) {
return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
}
llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
return Add(c, Mul(a, b));
}
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, double c) {
return Add(llvm::ConstantFP::get(vector_type(), c), Mul(a, b));
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, float c) {
return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
}
llvm::Value* MulAdd(llvm::Value* a, double b, double c) {
return Add(llvm::ConstantFP::get(a->getType(), c),
Mul(a, llvm::ConstantFP::get(a->getType(), b)));
llvm::Value* MulAdd(llvm::Value* a, float b, float c) {
return Add(GetConstantFloat(a->getType(), c),
Mul(a, GetConstantFloat(a->getType(), b)));
}
llvm::Value* Floor(llvm::Value* a);
llvm::Value* Clamp(llvm::Value* a, double low, double high);
llvm::Value* SplatFloat(double d) {
return llvm::ConstantFP::get(vector_type(), d);
llvm::Value* Clamp(llvm::Value* a, float low, float high);
llvm::Value* SplatFloat(float d) {
return GetConstantFloat(vector_type(), d);
}
// These compare instructions return a floating point typed mask instead of an
// i1. For instance, on a vector typed input, lanes where the predicate is
// true get a float with all ones and other lanes get a float with all zeros.
// This is slightly odd from the perspective of LLVM's type system, but it
// makes kernel IR generation code written using VectorSupportLibrary (its
// raison d'etre) less cluttered.
llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpOLTMask(llvm::Value* lhs, float rhs) {
return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
}
// These boolean operations operate on the bitwise values of the floating
// point inputs. They return a (vector of) float(s) but like in the mask
// generating predicates above this type system oddity makes the kernel IR
// generation code less cluttered.
llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FloatAnd(llvm::Value* lhs, float rhs) {
return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FloatOr(llvm::Value* lhs, float rhs) {
return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
}
llvm::Value* FloatNot(llvm::Value* lhs);
llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
return FloatAnd(FloatNot(lhs), rhs);
}
llvm::Value* BroadcastScalar(llvm::Value* x);
llvm::Value* BroadcastScalar(float d) {
return BroadcastScalar(GetConstantFloat(scalar_type(), d));
}
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
@ -194,6 +236,17 @@ class VectorSupportLibrary {
std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
std::vector<llvm::Value*> vectors, llvm::Value* init_values);
llvm::Type* IntegerTypeForFloatSize(bool vector);
llvm::Value* I1ToFloat(llvm::Value* i1);
llvm::Value* GetConstantFloat(llvm::Type* type, float f) {
llvm::Constant* scalar_value =
llvm::ConstantFP::get(type->getContext(), llvm::APFloat(f));
if (llvm::isa<llvm::VectorType>(type)) {
return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
}
return scalar_value;
}
int64 vector_size_;
PrimitiveType primitive_type_;
llvm::IRBuilder<>* ir_builder_;

View File

@ -45,7 +45,7 @@ ConvolutionThunk::ConvolutionThunk(
const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
const Shape& filter_shape, const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
const HloInstruction* hlo)
bool tensor_ops_enabled, const HloInstruction* hlo)
: Thunk(Kind::kConvolution, hlo),
convolution_kind_(convolution_kind),
input_buffer_(input_buffer),
@ -58,7 +58,8 @@ ConvolutionThunk::ConvolutionThunk(
output_shape_(output_shape),
window_(window),
dim_nums_(dim_nums),
algorithm_(algorithm) {}
algorithm_(algorithm),
tensor_ops_enabled_(tensor_ops_enabled) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
@ -72,7 +73,7 @@ Status ConvolutionThunk::ExecuteOnStream(
buffer_allocations.GetDeviceAddress(scratch_buffer_);
se::dnn::AlgorithmConfig algorithm_config(
se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false));
se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
TF_RETURN_IF_ERROR(RunCudnnConvolution(
convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,

View File

@ -59,7 +59,7 @@ class ConvolutionThunk : public Thunk {
const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
const HloInstruction* hlo);
bool tensor_ops_enabled, const HloInstruction* hlo);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
@ -99,6 +99,7 @@ class ConvolutionThunk : public Thunk {
const Window window_;
const ConvolutionDimensionNumbers dim_nums_;
int64 algorithm_;
bool tensor_ops_enabled_;
};
} // namespace gpu

View File

@ -172,7 +172,7 @@ string NumBytesToString(int64 bytes) {
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
optional<std::pair<int64, int64>>
optional<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
@ -260,8 +260,9 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
<< AlgorithmToString(best_result.algorithm()) << ", takes "
<< best_result.elapsed_time_in_ms() << "ms, and uses "
<< best_result_bytes_used << "B of scratch memory.";
return std::make_pair(best_result.algorithm().algo_id(),
best_result_bytes_used);
return std::make_tuple(best_result.algorithm().algo_id(),
best_result.algorithm().tensor_ops_enabled(),
best_result_bytes_used);
}
LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
@ -277,19 +278,19 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
const auto& lhs_shape = instr->operand(0)->shape();
const auto& rhs_shape = instr->operand(1)->shape();
const auto& conv_result_shape = instr->shape().tuple_shapes(0);
optional<std::pair<int64, int64>> alg_and_scratch_bytes;
optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
if (call_target == kCudnnConvForwardCallTarget) {
alg_and_scratch_bytes = PickBestAlgorithm(
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
instr->window(), instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
alg_and_scratch_bytes = PickBestAlgorithm(
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardFilterCallTarget) {
alg_and_scratch_bytes = PickBestAlgorithm(
alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
/*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
instr->window(), instr->convolution_dimension_numbers(), instr);
@ -298,17 +299,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
<< instr->ToString();
}
if (!alg_and_scratch_bytes.has_value()) {
if (!alg_scratch_and_tc.has_value()) {
return false;
}
int64 algorithm;
bool tensor_ops_enabled;
int64 scratch_bytes;
std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes;
std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
<< NumBytesToString(scratch_bytes)
<< " of scratch memory: " << instr->ToString();
<< " of scratch memory: " << instr->ToString()
<< " tensor_ops_enabled: " << tensor_ops_enabled;
// Replace instr with a new CustomCall which has the correct algorithm, and
// whose output shape has the appropriate amount of scratch memory.
@ -318,10 +322,15 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
ShapeUtil::MakeShape(U8, {scratch_bytes})});
HloInstruction* algorithm_hlo = computation->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
HloInstruction* tensor_ops_enabled_hlo =
computation->AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<bool>(tensor_ops_enabled)));
HloInstruction* new_call =
computation->AddInstruction(HloInstruction::CreateCustomCall(
new_call_shape,
{instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo},
{instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo,
tensor_ops_enabled_hlo},
instr->custom_call_target()));
new_call->set_window(instr->window());
new_call->set_convolution_dimension_numbers(

View File

@ -47,7 +47,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
tensorflow::gtl::optional<std::pair<int64, int64>> PickBestAlgorithm(
tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);

View File

@ -106,6 +106,9 @@ Status RunCudnnConvolution(
se::ScratchAllocator* scratch_allocator, const Window& window,
const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
VLOG(3) << "tensor_ops_enabled: "
<< algorithm.algorithm().tensor_ops_enabled();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";

View File

@ -79,9 +79,9 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
} else if (IsCustomCallToDnnConvolution(*hlo)) {
// The last argument to a CUDNN convolution is its algorithm, which must
// be an HLO constant -- it shouldn't be copied.
for (int64 i = 0; i < hlo->operand_count() - 1; ++i) {
// The last two arguments to a CUDNN convolution are two HLO constants for
// cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied.
for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
} else if (ImplementedAsLibraryCall(*hlo)) {

View File

@ -63,10 +63,11 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
// strings.
//
// These CustomCalls have window() and convolution_dimension_numbers() set like
// regular convolution ops. They have the same LHS and RHS operands, plus one
// additional int64 operand, representing which cudnn algorithm to run. This
// operand must be an HLO constant. A value of -1 means that the implementation
// is free to choose the best algorithm it can.
// regular convolution ops. They have the same LHS and RHS operands, plus two
// additional constant operands: an int64 operand for the cudnn algorithm and
// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
// algorithm means that the implementation is free to choose the best algorithm
// it can.
//
// These calls output a tuple (conv_result, scratch_memory), where conv_result
// is the actual result of the convolution, and scratch_memory is temporary

View File

@ -393,6 +393,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
int64 algorithm = algorithm_inst->literal().Get<int64>({});
const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
CHECK(tensor_ops_enabled_inst->IsConstant())
<< tensor_ops_enabled_inst->ToString();
bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
const auto& target = custom_call->custom_call_target();
std::unique_ptr<ConvolutionThunk> thunk;
if (target == kCudnnConvForwardCallTarget) {
@ -407,7 +412,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
algorithm, custom_call);
algorithm, tensor_ops_enabled, custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
thunk = MakeUnique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
@ -420,7 +425,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/lhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
algorithm, custom_call);
algorithm, tensor_ops_enabled, custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
thunk = MakeUnique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
@ -433,7 +438,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
/*filter_shape=*/conv_result_shape,
/*output_shape=*/rhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
algorithm, custom_call);
algorithm, tensor_ops_enabled, custom_call);
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();

View File

@ -1805,7 +1805,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
HloInstruction* new_producer) {
TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape()))
TF_RET_CHECK(
ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
<< "this shape: " << ShapeUtil::HumanString(shape())
<< ", replacement shape: "
<< ShapeUtil::HumanString(new_producer->shape());
@ -1828,8 +1829,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
TF_RET_CHECK(operand_num >= 0);
TF_RET_CHECK(operand_num < operand_count());
HloInstruction* old_operand = mutable_operand(operand_num);
TF_RET_CHECK(
ShapeUtil::Compatible(old_operand->shape(), new_operand->shape()))
TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
new_operand->shape()))
<< old_operand->shape().ShortDebugString() << " is not compatible with "
<< new_operand->shape().ShortDebugString();
operands_[operand_num] = new_operand;

View File

@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <set>
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/lib/core/errors.h"
@ -164,6 +166,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
// HLO broadcast has no exact analog at the proto level so there is no
// ShapeInference method. Check the output shape explicitly.
const Shape& operand_shape = broadcast->operand(0)->shape();
// Check for mixed precision.
TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
broadcast->dimensions().size());
for (int64 operand_dimension = 0;
@ -178,6 +182,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
}
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
// Check for mixed precision.
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
return tensorflow::Status::OK();
@ -359,13 +365,122 @@ Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
batch_norm_grad->feature_index()));
}
namespace {
// Checks that the instruction does not have mixed precision floating point
// inputs.
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
switch (instruction->opcode()) {
// White list the following opcodes for mixed-precision check, because they
// involve data pass through or grouping via tuples, where the precisions
// of buffers can be different.
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConstant:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kCustomCall:
case HloOpcode::kFusion:
case HloOpcode::kGetTupleElement:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kReducePrecision:
case HloOpcode::kSelect:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
break;
default: {
PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
for (auto operand : instruction->operands()) {
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
operand->shape(),
[&](const Shape& subshape, const ShapeIndex& index) {
if (!ShapeUtil::ElementIsFloating(subshape)) {
return Status::OK();
}
if (fp_type == PRIMITIVE_TYPE_INVALID) {
fp_type = subshape.element_type();
} else if (fp_type != subshape.element_type()) {
return FailedPrecondition(
"Seen floating point types of different precisions in "
"%s, but mixed precision is disallowed.",
instruction->ToString().c_str());
}
return Status::OK();
}));
}
}
}
return Status::OK();
}
} // namespace
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const Shape& expected_shape) {
if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) {
const Shape& inferred_shape) {
// If allow_mixed_precision_ is false, check if there are operands with
// different precisions. We need this check because ShapeInference allows
// mixed precision inputs.
if (!allow_mixed_precision_) {
TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
}
// Check if the output shape matches the expected shape.
bool compatible;
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
// but only when the instruction defines the BF16/F32 buffer.
switch (instruction->opcode()) {
case HloOpcode::kSelect:
if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
// Select only defines the top-level buffer, which in this case is the
// tuple, so we cannot allow mixed precision.
compatible =
ShapeUtil::Compatible(instruction->shape(), inferred_shape);
} else {
compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
instruction->shape(), inferred_shape);
}
break;
case HloOpcode::kGetTupleElement:
case HloOpcode::kTuple:
// Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
// precision is disallowed.
case HloOpcode::kConstant:
case HloOpcode::kBitcast:
case HloOpcode::kBitcastConvert:
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kConvert:
case HloOpcode::kCustomCall:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kRecv:
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kWhile:
// The above opcodes should match the expected shapes exactly.
compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
break;
default:
if (allow_mixed_precision_) {
compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
instruction->shape(), inferred_shape);
} else {
compatible =
ShapeUtil::Compatible(instruction->shape(), inferred_shape);
}
}
if (!compatible) {
return InvalidArgument(
"Expected instruction to have shape compatible with %s, actual "
"shape is %s:\n%s",
ShapeUtil::HumanString(expected_shape).c_str(),
ShapeUtil::HumanString(inferred_shape).c_str(),
ShapeUtil::HumanString(instruction->shape()).c_str(),
instruction->ToString().c_str());
}
@ -373,14 +488,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
}
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
const StatusOr<Shape>& expected_shape_status) {
if (!expected_shape_status.ok()) {
Status s = expected_shape_status.status();
const StatusOr<Shape>& inferred_shape_status) {
if (!inferred_shape_status.ok()) {
Status s = inferred_shape_status.status();
tensorflow::errors::AppendToMessage(&s, ", for instruction ",
instruction->ToString());
return s;
}
return CheckShape(instruction, expected_shape_status.ValueOrDie());
return CheckShape(instruction, inferred_shape_status.ValueOrDie());
}
Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {

View File

@ -27,6 +27,10 @@ namespace xla {
// TODO(b/26024837): Check output shape for all instruction types.
class ShapeVerifier : public DfsHloVisitor {
public:
explicit ShapeVerifier() : allow_mixed_precision_(false) {}
explicit ShapeVerifier(bool allow_mixed_precision)
: allow_mixed_precision_(allow_mixed_precision) {}
Status HandleElementwiseUnary(HloInstruction* hlo) override;
Status HandleElementwiseBinary(HloInstruction* hlo) override;
Status HandleClamp(HloInstruction* clamp) override;
@ -81,14 +85,14 @@ class ShapeVerifier : public DfsHloVisitor {
}
protected:
// Check the instruction's shape against the given expected shape and return
// an appropriate error if there is a mismatch.
// Check the instruction's shape against the shape given by ShapeInference
// and return an appropriate error if there is a mismatch.
Status CheckShape(const HloInstruction* instruction,
const Shape& expected_shape);
const Shape& inferred_shape);
// Overload which takes a StatusOr to reduce boilerplate in the caller.
Status CheckShape(const HloInstruction* instruction,
const StatusOr<Shape>& expected_shape_status);
const StatusOr<Shape>& inferred_shape_status);
// Check a unary (binary, etc) instruction's shape against the inferred shape.
Status CheckUnaryShape(const HloInstruction* instruction);
@ -99,19 +103,32 @@ class ShapeVerifier : public DfsHloVisitor {
// Checks if the given two instructions shares the same channel id.
Status CheckSameChannel(const HloInstruction* instr1,
const HloInstruction* instr2);
private:
// Whether the inputs and output of an instruction can contain both F32s and
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
// this flag.
bool allow_mixed_precision_;
};
// HLO pass that verifies invariants of HLO instructions for each computation in
// the module.
class HloVerifier : public HloPassInterface {
public:
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
// Uses standard shape inference.
explicit HloVerifier()
: shape_verifier_factory_([] { return MakeUnique<ShapeVerifier>(); }) {}
: shape_verifier_factory_(
[] { return MakeUnique<ShapeVerifier>(false); }) {}
explicit HloVerifier(bool allow_mixed_precision)
: shape_verifier_factory_([allow_mixed_precision] {
return MakeUnique<ShapeVerifier>(allow_mixed_precision);
}) {}
// Uses custom shape verification.
explicit HloVerifier(
std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory)
explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
: shape_verifier_factory_(std::move(shape_verifier_factory)) {}
~HloVerifier() override = default;
@ -129,7 +146,7 @@ class HloVerifier : public HloPassInterface {
// expectations. This is a factory function because ShapeVerifier, Note that
// ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object
// for each run of the verifier.
std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory_;
ShapeVerifierFactory shape_verifier_factory_;
};
} // namespace xla

View File

@ -209,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
}
// Check that init_value's shape is suitable for reducer_shape.
if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) {
if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
init_value_shape)) {
return InvalidArgument(
"Reduction function's accumulator shape differs from the "
"init_value shape: %s vs %s",
@ -220,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
// Check that the inputs can be passed in as the second argument.
const Shape& input_element_shape =
ShapeUtil::MakeShape(input_element_type, {});
if (!ShapeUtil::Compatible(input_element_shape,
reducer_shape.parameters(1))) {
if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
reducer_shape.parameters(1))) {
return InvalidArgument(
"Reduction function's second parameter shape differs from the "
"input type element type: %s vs %s",
@ -231,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
// Currently the accumulator and inputs must be the same type,
// though that restriction could be relaxed.
if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) {
if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
reducer_shape.parameters(1))) {
return InvalidArgument(
"Reduction function's second parameter shape currently must "
"match the result shape. Got %s vs %s",
@ -394,11 +396,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
dimension);
}
const Shape* arg_shape = nullptr;
PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
for (const Shape* shape : arg_shapes) {
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
if (!arg_shape) {
arg_shape = shape;
element_type = arg_shape->element_type();
continue;
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
@ -409,7 +413,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
ShapeUtil::HumanString(*shape).c_str());
}
if (arg_shape->element_type() != shape->element_type()) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
return InvalidArgument(
"cannot concatenate arrays with different element types: %s vs %s",
PrimitiveType_Name(arg_shape->element_type()).c_str(),
@ -431,6 +435,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(*shape).c_str(), dimension);
}
}
element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
}
std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
@ -438,7 +443,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
for (size_t i = 1; i < arg_shapes.size(); ++i) {
new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
}
return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
return ShapeUtil::MakeShape(element_type, new_dimensions);
}
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
@ -536,7 +541,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
ShapeUtil::HumanString(operand_shape).c_str(),
padding_config.ShortDebugString().c_str());
}
if (operand_shape.element_type() != padding_value_shape.element_type()) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
padding_value_shape)) {
return InvalidArgument(
"the element types of the operands to pad do not match");
}
@ -548,7 +554,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
padding_config.dimensions(i).interior_padding();
}
return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
return ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
dimensions);
}
// Current DotDimensionNumbers Requirements:
@ -673,7 +681,7 @@ Status ValidateDotDimensionNumbers(
};
// Check if both element types are the same.
if (lhs.element_type() != rhs.element_type()) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return fail("element types do not match");
}
@ -736,7 +744,8 @@ Status ValidateDotDimensionNumbers(
dimensions.push_back(rhs.dimensions(i));
}
}
Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);
Shape result = ShapeUtil::MakeShape(
ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions);
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
@ -767,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::HumanString(rhs).c_str());
}
}
return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions);
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
output_dimensions);
}
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
@ -829,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// specified in broadcast_dimensions are then changed to match the
// corresponding dimension size in smaller_shape.
Shape output_shape(larger_shape);
output_shape.set_element_type(
ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
int64 dimension_to_match = broadcast_dimensions.at(i);
@ -878,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
if (!ShapeUtil::SameElementType(lhs, rhs)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"binary op %s with different element types: %s and %s",
BinaryOperation_Name(operation).c_str(),
@ -897,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
}
if (ShapeUtil::Compatible(lhs, rhs)) {
if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
// If the shapes are the same other than layout, the output shape is the
// same (elementwise op).
return lhs;
return ShapeUtil::ChangeElementType(
lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
}
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
@ -973,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_ASSIGN_OR_RETURN(const Shape& shape,
InferElementwiseBinaryOpShape(operation, lhs, rhs,
broadcast_dimensions));
if (lhs.element_type() == F32) {
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
return ShapeUtil::ChangeElementType(shape, C64);
} else {
return Unimplemented("complex component type not supported");
@ -1078,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_RETURN_IF_ERROR(
ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
continue;
}
if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
!ShapeUtil::IsTuple(*arg_shape) &&
ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
*arg_shape)) {
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
continue;
}
@ -1148,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
i, ShapeUtil::HumanString(parameter_shape).c_str());
}
if (parameter_shape.element_type() != arg_shape->element_type()) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
*arg_shape)) {
return InvalidArgument(
"mapped computation's parameter type has to match argument element "
"type; got parameter %d shape: %s, argument shape: %s",
@ -1221,7 +1236,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-training, "
"but the shape of offset factor is %s "
@ -1230,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-training, "
"but the shape of scale factor is %s "
@ -1329,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
@ -1339,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
@ -1349,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
@ -1359,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for "
"batch-norm-inference, "
@ -1481,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(output_grad_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of output_grad is %s "
@ -1490,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of scale factor is %s "
@ -1499,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
@ -1508,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
PrimitiveType_Name(operand_shape.element_type()).c_str());
}
if (!ShapeUtil::SameElementType(var_shape, operand_shape)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
operand_shape)) {
return InvalidArgument(
"The inputs should have the same element type for batch-norm-grad, "
"but the element type of mean is %s "
@ -1569,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
if (!ShapeUtil::SameElementType(lhs, rhs)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return InvalidArgument(
"Convolution with different element types: %s and %s",
ShapeUtil::HumanString(lhs).c_str(),
@ -1714,8 +1739,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
dimensions[dnums.output_spatial_dimensions(i)] =
window_output_shape.dimensions(i);
}
return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
dimensions);
}
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
@ -1877,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
}
const Shape& operand_element_shape =
ShapeUtil::MakeShape(operand_shape.element_type(), {});
if (!ShapeUtil::Compatible(operand_element_shape,
select_shape.parameters(0))) {
if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
select_shape.parameters(0))) {
return InvalidArgument(
"select function's first parameter shape currently must "
"match the operand element shape. Got %s vs %s",
ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
ShapeUtil::HumanString(operand_element_shape).c_str());
}
if (!ShapeUtil::Compatible(operand_element_shape,
select_shape.parameters(1))) {
if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
select_shape.parameters(1))) {
return InvalidArgument(
"select function's second parameter shape currently must "
"match the operand element shape. Got %s vs %s",
@ -1903,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
InferWindowOutputShape(operand_shape, window,
operand_shape.element_type(),
/*allow_negative_padding=*/false));
if (!ShapeUtil::Compatible(source_shape, window_result_shape)) {
if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
window_result_shape)) {
return InvalidArgument(
"source shape does not match the shape of window-reduced operand: "
"source(%s), window-reduced operand(%s)",
@ -2086,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
}
if (operand_shape.element_type() != update_shape.element_type()) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
update_shape)) {
return InvalidArgument(
"dynamic update slice update element type does not match argument. "
"operand.element_type: %s vs update.element_type: %s",
@ -2322,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
if (!ShapeUtil::SameElementType(min, operand) ||
!ShapeUtil::SameElementType(max, operand)) {
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
return InvalidArgument("clamp op with different operand types: %s, %s, %s",
ShapeUtil::HumanString(min).c_str(),
ShapeUtil::HumanString(operand).c_str(),
ShapeUtil::HumanString(max).c_str());
}
if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
(ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
ShapeUtil::IsScalar(min)) &&
(ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
ShapeUtil::IsScalar(max)))) {
return operand;
}
if (ShapeUtil::IsScalar(operand)) {
if (ShapeUtil::Compatible(min, max)) {
return min;
if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
return ShapeUtil::ChangeElementType(min, operand.element_type());
} else if (ShapeUtil::IsScalar(min)) {
return max;
return ShapeUtil::ChangeElementType(max, operand.element_type());
} else if (ShapeUtil::IsScalar(max)) {
return min;
return ShapeUtil::ChangeElementType(min, operand.element_type());
}
}
return Unimplemented(
@ -2352,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// broadcast from all operands, not just the predicate.
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
const Shape& pred, const Shape& on_true, const Shape& on_false) {
if (!ShapeUtil::Compatible(on_true, on_false)) {
bool compatible;
if (ShapeUtil::IsTuple(on_true)) {
// Select only defines the top-level buffer, so if it's a tuple, the two
// input must match exactly.
compatible = ShapeUtil::Compatible(on_true, on_false);
} else {
compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
}
if (!compatible) {
return InvalidArgument(
"operands to select must be the same shape; got %s and %s",
ShapeUtil::HumanString(on_true).c_str(),
@ -2367,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
return on_true;
return ShapeUtil::ChangeElementType(
on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
} else {
return Unimplemented(
"select operation with non-scalar predicate with dimensionality "

View File

@ -630,6 +630,19 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return SameDimensions(lhs, rhs);
}
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
const Shape& rhs) {
if (lhs.element_type() == TUPLE) {
return rhs.element_type() == TUPLE &&
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
CompatibleIgnoringFpPrecision);
}
if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
return CompatibleIgnoringElementType(lhs, rhs);
}
return false;
}
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
int64 dimension_number) {
return shape.dimensions(GetDimensionNumber(shape, dimension_number));

View File

@ -23,6 +23,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
@ -211,6 +212,31 @@ class ShapeUtil {
return lhs.element_type() == rhs.element_type();
}
// As SameElementType, but allows floating point types to have different
// precisions.
static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
const Shape& b) {
if (ElementIsFloating(a) && ElementIsFloating(b)) {
return true;
}
return ShapeUtil::SameElementType(a, b);
}
// Returns the higher-precision element type if a and b are both floating
// point types; otherwise, checks that that they have the same element type
// and returns it.
static PrimitiveType HigherPrecisionElementType(const Shape& a,
const Shape& b) {
if (SameElementType(a, b)) {
return a.element_type();
}
CHECK(SameElementTypeIgnoringFpPrecision(a, b));
return primitive_util::BitWidth(a.element_type()) <
primitive_util::BitWidth(b.element_type())
? b.element_type()
: a.element_type();
}
// Returns true if the rank, dimension sizes, and element type are
// identical. Layout is ignored. Tuple elements are compared recursively for
// compatibility.
@ -221,6 +247,10 @@ class ShapeUtil {
// compatibility.
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
// As Compatible, but allow one of lhs and rhs to be BF16 while the other
// being F32. Tuple elements are compared recursively for compatibility.
static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
// Returns whether the lhs and rhs shapes are identical protobufs.
static bool Equal(const Shape& lhs, const Shape& rhs);

View File

@ -170,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
}
TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) {
Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
}
TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) {
Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
}
TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
@ -184,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) {
EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
}
TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})});
Shape tuple2 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
}
TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
@ -193,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
}
TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
Shape tuple2 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
}
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
Shape tuple1 = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});

View File

@ -2121,6 +2121,44 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
// The input tensor is large enough to exercise the vectorized exp
// implementation on XLA CPU.
ComputationBuilder builder(client_, TestName());
std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
1.74e+04, 1.89e+05, 1.9e+05, 1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
1.66e+07, 1e+07, 1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
1.44e+10, 1.5e+10, 1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
1.4e+12, 1.03e+13, 1.6e+13, 1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
2e+18, 1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21, 1.35e+22, 1.84e+22, 1.02e+22,
1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
1.62e+25, 1.2e+26, 1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
1.5e+28, 1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30, 1.81e+30, 1.34e+30,
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
client_->TransferToServer(*input_literal));
auto input = builder.Parameter(0, input_literal->shape(), "input");
builder.Log(input);
std::vector<float> expected_result;
int64 input_size = input_literal->shape().dimensions(0);
expected_result.reserve(input_size);
for (int64 i = 0; i < input_size; i++) {
expected_result.push_back(std::log(input_literal->Get<float>({i})));
}
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
error_spec_);
}
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
// a ------ (add) --------- (add)
// / /

View File

@ -276,8 +276,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows.
"${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py"
# Broken tensorboard test due to cmake issues.
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561
# tensor_forest tests (also note that we exclude the hybrid tests for now)
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.

View File

@ -166,8 +166,8 @@ def crf_log_likelihood(inputs,
sequence_lengths: A [batch_size] vector of true sequence lengths.
transition_params: A [num_tags, num_tags] transition matrix, if available.
Returns:
log_likelihood: A scalar containing the log-likelihood of the given sequence
of tag indices.
log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
each example, given the sequence of tag indices.
transition_params: A [num_tags, num_tags] transition matrix. This is either
provided by the caller or created in this function.
"""
@ -182,7 +182,7 @@ def crf_log_likelihood(inputs,
transition_params)
log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
# Normalize the scores to get the log-likelihood.
# Normalize the scores to get the log-likelihood per example.
log_likelihood = sequence_scores - log_norm
return log_likelihood, transition_params

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
import sqlite3
from tensorflow.contrib.data.python.ops import readers

View File

@ -790,7 +790,7 @@ def _extract_tensors(tensors_and_vars):
tensor, _ = tensor_and_var
if isinstance(tensor, ops_lib.IndexedSlices):
tensors.append(tensor.values)
else:
elif tensor is not None:
tensors.append(tensor)
return tensors

View File

@ -240,6 +240,13 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
labels = np.array([[1.0], [2.0]])
with self.test_session() as session:
# Add another trainable variable that doesn't produce a gradient to
# verify that None gradients are supported.
_ = variable_scope.get_variable(
'another_variable',
initializer=constant_op.constant(1, dtype=dtypes.float64),
dtype=dtypes.float64)
replicated_model_fn = replicate_model_fn.replicate_model_fn(
self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
estimator_spec = replicated_model_fn(
@ -1119,8 +1126,6 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
feature_shards, label_shards = replicate_model_fn._split_batch(
features, labels, 2, device='/gpu:0')
print(feature_shards[0]['x'].eval())
print(feature_shards[1]['x'].eval())
self.assertSparseValuesEqual(
sparse_tensor.SparseTensorValue(
indices=[[0, 0], [1, 0], [1, 1]],

View File

@ -181,7 +181,8 @@ class ClassifierMetricsTest(test.TestCase):
batch_size = 3
img = array_ops.ones([batch_size, 299, 299, 3])
pool = _run_with_mock(
classifier_metrics.run_inception, img,
classifier_metrics.run_inception,
img,
output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)
self.assertTrue(isinstance(pool, ops.Tensor))
@ -195,9 +196,12 @@ class ClassifierMetricsTest(test.TestCase):
batch_size = 3
img = array_ops.ones([batch_size, 299, 299, 3])
logits, pool = _run_with_mock(
classifier_metrics.run_inception, img,
output_tensor=[classifier_metrics.INCEPTION_OUTPUT,
classifier_metrics.INCEPTION_FINAL_POOL])
classifier_metrics.run_inception,
img,
output_tensor=[
classifier_metrics.INCEPTION_OUTPUT,
classifier_metrics.INCEPTION_FINAL_POOL
])
self.assertTrue(isinstance(logits, ops.Tensor))
self.assertTrue(isinstance(pool, ops.Tensor))
@ -209,8 +213,10 @@ class ClassifierMetricsTest(test.TestCase):
def test_inception_score_graph(self):
"""Test `inception_score` graph construction."""
score = _run_with_mock(classifier_metrics.inception_score,
array_ops.zeros([6, 299, 299, 3]), num_batches=3)
score = _run_with_mock(
classifier_metrics.inception_score,
array_ops.zeros([6, 299, 299, 3]),
num_batches=3)
self.assertTrue(isinstance(score, ops.Tensor))
score.shape.assert_has_rank(0)
@ -248,12 +254,14 @@ class ClassifierMetricsTest(test.TestCase):
array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
classifier_metrics._kl_divergence(
p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)
classifier_metrics._kl_divergence(p,
array_ops.zeros(
[8, 10], dtype=dtypes.int32), q)
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
classifier_metrics._kl_divergence(
p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))
classifier_metrics._kl_divergence(p, p_logits,
array_ops.zeros(
[10], dtype=dtypes.int32))
with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)
@ -266,8 +274,9 @@ class ClassifierMetricsTest(test.TestCase):
def test_inception_score_value(self):
"""Test that `inception_score` gives the correct value."""
logits = np.array([np.array([1, 2] * 500 + [4]),
np.array([4, 5] * 500 + [6])])
logits = np.array(
[np.array([1, 2] * 500 + [4]),
np.array([4, 5] * 500 + [6])])
unused_image = array_ops.zeros([2, 299, 299, 3])
incscore = _run_with_mock(classifier_metrics.inception_score, unused_image)
@ -285,9 +294,11 @@ class ClassifierMetricsTest(test.TestCase):
test_pool_real_a = np.float32(np.random.randn(512, 256))
test_pool_gen_a = np.float32(np.random.randn(512, 256))
fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance,
test_pool_real_a, test_pool_gen_a,
classifier_fn=lambda x: x)
fid_op = _run_with_mock(
classifier_metrics.frechet_classifier_distance,
test_pool_real_a,
test_pool_gen_a,
classifier_fn=lambda x: x)
with self.test_session() as sess:
actual_fid = sess.run(fid_op)
@ -296,6 +307,33 @@ class ClassifierMetricsTest(test.TestCase):
self.assertAllClose(expected_fid, actual_fid, 0.0001)
def test_frechet_classifier_distance_covariance(self):
"""Test that `frechet_classifier_distance` takes covariance into account."""
np.random.seed(0)
# Make num_examples > num_features to ensure scipy's sqrtm function
# doesn't return a complex matrix.
test_pool_reals, test_pool_gens = [], []
for i in range(1, 11, 2):
test_pool_reals.append(np.float32(np.random.randn(2048, 256) * i))
test_pool_gens.append(np.float32(np.random.randn(2048, 256) * i))
fid_ops = []
for i in range(len(test_pool_reals)):
fid_ops.append(_run_with_mock(
classifier_metrics.frechet_classifier_distance,
test_pool_reals[i],
test_pool_gens[i],
classifier_fn=lambda x: x))
fids = []
with self.test_session() as sess:
for fid_op in fid_ops:
fids.append(sess.run(fid_op))
# Check that the FIDs increase monotonically.
self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:])))
def test_trace_sqrt_product_value(self):
"""Test that `trace_sqrt_product` gives the correct value."""
np.random.seed(0)

View File

@ -305,6 +305,7 @@ def wasserstein_gradient_penalty(
discriminator_fn,
discriminator_scope,
epsilon=1e-10,
target=1.0,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@ -324,6 +325,8 @@ def wasserstein_gradient_penalty(
discriminator_scope: If not `None`, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when
computing the gradient norm.
target: Optional Python number or `Tensor` indicating the target value of
gradient norm. Defaults to 1.0.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data` and `generated_data`, and must be broadcastable to
them (i.e., all dimensions must be either `1`, or the same as the
@ -374,7 +377,7 @@ def wasserstein_gradient_penalty(
# For numerical stability, add epsilon to the sum before taking the square
# root. Note tf.norm does not add epsilon.
slopes = math_ops.sqrt(gradient_squares + epsilon)
penalties = math_ops.square(slopes - 1.0)
penalties = math_ops.square(slopes / target - 1.0)
penalty = losses.compute_weighted_loss(
penalties, weights, scope=scope, loss_collection=loss_collection,
reduction=reduction)

View File

@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
})
self.assertAlmostEqual(self._expected_loss, loss, 5)
def test_loss_with_gradient_norm_target(self):
"""Test loss value with non default gradient norm target."""
generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
loss = tfgan_losses.wasserstein_gradient_penalty(
generated_data,
real_data,
self._kwargs['generator_inputs'],
self._kwargs['discriminator_fn'],
self._kwargs['discriminator_scope'],
target=2.0)
with self.test_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(
loss,
feed_dict={
generated_data: self._generated_data_np,
real_data: self._real_data_np,
})
self.assertAlmostEqual(1.0, loss, 5)
def test_reuses_scope(self):
"""Test that gradient penalty reuses discriminator scope."""
num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))

View File

@ -460,6 +460,7 @@ def gan_loss(
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
@ -481,6 +482,9 @@ def gan_loss(
small positive value used by the gradient penalty function for numerical
stability. Note some applications will need to increase this value to
avoid NaNs.
gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
number or `Tensor` indicating the target value of gradient norm. See the
CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
mutual_information_penalty_weight: If not `None`, must be a non-negative
Python number or Tensor indicating how much to weight the mutual
information penalty. See https://arxiv.org/abs/1606.03657 for more
@ -539,7 +543,10 @@ def gan_loss(
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
info_loss = tfgan_losses.mutual_information_penalty(

View File

@ -64,8 +64,11 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
ops::builtin::BuiltinOpResolver resolver;
TfLiteRegistration* resize_op =
resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR);
interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr,
resize_op, nullptr);
auto* params = reinterpret_cast<TfLiteResizeBilinearParams*>(
malloc(sizeof(TfLiteResizeBilinearParams)));
params->align_corners = false;
interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params, resize_op,
nullptr);
interpreter->AllocateTensors();

View File

@ -28,6 +28,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
],
)

View File

@ -73,6 +73,7 @@ import itertools as _itertools
import uuid as _uuid
from tensorflow.contrib import framework as _framework
from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
from tensorflow.python.framework import ops as _ops
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.util.all_util import remove_undocumented
@ -133,10 +134,17 @@ class OpHint(object):
def augmented_identity(arg):
identity_op = _array_ops.identity(arg)
attr = identity_op.op.node_def.attr
attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name
attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id
attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i = self._curr_input_index
# pylint: disable=protected-access
identity_op.op._set_attr(
OpHint.FUNCTION_NAME_ATTR,
_attr_value_pb2.AttrValue(s=self._function_name))
identity_op.op._set_attr(
OpHint.FUNCTION_UUID_ATTR,
_attr_value_pb2.AttrValue(s=self._unique_function_id))
identity_op.op._set_attr(
OpHint.FUNCTION_INPUT_INDEX_ATTR,
_attr_value_pb2.AttrValue(i=self._curr_input_index))
# pylint: enable=protected-access
self._curr_input_index += 1
return identity_op
@ -154,10 +162,17 @@ class OpHint(object):
def augmented_identity(arg):
identity_op = _array_ops.identity(arg)
attr = identity_op.op.node_def.attr
attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name
attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id
attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i = self._curr_output_index
# pylint: disable=protected-access
identity_op.op._set_attr(
OpHint.FUNCTION_NAME_ATTR,
_attr_value_pb2.AttrValue(s=self._function_name))
identity_op.op._set_attr(
OpHint.FUNCTION_UUID_ATTR,
_attr_value_pb2.AttrValue(s=self._unique_function_id))
identity_op.op._set_attr(
OpHint.FUNCTION_OUTPUT_INDEX_ATTR,
_attr_value_pb2.AttrValue(i=self._curr_output_index))
# pylint: enable=protected-access
self._curr_output_index += 1
return identity_op

View File

@ -139,14 +139,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
output_buffer_size * sizeof(output_float_data[0]));
} else if (unary_op->type == OperatorType::kTensorFlowSum) {
// At the moment only full reduction across all dimensions is supported.
for (int i = 0; i < output_dims_count; i++) {
CHECK_EQ(output_shape.dims(i), 1);
}
float sum = 0.f;
for (int i = 0; i < input_buffer_size; i++) {
sum += (*input_float_data)[i];
}
output_float_data[0] = sum;
for (int i = 0; i < output_buffer_size; ++i) {
output_float_data[i] = sum;
}
} else if (unary_op->type == OperatorType::kTensorFlowMin) {
// At the moment only full reduction across all dimensions is supported.
// TODO(starka): Output should not be padded.

View File

@ -35,6 +35,7 @@ NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
# TODO(mdan): Verify that these names are not hidden by generated code.
# TODO(mdan): Make sure copybara renames the reference below.
COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
'import tensorflow as tf',
'from tensorflow.contrib.py2tf import utils as '
'py2tf_utils')

View File

@ -28,13 +28,11 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
# TODO(suharshs): Add tests for testing experimental APIs and additional
# input arguments
class QuantizeGraphTest(test_util.TensorFlowTestCase):
# We have a lot of other tests that test the details of the rewrite, here we
# just the specific features of the quantize_graph API.
def _RunTestOverParameters(self, test_fn):
def _RunTestOverAllRewrites(self, test_fn):
rewrite_fns = [
quantize_graph.create_training_graph,
quantize_graph.create_eval_graph,
@ -44,71 +42,125 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
for fn in rewrite_fns:
test_fn(fn)
def testRewrite(self):
self._RunTestOverParameters(self._TestRewrite)
def _RunTestOverTrainingRewrites(self, test_fn):
rewrite_fns = [
quantize_graph.create_training_graph,
quantize_graph.experimental_create_training_graph,
]
for fn in rewrite_fns:
test_fn(fn)
def _TestRewrite(self, fn):
def _RunTestOverExperimentalRewrites(self, test_fn):
rewrite_fns = [
quantize_graph.experimental_create_training_graph,
quantize_graph.experimental_create_eval_graph,
]
for fn in rewrite_fns:
test_fn(fn)
def testRewrite(self):
self._RunTestOverAllRewrites(self._TestRewrite)
def _TestRewrite(self, rewrite_fn):
graph = ops.Graph()
with graph.as_default():
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
conv = layers.conv2d(
inputs,
32, [5, 5],
stride=2,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
scope='test')
_ = nn_ops.relu6(conv)
self._ConvLayer()
orig_variable_names = set(
[v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
fn(graph)
rewrite_fn(graph)
q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
def testDefaultGraph(self):
self._RunTestOverParameters(self._TestRewrite)
self._RunTestOverAllRewrites(self._TestRewrite)
def _TestDefaultGraph(self, fn):
def _TestDefaultGraph(self, rewrite_fn):
# Tests that the default graph is correctly used when no args are provided
# to rewrite_fn.
with ops.Graph().as_default() as g:
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
conv = layers.conv2d(
inputs,
32, [5, 5],
stride=2,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
scope='test')
_ = nn_ops.relu6(conv)
self._ConvLayer()
orig_variable_names = set(
[v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
fn()
rewrite_fn()
q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.
def testQuantDelay(self):
self._RunTestOverTrainingRewrites(self._TestQuantDelay)
Function is defined purely to shorten the name so that it stops wrapping.
def _TestQuantDelay(self, rewrite_fn):
with ops.Graph().as_default() as g:
self._ConvLayer()
quant_delay = 100
rewrite_fn(quant_delay=quant_delay)
Args:
stddev: Standard deviation of normal variable.
quant_delay_found = False
for op in g.get_operations():
# Check to see if the quant_delay is correctly set.
if 'activate_quant' in op.name and op.type == 'Const':
quant_delay_found = True
const_value = str(op.get_attr('value'))
self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
self.assertTrue(quant_delay_found)
Returns:
An initialized that initialzes with a truncated normal variable.
"""
return init_ops.truncated_normal_initializer(stddev=stddev)
def testWeightBits(self):
self._RunTestOverExperimentalRewrites(self._TestWeightBits)
def _TestWeightBits(self, rewrite_fn):
with ops.Graph().as_default() as g:
self._ConvLayer()
weight_bits = 4
rewrite_fn(weight_bits=weight_bits)
weights_quant_found = False
for op in g.get_operations():
# Check to see if FakeQuant operations for weights have the right bits
# set.
if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
weights_quant_found = True
self.assertEqual(op.get_attr('num_bits'), weight_bits)
self.assertTrue(weights_quant_found)
def testActivationBits(self):
self._RunTestOverExperimentalRewrites(self._TestActivationBits)
def _TestActivationBits(self, rewrite_fn):
with ops.Graph().as_default() as g:
self._ConvLayer()
activation_bits = 4
rewrite_fn(activation_bits=activation_bits)
act_quant_found = False
for op in g.get_operations():
# Check to see if FakeQuant operations for activations have the right bits
# set.
act_quant_names = ['act_quant', 'conv_quant', 'add_quant']
if any(s in op.name
for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars':
act_quant_found = True
self.assertEqual(op.get_attr('num_bits'), activation_bits)
self.assertTrue(act_quant_found)
def _ConvLayer(self):
"""Add a basic convolution layer to the default graph."""
batch_size, height, width, depth = 5, 128, 128, 3
inputs = array_ops.zeros((batch_size, height, width, depth))
weight_init = init_ops.truncated_normal_initializer
conv = layers.conv2d(
inputs,
32, [5, 5],
stride=2,
padding='SAME',
weights_initializer=weight_init(0.09),
activation_fn=None,
scope='test')
_ = nn_ops.relu6(conv)
if __name__ == '__main__':

View File

@ -349,7 +349,8 @@ class Image(ItemHandler):
shape=None,
channels=3,
dtype=dtypes.uint8,
repeated=False):
repeated=False,
dct_method=''):
"""Initializes the image.
Args:
@ -368,6 +369,11 @@ class Image(ItemHandler):
tf.decode_raw,
repeated: if False, decodes a single image. If True, decodes a
variable number of image strings from a 1D tensor of strings.
dct_method: An optional string. Defaults to empty string. It only takes
effect when image format is jpeg, used to specify a hint about the
algorithm used for jpeg decompression. Currently valid values
are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
example, the jpeg library does not have that specific option.
"""
if not image_key:
image_key = 'image/encoded'
@ -381,6 +387,7 @@ class Image(ItemHandler):
self._channels = channels
self._dtype = dtype
self._repeated = repeated
self._dct_method = dct_method
def tensors_to_item(self, keys_to_tensors):
"""See base class."""
@ -406,9 +413,25 @@ class Image(ItemHandler):
A tensor that represents decoded image of self._shape, or
(?, ?, self._channels) if self._shape is not specified.
"""
def decode_image():
"""Decodes a png or jpg based on the headers."""
return image_ops.decode_image(image_buffer, self._channels)
"""Decodes a image based on the headers."""
return image_ops.decode_image(image_buffer, channels=self._channels)
def decode_jpeg():
"""Decodes a jpeg image with specified '_dct_method'."""
return image_ops.decode_jpeg(
image_buffer, channels=self._channels, dct_method=self._dct_method)
def check_jpeg():
"""Checks if an image is jpeg."""
# For jpeg, we directly use image_ops.decode_jpeg rather than decode_image
# in order to feed the jpeg specify parameter 'dct_method'.
return control_flow_ops.cond(
image_ops.is_jpeg(image_buffer),
decode_jpeg,
decode_image,
name='cond_jpeg')
def decode_raw():
"""Decodes a raw image."""
@ -420,7 +443,7 @@ class Image(ItemHandler):
math_ops.equal(image_format, 'RAW')): decode_raw,
}
image = control_flow_ops.case(
pred_fn_pairs, default=decode_image, exclusive=True)
pred_fn_pairs, default=check_jpeg, exclusive=True)
image.set_shape([None, None, self._channels])
if self._shape is not None:

View File

@ -0,0 +1,60 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal helpers for tests in this directory."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import sqlite3
from tensorflow.contrib.summary import summary_ops
from tensorflow.python.framework import test_util
class SummaryDbTest(test_util.TensorFlowTestCase):
"""Helper for summary database testing."""
def setUp(self):
super(SummaryDbTest, self).setUp()
self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
if os.path.exists(self.db_path):
os.unlink(self.db_path)
self.db = sqlite3.connect(self.db_path)
self.create_db_writer = functools.partial(
summary_ops.create_db_writer,
db_uri=self.db_path,
experiment_name='experiment',
run_name='run',
user_name='user')
def tearDown(self):
self.db.close()
super(SummaryDbTest, self).tearDown()
def get_one(db, q, *p):
return db.execute(q, p).fetchone()[0]
def get_all(db, q, *p):
return unroll(db.execute(q, p).fetchall())
def unroll(list_of_tuples):
return sum(list_of_tuples, ())

View File

@ -21,6 +21,7 @@ from __future__ import print_function
import functools
import os
import sqlite3
from tensorflow.contrib.summary import summary_ops

View File

@ -106,7 +106,9 @@ class _TPUContext(object):
# pylint: disable=protected-access
tpu_system_metadata = (
tpu_system_metadata_lib._query_tpu_system_metadata(
master, query_topology=self.model_parallelism_enabled))
master,
run_config=self._config,
query_topology=self.model_parallelism_enabled))
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
return tpu_system_metadata
@ -409,6 +411,29 @@ class _TPUContext(object):
'Tensorflow master address and TPU worker(s). Available devices '
'are {}.'.format(tpu_system_metadata.devices))
if self._config.tpu_config.num_shards:
user_provided_num_replicas = self._config.tpu_config.num_shards
if user_provided_num_replicas != num_replicas:
message = (
'TPUConfig.num_shards is not set correctly. According to TPU '
'system metadata for Tensorflow master ({}): num_replicas should '
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
'be the total num of TPU cores in the system. For '
'model-parallelism, the total number of TPU cores should be '
'product(computation_shape) * num_replicas. Please set it '
'accordingly or leave it as `None`'.format(
self._get_master_address(), num_replicas,
user_provided_num_replicas))
if self.model_parallelism_enabled:
raise ValueError(message)
else:
logging.warning(message)
logging.warning(
'For non-model-parallelism, TPUEstimator currently '
'automatically queries the TPU system information so ignores '
'this field.')
if mode == model_fn_lib.ModeKeys.TRAIN:
if self._train_batch_size % num_replicas != 0:
raise ValueError(

View File

@ -45,7 +45,8 @@ _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
])
def _query_tpu_system_metadata(master_address, query_topology=False):
def _query_tpu_system_metadata(master_address, run_config,
query_topology=False):
"""Automatically detects the TPU system metadata in the system."""
tpu_core_count = 0
devices = []
@ -59,8 +60,8 @@ def _query_tpu_system_metadata(master_address, query_topology=False):
with ops.Graph().as_default():
with session_lib.Session(
master_address,
config=config_pb2.ConfigProto(
operation_timeout_in_ms=_PINGING_MASTER_TIMEOUT_IN_MS)) as sess:
config=_get_session_config_with_timeout(
_PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess:
devices = sess.list_devices()
for device in devices:
match = _TPU_DEVICE_REG.match(device.name)
@ -104,7 +105,7 @@ def _query_tpu_system_metadata(master_address, query_topology=False):
'TPU worker has some problems. Available devices: {}'.format(
master_address, devices))
topology = _obtain_topology(master_address)
topology = _obtain_topology(master_address, run_config)
metadata = _TPUSystemMetadata(
num_cores=tpu_core_count,
@ -113,19 +114,26 @@ def _query_tpu_system_metadata(master_address, query_topology=False):
topology=topology,
devices=devices)
msg = 'Found TPU system %s' if tpu_core_count else 'Failed to find TPU: %s'
logging.info(msg, metadata)
if tpu_core_count:
logging.info('Found TPU system:')
logging.info('*** Num TPU Cores: %d', metadata.num_cores)
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
logging.info('*** Num TPU Cores Per Worker: %d',
metadata.num_of_cores_per_host)
logging.info('*** Available Devices: %s', metadata.devices)
else:
logging.info('Failed to find TPU: %s', metadata)
return metadata
def _obtain_topology(master_address):
def _obtain_topology(master_address, run_config):
try:
logging.info('Initializing TPU system (master: %s) to fetch topology '
'for model parallelism. This might take a while.',
master_address)
with ops.Graph().as_default():
session_config = config_pb2.ConfigProto(
operation_timeout_in_ms=_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS)
session_config = _get_session_config_with_timeout(
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config)
with session_lib.Session(
master_address, config=session_config) as sess:
topology = sess.run(tpu.initialize_system())
@ -137,3 +145,11 @@ def _obtain_topology(master_address):
master_address))
def _get_session_config_with_timeout(timeout_in_secs, run_config):
cluster_def = None
if run_config.session_config and run_config.session_config.cluster_def.job:
cluster_def = run_config.session_config.cluster_def
config = config_pb2.ConfigProto(
operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
return config

View File

@ -31,6 +31,11 @@ END
description: <<END
Input images can be of different types but output images are always float.
The range of pixel values for the output image might be slightly different
from the range for the input image because of limited numerical precision.
To guarantee an output range, for example `[0.0, 1.0]`, apply
`tf.clip_by_value` to the output.
Each output pixel is computed by first transforming the pixel's footprint into
the input tensor and then averaging the pixels that intersect the footprint. An
input pixel's contribution to the average is weighted by the fraction of its

View File

@ -56,9 +56,9 @@ limitations under the License.
//
// To add values to feature_lists:
// AppendFeatureValues({4.0},
// GetFeatureList("movie_ratings", &se)->Add());
// GetFeatureList("images", &se)->Add());
// AppendFeatureValues({5.0, 3.0},
// GetFeatureList("movie_ratings", &se)->Add());
// GetFeatureList("images", &se)->Add());
// This will create a feature list keyed as "images" with two features:
// feature_lists {
// feature_list {

View File

@ -1025,9 +1025,8 @@ StringPiece Tensor::tensor_data() const {
}
bool Tensor::SharesBufferWith(const Tensor& b) const {
CHECK_NE(nullptr, buf_);
CHECK_NE(nullptr, b.buf_);
return buf_->root_buffer() == b.buf_->root_buffer();
return buf_ != nullptr && b.buf_ != nullptr &&
buf_->root_buffer() == b.buf_->root_buffer();
}
string Tensor::DebugString() const {

View File

@ -1085,6 +1085,21 @@ class DummyCPUAllocator : public Allocator {
void DeallocateRaw(void* ptr) override {}
};
TEST(Tensor, SharesBufferWith) {
Tensor a_empty;
Tensor b_empty;
Tensor a(DT_FLOAT, TensorShape({1}));
Tensor b(DT_FLOAT, TensorShape({1}));
Tensor copy(a);
EXPECT_FALSE(a_empty.SharesBufferWith(a_empty));
EXPECT_FALSE(a_empty.SharesBufferWith(b_empty));
EXPECT_FALSE(a_empty.SharesBufferWith(a));
EXPECT_FALSE(a_empty.SharesBufferWith(copy));
EXPECT_TRUE(a.SharesBufferWith(a));
EXPECT_FALSE(a.SharesBufferWith(b));
EXPECT_TRUE(a.SharesBufferWith(copy));
}
TEST(Tensor, FailureToAllocate) {
TensorShape shape({1});
DummyCPUAllocator allocator;

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
namespace tensorflow {
@ -86,7 +87,9 @@ Status SingleMachine::Provision() {
attr = GetLocalCPUInfo();
} else if (dev.device_type() == "GPU") {
attr = GetLocalGPUInfo(gpu_id++);
} else {
} else if (dev.device_type().find("XLA") == string::npos) {
// Filter out the fake XLA devices to avoid double counting the actual
// hardware resources that are available.
attr.set_type(dev.device_type());
}
// Overwrite the memory size since users might have requested to use only a

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
@ -446,13 +447,14 @@ Status VirtualScheduler::Init() {
}
if (ready_nodes_->Empty()) {
return Status(error::UNAVAILABLE, "No ready nodes in the graph.");
return errors::InvalidArgument("No ready nodes in the graph.");
}
if (!feed_nodes.empty())
LOG(ERROR) << "Some feed nodes were not found in the graph: "
<< str_util::Join(feed_nodes, ",");
if (!feed_nodes.empty()) {
return errors::InvalidArgument(
strings::StrCat("Some feed nodes were not found in the graph: ",
str_util::Join(feed_nodes, ",")));
}
initialized_ = true;
return Status::OK();
}

View File

@ -371,6 +371,7 @@ cc_library(
":dependency_optimizer",
":graph_optimizer",
":layout_optimizer",
":loop_optimizer",
":memory_optimizer",
":model_pruner",
"//tensorflow/core:framework",
@ -380,3 +381,39 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
],
)
cc_library(
name = "loop_optimizer",
srcs = ["loop_optimizer.cc"],
hdrs = [
"loop_optimizer.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
],
)
tf_cc_test(
name = "loop_optimizer_test",
size = "small",
srcs = ["loop_optimizer_test.cc"],
deps = [
":loop_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)

View File

@ -1717,13 +1717,28 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
protected:
bool ShouldProcess() const override {
return !MustPreserve() && IsPortZeroDimsN(*node_, 2) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() &&
IsOnGPU();
bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
(IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsInputConvertible() && is_dims_supported && IsOnGPU();
}
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
Status CustomizedProcessing() override {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
if (list->i_size() == 2) {
list->set_i(0, 2);
list->set_i(1, 3);
} else if (list->i_size() == 3) {
list->set_i(1, 2);
list->set_i(2, 3);
}
return Status::OK();
}
private:
bool IsInputConvertible() const {
int input_port;
auto input = node_map_->GetNode(node_->input(0));
@ -1736,33 +1751,31 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
return true;
}
}
return false;
}
bool IsAlongDimHW() const {
if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
auto list = node_->attr().at("squeeze_dims").list();
// If list is empty, Squeeze op will squeeze all dimensions of size 1.
if (list.i_size() == 0) return true;
if (list.i_size() == 2) {
if (list.i(0) == 1 && list.i(1) == 2) {
return true;
}
if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
shape.dim(2).size() == 1) {
return true;
}
}
return false;
}
Status CustomizedProcessing() override {
TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
if (list->i_size() == 2) {
list->set_i(0, 2);
list->set_i(1, 3);
bool IsAlongAxis(const std::vector<int>& axis) const {
if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
auto list = node_->attr().at("squeeze_dims").list();
// If list is empty, Squeeze op will squeeze all dimensions of size 1.
if (list.i_size() == 0) return true;
if (list.i_size() == axis.size()) {
bool along_axis = true;
for (int i = 0; i < axis.size(); i++) {
along_axis = along_axis && (list.i(i) == axis[i]);
}
if (along_axis) return true;
}
}
return Status::OK();
return false;
}
bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
};
class ReduceProcessor : public AgnosticNodeProcessor {
@ -1789,12 +1802,18 @@ class ReduceProcessor : public AgnosticNodeProcessor {
return Status::OK();
}
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
Status AddLayoutTransposeToOutputs() override {
if (KeepDims()) {
return AddTransformToOutputs("Transpose");
}
return Status::OK();
}
private:
bool IsReduceAxisSupported() const {
return IsAlongAllFourDims() || IsAlongHWC() ||
((IsAlongNHW() || IsAlongHW() || IsAlongC()) && !KeepDims());
return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
!KeepDims());
}
bool IsAlongAxis(const std::vector<int>& axis) const {

View File

@ -0,0 +1,46 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace grappler {
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
*optimized_graph = item.graph;
return Status::OK();
}
void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
const GraphDef& /*optimized_graph*/,
double /*result*/) {
// Nothing to do for LoopOptimizer.
}
} // end namespace grappler
} // end namespace tensorflow

View File

@ -0,0 +1,49 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
#include <unordered_set>
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
namespace grappler {
class LoopOptimizer : public GraphOptimizer {
public:
LoopOptimizer() : opt_level_(RewriterConfig::ON) {}
explicit LoopOptimizer(RewriterConfig::Toggle opt_level)
: opt_level_(opt_level) {}
~LoopOptimizer() override {}
string name() const override { return "loop_optimizer"; };
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) override;
void Feedback(Cluster* cluster, const GrapplerItem& item,
const GraphDef& optimized_graph, double result) override;
private:
RewriterConfig::Toggle opt_level_;
};
} // end namespace grappler
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_

View File

@ -0,0 +1,62 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace grappler {
namespace {
class LoopOptimizerTest : public ::testing::Test {};
void VerifyGraphsEqual(const GraphDef& original_graph,
const GraphDef& optimized_graph, const string& func) {
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
for (int i = 0; i < original_graph.node_size(); ++i) {
const NodeDef& original = original_graph.node(i);
const NodeDef& optimized = optimized_graph.node(i);
EXPECT_EQ(original.name(), optimized.name()) << func;
EXPECT_EQ(original.op(), optimized.op()) << func;
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
for (int j = 0; j < original.input_size(); ++j) {
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
}
}
}
TEST_F(LoopOptimizerTest, NoOp) {
// This trivial graph is so basic there's nothing to optimize.
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
LoopOptimizer optimizer;
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
}
} // namespace
} // namespace grappler
} // namespace tensorflow

View File

@ -490,12 +490,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
}
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
// Look for AddN nodes and record input names.
// Look for AddN nodes (and equivalent) and record input names.
GraphView view(&item->graph);
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
for (NodeDef& node : *item->graph.mutable_node()) {
if (!IsAddN(node)) {
if (!IsAddN(node) && node.op() != "AccumulateNV2") {
continue;
}
// There is nothing to gain by optimizing nodes with 2 or fewer inputs.
@ -511,6 +511,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
}
}
if (addn_list.empty()) {
return false;
}
GraphMemory memory(*item);
const std::unordered_map<string, DeviceProperties>& devices =
cluster->GetDevices();
@ -560,6 +564,13 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
VLOG(1) << "Missing properties for " << node->name();
continue;
}
const TensorShapeProto& shape =
properties.GetOutputProperties(node->name())[0].shape();
PartialTensorShape shp(shape);
if (!shp.IsFullyDefined()) {
VLOG(1) << "Shape not fully known for " << node->name();
continue;
}
// Compute a topological ordering for the node fanin.
std::unordered_map<NodeDef*, int> topo_order;
@ -608,8 +619,6 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
}
}
const TensorShapeProto& shape =
properties.GetOutputProperties(node->name())[0].shape();
DataType dtype = node->attr().at("T").type();
const string& device = node->device();
@ -721,6 +730,7 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap,
*swap_in_node->add_input() = swap_out_node->name();
// Colocate the swap_in_ node with the node itself.
swap_in_node->set_device(node->device());
string coloc_group = strings::StrCat("loc@", tensor_to_swap);
(*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
(*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
@ -1223,7 +1233,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
bool updated_graph = true;
for (int i = 0; i < 25 && updated_graph; ++i) {
updated_graph = false;
if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
optimization_level_ == RewriterConfig::HEURISTICS) &&
cluster != nullptr) {
updated_graph |= SchedulingPass(cluster, &optimized_item);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
@ -75,6 +76,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
graph_optimizer.reset(
new DependencyOptimizer(cfg_.dependency_optimization()));
}
if (optimizer == "loop") {
graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization()));
}
return graph_optimizer;
}
@ -97,11 +101,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new DependencyOptimizer(cfg_.dependency_optimization())));
}
if (cfg_.loop_optimization() != RewriterConfig::OFF) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
new LoopOptimizer(cfg_.loop_optimization())));
}
if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
optimizers.push_back(
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
}
if (cfg_.memory_optimization() > 1) {
if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
// Use the default target node name prefix "gradients/"
@ -119,8 +127,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
} else {
std::set<string> available_optimizers = {
"pruning", "constfold", "layout", "memory",
"autoparallel", "arithmetic", "dependency"};
"pruning", "constfold", "layout", "memory",
"autoparallel", "arithmetic", "dependency", "loop"};
for (const auto& optimizer : cfg_.optimizers()) {
if (available_optimizers.find(optimizer) != available_optimizers.end()) {
optimizers.push_back(NewOptimizer(optimizer));
@ -136,7 +144,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
bool already_optimized = false;
for (const auto& optimizer : optimizers) {
if (!already_optimized) {
auto status = optimizer->Optimize(cluster, item, optimized_graph);
Status status = optimizer->Optimize(cluster, item, optimized_graph);
string result;
if (!status.ok()) {
VLOG(1) << "Not able to apply optimizer " << optimizer->name()
@ -152,7 +160,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
<< " return status: " << result;
} else {
GrapplerItem optimized_item(item, std::move(*optimized_graph));
auto status =
Status status =
optimizer->Optimize(cluster, optimized_item, optimized_graph);
string result;
if (!status.ok()) {
@ -204,8 +212,10 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.constant_folding() != RewriterConfig::OFF ||
cfg.dependency_optimization() != RewriterConfig::OFF ||
cfg.loop_optimization() == RewriterConfig::ON ||
cfg.arithmetic_optimization() != RewriterConfig::OFF ||
cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 ||
cfg.auto_parallel().enable() ||
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
!cfg.optimizers().empty();
}

View File

@ -67,7 +67,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
// let's be conservative and preserve the graph as is.
return errors::InvalidArgument("Invalid input graph.");
}
// Try to keep the nodes ordored somewhat topologically since this helps
// Try to keep the nodes ordered somewhat topologically since this helps
// further optimizations perform better.
for (int i = keep.size() - 1; i >= 0; --i) {
*runnable_item.graph.add_node() = *keep[i];

File diff suppressed because it is too large Load Diff

View File

@ -37,11 +37,13 @@ message RewriterConfig {
Toggle arithmetic_optimization = 7;
// Control dependency optimizations (default is ON).
Toggle dependency_optimization = 8;
// Loop optimizations (default is OFF).
Toggle loop_optimization = 9;
// If true, don't remove unnecessary ops from the graph
bool disable_model_pruning = 2;
enum MemOptType {
// The default setting (currently disabled)
// The default setting (SCHEDULING_HEURISTICS only)
DEFAULT_MEM_OPT = 0;
// Disabled in the meta-optimizer.
NO_MEM_OPT = 1;

View File

@ -36,6 +36,7 @@ the following three:
alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor"
src="../images/iris_three_species.jpg">
</div>
**From left to right,
[*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by
[Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0),
@ -188,6 +189,7 @@ provides a programming stack consisting of multiple API layers:
<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/tensorflow_programming_environment.png">
</div>
**The TensorFlow Programming Environment.**
<p>&nbsp;</p>
@ -331,7 +333,7 @@ interpret data is such a rich topic that we devote an entire
From a code perspective, you build a list of `feature_column` objects by calling
functions from the @{tf.feature_column} module. Each object describes an input
to the model. To tell the model to interpret data as a floating-point value,
call @{tf.feature_column.numeric_column). In `premade_estimator.py`, all
call @{tf.feature_column.numeric_column}. In `premade_estimator.py`, all
four features should be interpreted as literal floating-point values, so
the code to create a feature column looks as follows:
@ -380,6 +382,7 @@ fully connected neural network consisting of three hidden layers:
<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="../images/simple_dnn.svg">
</div>
**A neural network with three hidden layers.**
<p>&nbsp;</p>
@ -568,6 +571,7 @@ of 0.5. The following suggests a more effective model:
<tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td>
<td style="background-color:green">1</td></tr>
</table>
**A model that is 80% accurate.**
<p>&nbsp;</p>

View File

@ -98,6 +98,7 @@ classifies Iris flowers into three different species based on the size of their
alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor"
src="../images/iris_three_species.jpg">
</div>
**From left to right,
[*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by
[Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0),

Some files were not shown because too many files have changed in this diff Show More