Merge pull request #16991 from yifeif/branch_185565363
Branch 185565363
Commit 17103a0b8d
@@ -224,9 +224,6 @@ def tf_library(name, graph, config,
      # TODO(cwhipkey): only depend on kernel code that the model actually needed.
      "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d",
      "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d",
      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx",
      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon",
      "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1",
      "//tensorflow/compiler/xla/service/cpu:runtime_conv2d",
      "//tensorflow/compiler/xla/service/cpu:runtime_matmul",
      "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d",

@@ -639,6 +639,7 @@ tf_xla_py_test(
    name = "variable_ops_test",
    size = "small",
    srcs = ["variable_ops_test.py"],
    tags = ["optonly"],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",

@@ -677,6 +678,19 @@ tf_xla_py_test(
    ],
)

tf_xla_py_test(
    name = "scatter_nd_op_test",
    size = "medium",
    srcs = ["scatter_nd_op_test.py"],
    tags = ["optonly"],
    deps = [
        ":xla_test",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:platform_test",
    ],
)

cuda_py_test(
    name = "xla_device_test",
    size = "small",

@@ -801,6 +815,17 @@ tf_library(
    tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"],
)

tf_xla_py_test(
    name = "fake_quant_ops_test",
    size = "medium",
    srcs = ["fake_quant_ops_test.py"],
    deps = [
        ":xla_test",
        "//tensorflow/python:framework_for_generated_wrappers",
        "//tensorflow/python:platform_test",
    ],
)

# -----------------------------------------------------------------------------

filegroup(
tensorflow/compiler/tests/fake_quant_ops_test.py (new file, 452 lines)
@@ -0,0 +1,452 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.platform import googletest


class FakeQuantWithMinMaxArgsTest(XLATestCase):
  """Test cases for FakeQuantWithMinMaxArgs operation."""

  # 8 bits, wide range.
  def testOp_with8BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)

  def testOp_with8BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)

  def testOp_with8BitsScalingAndNudgingUp(self):
    self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)

  def testOp_with8BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)

  # 8 bits, narrow range.
  def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)

  def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)

  # 7 bits, wide range.
  def testOp_with7BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)

  def testOp_with7BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)

  def testOp_with7BitsScalingAndNudgingUp(self):
    self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)

  def testOp_with7BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)

  # 7 bits, narrow range.
  def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)

  def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)

  def _TestOp(self, input_min, input_max, num_bits, narrow_range,
              expected_nudged_input_min, expected_nudged_input_max,
              expected_step):
    inputs = np.array(
        [
            expected_nudged_input_min - expected_step,
            expected_nudged_input_min - 0.01, expected_nudged_input_min,
            expected_nudged_input_min + 0.01,
            expected_nudged_input_min + expected_step - 0.01,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step + 0.01,
            expected_nudged_input_max - 0.01, expected_nudged_input_max,
            expected_nudged_input_max + 0.01,
            expected_nudged_input_max + expected_step
        ],
        dtype=np.float32)
    expected = np.array(
        [
            expected_nudged_input_min, expected_nudged_input_min,
            expected_nudged_input_min, expected_nudged_input_min,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_max, expected_nudged_input_max,
            expected_nudged_input_max, expected_nudged_input_max
        ],
        dtype=np.float32)

    with self.test_session() as session:
      with self.test_scope():
        input_placeholder = array_ops.placeholder(
            dtypes.float32, inputs.shape, name="inputs")
        outputs = array_ops.fake_quant_with_min_max_args(
            input_placeholder,
            min=input_min,
            max=input_max,
            num_bits=num_bits,
            narrow_range=narrow_range)
      result = session.run(outputs, {input_placeholder: inputs})
      self.assertAllCloseAccordingToType(
          result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)


class FakeQuantWithMinMaxArgsGradientTest(XLATestCase):
  """Test cases for FakeQuantWithMinMaxArgsGradient operation."""

  # 8 bits, wide range.
  def testOp_with8BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)

  def testOp_with8BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)

  def testOp_with8BitsScalingAndNudgingUp(self):
    self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)

  def testOp_with8BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)

  # 8 bits, narrow range.
  def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)

  def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)

  # 7 bits, wide range.
  def testOp_with7BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)

  def testOp_with7BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)

  def testOp_with7BitsScalingAndNudgingUp(self):
    self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)

  def testOp_with7BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)

  # 7 bits, narrow range.
  def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)

  def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)

  def _TestOp(self, input_min, input_max, num_bits, narrow_range,
              expected_nudged_input_min, expected_nudged_input_max,
              expected_step):
    inputs = np.array(
        [
            expected_nudged_input_min - expected_step,
            expected_nudged_input_min - 0.01, expected_nudged_input_min,
            expected_nudged_input_min + 0.01,
            expected_nudged_input_min + expected_step - 0.01,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step + 0.01,
            expected_nudged_input_max - 0.01, expected_nudged_input_max,
            expected_nudged_input_max + 0.01,
            expected_nudged_input_max + expected_step
        ],
        dtype=np.float32)
    gradients = np.arange(1, len(inputs) + 1, dtype=np.float32)
    expected_backprops = np.array(
        [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
        dtype=np.float32)

    with self.test_session() as session:
      with self.test_scope():
        gradient_placeholder = array_ops.placeholder(
            dtypes.float32, gradients.shape, name="gradients")
        input_placeholder = array_ops.placeholder(
            dtypes.float32, inputs.shape, name="inputs")
        outputs = gen_array_ops.fake_quant_with_min_max_args_gradient(
            gradient_placeholder,
            input_placeholder,
            min=input_min,
            max=input_max,
            num_bits=num_bits,
            narrow_range=narrow_range)
      backprops = session.run(outputs, {
          gradient_placeholder: gradients,
          input_placeholder: inputs
      })
      self.assertAllCloseAccordingToType(
          backprops,
          expected_backprops,
          rtol=1e-3,
          atol=1e-5,
          bfloat16_rtol=0.03)


class FakeQuantWithMinMaxVarsTest(XLATestCase):
  """Test cases for FakeQuantWithMinMaxVars operation."""

  # 8 bits, wide range.
  def testOp_with8BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)

  def testOp_with8BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)

  def testOp_with8BitsScalingAndNudgingUp(self):
    self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)

  def testOp_with8BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)

  # 8 bits, narrow range.
  def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)

  def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)

  # 7 bits, wide range.
  def testOp_with7BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)

  def testOp_with7BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)

  def testOp_with7BitsScalingAndNudgingUp(self):
    self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)

  def testOp_with7BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)

  # 7 bits, narrow range.
  def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)

  def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)

  def _TestOp(self, input_min, input_max, num_bits, narrow_range,
              expected_nudged_input_min, expected_nudged_input_max,
              expected_step):
    inputs = np.array(
        [
            expected_nudged_input_min - expected_step,
            expected_nudged_input_min - 0.01, expected_nudged_input_min,
            expected_nudged_input_min + 0.01,
            expected_nudged_input_min + expected_step - 0.01,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step + 0.01,
            expected_nudged_input_max - 0.01, expected_nudged_input_max,
            expected_nudged_input_max + 0.01,
            expected_nudged_input_max + expected_step
        ],
        dtype=np.float32)
    expected = np.array(
        [
            expected_nudged_input_min, expected_nudged_input_min,
            expected_nudged_input_min, expected_nudged_input_min,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_max, expected_nudged_input_max,
            expected_nudged_input_max, expected_nudged_input_max
        ],
        dtype=np.float32)

    with self.test_session() as session:
      with self.test_scope():
        input_placeholder = array_ops.placeholder(
            dtypes.float32, inputs.shape, name="inputs")
        min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min")
        max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max")
        outputs = array_ops.fake_quant_with_min_max_vars(
            input_placeholder,
            min_placeholder,
            max_placeholder,
            num_bits=num_bits,
            narrow_range=narrow_range)
      result = session.run(
          outputs, {
              input_placeholder: inputs,
              min_placeholder: input_min,
              max_placeholder: input_max
          })
      self.assertAllCloseAccordingToType(
          result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03)


class FakeQuantWithMinMaxVarsGradientTest(XLATestCase):
  """Test cases for FakeQuantWithMinMaxVarsGradient operation."""

  # 8 bits, wide range.
  def testOp_with8BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0)

  def testOp_with8BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5)

  def testOp_with8BitsScalingAndNudgingUp(self):
    self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5)

  def testOp_with8BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5)

  # 8 bits, narrow range.
  def testOp_with8BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0)

  def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5)

  def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5)

  # 7 bits, wide range.
  def testOp_with7BitsNoScalingNoNudging(self):
    self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0)

  def testOp_with7BitsScalingAndNudgingDown(self):
    self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5)

  def testOp_with7BitsScalingAndNudgingUp(self):
    self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5)

  def testOp_with7BitsScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5)

  # 7 bits, narrow range.
  def testOp_with7BitsNarrowRangeNoScalingNoNudging(self):
    self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0)

  def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self):
    self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self):
    self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5)

  def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self):
    self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5)

  def _TestOp(self, input_min, input_max, num_bits, narrow_range,
              expected_nudged_input_min, expected_nudged_input_max,
              expected_step):
    inputs = np.array(
        [
            expected_nudged_input_min - expected_step,
            expected_nudged_input_min - 0.01, expected_nudged_input_min,
            expected_nudged_input_min + 0.01,
            expected_nudged_input_min + expected_step - 0.01,
            expected_nudged_input_min + expected_step,
            expected_nudged_input_min + expected_step + 0.01,
            expected_nudged_input_max - 0.01, expected_nudged_input_max,
            expected_nudged_input_max + 0.01,
            expected_nudged_input_max + expected_step
        ],
        dtype=np.float32)
    gradients = np.arange(1, len(inputs) + 1, dtype=np.float32)
    expected_backprops_wrt_input = np.array(
        [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0],
        dtype=np.float32)
    expected_backprops_wrt_min = 1.0 + 2.0
    expected_backprops_wrt_max = 10.0 + 11.0

    with self.test_session() as session:
      with self.test_scope():
        gradient_placeholder = array_ops.placeholder(
            dtypes.float32, gradients.shape, name="gradients")
        input_placeholder = array_ops.placeholder(
            dtypes.float32, inputs.shape, name="inputs")
        min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min")
        max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max")
        outputs = array_ops.fake_quant_with_min_max_vars_gradient(
            gradient_placeholder,
            input_placeholder,
            min_placeholder,
            max_placeholder,
            num_bits=num_bits,
            narrow_range=narrow_range)
      backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = session.run(
          outputs, {
              gradient_placeholder: gradients,
              input_placeholder: inputs,
              min_placeholder: input_min,
              max_placeholder: input_max
          })
      self.assertAllCloseAccordingToType(
          backprops_wrt_input,
          expected_backprops_wrt_input,
          rtol=1e-3,
          atol=1e-5,
          bfloat16_rtol=0.03)
      self.assertAllCloseAccordingToType(
          backprops_wrt_min,
          expected_backprops_wrt_min,
          rtol=1e-3,
          atol=1e-5,
          bfloat16_rtol=0.03)
      self.assertAllCloseAccordingToType(
          backprops_wrt_max,
          expected_backprops_wrt_max,
          rtol=1e-3,
          atol=1e-5,
          bfloat16_rtol=0.03)


if __name__ == "__main__":
  googletest.main()
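Editor's note (not part of the diff): the expected_nudged_input_min/max and expected_step arguments in the tests above come from the op's "nudging" of the user-supplied [min, max] range so that 0.0 lands exactly on the quantization grid. A minimal Python sketch of that logic, written here only for orientation; the helper name _nudged_min_max is hypothetical and not part of the PR:

def _nudged_min_max(input_min, input_max, num_bits, narrow_range):
  # Quantized integer grid: [1, 2^num_bits - 1] for narrow range, else [0, 2^num_bits - 1].
  quant_min = 1 if narrow_range else 0
  quant_max = (1 << num_bits) - 1
  step = (input_max - input_min) / (quant_max - quant_min)
  # Pick the integer "zero point" closest to where 0.0 falls on the grid,
  # clamped to the valid grid range.
  zero_point_from_min = quant_min - input_min / step
  if zero_point_from_min < quant_min:
    nudged_zero_point = quant_min
  elif zero_point_from_min > quant_max:
    nudged_zero_point = quant_max
  else:
    nudged_zero_point = int(round(zero_point_from_min))
  nudged_min = (quant_min - nudged_zero_point) * step
  nudged_max = (quant_max - nudged_zero_point) * step
  return nudged_min, nudged_max, step

# Matches testOp_with8BitsScalingAndNudgingDown above: (0.5, 128.0, 8, False).
print(_nudged_min_max(0.5, 128.0, 8, False))  # -> (0.0, 127.5, 0.5)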
tensorflow/compiler/tests/scatter_nd_op_test.py (new file, 188 lines)
@@ -0,0 +1,188 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.scatter_nd."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools

import numpy as np

from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import errors
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


def _AsType(v, vtype):
  return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v)


def _FlatInnerDims(tensor, ndims=2):
  shape = list(tensor.shape)
  return tensor.reshape(
      [functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)] +
      shape[-ndims + 1:])


def _FlatOuterDims(tensor, ndims=2):
  shape = list(tensor.shape)
  return tensor.reshape(
      shape[:ndims - 1] +
      [functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)])


def _NumpyScatterNd(ref, indices, updates, op):
  ixdim = indices.shape[-1]
  num_updates = indices.size // ixdim
  total_nd = len(ref.shape)
  slice_size = 1
  for i in range(ixdim, total_nd):
    slice_size *= ref.shape[i]
  flat_indices = _FlatInnerDims(indices)
  flat_updates = updates.reshape((num_updates, slice_size))
  output_flat = _FlatOuterDims(ref, ixdim + 1)
  for ix_updates, ix_output in enumerate(flat_indices):
    ix_output = tuple(ix_output)
    output_flat[ix_output] = op(output_flat[ix_output],
                                flat_updates[ix_updates])
  return output_flat.reshape(ref.shape)


def _NumpyUpdate(indices, updates, shape):
  ref = np.zeros(shape, dtype=updates.dtype)
  return _NumpyScatterNd(ref, indices, updates, lambda p, u: u)


class ScatterNdTest(XLATestCase):

  def _VariableRankTest(self,
                        np_scatter,
                        tf_scatter,
                        vtype,
                        itype,
                        repeat_indices=False):
    np.random.seed(8)
    ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)]
    indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)]
    for ref_shape, indices_shape in zip(ref_shapes, indices_shapes):
      num_updates = indices_shape[0]
      ixdim = indices_shape[-1]

      indexable_area_shape = ()
      for i in range(ixdim):
        indexable_area_shape += (ref_shape[i],)
      all_indices = [
          list(coord)
          for coord, _ in np.ndenumerate(np.empty(indexable_area_shape, vtype))
      ]
      np.random.shuffle(all_indices)
      indices = np.array(all_indices[:num_updates])

      if num_updates > 1 and repeat_indices:
        indices = indices[:num_updates // 2]
        for _ in range(num_updates - num_updates // 2):
          indices = np.append(
              indices, [indices[np.random.randint(num_updates // 2)]], axis=0)
        np.random.shuffle(indices)
      indices = _AsType(indices[:num_updates], itype)

      updates_shape = (num_updates,)
      for i in range(ixdim, len(ref_shape)):
        updates_shape += (ref_shape[i],)
      updates = _AsType(np.random.randn(*(updates_shape)), vtype)

      # Scatter via numpy
      np_out = np_scatter(indices, updates, ref_shape)
      # Scatter via tensorflow
      tf_out = tf_scatter(indices, updates, ref_shape)

      self.assertAllClose(np_out, tf_out)

  def _VariableRankTests(self, np_scatter, tf_scatter):
    for vtype in self.numeric_types:
      for itype in set([np.int32, np.int64]).intersection(set(self.int_types)):
        self._VariableRankTest(np_scatter, tf_scatter, vtype, itype)

  def _runScatterNd(self, indices, updates, shape):
    with self.test_session():
      updates_placeholder = array_ops.placeholder(updates.dtype)
      indices_placeholder = array_ops.placeholder(indices.dtype)
      with self.test_scope():
        output = array_ops.scatter_nd(indices_placeholder, updates_placeholder,
                                      shape)
      feed_dict = {updates_placeholder: updates, indices_placeholder: indices}
      return output.eval(feed_dict=feed_dict)

  def testSimple(self):
    indices = np.array([[4], [3], [1], [7]], dtype=np.int32)
    updates = np.array([9, 10, 11, 12], dtype=np.float32)
    expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8]))

  def testSimple2(self):
    indices = np.array([[1, 0], [1, 1]], dtype=np.int32)
    updates = np.array([11., 12.], dtype=np.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2]))

  def testSimple3(self):
    indices = np.array([[1]], dtype=np.int32)
    updates = np.array([[11., 12.]], dtype=np.float32)
    expected = np.array([[0., 0.], [11., 12.], [0., 0.]])
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2]))

  def testVariableRankUpdate(self):
    self._VariableRankTests(_NumpyUpdate, self._runScatterNd)

  def testExtraIndicesDimensions(self):
    indices = np.zeros([1, 1, 2], np.int32)
    updates = np.zeros([1, 1], np.int32)
    expected = np.zeros([2, 2], dtype=np.int32)
    self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2]))

  def testRank3InvalidShape1(self):
    indices = np.zeros([3, 2, 2], np.int32)
    updates = np.zeros([2, 2, 2], np.int32)
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             "Must have updates.shape"):
      self._runScatterNd(indices, updates, [2, 2, 2])

  def testRank3InvalidShape2(self):
    indices = np.zeros([2, 2, 1], np.int32)
    updates = np.zeros([2, 2], np.int32)
    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
                                             "Must have updates.shape"):
      self._runScatterNd(indices, updates, [2, 2, 2])

  def testScatterOutOfRange(self):
    updates = np.array([-3, -4, -5]).astype(np.float32)

    # Indices all in range, no problem.
    indices = np.array([[2], [0], [5]], dtype=np.int32)
    self._runScatterNd(indices, updates, [6])

    # Indices out of range should not fail. It produces implementation-defined
    # output.
    indices = np.array([[-1], [0], [5]], dtype=np.int32)
    self._runScatterNd(indices, updates, [6])
    indices = np.array([[2], [0], [6]], dtype=np.int32)
    self._runScatterNd(indices, updates, [6])


if __name__ == "__main__":
  test.main()
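Editor's note (not part of the diff): the NumPy reference model above reproduces the expected value used in testSimple. A short check, intended to be run with the module above in scope (so np and _NumpyUpdate are defined):

indices = np.array([[4], [3], [1], [7]], dtype=np.int32)
updates = np.array([9, 10, 11, 12], dtype=np.float32)
# Each row of `indices` picks one position in a zero-initialized [8] buffer:
# 4 <- 9, 3 <- 10, 1 <- 11, 7 <- 12.
print(_NumpyUpdate(indices, updates, [8]))  # [ 0. 11.  0. 10.  9.  0.  0. 12.]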
@@ -60,6 +60,14 @@ class SegmentReductionOpsTest(XLATestCase):
            np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
            np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4))

  def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self):
    for dtype in self.numeric_types:
      self.assertAllClose(
          np.array([0, 3, 2, 5], dtype=dtype),
          self.UnsortedSegmentSum(
              np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
              np.array([3, -1, 2, 1, -1, 3], dtype=np.int32), 4))

  def testUnsortedSegmentSum1DIndices2DDataDisjoint(self):
    for dtype in self.numeric_types:
      data = np.array(
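Editor's note (not part of the diff): the new negative-indices case relies on unsorted segment sum dropping entries whose segment id is negative. A quick Python check of the expected vector:

import numpy as np

data = np.array([0, 1, 2, 3, 4, 5])
segment_ids = np.array([3, -1, 2, 1, -1, 3])
num_segments = 4
out = np.zeros(num_segments)
for value, seg in zip(data, segment_ids):
  if seg >= 0:  # entries with a negative segment id are ignored
    out[seg] += value
print(out)  # [0. 3. 2. 5.]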
@@ -285,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame,
Status FunctionalizeLoop(Graph* graph, Frame* frame,
                         FunctionLibraryDefinition* library) {
  VLOG(2) << "Frame " << frame->name << " before: "
          << dump_graph::DumpGraphToFile("functionalize_before", *graph);
          << dump_graph::DumpGraphToFile("functionalize_before", *graph,
                                         library);

  // Split loop-varying Enter nodes with multiple successors. If the same
  // Tensor is fed as input to multiple loop arguments, we may end up with a

@@ -470,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
  TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));

  VLOG(2) << "Frame " << frame->name << " condition: "
          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library)
          << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);

  static std::atomic<int64> sequence_num(0LL);

@@ -551,7 +552,8 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame,
  frame->parent->nodes.insert(while_node);

  VLOG(2) << "Frame " << frame->name << " after: "
          << dump_graph::DumpGraphToFile("functionalize_after", *graph);
          << dump_graph::DumpGraphToFile("functionalize_after", *graph,
                                         library);

  return Status::OK();
}

@@ -584,11 +586,11 @@ class FunctionalizeCond {
    explicit CondArgNode(Node* input) : input(input) {}
    string ToString() const {
      return strings::StrCat("input=", input->name(),
                             " switches=", NodesToString(switch_nodes));
                             " switches=", NodesToString(switches));
    }

    Node* input;
    std::vector<Node*> switch_nodes;
    std::vector<Node*> switches;
  };
  using CondArgNodes = std::vector<CondArgNode>;

@@ -602,15 +604,22 @@ class FunctionalizeCond {
    int count;
  };

  struct PredicateSwitches {
    explicit PredicateSwitches(Node* predicate) : predicate(predicate) {}
  // Group of switch nodes that will be part of the same XlaIf.
  struct SwitchCluster {
    explicit SwitchCluster(Node* predicate) : predicate(predicate) {}
    string ToString() const {
      return strings::StrCat(name, " predicate=", predicate->name(),
                             " switches=", NodesToString(switches));
    }

    string name;
    Node* predicate;
    std::vector<Node*> switches;
  };

  FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library)
      : library_(library), graph_(graph) {}
  FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library,
                    bool dump_graphs)
      : library_(library), graph_(graph), dump_graphs_(dump_graphs) {}

  // Perform the actual cond functionalization. Iterate over groups of switch
  // nodes (linked by common predicate), from innermost to outermost, and

@@ -621,27 +630,25 @@ class FunctionalizeCond {
  // frontier (the nodes where the cond ends).
  StatusOr<std::pair<std::unordered_map<Node*, ForwardFlowNode>,
                     std::unordered_set<Node*>>>
  DetermineBranchMapAndFrontier(const std::vector<Node*>& switches);
  DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster);

  // Returns XlaIf node created from subgraph of merge and switch nodes. This
  // encapsulates the process of extracting the bodies needed for the then and
  // else branch, creates a XlaIf node, removing the nodes of the branches from
  // the graph and replacing the merge node with a XlaIf.
  StatusOr<Node*> ConvertToXlaIf(const CondArgNodes& cond_arg_nodes,
                                 const std::vector<Node*>& switch_nodes,
                                 const std::vector<Node*>& merge_nodes,
                                 Node* predicate);
                                 const SwitchCluster& switch_cluster,
                                 const std::vector<Node*>& switches);

  // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with.
  StatusOr<Node*> BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes,
                                     const std::vector<Node*>& switch_nodes,
                                     const std::vector<Node*>& merge_nodes,
                                     Node* predicate);
                                     const SwitchCluster& switch_cluster,
                                     const std::vector<Node*>& merge_nodes);

  // Extracts a function body corresponding to the given input edge of the merge
  // node.
  Status ExtractBody(const CondArgNodes& cond_arg_nodes,
                     const std::vector<Node*>& switch_nodes,
                     const std::vector<Node*>& switches,
                     const std::vector<Node*>& merge_nodes, int input_edge,
                     Graph* body);

@@ -652,9 +659,9 @@ class FunctionalizeCond {
  // Adds all output edges from the `if_node`.
  Status AddOutputEdges(const std::vector<Node*>& outputs, Node* if_node);

  // Returns the switches of graph_ (along with grouping predicates) in
  // postorder. Dead switch nodes are skipped and removed from the graph.
  std::vector<PredicateSwitches> DeterminePredicateSwitchOrder();
  // Returns the switch clusters of graph_ in postorder. Dead switch nodes are
  // skipped and removed from the graph.
  StatusOr<std::vector<SwitchCluster>> DeterminePredicateSwitchOrder();

  // Update the state for destination based on the state of source and the node
  // being updated.

@@ -677,6 +684,7 @@ class FunctionalizeCond {

  FunctionLibraryDefinition* library_;
  Graph* graph_;
  bool dump_graphs_;
};

bool IsDeadSwitch(const Node* node) {

@@ -724,10 +732,13 @@ Status FunctionalizeCond::ValidateFrontier(
          ") in both Else and Then branch should be in Both.");
    }
  }
  if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
      pending[kElseBranch].empty()) {
    return errors::Internal("Unexpected empty frontier for switch nodes");
  }
  // An empty frontier indicates a dead switch. Above we attempt to remove dead
  // switch nodes, but not all are removed so don't treat it as an error yet.
  // TODO(jpienaar): Find out why dead switch nodes remain.
  // if (pending[kBoth].empty() && pending[kThenBranch].empty() &&
  //     pending[kElseBranch].empty()) {
  //   return errors::Internal("Unexpected empty frontier for switch nodes");
  // }
  return Status::OK();
}

@@ -754,33 +765,191 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state,
  return Status::OK();
}

std::vector<FunctionalizeCond::PredicateSwitches>
StatusOr<std::vector<FunctionalizeCond::SwitchCluster>>
FunctionalizeCond::DeterminePredicateSwitchOrder() {
  struct Cluster {
    bool operator==(const Cluster& other) const {
      return representative == other.representative;
    }
    int representative = -1;
  };

  // Perform a DFS over the graph and
  // * Determine the reverse topological order of the nodes (there should be no
  //   cycles at this point so the post-order numbering corresponds to the
  //   reverse topological sorting);
  // * Identify dead switches;
  // * Initialize the cluster's representative;
  std::vector<UnionFind<Cluster>> clusters(graph_->num_node_ids());
  std::vector<Node*> dead_switches;
  std::vector<Node*> switch_order;
  DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) {
  std::vector<Node*> rev_topo_sorted_nodes;
  DFS(*graph_, nullptr, [&](Node* n) {
    clusters[n->id()].Get().representative = n->id();
    if (IsSwitch(n)) {
      if (IsDeadSwitch(n)) {
        dead_switches.push_back(n);
      } else {
        rev_topo_sorted_nodes.push_back(n);
        switch_order.push_back(n);
      }
    } else if (n->IsOp()) {
      // Exclude src and sink nodes from further consideration.
      rev_topo_sorted_nodes.push_back(n);
    }
  });

  std::vector<SwitchCluster> switch_clusters;
  // Return early if there are no switches in the graph.
  if (switch_order.empty()) {
    return switch_clusters;
  }

  // Remove all dead switch nodes.
  for (Node* n : dead_switches) {
    VLOG(2) << "Removing dead switch: " << n->DebugString();
    graph_->RemoveNode(n);
  }

  std::vector<PredicateSwitches> predicate_switch_order;
  if (switch_order.empty()) {
    return predicate_switch_order;
  // Identify switch nodes that are part of the same control flow context by
  // considering the operands of operations: an operation is part of the same
  // control context as its operands unless the operation is a switch. Control
  // dependencies are considered part of the same control flow context if the
  // switch depth is the same (see comment below).

  // entry_cluster records the input cluster to a switch node. This is used when
  // merging with a merge node where the dst's cluster is merged with the entry
  // cluster of the merge node's cluster (which corresponds to a switch cluster
  // and so has an entry cluster).
  std::unordered_map<int, UnionFind<Cluster>*> entry_cluster;

  // Returns the output cluster of a node. Where the output cluster is cluster
  // where the output of the node is used. For non-merge nodes this is simply
  // the cluster they are part of, while for merge nodes it is the entry cluster
  // of the cluster they are part of (this will correspond to the entry node of
  // a switch node that dominates the merge).
  auto find_output_cluster = [&](Node* n) {
    UnionFind<Cluster>* cluster = &clusters[n->id()];
    if (!IsMerge(n)) return cluster;
    auto it = entry_cluster.find(clusters[n->id()].Get().representative);
    // If the cluster is not found in the entry_cluster map then an
    // instruction not dominated by a switch node has been merged into the
    // cluster of the merge. This indicates a failure of the clustering.
    CHECK(it != entry_cluster.end())
        << "Unable to find entry for n=" << n->id() << " ("
        << cluster->Get().representative << ")";
    return it->second;
  };

  // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier.
  std::vector<int> switch_depth(graph_->num_node_ids());
  for (auto it = rev_topo_sorted_nodes.rbegin();
       it != rev_topo_sorted_nodes.rend(); ++it) {
    Node* n = *it;

    // Compute switch depth.
    int new_switch_depth = 0;
    for (const Edge* e : n->in_edges()) {
      Node* src = e->src();
      new_switch_depth = std::max(
          new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0));
    }
    switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0);

    // Only merge the input operands of a switch. The switch's clustering itself
    // is determined by the interaction of the switch's outputs.
    if (IsSwitch(n)) {
      Node* input;
      TF_CHECK_OK(n->input_node(0, &input));
      entry_cluster[n->id()] = &clusters[input->id()];
      UnionFind<Cluster>* cluster = find_output_cluster(input);
      int cluster_depth = switch_depth[cluster->Get().representative];
      // Merge the inputs of the switch node with one another. This results in
      // predicates and control input residing in the same cluster.
      for (const Edge* e : n->in_edges()) {
        Node* src = e->src();
        UnionFind<Cluster>* src_cluster = find_output_cluster(src);
        int src_cluster_depth = switch_depth[src_cluster->Get().representative];
        if (cluster_depth != src_cluster_depth) {
          return errors::InvalidArgument(
              "Unable to functionalize control flow in graph: Switch ('",
              n->name(), "') has operands ('", input->name(), "' and '",
              src->name(), "') that have different switch depths (",
              cluster_depth, " != ", src_cluster_depth, ")");
        }
        cluster->Merge(src_cluster);
      }
      continue;
    }

    for (const Edge* e : n->in_edges()) {
      Node* src = e->src();
      if (!src->IsOp()) continue;
      UnionFind<Cluster>* cluster = find_output_cluster(src);
      // Merge a node with its data operands and with its control operands if
      // the src and dst are in the same ControlContext. The ControlContext is
      // not explicitly available here, and instead the switch depth is used as
      // a proxy here. Due to the invariant that control edges can only be from
      // a containing scope to an inner scope or from the inner scope to its
      // containing scope (for exit nodes), the switch depth will only match if
      // the src and dst are in the same ControlContext. Control edges between
      // ControlContexts are handled during the extraction.
      int src_id = cluster->Get().representative;
      int src_depth = switch_depth[src_id];
      if (!e->IsControlEdge() || new_switch_depth == src_depth) {
        if (src_depth != new_switch_depth) {
          return errors::InvalidArgument(
              "Unable to functionalize control flow in graph: Operand ('",
              src->name(), "') and operator ('", n->name(),
              "') have different switch depths (", src_depth,
              " != ", new_switch_depth, ")");
        }
        cluster->Merge(&clusters[n->id()]);
      }
    }
  }

  if (dump_graphs_) {
    // Mark the switch cluster each node is part of.
    for (Node* n : graph_->nodes()) {
      n->ClearAttr("_XlaFunctionalizeSwitchGroup");
      n->AddAttr("_XlaFunctionalizeSwitchGroup",
                 clusters[n->id()].Get().representative);
    }
    LOG(INFO) << "FunctionalizeControlFlow (with_clusters): "
              << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_,
                                             library_);
  }

  // Verify all the nodes of a cluster are at the same depth.
  std::unordered_map<int, std::pair<int, Node*>> cluster_to_depth_node;
  for (Node* n : graph_->nodes()) {
    int depth = switch_depth[n->id()];
    int cluster_rep = clusters[n->id()].Get().representative;
    auto it = cluster_to_depth_node.find(cluster_rep);
    if (it == cluster_to_depth_node.end()) {
      cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n);
    } else {
      if (it->second.first != depth) {
        return errors::Internal(
            "Illegal clustering created, mismatch in depths:", "\n\t",
            n->DebugString(), "(", clusters[n->id()].Get().representative,
            ") at depth=", depth, " vs\n\t", it->second.second->DebugString(),
            "(", clusters[n->id()].Get().representative, ") at depth ",
            it->second.first);
      }
    }
  }

  struct Hash {
    size_t operator()(const std::pair<Node*, Cluster>& item) const {
      return Hash64Combine(hash<Node*>()(item.first),
                           std::hash<int>()(item.second.representative));
    }
  };

  // Merge Switch nodes with common predicate.
  std::unordered_map<Node*, int> predicate_index;
  std::unordered_map<std::pair<Node*, Cluster>, int, Hash> predicate_index;
  // The nodes in switch_order are in reverse topological order, but the
  // clustered switches need not be (i.e., when considered as a cluster one
  // element of a cluster may be later in the topological order than another
@@ -789,13 +958,19 @@ FunctionalizeCond::DeterminePredicateSwitchOrder() {
  for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) {
    Node* pred;
    TF_CHECK_OK((*it)->input_node(1, &pred));
    if (predicate_index.find(pred) == predicate_index.end()) {
      predicate_index[pred] = predicate_switch_order.size();
      predicate_switch_order.emplace_back(pred);
    auto repr = std::make_pair(pred, clusters[(*it)->id()].Get());
    if (predicate_index.find(repr) == predicate_index.end()) {
      predicate_index[repr] = switch_clusters.size();
      switch_clusters.emplace_back(pred);
      // Generate a name by concatenating with the cluster representative as
      // there could be multiple switch clusters with the same predicate.
      switch_clusters[predicate_index[repr]].name =
          strings::StrCat(pred->name(), "_", repr.second.representative, "_If");
    }
    predicate_switch_order[predicate_index[pred]].switches.push_back(*it);
    switch_clusters[predicate_index[repr]].switches.push_back(*it);
  }
  return predicate_switch_order;

  return switch_clusters;
}

StatusOr<std::vector<Node*>>

@@ -843,10 +1018,10 @@ StatusOr<
    std::pair<std::unordered_map<Node*, FunctionalizeCond::ForwardFlowNode>,
              std::unordered_set<Node*>>>
FunctionalizeCond::DetermineBranchMapAndFrontier(
    const std::vector<Node*>& switches) {
    const SwitchCluster& switch_cluster) {
  std::unordered_map<Node*, ForwardFlowNode> branch_map;
  std::unordered_set<Node*> frontier;
  std::vector<Node*> stack = switches;
  std::vector<Node*> stack = switch_cluster.switches;
  std::vector<bool> visited(graph_->num_node_ids(), false);
  while (!stack.empty()) {
    Node* n = stack.back();

@@ -888,7 +1063,7 @@ FunctionalizeCond::DetermineBranchMapAndFrontier(
    }
  }

  if (VLOG_IS_ON(2)) {
  if (dump_graphs_) {
    for (const auto& kv : branch_map) {
      // Append attribute to the graph if running with logging to make the
      // changes clearer in the visualization.

@@ -900,8 +1075,8 @@ FunctionalizeCond::DetermineBranchMapAndFrontier(
}

Status FunctionalizeCond::FunctionalizeInternal() {
  std::vector<PredicateSwitches> predicate_switch_order =
      DeterminePredicateSwitchOrder();
  TF_ASSIGN_OR_RETURN(std::vector<SwitchCluster> predicate_switch_order,
                      DeterminePredicateSwitchOrder());

  // Iterate from innermost set of clustered switches to outermost, replacing
  // matching switch->merge subgraphs with single XlaIf nodes.

@@ -914,10 +1089,12 @@ Status FunctionalizeCond::FunctionalizeInternal() {
    std::unordered_map<Node*, ForwardFlowNode> branch_map;
    std::unordered_set<Node*> frontier;
    TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier),
                        DetermineBranchMapAndFrontier(ps.switches));
                        DetermineBranchMapAndFrontier(ps));

    VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): "
            << dump_graph::DumpGraphToFile("functionalize_bc", *graph_);
    if (dump_graphs_)
      LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): "
                << dump_graph::DumpGraphToFile("functionalize_bc", *graph_,
                                               library_);
    TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier));

    // Sort the merge and switch nodes using NodeCmp. The switch-nodes are

@@ -934,7 +1111,7 @@ Status FunctionalizeCond::FunctionalizeInternal() {
        input_index[in] = cond_arg_nodes.size();
        cond_arg_nodes.emplace_back(in);
      }
      cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node);
      cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node);
    }
    std::vector<Node*> merge_nodes(frontier.begin(), frontier.end());
    std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp());

@@ -943,9 +1120,8 @@ Status FunctionalizeCond::FunctionalizeInternal() {
        EnsureDominanceAndReturnNonDominatedControlNodes(
            branch_map, ps.switches));

    TF_ASSIGN_OR_RETURN(
        Node * if_node,
        ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate));
    TF_ASSIGN_OR_RETURN(Node * if_node,
                        ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes));
    for (Node* old : old_control_nodes) {
      graph_->AddControlEdge(old, if_node);
    }

@@ -954,25 +1130,26 @@ Status FunctionalizeCond::FunctionalizeInternal() {
      graph_->RemoveNode(del_kv.first);
    }
    for (auto& kv : cond_arg_nodes) {
      for (Node* node : kv.switch_nodes) {
      for (Node* node : kv.switches) {
        graph_->RemoveNode(node);
      }
    }
    VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): "
            << dump_graph::DumpGraphToFile("functionalize_ac", *graph_);
    if (dump_graphs_)
      LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): "
                << dump_graph::DumpGraphToFile("functionalize_ac", *graph_,
                                               library_);
  }
  return Status::OK();
}

StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
    const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
    const std::vector<Node*>& merge_nodes, Node* predicate) {
  VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input "
          << NodesToString(switch_nodes);
    const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
    const std::vector<Node*>& merge_nodes) {
  VLOG(2) << "Build if op for " << switch_cluster.name;

  NodeDef if_def;
  // Create a new If node using the name of the merge node.
  NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf");
  NodeDefBuilder builder(switch_cluster.name, "XlaIf");
  string branch[] = {"else_branch", "then_branch"};
  for (int i = 0; i < 2; ++i) {
    static std::atomic<int64> sequence_num(0LL);

@@ -982,12 +1159,9 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
    body_name.set_name(
        strings::StrCat("_functionalize_if_", branch[i], "_", id));
    auto body = xla::MakeUnique<Graph>(graph_->op_registry());
    TF_RETURN_IF_ERROR(
        ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get()));
    TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches,
                                   merge_nodes, i, body.get()));
    VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get());
    VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): "
            << dump_graph::DumpGraphToFile(
                   strings::StrCat("functionalize_", branch[i]), *body);
    FunctionDef body_fdef;
    TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef));
    TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef));

@@ -999,7 +1173,7 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
  DataTypeVector in_arg_types;
  for (auto& kv : cond_arg_nodes) {
    bool inserted = false;
    for (const Node* arg : kv.switch_nodes) {
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {

@@ -1026,10 +1200,11 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
  builder.Attr("Tout", out_type);

  builder.Attr("Tcond", DT_BOOL);
  builder.Device(predicate->assigned_device_name());
  builder.Device(switch_cluster.predicate->assigned_device_name());
  // Conditional should be the first input ...
  builder.Input(
      NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0)));
      NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0,
                              switch_cluster.predicate->output_type(0)));
  // ... followed by the other inputs.
  builder.Input(inputs);

@@ -1039,7 +1214,7 @@ StatusOr<Node*> FunctionalizeCond::BuildAndAddXlaIfOp(
}

Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
                                      const std::vector<Node*>& switch_nodes,
                                      const std::vector<Node*>& switches,
                                      const std::vector<Node*>& merge_nodes,
                                      int input_edge, Graph* body) {
  VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge "

@@ -1049,7 +1224,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
  int arg_count = 0;
  for (auto& kv : cond_arg_nodes) {
    Node* arg_node = nullptr;
    for (const auto* arg : kv.switch_nodes) {
    for (const auto* arg : kv.switches) {
      DataType dtype = arg->input_type(0);
      if (arg_node == nullptr) {
        TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++));

@@ -1073,8 +1248,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes,
        node_map.at(in->id()) = body->CopyNode(in);
      }

      if (std::find(switch_nodes.begin(), switch_nodes.end(), in) ==
          switch_nodes.end()) {
      if (std::find(switches.begin(), switches.end(), in) == switches.end()) {
        body->AddEdge(node_map.at(in->id()), in_edge->src_output(),
                      node_map.at(node->id()), 0);
      } else {

@@ -1096,7 +1270,7 @@ Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes,
  graph_->AddEdge(predicate, 0, if_node, index++);
  for (auto& kv : cond_arg_nodes) {
    bool inserted = false;
    for (const Node* arg : kv.switch_nodes) {
    for (const Node* arg : kv.switches) {
      const Edge* in_edge;
      TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge));
      if (in_edge->IsControlEdge()) {

@@ -1139,16 +1313,17 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector<Node*>& outputs,
}

StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
    const CondArgNodes& cond_arg_nodes, const std::vector<Node*>& switch_nodes,
    const std::vector<Node*>& merge_nodes, Node* predicate) {
  VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> "
    const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster,
    const std::vector<Node*>& merge_nodes) {
  VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> "
          << NodesToString(merge_nodes);

  // Extract bodies and builds a If operator.
  TF_ASSIGN_OR_RETURN(
      Node * if_node,
      BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate));
  TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node));
      BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes));
  TF_RETURN_IF_ERROR(
      AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node));
  TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node));

  return if_node;

@@ -1157,18 +1332,19 @@ StatusOr<Node*> FunctionalizeCond::ConvertToXlaIf(
Status FunctionalizeCond::Functionalize(Graph* graph,
                                        FunctionLibraryDefinition* library) {
  VLOG(1) << "FunctionalizeCond::Functionalize";
  FunctionalizeCond fc(graph, library);
  FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2));
  return fc.FunctionalizeInternal();
}

}  // namespace

// Transformation that converts Tensorflow's graph control flow constructs into
// Transformation that converts TensorFlow's graph control flow constructs into
// functional equivalents.
Status FunctionalizeControlFlow(Graph* graph,
                                FunctionLibraryDefinition* library) {
  VLOG(2) << "FunctionalizeControlFlow (initial): "
          << dump_graph::DumpGraphToFile("functionalize_initial", *graph);
          << dump_graph::DumpGraphToFile("functionalize_initial", *graph,
                                         library);
  // Note: BuildControlFlowInfo() requires that the graph's source node is
  // connected to all source nodes in the graph. Many graphs violate this
  // invariant.

@@ -1180,7 +1356,8 @@ Status FunctionalizeControlFlow(Graph* graph,
  for (Node* node : graph->op_nodes()) {
    const ControlFlowInfo& cf = cf_info[node->id()];

    VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
    VLOG(2) << "node: " << node->name() << " (" << node->id()
            << ") frame_name: " << cf.frame_name
            << " frame: " << (cf.frame ? cf.frame->name() : "---")
            << " parent_frame: "
            << (cf.parent_frame ? cf.parent_frame->name() : "---");

@@ -1248,7 +1425,8 @@ Status FunctionalizeControlFlow(Graph* graph,
  TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library));

  VLOG(2) << "FunctionalizeControlFlow (final): "
          << dump_graph::DumpGraphToFile("functionalize_final", *graph);
          << dump_graph::DumpGraphToFile("functionalize_final", *graph,
                                         library);
  return Status::OK();
}

@ -38,10 +38,11 @@ namespace {
|
||||
|
||||
// Returns the names of the "then" and "else" functions for the XlaIf node in a
|
||||
// graph.
|
||||
Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn,
|
||||
NameAttrList* else_fn) {
|
||||
Status FindIfThenAndElse(const GraphDef& graph, string* op_name,
|
||||
NameAttrList* then_fn, NameAttrList* else_fn) {
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
if (node.op() == "XlaIf") {
|
||||
*op_name = node.name();
|
||||
const NameAttrList* result;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result));
|
||||
*then_fn = *result;
|
||||
@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) {
|
||||
|
||||
GraphDef graph_def;
|
||||
graph.ToGraphDef(&graph_def);
|
||||
string op_name;
|
||||
NameAttrList then_fn;
|
||||
NameAttrList else_fn;
|
||||
TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn));
|
||||
TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn));
|
||||
InstantiationResultForTest else_result;
|
||||
TF_EXPECT_OK(
|
||||
InstantiateFunctionForTest(else_fn.name(), library, &else_result));
|
||||
@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) {
|
||||
auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32);
|
||||
auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32);
|
||||
auto less = ops::Less(scope.WithOpName("cond/Less"), y, x);
|
||||
auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less,
|
||||
auto if_op = ops::XlaIf(scope.WithOpName(op_name), less,
|
||||
std::initializer_list<Input>{less, y, x}, then_fn,
|
||||
else_fn, {DT_INT32});
|
||||
GraphDef expected;
|
||||
|
@ -134,7 +134,7 @@ Status GraphCompiler::Compile() {
|
||||
TF_RET_CHECK(src->id() < output_registry.size());
|
||||
const NodeOutputs& src_outputs = output_registry[src->id()];
|
||||
|
||||
tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()];
|
||||
tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
|
||||
}
|
||||
|
||||
OpKernelContext op_context(¶ms, n->num_outputs());
|
||||
|
@ -32,6 +32,7 @@ tf_kernel_library(
|
||||
"dynamic_stitch_op.cc",
|
||||
"elu_op.cc",
|
||||
"extract_image_patches_op.cc",
|
||||
"fake_quantize_ops.cc",
|
||||
"fft_ops.cc",
|
||||
"fill_op.cc",
|
||||
"function_ops.cc",
|
||||
@ -64,6 +65,7 @@ tf_kernel_library(
|
||||
"reverse_op.cc",
|
||||
"reverse_sequence_op.cc",
|
||||
"scan_ops.cc",
|
||||
"scatter_nd_op.cc",
|
||||
"segment_reduction_ops.cc",
|
||||
"select_op.cc",
|
||||
"sendrecv_ops.cc",
|
||||
@ -96,12 +98,15 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla/lib:batch_dot",
|
||||
"//tensorflow/compiler/tf2xla/lib:cholesky",
|
||||
"//tensorflow/compiler/tf2xla/lib:scatter",
|
||||
"//tensorflow/compiler/tf2xla/lib:triangular_solve",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/tf2xla/lib:while_loop",
|
||||
"//tensorflow/compiler/tf2xla/ops:sendrecv_ops",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
|
289
tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc
Normal file
@ -0,0 +1,289 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Gymnastics with nudged zero point is to ensure that the real zero maps to
// an integer, which is required for e.g. zero-padding in convolutional layers.
void CpuNudge(const float min, const float max, const float quant_min,
              const float quant_max, float* nudged_min, float* nudged_max,
              float* scale) {
  *scale = (max - min) / (quant_max - quant_min);

  const float zero_point_from_min = quant_min - min / *scale;
  float nudged_zero_point;
  if (zero_point_from_min <= quant_min) {
    nudged_zero_point = quant_min;
  } else if (zero_point_from_min >= quant_max) {
    nudged_zero_point = quant_max;
  } else {
    nudged_zero_point = std::round(zero_point_from_min);
  }

  *nudged_min = (quant_min - nudged_zero_point) * (*scale);
  *nudged_max = (quant_max - nudged_zero_point) * (*scale);
}
|
||||
|
||||
// An XLA version of CpuNudge().
|
||||
void XlaNudge(xla::ComputationBuilder* b, const DataType data_type,
|
||||
const xla::ComputationDataHandle& min,
|
||||
const xla::ComputationDataHandle& max,
|
||||
const float quant_min_value, const float quant_max_value,
|
||||
xla::ComputationDataHandle* nudged_min,
|
||||
xla::ComputationDataHandle* nudged_max,
|
||||
xla::ComputationDataHandle* scale) {
|
||||
*scale = b->Div(b->Sub(max, min),
|
||||
XlaHelpers::FloatLiteral(b, data_type,
|
||||
quant_max_value - quant_min_value));
|
||||
xla::ComputationDataHandle quant_min =
|
||||
XlaHelpers::FloatLiteral(b, data_type, quant_min_value);
|
||||
xla::ComputationDataHandle zero_point_from_min =
|
||||
b->Sub(quant_min, b->Div(min, *scale));
|
||||
xla::ComputationDataHandle quant_max =
|
||||
XlaHelpers::FloatLiteral(b, data_type, quant_max_value);
|
||||
xla::ComputationDataHandle nudged_zero_point =
|
||||
b->Select(b->Le(zero_point_from_min, quant_min), quant_min,
|
||||
b->Select(b->Ge(zero_point_from_min, quant_max), quant_max,
|
||||
b->Round(zero_point_from_min)));
|
||||
*nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale);
|
||||
*nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale);
|
||||
}
|
||||
|
||||
xla::ComputationDataHandle Quantize(
|
||||
xla::ComputationBuilder* b, const xla::ComputationDataHandle& input,
|
||||
const DataType data_type,
|
||||
const xla::ComputationDataHandle& nudged_input_min,
|
||||
const xla::ComputationDataHandle& nudged_input_max,
|
||||
const xla::ComputationDataHandle& input_scale) {
|
||||
xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f);
|
||||
xla::ComputationDataHandle inv_scale = b->Div(one, input_scale);
|
||||
xla::ComputationDataHandle half =
|
||||
XlaHelpers::FloatLiteral(b, data_type, 0.5f);
|
||||
|
||||
xla::ComputationDataHandle clamped =
|
||||
b->Clamp(nudged_input_min, input, nudged_input_max);
|
||||
xla::ComputationDataHandle clamped_shifted =
|
||||
b->Sub(clamped, nudged_input_min);
|
||||
xla::ComputationDataHandle rounded =
|
||||
b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half));
|
||||
return b->Add(b->Mul(rounded, input_scale), nudged_input_min);
|
||||
}
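For reference, a minimal NumPy sketch (illustrative only, not part of this change) of the nudge-and-quantize arithmetic that CpuNudge() and Quantize() above implement; the Python function names are hypothetical:

import numpy as np

def nudge(input_min, input_max, quant_min, quant_max):
    # Mirror of CpuNudge: shift the range so that real zero maps to an integer.
    scale = (input_max - input_min) / (quant_max - quant_min)
    zero_point_from_min = quant_min - input_min / scale
    nudged_zero_point = min(max(round(zero_point_from_min), quant_min), quant_max)
    nudged_min = (quant_min - nudged_zero_point) * scale
    nudged_max = (quant_max - nudged_zero_point) * scale
    return nudged_min, nudged_max, scale

def fake_quant(x, input_min, input_max, num_bits=8, narrow_range=False):
    # Mirror of Quantize: clamp, snap to the nearest quantized step, map back.
    quant_min = 1.0 if narrow_range else 0.0
    quant_max = float((1 << num_bits) - 1)
    nudged_min, nudged_max, scale = nudge(input_min, input_max, quant_min, quant_max)
    clamped = np.clip(x, nudged_min, nudged_max)
    return np.floor((clamped - nudged_min) / scale + 0.5) * scale + nudged_min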
|
||||
|
||||
class FakeQuantWithMinMaxArgsOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
int num_bits;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
|
||||
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
|
||||
errors::InvalidArgument("num_bits is out of range, expected "
|
||||
"between 2 and 16, was: ",
|
||||
num_bits));
|
||||
bool narrow_range;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
|
||||
quant_min_ = narrow_range ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits) - 1;
|
||||
|
||||
float input_min, input_max;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
|
||||
CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_,
|
||||
&nudged_input_max_, &input_scale_);
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationDataHandle input = ctx->Input(0);
|
||||
const DataType data_type = ctx->input_type(0);
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
xla::ComputationDataHandle nudged_input_min =
|
||||
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
|
||||
xla::ComputationDataHandle nudged_input_max =
|
||||
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
|
||||
xla::ComputationDataHandle input_scale =
|
||||
XlaHelpers::FloatLiteral(b, data_type, input_scale_);
|
||||
xla::ComputationDataHandle output = Quantize(
|
||||
b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
float nudged_input_min_;
|
||||
float nudged_input_max_;
|
||||
float input_scale_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp);
|
||||
|
||||
class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
int num_bits;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
|
||||
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
|
||||
errors::InvalidArgument("num_bits is out of range, expected "
|
||||
"between 2 and 16, was: ",
|
||||
num_bits));
|
||||
bool narrow_range;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
|
||||
const float quant_min = narrow_range ? 1 : 0;
|
||||
const float quant_max = (1 << num_bits) - 1;
|
||||
|
||||
float input_min, input_max, scale;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max));
|
||||
CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_,
|
||||
&nudged_input_max_, &scale);
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationDataHandle gradient = ctx->Input(0);
|
||||
const TensorShape gradient_shape = ctx->InputShape(0);
|
||||
xla::ComputationDataHandle input = ctx->Input(1);
|
||||
const DataType data_type = ctx->input_type(1);
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
xla::ComputationDataHandle nudged_input_min =
|
||||
XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_);
|
||||
xla::ComputationDataHandle nudged_input_max =
|
||||
XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_);
|
||||
|
||||
xla::ComputationDataHandle between_nudged_min_max =
|
||||
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
|
||||
xla::ComputationDataHandle zeroes = b->Broadcast(
|
||||
XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes());
|
||||
xla::ComputationDataHandle output =
|
||||
b->Select(between_nudged_min_max, gradient, zeroes);
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
float nudged_input_min_;
|
||||
float nudged_input_max_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"),
|
||||
FakeQuantWithMinMaxArgsGradOp);
|
||||
|
||||
class FakeQuantWithMinMaxVarsOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
int num_bits;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
|
||||
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
|
||||
errors::InvalidArgument("num_bits is out of range, expected "
|
||||
"between 2 and 16, was: ",
|
||||
num_bits));
|
||||
bool narrow_range;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
|
||||
quant_min_ = narrow_range ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits) - 1;
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationDataHandle input = ctx->Input(0);
|
||||
const DataType data_type = ctx->input_type(0);
|
||||
xla::ComputationDataHandle input_min = ctx->Input(1);
|
||||
xla::ComputationDataHandle input_max = ctx->Input(2);
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
|
||||
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
|
||||
&nudged_input_min, &nudged_input_max, &input_scale);
|
||||
|
||||
xla::ComputationDataHandle output = Quantize(
|
||||
b, input, data_type, nudged_input_min, nudged_input_max, input_scale);
|
||||
ctx->SetOutput(0, output);
|
||||
}
|
||||
|
||||
private:
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp);
|
||||
|
||||
class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx)
|
||||
: XlaOpKernel(ctx) {
|
||||
int num_bits;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits));
|
||||
OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16,
|
||||
errors::InvalidArgument("num_bits is out of range, expected "
|
||||
"between 2 and 16, was: ",
|
||||
num_bits));
|
||||
bool narrow_range;
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range));
|
||||
quant_min_ = narrow_range ? 1 : 0;
|
||||
quant_max_ = (1 << num_bits) - 1;
|
||||
}
|
||||
|
||||
void Compile(XlaOpKernelContext* ctx) override {
|
||||
xla::ComputationDataHandle gradient = ctx->Input(0);
|
||||
const TensorShape gradient_shape = ctx->InputShape(0);
|
||||
xla::ComputationDataHandle input = ctx->Input(1);
|
||||
const DataType data_type = ctx->input_type(1);
|
||||
xla::ComputationDataHandle input_min = ctx->Input(2);
|
||||
xla::ComputationDataHandle input_max = ctx->Input(3);
|
||||
|
||||
xla::ComputationBuilder* b = ctx->builder();
|
||||
xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale;
|
||||
XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_,
|
||||
&nudged_input_min, &nudged_input_max, &input_scale);
|
||||
|
||||
xla::ComputationDataHandle between_nudged_min_max =
|
||||
b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max));
|
||||
xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type);
|
||||
xla::ComputationDataHandle zeroes =
|
||||
b->Broadcast(zero, gradient_shape.dim_sizes());
|
||||
xla::ComputationDataHandle output0 =
|
||||
b->Select(between_nudged_min_max, gradient, zeroes);
|
||||
ctx->SetOutput(0, output0);
|
||||
|
||||
xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min);
|
||||
xla::ComputationDataHandle output1 =
|
||||
b->ReduceAll(b->Select(below_min, gradient, zeroes), zero,
|
||||
*ctx->GetOrCreateAdd(data_type));
|
||||
ctx->SetOutput(1, output1);
|
||||
|
||||
xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max);
|
||||
xla::ComputationDataHandle output2 =
|
||||
b->ReduceAll(b->Select(above_max, gradient, zeroes), zero,
|
||||
*ctx->GetOrCreateAdd(data_type));
|
||||
ctx->SetOutput(2, output2);
|
||||
}
|
||||
|
||||
private:
|
||||
float quant_min_;
|
||||
float quant_max_;
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"),
|
||||
FakeQuantWithMinMaxVarsGradOp);
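As a rough reference (not part of this change), the gradient routing the Vars-gradient kernel performs can be written in NumPy as follows; the names are illustrative:

import numpy as np

def fake_quant_vars_grad(grad, x, nudged_min, nudged_max):
    # Pass gradients through where the input lies inside the nudged range;
    # sum the below-range and above-range gradients into d_min and d_max.
    in_range = (nudged_min <= x) & (x <= nudged_max)
    d_input = np.where(in_range, grad, 0.0)
    d_min = np.sum(np.where(x < nudged_min, grad, 0.0))
    d_max = np.sum(np.where(x > nudged_max, grad, 0.0))
    return d_input, d_min, d_max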
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
@ -32,12 +33,12 @@ Status XlaGather(const xla::ComputationDataHandle& input,
|
||||
DataType dtype, DataType index_type,
|
||||
xla::ComputationBuilder* builder,
|
||||
xla::ComputationDataHandle* gather_output) {
|
||||
// If the indices are N-dimensional, then the last dimension of indices should
|
||||
// be of size N and correspond to the N indices.
|
||||
int64 num_axes = 1;
|
||||
// If the indices are N-dimensional, then the minor dimension of indices
|
||||
// should be of size N and correspond to the N indices.
|
||||
int64 num_index_dims = 1;
|
||||
if (indices_are_nd) {
|
||||
CHECK_GE(indices_shape.dims(), 1);
|
||||
num_axes = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
indices_shape.RemoveLastDims(1);
|
||||
}
|
||||
|
||||
@ -46,15 +47,15 @@ Status XlaGather(const xla::ComputationDataHandle& input,
|
||||
// input, the output is returned with shape:
|
||||
// input.shape[:axis] + indices.shape + input.shape[axis+1:]
|
||||
|
||||
const int num_indices = indices_shape.num_elements();
|
||||
const int64 num_indices = indices_shape.num_elements();
|
||||
TensorShape input_shape_pre_axis(input_shape);
|
||||
input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims());
|
||||
TensorShape input_shape_post_axis(input_shape);
|
||||
input_shape_post_axis.RemoveDimRange(0, axis + num_axes);
|
||||
input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims);
|
||||
// Each slice of the input tensor has shape:
|
||||
// [<input_shape_pre_axis>, 1, ..., 1, <input shape_post_axis>]
|
||||
TensorShape slice_shape(input_shape);
|
||||
for (int64 i = 0; i < num_axes; ++i) {
|
||||
for (int64 i = 0; i < num_index_dims; ++i) {
|
||||
slice_shape.set_dim(axis + i, 1);
|
||||
}
|
||||
|
||||
@ -79,7 +80,7 @@ Status XlaGather(const xla::ComputationDataHandle& input,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
for (int64 i = 0; i < num_axes; ++i) {
|
||||
for (int64 i = 0; i < num_index_dims; ++i) {
|
||||
if (input_shape.dim_size(axis + i) == 0) {
|
||||
return errors::InvalidArgument("Gather dimension ", axis + i,
|
||||
" is of size zero in tensor with shape ",
|
||||
@ -91,57 +92,30 @@ Status XlaGather(const xla::ComputationDataHandle& input,
|
||||
// iteration. If there is an axis dimension, we must leave it alone.
|
||||
std::vector<int64> flat_indices_shape = {num_indices};
|
||||
if (indices_are_nd) {
|
||||
flat_indices_shape.push_back(num_axes);
|
||||
flat_indices_shape.push_back(num_index_dims);
|
||||
}
|
||||
|
||||
// Specify the shape of the loop-carried Tensor tuple.
|
||||
xla::PrimitiveType ptype;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype));
|
||||
xla::PrimitiveType idxtype;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype));
|
||||
std::vector<xla::Shape> tuple_shapes(
|
||||
{// The iteration counter i is a scalar, incremented each iteration.
|
||||
xla::ShapeUtil::MakeShape(idxtype, {}),
|
||||
// The input array has shape input_shape. Loop invariant.
|
||||
xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()),
|
||||
// The gather indices are reshaped to flat_indices_shape. Loop invariant.
|
||||
xla::ShapeUtil::MakeShape(idxtype, flat_indices_shape),
|
||||
// The output array, which is updated on each loop iteration.
|
||||
xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())});
|
||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
|
||||
|
||||
// Construct the initial values of the loop-carried Tensors.
|
||||
auto init_i = XlaHelpers::Zero(builder, index_type);
|
||||
auto flat_indices = builder->Reshape(indices, flat_indices_shape);
|
||||
auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
|
||||
loop_out_shape.dim_sizes());
|
||||
auto flat_indices = builder->Reshape(indices, flat_indices_shape);
|
||||
auto init = builder->Tuple({init_i, input, flat_indices, init_out});
|
||||
|
||||
// Construct the while loop condition (i < num_indices)
|
||||
std::unique_ptr<xla::ComputationBuilder> condb =
|
||||
builder->CreateSubBuilder("GatherWhileCond");
|
||||
condb->Lt(condb->GetTupleElement(
|
||||
condb->Parameter(0, tuple_shape, "GatherWhileTuple"), 0),
|
||||
XlaHelpers::IntegerLiteral(condb.get(), index_type, num_indices));
|
||||
auto cond_status = condb->Build();
|
||||
auto cond = cond_status.ConsumeValueOrDie();
|
||||
auto init = {input, flat_indices, init_out};
|
||||
|
||||
// Construct the while loop body's function. The implementation of gather is:
|
||||
// for i in range(num_indices):
|
||||
// index = dynamic-slice(indices, i)
|
||||
// xi = dynamic-slice(input, index)
|
||||
// output = dynamic-update-slice(output, xi, i)
|
||||
std::unique_ptr<xla::ComputationBuilder> bodyb =
|
||||
builder->CreateSubBuilder("GatherWhileBody");
|
||||
{
|
||||
// The four loop carried values.
|
||||
auto loop_tuple = bodyb->Parameter(0, tuple_shape, "GatherWhileTuple");
|
||||
auto i = bodyb->GetTupleElement(loop_tuple, 0);
|
||||
auto input = bodyb->GetTupleElement(loop_tuple, 1);
|
||||
auto indices = bodyb->GetTupleElement(loop_tuple, 2);
|
||||
auto output = bodyb->GetTupleElement(loop_tuple, 3);
|
||||
auto body_fn = [&](xla::ComputationDataHandle i,
|
||||
gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
|
||||
xla::ComputationBuilder* bodyb) {
|
||||
auto input = loop_vars[0];
|
||||
auto indices = loop_vars[1];
|
||||
auto output = loop_vars[2];
|
||||
|
||||
auto zero_index = XlaHelpers::Zero(bodyb.get(), index_type);
|
||||
auto zero_index = XlaHelpers::Zero(bodyb, index_type);
|
||||
|
||||
// Slice the i-th index from the indices array.
|
||||
xla::ComputationDataHandle index;
|
||||
@ -150,7 +124,7 @@ Status XlaGather(const xla::ComputationDataHandle& input,
|
||||
// Slice out the entire nd index, if applicable.
|
||||
indices_offset = bodyb->Pad(indices_offset, zero_index,
|
||||
xla::MakeEdgePaddingConfig({{0, 1}}));
|
||||
index = bodyb->DynamicSlice(indices, indices_offset, {1, num_axes});
|
||||
index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims});
|
||||
index = bodyb->Collapse(index, {0, 1});
|
||||
} else {
|
||||
index = bodyb->DynamicSlice(indices, indices_offset, {1});
|
||||
@ -174,16 +148,16 @@ Status XlaGather(const xla::ComputationDataHandle& input,
|
||||
// Update the output Tensor
|
||||
auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index);
|
||||
|
||||
bodyb->Tuple({bodyb->Add(i, XlaHelpers::One(bodyb.get(), index_type)),
|
||||
input, indices, updated_output});
|
||||
}
|
||||
auto body_status = bodyb->Build();
|
||||
auto body = body_status.ConsumeValueOrDie();
|
||||
return std::vector<xla::ComputationDataHandle>{input, indices,
|
||||
updated_output};
|
||||
};
|
||||
|
||||
// Construct the While loop, extract and reshape the output.
|
||||
auto gather_while = builder->While(cond, body, init);
|
||||
auto result = builder->GetTupleElement(gather_while, 3);
|
||||
*gather_output = builder->Reshape(result, out_shape.dim_sizes());
|
||||
xla::PrimitiveType ptype;
|
||||
TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype));
|
||||
TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn,
|
||||
init, "gather", builder));
|
||||
*gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes());
|
||||
return Status::OK();
|
||||
}
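A NumPy sketch (illustrative only, hypothetical name) of the per-index loop described by the comment above, for the scalar-index case:

import numpy as np

def gather_by_loop(params, indices, axis=0):
    # One index per iteration: slice the index, slice the input, stack into the
    # output, then reshape to params.shape[:axis] + indices.shape + the rest.
    flat_indices = indices.reshape(-1)
    slices = [np.take(params, int(i), axis=axis) for i in flat_indices]
    out = np.stack(slices, axis=axis)
    return out.reshape(params.shape[:axis] + indices.shape +
                       params.shape[axis + 1:])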
|
||||
|
||||
@ -250,9 +224,10 @@ class GatherNdOp : public XlaOpKernel {
|
||||
errors::InvalidArgument("params must be at least a vector"));
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape),
|
||||
errors::InvalidArgument("indices must be at least a vector"));
|
||||
const int64 num_axes = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
const int64 num_index_dims =
|
||||
indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
OP_REQUIRES(
|
||||
context, num_axes <= params_shape.dims(),
|
||||
context, num_index_dims <= params_shape.dims(),
|
||||
errors::InvalidArgument(
|
||||
"index innermost dimension length must be <= params rank; saw: ",
|
||||
indices_shape.dim_size(indices_shape.dims() - 1), " vs. ",
|
||||
|
121
tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc
Normal file
@ -0,0 +1,121 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Check whether updates.shape = indices.shape[:batch_dim] +
|
||||
// buffer_shape[num_index_dims:]
|
||||
Status ValidateUpdateShape(const TensorShape& buffer_shape,
|
||||
const TensorShape& indices_shape,
|
||||
const TensorShape& updates_shape) {
|
||||
if (indices_shape.dims() < 1) {
|
||||
return errors::InvalidArgument(
|
||||
"indices shape must have >= 1 dimension; got ",
|
||||
indices_shape.DebugString());
|
||||
}
|
||||
|
||||
const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1);
|
||||
const int64 batch_dim = indices_shape.dims() - 1;
|
||||
|
||||
auto shape_err = [&]() {
|
||||
return errors::InvalidArgument(
|
||||
"Must have updates.shape = indices.shape[:batch_dim] + ",
|
||||
"buffer_shape[num_index_dims:], got updates.shape: ",
|
||||
updates_shape.DebugString(),
|
||||
", indices.shape: ", indices_shape.DebugString(),
|
||||
", buffer_shape: ", buffer_shape.DebugString(),
|
||||
", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
|
||||
};
|
||||
|
||||
if (updates_shape.dims() < batch_dim) return shape_err();
|
||||
if (buffer_shape.dims() <
|
||||
num_index_dims + (updates_shape.dims() - batch_dim)) {
|
||||
return shape_err();
|
||||
}
|
||||
if (updates_shape.dims() !=
|
||||
batch_dim + buffer_shape.dims() - num_index_dims) {
|
||||
return shape_err();
|
||||
}
|
||||
for (int d = 0; d < batch_dim; ++d) {
|
||||
if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) {
|
||||
return shape_err();
|
||||
}
|
||||
}
|
||||
for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
|
||||
if (updates_shape.dim_size(d + batch_dim) !=
|
||||
buffer_shape.dim_size(d + num_index_dims)) {
|
||||
return shape_err();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
class ScatterNdOp : public XlaOpKernel {
|
||||
public:
|
||||
explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
|
||||
|
||||
void Compile(XlaOpKernelContext* context) override {
|
||||
DataType dtype = context->input_type(1);
|
||||
|
||||
TensorShape indices_shape = context->InputShape(0);
|
||||
TensorShape updates_shape = context->InputShape(1);
|
||||
|
||||
TensorShape buffer_shape;
|
||||
OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape));
|
||||
|
||||
OP_REQUIRES(
|
||||
context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
|
||||
errors::InvalidArgument("Output must be at least 1-D, ",
|
||||
"got shape: ", buffer_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
|
||||
updates_shape.num_elements() == 0),
|
||||
errors::InvalidArgument(
|
||||
"Indices and updates specified for empty output. indices shape: ",
|
||||
indices_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape,
|
||||
updates_shape));
|
||||
|
||||
xla::ComputationBuilder* builder = context->builder();
|
||||
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
|
||||
buffer_shape.dim_sizes());
|
||||
auto indices = context->Input(0);
|
||||
auto updates = context->Input(1);
|
||||
auto result =
|
||||
XlaScatter(buffer, updates, indices,
|
||||
/*indices_are_vectors=*/true, /*combiner=*/{}, builder);
|
||||
OP_REQUIRES_OK(context, result.status());
|
||||
context->SetOutput(0, result.ValueOrDie());
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstInput("shape"), ScatterNdOp);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
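For reference, the shape relation checked by ValidateUpdateShape() and the kernel's semantics can be sketched in NumPy as below (illustrative only; duplicate indices overwrite here because the kernel passes no combiner to XlaScatter):

import numpy as np

def scatter_nd_reference(indices, updates, shape):
    # updates.shape must equal indices.shape[:-1] + shape[num_index_dims:].
    num_index_dims = indices.shape[-1]
    batch_shape = indices.shape[:-1]
    assert updates.shape == batch_shape + tuple(shape[num_index_dims:])
    buffer = np.zeros(shape, dtype=updates.dtype)
    for pos in np.ndindex(*batch_shape):
        buffer[tuple(indices[pos])] = updates[pos]
    return buffer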
|
@ -1,39 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// Helper methods for XLA Scatter Ops.
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/util/bcast.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Adds to builder an XLA computation that performs a scatter-add of input (of
|
||||
// shape input_shape) keyed on indices (of shape indices_shape). The shape
|
||||
// of the Tensor returned by this is num_segments input_shape[indices.dims():]
|
||||
//
|
||||
static xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice(
|
||||
XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input,
|
||||
const TensorShape& input_shape, const xla::ComputationDataHandle& indices,
|
||||
const TensorShape& indices_shape, int64 num_segments, DataType dtype,
|
||||
xla::ComputationBuilder* builder);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_
|
@ -1,4 +1,4 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
@ -13,125 +13,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <sstream>
|
||||
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice(
|
||||
XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input,
|
||||
const TensorShape& input_shape, const xla::ComputationDataHandle& indices,
|
||||
const TensorShape& indices_shape, int64 num_segments, DataType dtype,
|
||||
xla::ComputationBuilder* builder) {
|
||||
// Flatten data for dynamic indexing via indices_1d.
|
||||
TensorShape input_shape_i(input_shape);
|
||||
for (int64 d = 0; d < indices_shape.dims(); ++d) {
|
||||
input_shape_i.RemoveDim(0);
|
||||
}
|
||||
TensorShape flat_shape({indices_shape.num_elements()});
|
||||
flat_shape.AppendShape(input_shape_i);
|
||||
|
||||
// output is same as flattened input shape with dim_size(0) = num_segments.
|
||||
TensorShape out_shape(flat_shape);
|
||||
out_shape.set_dim(0, num_segments);
|
||||
|
||||
// Slices from the input data are same shape as the input data, except dim 0.
|
||||
TensorShape slice_shape(flat_shape);
|
||||
slice_shape.set_dim(0, 1);
|
||||
TensorShape loop_out_slice_shape(out_shape);
|
||||
loop_out_slice_shape.set_dim(0, 1);
|
||||
|
||||
// Construct the initial values of the loop-carried variables
|
||||
// Flatten the indices into 1-D for ease of iteration.
|
||||
auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()});
|
||||
// Flatten the data for ease of indexing via values in indices_1d.
|
||||
auto data_flat = builder->Reshape(input, flat_shape.dim_sizes());
|
||||
|
||||
auto init_i = builder->ConstantR0<int32>(0);
|
||||
auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype),
|
||||
out_shape.dim_sizes());
|
||||
|
||||
xla::PrimitiveType ptype;
|
||||
TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype));
|
||||
|
||||
std::vector<xla::Shape> tuple_shapes(
|
||||
{// The loop iteration counter is a scalar, incremented each iteration.
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {}),
|
||||
// The flattened input data is loop invariant.
|
||||
xla::ShapeUtil::MakeShape(ptype, flat_shape.dim_sizes()),
|
||||
// The scatter indices tensor is loop invariant.
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {indices_shape.num_elements()}),
|
||||
// The output data array is updated each loop iteration.
|
||||
xla::ShapeUtil::MakeShape(ptype, out_shape.dim_sizes())});
|
||||
xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes);
|
||||
|
||||
auto init = builder->Tuple({init_i, data_flat, indices_1d, init_out});
|
||||
|
||||
// Construct the while loop condition (i < num_indices)
|
||||
xla::ComputationBuilder condb(ctx->builder()->client(),
|
||||
"ScatterAddWhileCond");
|
||||
condb.Lt(condb.GetTupleElement(
|
||||
condb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"), 0),
|
||||
condb.ConstantR0<int32>(indices_shape.num_elements()));
|
||||
auto cond_status = condb.Build();
|
||||
auto cond = cond_status.ConsumeValueOrDie();
|
||||
|
||||
// Construct the while loop body's function. The implementation of scatter is:
|
||||
// for i in range(num_indices):
|
||||
// index = dynamic-slice(indices, i)
|
||||
// xi = dynamic-slice(input, i)
|
||||
// output = dynamic-update-slice(output, xi, index)
|
||||
xla::ComputationBuilder bodyb(ctx->builder()->client(),
|
||||
"ScatterAddWhileBody");
|
||||
{
|
||||
auto input_tuple = bodyb.Parameter(0, tuple_shape, "ScatterAddWhileTuple");
|
||||
auto i = bodyb.GetTupleElement(input_tuple, 0);
|
||||
auto data = bodyb.GetTupleElement(input_tuple, 1);
|
||||
auto idcs = bodyb.GetTupleElement(input_tuple, 2);
|
||||
auto output = bodyb.GetTupleElement(input_tuple, 3);
|
||||
|
||||
// Index into the data array at i.
|
||||
auto zero = bodyb.ConstantR1<int32>({0});
|
||||
std::vector<xla::ComputationDataHandle> index_vals(flat_shape.dims(), zero);
|
||||
index_vals[0] = bodyb.Reshape(i, {1});
|
||||
auto index = bodyb.ConcatInDim(index_vals, 0);
|
||||
|
||||
auto data_slice =
|
||||
bodyb.Reshape(bodyb.DynamicSlice(data, index, slice_shape.dim_sizes()),
|
||||
loop_out_slice_shape.dim_sizes());
|
||||
|
||||
// Index into the output array.
|
||||
std::vector<xla::ComputationDataHandle> out_index_vals(out_shape.dims(),
|
||||
zero);
|
||||
out_index_vals[0] = bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1});
|
||||
auto out_index = bodyb.ConcatInDim(out_index_vals, 0);
|
||||
|
||||
// Slice the output array, update value, and update the output slice.
|
||||
auto updated_output = bodyb.DynamicUpdateSlice(
|
||||
output,
|
||||
bodyb.Add(data_slice,
|
||||
bodyb.DynamicSlice(output, out_index,
|
||||
loop_out_slice_shape.dim_sizes())),
|
||||
out_index);
|
||||
|
||||
auto ip1 = bodyb.Add(i, bodyb.ConstantR0<int32>(1));
|
||||
bodyb.Tuple({ip1, data, idcs, updated_output});
|
||||
}
|
||||
auto body_status = bodyb.Build();
|
||||
auto body = body_status.ConsumeValueOrDie();
|
||||
|
||||
auto gather_while = builder->While(cond, body, init);
|
||||
return builder->GetTupleElement(gather_while, 3);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class UnsortedSegmentSum : public XlaOpKernel {
|
||||
@ -153,10 +41,10 @@ class UnsortedSegmentSum : public XlaOpKernel {
|
||||
// as data with the first indices.rank dimensions are replaced
|
||||
// by a single dimension with size num_segments.
|
||||
auto data = ctx->Input(0);
|
||||
auto data_shape = ctx->InputShape(0);
|
||||
TensorShape data_shape = ctx->InputShape(0);
|
||||
|
||||
auto indices = ctx->Input(1);
|
||||
auto indices_shape = ctx->InputShape(1);
|
||||
TensorShape indices_shape = ctx->InputShape(1);
|
||||
|
||||
int64 num_segments;
|
||||
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
|
||||
@ -174,10 +62,21 @@ class UnsortedSegmentSum : public XlaOpKernel {
|
||||
d, " differs ", data_shape.dim_size(d), " vs. ",
|
||||
indices_shape.dim_size(d)));
|
||||
}
|
||||
auto result = XlaComputeScatterAddDynamicSlice(
|
||||
ctx, data, data_shape, indices, indices_shape, num_segments, dtype_,
|
||||
ctx->builder());
|
||||
ctx->SetOutput(0, result);
|
||||
xla::ComputationBuilder* builder = ctx->builder();
|
||||
TensorShape buffer_shape = data_shape;
|
||||
buffer_shape.RemoveDimRange(0, indices_shape.dims());
|
||||
buffer_shape.InsertDim(0, num_segments);
|
||||
auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_),
|
||||
buffer_shape.dim_sizes());
|
||||
|
||||
auto combiner =
|
||||
[](xla::ComputationDataHandle a, xla::ComputationDataHandle b,
|
||||
xla::ComputationBuilder* builder) { return builder->Add(a, b); };
|
||||
|
||||
auto result = XlaScatter(buffer, /*updates=*/data, indices,
|
||||
/*indices_are_vectors=*/false, combiner, builder);
|
||||
OP_REQUIRES_OK(ctx, result.status());
|
||||
ctx->SetOutput(0, result.ValueOrDie());
|
||||
}
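In NumPy terms (illustrative only), the rewritten UnsortedSegmentSum computes the following; negative segment ids are skipped because XlaScatter discards negative indices:

import numpy as np

def unsorted_segment_sum_reference(data, segment_ids, num_segments):
    # Buffer of shape [num_segments] + data.shape[segment_ids.ndim:], filled by
    # scattering `data` with an add combiner.
    out = np.zeros((num_segments,) + data.shape[segment_ids.ndim:],
                   dtype=data.dtype)
    for pos in np.ndindex(*segment_ids.shape):
        sid = int(segment_ids[pos])
        if sid < 0:
            continue  # dropped, as XlaScatter discards negative indices
        out[sid] += data[pos]
    return out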
|
||||
|
||||
private:
|
||||
|
@ -49,6 +49,25 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "scatter",
|
||||
srcs = ["scatter.cc"],
|
||||
hdrs = ["scatter.h"],
|
||||
deps = [
|
||||
":util",
|
||||
":while_loop",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:computation",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "triangular_solve",
|
||||
srcs = ["triangular_solve.cc"],
|
||||
@ -107,6 +126,21 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop",
|
||||
srcs = ["while_loop.cc"],
|
||||
hdrs = ["while_loop.h"],
|
||||
deps = [
|
||||
":util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:computation",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
|
189
tensorflow/compiler/tf2xla/lib/scatter.cc
Normal file
@ -0,0 +1,189 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/scatter.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/lib/util.h"
|
||||
#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
|
||||
const xla::ComputationDataHandle& buffer,
|
||||
const xla::ComputationDataHandle& updates,
|
||||
const xla::ComputationDataHandle& indices, bool indices_are_vectors,
|
||||
const std::function<xla::ComputationDataHandle(
|
||||
xla::ComputationDataHandle, xla::ComputationDataHandle,
|
||||
xla::ComputationBuilder*)>& combiner,
|
||||
xla::ComputationBuilder* builder) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> buffer_shape,
|
||||
builder->GetShape(buffer));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> updates_shape,
|
||||
builder->GetShape(updates));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::Shape> indices_shape,
|
||||
builder->GetShape(indices));
|
||||
gtl::ArraySlice<int64> indices_dims =
|
||||
xla::AsInt64Slice(indices_shape->dimensions());
|
||||
gtl::ArraySlice<int64> buffer_dims =
|
||||
xla::AsInt64Slice(buffer_shape->dimensions());
|
||||
|
||||
// If the indices are N-dimensional, the minor dimension of indices contains
|
||||
// the indices to update. Otherwise the indices are all scalars.
|
||||
int64 num_index_dims = 1;
|
||||
if (indices_are_vectors) {
|
||||
TF_RET_CHECK(!indices_dims.empty());
|
||||
num_index_dims = indices_dims.back();
|
||||
if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) {
|
||||
return errors::InvalidArgument(
|
||||
"The size of the minor dimension of the indices (shape: ",
|
||||
xla::ShapeUtil::HumanString(*indices_shape),
|
||||
") must be <= the rank of the buffer (shape: ",
|
||||
xla::ShapeUtil::HumanString(*buffer_shape), ")");
|
||||
}
|
||||
indices_dims.pop_back();
|
||||
}
|
||||
|
||||
int64 num_indices = 1;
|
||||
for (int64 dim : indices_dims) {
|
||||
num_indices *= dim;
|
||||
}
|
||||
|
||||
// Degenerate case: nothing to update. Return the buffer unchanged.
|
||||
if (num_indices == 0) {
|
||||
return buffer;
|
||||
}
|
||||
|
||||
// If any of the indexed dimensions are zero in the buffer, the update cannot
|
||||
// succeed since it updates a slice of size 1.
|
||||
for (int64 i = 0; i < num_index_dims; ++i) {
|
||||
if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) {
|
||||
return errors::InvalidArgument(
|
||||
"Scatter dimension ", i, " is of size zero in tensor with shape ",
|
||||
xla::ShapeUtil::HumanString(*buffer_shape));
|
||||
}
|
||||
}
|
||||
|
||||
// Shape of the non-indexed dimensions of the buffer.
|
||||
std::vector<int64> buffer_shape_post_axes(
|
||||
buffer_dims.begin() + num_index_dims, buffer_dims.end());
|
||||
|
||||
// Flatten the major dimensions of indices and updates into a single dimension
|
||||
// for ease of iteration.
|
||||
std::vector<int64> flat_indices_shape({num_indices});
|
||||
if (indices_are_vectors) {
|
||||
flat_indices_shape.push_back(num_index_dims);
|
||||
}
|
||||
|
||||
std::vector<int64> flat_updates_shape({num_indices});
|
||||
flat_updates_shape.insert(flat_updates_shape.end(),
|
||||
buffer_shape_post_axes.begin(),
|
||||
buffer_shape_post_axes.end());
|
||||
|
||||
// Construct the initial values of the loop-carried Tensors.
|
||||
auto flat_indices = builder->Reshape(indices, flat_indices_shape);
|
||||
auto flat_updates = builder->Reshape(updates, flat_updates_shape);
|
||||
auto init = {flat_indices, flat_updates, buffer};
|
||||
|
||||
// Constructs the loop body. The implementation of scatter is essentially:
|
||||
// for i in range(num_indices):
|
||||
// index = dynamic-slice(indices, i)
|
||||
// update = dynamic-slice(updates, i)
|
||||
// buffer = dynamic-update-slice(buffer, update, index)
|
||||
auto body_fn = [&](xla::ComputationDataHandle i,
|
||||
gtl::ArraySlice<xla::ComputationDataHandle> loop_vars,
|
||||
xla::ComputationBuilder* body_builder) {
|
||||
auto indices = loop_vars[0];
|
||||
auto updates = loop_vars[1];
|
||||
auto buffer = loop_vars[2];
|
||||
|
||||
auto zero_index = body_builder->ConstantLiteral(
|
||||
xla::Literal::Zero(indices_shape->element_type()));
|
||||
|
||||
// Slice the i-th index from the indices array.
|
||||
xla::ComputationDataHandle index;
|
||||
auto indices_offset = body_builder->Reshape(i, {1});
|
||||
if (indices_are_vectors) {
|
||||
indices_offset = body_builder->Pad(indices_offset, zero_index,
|
||||
xla::MakeEdgePaddingConfig({{0, 1}}));
|
||||
|
||||
index = body_builder->DynamicSlice(indices, indices_offset,
|
||||
{1, num_index_dims});
|
||||
index = body_builder->Collapse(index, {0, 1});
|
||||
} else {
|
||||
index = body_builder->DynamicSlice(indices, indices_offset, {1});
|
||||
}
|
||||
|
||||
// Discard updates with negative indices, since some users expect this.
|
||||
auto index_in_range =
|
||||
body_builder->ReduceAll(body_builder->Le(zero_index, index),
|
||||
body_builder->ConstantR0<bool>(true),
|
||||
xla::CreateScalarAndComputation(body_builder));
|
||||
|
||||
index = body_builder->Pad(
|
||||
index, zero_index,
|
||||
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
|
||||
|
||||
// Slice the i-th index from the updates array.
|
||||
auto updates_offset = body_builder->Reshape(i, {1});
|
||||
updates_offset = body_builder->Pad(
|
||||
updates_offset, zero_index,
|
||||
xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}}));
|
||||
std::vector<int64> flat_updates_slice_shape({1});
|
||||
flat_updates_slice_shape.insert(flat_updates_slice_shape.end(),
|
||||
buffer_shape_post_axes.begin(),
|
||||
buffer_shape_post_axes.end());
|
||||
auto update = body_builder->DynamicSlice(updates, updates_offset,
|
||||
flat_updates_slice_shape);
|
||||
|
||||
// Unflatten the major (iteration) dimensions of the slice to their original
|
||||
// shape.
|
||||
std::vector<int64> updates_slice_shape(num_index_dims, 1);
|
||||
updates_slice_shape.insert(updates_slice_shape.end(),
|
||||
buffer_shape_post_axes.begin(),
|
||||
buffer_shape_post_axes.end());
|
||||
update = body_builder->Reshape(update, updates_slice_shape);
|
||||
|
||||
// Apply the update to the buffer. If there is a combiner, use it to merge
|
||||
// the current values with the update.
|
||||
if (combiner) {
|
||||
auto current_value =
|
||||
body_builder->DynamicSlice(buffer, index, updates_slice_shape);
|
||||
update = combiner(current_value, update, body_builder);
|
||||
}
|
||||
// Apply the update if it is in range.
|
||||
buffer = body_builder->Select(
|
||||
index_in_range, body_builder->DynamicUpdateSlice(buffer, update, index),
|
||||
buffer);
|
||||
|
||||
return std::vector<xla::ComputationDataHandle>{indices, updates, buffer};
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(),
|
||||
body_fn, init, "scatter", builder));
|
||||
return outputs[2];
|
||||
}
|
||||
|
||||
} // namespace tensorflow
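A NumPy sketch (illustrative only, hypothetical names) of the loop that body_fn above builds: one update per iteration, merged with the current value when a combiner is given, and dropped when any index component is negative:

import numpy as np

def xla_scatter_reference(buffer, updates, indices, indices_are_vectors,
                          combiner=None):
    buffer = buffer.copy()
    if indices_are_vectors:
        num_index_dims = indices.shape[-1]
        flat_indices = indices.reshape(-1, num_index_dims)
    else:
        num_index_dims = 1
        flat_indices = indices.reshape(-1, 1)
    flat_updates = updates.reshape((flat_indices.shape[0],) +
                                   buffer.shape[num_index_dims:])
    for i in range(flat_indices.shape[0]):
        index = tuple(int(j) for j in flat_indices[i])
        if any(j < 0 for j in index):
            continue  # negative indices are discarded
        update = flat_updates[i]
        if combiner is not None:
            update = combiner(buffer[index], update)
        buffer[index] = update
    return buffer

# Example: accumulate with an add combiner.
# xla_scatter_reference(np.zeros(5), np.array([1.0, 2.0]), np.array([1, 3]),
#                       indices_are_vectors=False, combiner=lambda a, b: a + b)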
|
53
tensorflow/compiler/tf2xla/lib/scatter.h
Normal file
@ -0,0 +1,53 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_

#include <functional>

#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace tensorflow {

// Builds an XLA computation that performs a scatter operation on `buffer`,
// returning an updated buffer.
// For each i0, i1, ..., sets
// buffer[indices[i0, i1, ...], ...] := updates[i0, i1, ...]
//
// If `indices_are_vectors` is false, then each index in indices is a scalar,
// and the shape of `indices` must be a prefix of the shape of updates.
// Otherwise, if `indices_are_vectors` is true, then `indices` is
// multidimensional and its minor dimension represents a vector of indices.
//
// If any indices are negative, the corresponding update is discarded.
//
// If a `combiner` is provided, updates are combined with the existing values in
// the buffer using the combiner function. Otherwise, the updates replace the
// existing values. The order of updates is implementation-defined.
xla::StatusOr<xla::ComputationDataHandle> XlaScatter(
    const xla::ComputationDataHandle& buffer,
    const xla::ComputationDataHandle& updates,
    const xla::ComputationDataHandle& indices, bool indices_are_vectors,
    const std::function<xla::ComputationDataHandle(
        xla::ComputationDataHandle, xla::ComputationDataHandle,
        xla::ComputationBuilder*)>& combiner,
    xla::ComputationBuilder* builder);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_
@ -57,6 +57,61 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
|
||||
}
|
||||
}
|
||||
|
||||
xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
|
||||
xla::PrimitiveType type,
|
||||
int64 value) {
|
||||
xla::Literal literal;
|
||||
switch (type) {
|
||||
case xla::U8:
|
||||
literal = std::move(*xla::Literal::CreateR0<uint8>(value));
|
||||
break;
|
||||
case xla::U32:
|
||||
literal = std::move(*xla::Literal::CreateR0<uint32>(value));
|
||||
break;
|
||||
case xla::U64:
|
||||
literal = std::move(*xla::Literal::CreateR0<uint64>(value));
|
||||
break;
|
||||
case xla::S8:
|
||||
literal = std::move(*xla::Literal::CreateR0<int8>(value));
|
||||
break;
|
||||
case xla::S32:
|
||||
literal = std::move(*xla::Literal::CreateR0<int32>(value));
|
||||
break;
|
||||
case xla::S64:
|
||||
literal = std::move(*xla::Literal::CreateR0<int64>(value));
|
||||
break;
|
||||
case xla::F32:
|
||||
literal = std::move(*xla::Literal::CreateR0<float>(value));
|
||||
break;
|
||||
case xla::F64:
|
||||
literal = std::move(*xla::Literal::CreateR0<double>(value));
|
||||
break;
|
||||
case xla::C64:
|
||||
literal = std::move(*xla::Literal::CreateR0<complex64>(value));
|
||||
break;
|
||||
case xla::PRED:
|
||||
LOG(FATAL) << "pred element type is not integral";
|
||||
case xla::S16:
|
||||
case xla::U16:
|
||||
LOG(FATAL) << "u16/s16 literals not yet implemented";
|
||||
case xla::BF16:
|
||||
literal = std::move(
|
||||
*xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
|
||||
break;
|
||||
case xla::F16:
|
||||
literal = std::move(
|
||||
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
|
||||
break;
|
||||
case xla::TUPLE:
|
||||
LOG(FATAL) << "tuple element type is not integral";
|
||||
case xla::OPAQUE:
|
||||
LOG(FATAL) << "opaque element type is not integral";
|
||||
default:
|
||||
LOG(FATAL) << "unhandled element type " << type;
|
||||
}
|
||||
return builder->ConstantLiteral(literal);
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
|
||||
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
|
||||
gtl::ArraySlice<int64> start, gtl::ArraySlice<int64> end) {
|
||||
|
@ -32,6 +32,11 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder,
|
||||
xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
|
||||
xla::PrimitiveType type, double value);
|
||||
|
||||
// Returns an integer scalar constant of 'type' with 'value'.
|
||||
// If 'type' is complex, returns a real value with zero imaginary component.
|
||||
xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder,
|
||||
xla::PrimitiveType type, int64 value);
|
||||
|
||||
// Performs a slice in the minor dimensions of a Tensor.
|
||||
xla::StatusOr<xla::ComputationDataHandle> SliceInMinorDims(
|
||||
xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x,
|
||||
|
125
tensorflow/compiler/tf2xla/lib/while_loop.cc
Normal file
@ -0,0 +1,125 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/while_loop.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace tensorflow {

xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
    const LoopConditionFunction& condition_function,
    const LoopBodyFunction& body_function,
    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
    StringPiece name, xla::ComputationBuilder* builder) {
  int arity = initial_values.size();
  std::vector<xla::Shape> var_shapes;
  var_shapes.reserve(arity);
  for (const xla::ComputationDataHandle& input : initial_values) {
    TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input));
    var_shapes.push_back(std::move(*shape));
  }
  xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes);

  // Unpacks a tuple into its component parts.
  auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity,
                         xla::ComputationBuilder* builder) {
    std::vector<xla::ComputationDataHandle> elements(arity);
    for (int i = 0; i < arity; ++i) {
      elements[i] = builder->GetTupleElement(tuple, i);
    }
    return elements;
  };

  // Build the condition.
  std::unique_ptr<xla::ComputationBuilder> cond_builder =
      builder->CreateSubBuilder(strings::StrCat(name, "_condition"));
  {
    auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter");

    TF_ASSIGN_OR_RETURN(
        auto result,
        condition_function(unpack_tuple(parameter, arity, cond_builder.get()),
                           cond_builder.get()));
    TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result));
  }
  TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build());

  // Build the body.
  std::unique_ptr<xla::ComputationBuilder> body_builder =
      builder->CreateSubBuilder(strings::StrCat(name, "_body"));
  {
    auto parameter = body_builder->Parameter(0, tuple_shape, "parameter");

    TF_ASSIGN_OR_RETURN(
        auto result,
        body_function(unpack_tuple(parameter, arity, body_builder.get()),
                      body_builder.get()));

    TF_RET_CHECK(result.size() == initial_values.size());
    body_builder->Tuple(result);
  }
  TF_ASSIGN_OR_RETURN(auto body, body_builder->Build());

  auto outputs = builder->While(cond, body, builder->Tuple(initial_values));

  return unpack_tuple(outputs, arity, builder);
}

xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
    int64 num_iterations, xla::PrimitiveType num_iterations_type,
    const ForEachIndexBodyFunction& body_function,
    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
    StringPiece name, xla::ComputationBuilder* builder) {
  auto while_cond_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
                           xla::ComputationBuilder* cond_builder)
      -> xla::StatusOr<xla::ComputationDataHandle> {
    return cond_builder->Lt(
        values[0],
        IntegerLiteral(cond_builder, num_iterations_type, num_iterations));
  };
  auto while_body_fn = [&](gtl::ArraySlice<xla::ComputationDataHandle> values,
                           xla::ComputationBuilder* body_builder)
      -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
    xla::ComputationDataHandle iteration = values[0];

    std::vector<xla::ComputationDataHandle> updated_values;
    updated_values.reserve(values.size());
    updated_values.push_back(body_builder->Add(
        iteration,
        body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type))));

    values.remove_prefix(1);
    TF_ASSIGN_OR_RETURN(std::vector<xla::ComputationDataHandle> body_outputs,
                        body_function(iteration, values, body_builder));
    updated_values.insert(updated_values.end(), body_outputs.begin(),
                          body_outputs.end());
    return updated_values;
  };

  std::vector<xla::ComputationDataHandle> values;
  values.reserve(initial_values.size() + 1);
  values.push_back(
      builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type)));
  values.insert(values.end(), initial_values.begin(), initial_values.end());

  TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values,
                                           name, builder));
  values.erase(values.begin(), values.begin() + 1);
  return values;
}

}  // namespace tensorflow
74
tensorflow/compiler/tf2xla/lib/while_loop.h
Normal file
@ -0,0 +1,74 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_

#include <functional>
#include <vector>

#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"

namespace tensorflow {

// Function that builds a loop condition. Takes as input a sequence of input
// values, and returns a boolean value representing whether the condition
// succeeds.
typedef std::function<xla::StatusOr<xla::ComputationDataHandle>(
    gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
    LoopConditionFunction;

// Function that builds a loop body. Takes as input a sequence of input values
// and returns a sequence of output values.
typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
    gtl::ArraySlice<xla::ComputationDataHandle>, xla::ComputationBuilder*)>
    LoopBodyFunction;

// Helper function for building an XLA while loop, where the values carried by
// the loop are a tuple of values, e.g., (a, b, c):
// while(
//   condition: (a, b, c) -> bool,
//   body: (a, b, c) -> (a, b, c)
//   init: (a, b, c)
// )
// 'name' is a descriptive name for the loop.
xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaWhileLoop(
    const LoopConditionFunction& condition_function,
    const LoopBodyFunction& body_function,
    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
    StringPiece name, xla::ComputationBuilder* builder);

// Builds an XLA loop that repeats a computation `num_iterations` times.
//
// The body function (ForEachIndexBodyFunction) takes as input a pair of
// (current iteration number, loop-carried values), and returns an updated
// vector of the loop-carried values.
typedef std::function<xla::StatusOr<std::vector<xla::ComputationDataHandle>>(
    xla::ComputationDataHandle, gtl::ArraySlice<xla::ComputationDataHandle>,
    xla::ComputationBuilder*)>
    ForEachIndexBodyFunction;

xla::StatusOr<std::vector<xla::ComputationDataHandle>> XlaForEachIndex(
    int64 num_iterations, xla::PrimitiveType num_iterations_type,
    const ForEachIndexBodyFunction& body_function,
    gtl::ArraySlice<xla::ComputationDataHandle> initial_values,
    StringPiece name, xla::ComputationBuilder* builder);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_
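For context on how these helpers compose, a caller might count-and-accumulate with `XlaForEachIndex` roughly as sketched below. This is illustrative only; `acc_init` and the S32 element type are assumptions for the example, not taken from this change.

// Illustrative sketch: add the loop index to a carried S32 accumulator ten
// times. `acc_init` is an assumed S32 scalar handle; `builder` is the
// caller's xla::ComputationBuilder.
auto body_fn = [](xla::ComputationDataHandle i,
                  gtl::ArraySlice<xla::ComputationDataHandle> values,
                  xla::ComputationBuilder* b)
    -> xla::StatusOr<std::vector<xla::ComputationDataHandle>> {
  // values[0] is the loop-carried accumulator.
  return std::vector<xla::ComputationDataHandle>{b->Add(values[0], i)};
};
TF_ASSIGN_OR_RETURN(
    std::vector<xla::ComputationDataHandle> results,
    XlaForEachIndex(/*num_iterations=*/10, xla::S32, body_fn, {acc_init},
                    "acc_loop", builder));
// results[0] holds the final accumulator value.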
@ -135,58 +135,9 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,

xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
    xla::ComputationBuilder* b, DataType data_type, int64 value) {
  xla::Literal literal;
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  switch (type) {
    case xla::U8:
      literal = std::move(*xla::Literal::CreateR0<uint8>(value));
      break;
    case xla::U32:
      literal = std::move(*xla::Literal::CreateR0<uint32>(value));
      break;
    case xla::U64:
      literal = std::move(*xla::Literal::CreateR0<uint64>(value));
      break;
    case xla::S8:
      literal = std::move(*xla::Literal::CreateR0<int8>(value));
      break;
    case xla::S32:
      literal = std::move(*xla::Literal::CreateR0<int32>(value));
      break;
    case xla::S64:
      literal = std::move(*xla::Literal::CreateR0<int64>(value));
      break;
    case xla::F32:
      literal = std::move(*xla::Literal::CreateR0<float>(value));
      break;
    case xla::F64:
      literal = std::move(*xla::Literal::CreateR0<double>(value));
      break;
    case xla::C64:
      literal = std::move(*xla::Literal::CreateR0<complex64>(value));
      break;
    case xla::PRED:
      LOG(FATAL) << "pred element type is not integral";
    case xla::S16:
    case xla::U16:
      LOG(FATAL) << "u16/s16 literals not yet implemented";
    case xla::BF16:
      literal = std::move(
          *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value)));
      break;
    case xla::F16:
      literal = std::move(
          *xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value)));
      break;
    case xla::TUPLE:
      LOG(FATAL) << "tuple element type is not integral";
    case xla::OPAQUE:
      LOG(FATAL) << "opaque element type is not integral";
    default:
      LOG(FATAL) << "unhandled element type " << type;
  }
  return b->ConstantLiteral(literal);
  return ::tensorflow::IntegerLiteral(b, type, value);
}

xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b,
@ -43,6 +43,81 @@ filegroup(
    ]),
)

cc_library(
    name = "bfloat16_support",
    srcs = ["bfloat16_support.cc"],
    hdrs = ["bfloat16_support.h"],
    deps = [
        ":hlo",
    ],
)

cc_library(
    name = "bfloat16_conversion_folding",
    srcs = ["bfloat16_conversion_folding.cc"],
    hdrs = ["bfloat16_conversion_folding.h"],
    deps = [
        ":bfloat16_support",
        ":hlo",
        ":hlo_pass",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "bfloat16_conversion_folding_test",
    srcs = ["bfloat16_conversion_folding_test.cc"],
    deps = [
        ":bfloat16_conversion_folding",
        ":bfloat16_support",
        ":hlo",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "bfloat16_normalization",
    srcs = ["bfloat16_normalization.cc"],
    hdrs = ["bfloat16_normalization.h"],
    deps = [
        ":bfloat16_support",
        ":hlo",
        ":hlo_pass",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/core:lib",
    ],
)

tf_cc_test(
    name = "bfloat16_normalization_test",
    srcs = ["bfloat16_normalization_test.cc"],
    deps = [
        ":bfloat16_normalization",
        ":bfloat16_support",
        ":hlo",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:lib",
    ],
)

cc_library(
    name = "shape_inference",
    srcs = ["shape_inference.cc"],
184
tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc
Normal file
@ -0,0 +1,184 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16ConversionFoldingVisitor(
      HloComputation* computation, const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO has a BF16 -> F32 conversion as input, or an F32 -> BF16
  // conversion as output, and folds them into the HLO itself if feasible.
  Status TryFoldBF16Conversions(HloInstruction* hlo);

  // Folds the F32 -> BF16 conversions from the HLO's output.
  //
  // Precondition: all of the HLO's users are F32 -> BF16 conversions.
  Status FoldOutputConversions(HloInstruction* hlo);

  // Folds the BF16 -> F32 conversion operand to the HLO.
  //
  // Precondition: the operand is an F32 -> BF16 conversion.
  Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16ConversionFoldingVisitor::FoldOutputConversions(
    HloInstruction* hlo) {
  std::vector<HloInstruction*> materialized_users = hlo->users();
  hlo->mutable_shape()->set_element_type(BF16);
  for (auto user : materialized_users) {
    CHECK_EQ(user->opcode(), HloOpcode::kConvert);
    TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo));
    changed_ = true;
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::FoldOperandConversion(
    HloInstruction* hlo, int64 operand_index) {
  // The operand is a convert from BF16 to F32.
  auto operand = hlo->mutable_operand(operand_index);
  CHECK_EQ(operand->opcode(), HloOpcode::kConvert);
  TF_RETURN_IF_ERROR(
      hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0)));
  changed_ = true;
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
    HloInstruction* hlo) {
  std::vector<int64> bf16_to_f32_operands;
  bool has_other_f32_operands = false;
  for (int64 i = 0; i < hlo->operands().size(); ++i) {
    auto operand = hlo->operand(i);
    if (operand->shape().element_type() == F32) {
      if (operand->opcode() == HloOpcode::kConvert &&
          operand->operand(0)->shape().element_type() == BF16 &&
          bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
        // Operand is a convert from BF16 to F32 and we support BF16 input
        // directly in the current HLO at the operand index.
        bf16_to_f32_operands.push_back(i);
      } else {
        has_other_f32_operands = true;
      }
      continue;
    }
  }

  bool fold_output_conversion = hlo->user_count() > 0 &&
                                hlo->shape().element_type() == F32 &&
                                bfloat16_support_->SupportsBF16Output(*hlo) &&
                                hlo != computation_->root_instruction();
  if (fold_output_conversion) {
    for (auto user : hlo->users()) {
      if (user->opcode() == HloOpcode::kConvert &&
          user->shape().element_type() == BF16) {
        continue;
      }
      // We should not change the output type if any user is not a conversion
      // from F32 to BF16.
      fold_output_conversion = false;
      break;
    }
  }

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    if (has_other_f32_operands ||
        (!fold_output_conversion && hlo->shape().element_type() == F32)) {
      // Some of the operands/output will remain F32, but we cannot use mixed
      // precisions, so we cannot do anything here.
      return Status::OK();
    }
  }

  if (fold_output_conversion) {
    TF_RETURN_IF_ERROR(FoldOutputConversions(hlo));
  }

  for (int64 i : bf16_to_f32_operands) {
    TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i));
  }
  return Status::OK();
}

Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not fold BF16 conversions for instructions related to tuples, entry and
  // exit of a computation, fusion, convert, and control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kInfeed ||           //
      hlo->opcode() == HloOpcode::kOutfeed ||          //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional) {
    return Status::OK();
  }
  if (hlo == computation_->root_instruction() &&
      !bfloat16_support_->SupportsMixedPrecisions(*hlo)) {
    // If hlo is the root instruction, we cannot change its output, so folding
    // can only happen when it supports mixed precision so that we can change
    // its operands.
    return Status::OK();
  }
  return TryFoldBF16Conversions(hlo);
}

StatusOr<bool> BFloat16ConversionFolding::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeNonfusionComputations()) {
    if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(
      2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla
@ -0,0 +1,52 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_

#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// A pass which folds F32 <-> BF16 conversions to their operands or users, when
// it is supported by the backend.
//
// This pass follows the passed-in backend-specific BF16 support rules, but can
// introduce mixed precision in individual HLOs which breaks the assumption of
// some other HLO passes. So it should be used at the end of the HLO
// optimization pipeline followed by a DCE pass. If other passes are needed
// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the
// changes made by this pass.
class BFloat16ConversionFolding : public HloPassInterface {
 public:
  explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support)
      : bfloat16_support_(bfloat16_support) {}

  ~BFloat16ConversionFolding() override = default;
  tensorflow::StringPiece name() const override { return "bfloat16-fold"; }

  // Run BF16 conversion folding on the given computation. Returns whether the
  // computation was changed.
  StatusOr<bool> Run(HloModule* module) override;

 private:
  const BFloat16Support* bfloat16_support_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_
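As a minimal sketch of the call site the header comment implies (assumed wiring, not part of the change; `backend_bf16_support` is a placeholder for a real backend-provided BFloat16Support):

// Assumed call site: run the folding pass directly on a module, following the
// backend's BF16 support rules.
BFloat16ConversionFolding fold(&backend_bf16_support);
TF_ASSIGN_OR_RETURN(bool changed, fold.Run(module));
if (changed) {
  // The folded-away converts are now dead; a DCE pass would normally follow,
  // as the header comment recommends.
}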
@ -0,0 +1,209 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

class TestBFloat16Support : public BFloat16Support {
 public:
  TestBFloat16Support() {}
  ~TestBFloat16Support() override {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    if (hlo.opcode() == HloOpcode::kAdd ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }
};

class BFloat16ConversionFoldingTest : public HloTestBase {
 protected:
  bool FoldConversions(HloModule* module) {
    TestBFloat16Support bfloat16_support_;
    BFloat16ConversionFolding fold(&bfloat16_support_);
    StatusOr<bool> result = fold.Run(module);
    EXPECT_IS_OK(result.status());
    return result.ValueOrDie();
  }
};

TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c));
  builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), BF16);
  EXPECT_EQ(add1->operand(0), add0);
}

TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kMultiply, convert1, c));
  HloInstruction* convert2 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), convert2);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0), convert1);
}

TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, f32_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0));
  HloInstruction* convert1 = builder.AddInstruction(
      HloInstruction::CreateConvert(f32_shape, convert0));

  HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary(
      f32_shape, HloOpcode::kSubtract, convert1, c));
  HloInstruction* convert2 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), convert2);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0), convert1);
}

TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* convert0 =
      builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b));

  HloInstruction* tuple =
      builder.AddInstruction(HloInstruction::CreateTuple({a, convert0}));
  HloInstruction* gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0));
  HloInstruction* convert1 =
      builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(FoldConversions(module.get()));

  EXPECT_EQ(computation->root_instruction(), convert1);
  EXPECT_EQ(gte->shape().element_type(), F32);
  EXPECT_EQ(tuple->operand(1), convert0);
}

}  // namespace xla
351
tensorflow/compiler/xla/service/bfloat16_normalization.cc
Normal file
@ -0,0 +1,351 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
 public:
  explicit BFloat16NormalizationVisitor(HloComputation* computation,
                                        const BFloat16Support* bfloat16_support)
      : computation_(computation), bfloat16_support_(bfloat16_support) {}

  Status DefaultAction(HloInstruction* hlo) override;

  // Special handling for cross-replica-sum which can have a tuple output.
  Status HandleCrossReplicaSum(HloInstruction* crs) override;

  static bool Run(HloComputation* computation,
                  const BFloat16Support* bfloat16_support) {
    BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
    TF_CHECK_OK(computation->Accept(&visitor));
    return visitor.changed_;
  }

 private:
  // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts
  // conversions between F32 and BF16 to make it supported.
  Status HandleInstruction(HloInstruction* hlo);

  // Inserts a conversion HLO that changes the given HLO's output type.
  Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to,
                                  HloComputation* computation);

  // Changes the output type to the specified type, then inserts a conversion
  // to the original type.
  Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo,
                                               PrimitiveType to,
                                               HloComputation* computation);

  // Inserts a conversion HLO that changes the given HLO's operand type.
  Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx,
                                    PrimitiveType to,
                                    HloComputation* computation);

  // Inserts conversion HLOs to convert the called computations' BF16
  // operands/outputs to F32.
  Status ConvertCalledComputations(
      HloInstruction* hlo,
      tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps);

  HloComputation* computation_;
  const BFloat16Support* bfloat16_support_;
  bool changed_ = false;
};

Status BFloat16NormalizationVisitor::InsertConvertAfterOutput(
    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
  bool is_root = computation->root_instruction() == hlo;
  std::vector<HloInstruction*> materialized_users = hlo->users();
  // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith.
  auto convert = computation->AddInstruction(
      HloInstruction::CreateConvert(hlo->shape(), hlo));
  for (auto* user : materialized_users) {
    TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert));
  }
  if (is_root) {
    computation->set_root_instruction(convert);
  }
  convert->mutable_shape()->set_element_type(to);
  changed_ = true;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack(
    HloInstruction* hlo, PrimitiveType to, HloComputation* computation) {
  auto original_type = hlo->shape().element_type();
  hlo->mutable_shape()->set_element_type(to);
  return InsertConvertAfterOutput(hlo, original_type, computation);
}

Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand(
    HloInstruction* hlo, int64 operand_idx, PrimitiveType to,
    HloComputation* computation) {
  auto operand = hlo->mutable_operand(operand_idx);
  auto convert = computation->AddInstruction(HloInstruction::CreateConvert(
      ShapeUtil::ChangeElementType(operand->shape(), to), operand));
  TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert));
  changed_ = true;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::ConvertCalledComputations(
    HloInstruction* hlo,
    tensorflow::gtl::ArraySlice<HloComputation*> bf16_called_comps) {
  std::map<HloComputation*, HloComputation*> cloned_computations;
  for (auto& comp : bf16_called_comps) {
    auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone());
    cloned_computations[comp] = cloned;
    changed_ = true;
  }
  hlo->ReplaceCalledComputations([&](HloComputation* comp) {
    auto it = cloned_computations.find(comp);
    if (it != cloned_computations.end()) {
      return it->second;
    }
    return comp;
  });
  for (auto& comp_pair : cloned_computations) {
    auto comp = comp_pair.second;
    if (comp->root_instruction()->shape().element_type() == BF16) {
      TF_RETURN_IF_ERROR(
          InsertConvertAfterOutput(comp->root_instruction(), F32, comp));
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == BF16) {
        // This changes the parameter to F32 then inserts a convert after it.
        TF_RETURN_IF_ERROR(
            ChangeOutputTypeThenInsertConvertBack(param, F32, comp));
      }
    }
  }
  return Status::OK();
}

Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
    HloInstruction* crs) {
  if (!ShapeUtil::IsTuple(crs->shape())) {
    return HandleInstruction(crs);
  }

  std::vector<PrimitiveType> operand_types(crs->operand_count());
  std::vector<PrimitiveType> output_types(crs->operand_count());
  bool has_f32 = false;
  bool has_bf16 = false;
  bool has_bf16_output = false;
  for (int64 i = 0; i < crs->operand_count(); ++i) {
    operand_types[i] = crs->operand(i)->shape().element_type();
    output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type();
    if (operand_types[i] == F32 || output_types[i] == F32) {
      has_f32 = true;
    } else if (operand_types[i] == BF16) {
      has_bf16 = true;
    }
    if (output_types[i] == BF16) {
      has_bf16 = true;
      has_bf16_output = true;
    }
  }

  for (int64 i = 0; i < crs->operand_count(); ++i) {
    if (operand_types[i] != BF16) {
      continue;
    }
    if (bfloat16_support_->SupportsBF16Operand(*crs, i) &&
        (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
      continue;
    }
    TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_));
    has_f32 = true;
  }

  if (!has_bf16_output) {
    return Status::OK();
  }

  if (bfloat16_support_->SupportsBF16Output(*crs) &&
      (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) {
    return Status::OK();
  }

  std::vector<HloInstruction*> output_elements(crs->operand_count());
  auto original_shape = crs->shape();
  for (int64 i = 0; i < crs->operand_count(); ++i) {
    auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i});
    if (output_types[i] != BF16) {
      output_elements[i] = computation_->AddInstruction(
          HloInstruction::CreateGetTupleElement(*subshape, crs, i));
      continue;
    }
    subshape->set_element_type(F32);
    auto gte = computation_->AddInstruction(
        HloInstruction::CreateGetTupleElement(*subshape, crs, i));
    output_elements[i] =
        computation_->AddInstruction(HloInstruction::CreateConvert(
            ShapeUtil::ChangeElementType(*subshape, BF16), gte));
  }
  auto tuple = computation_->AddInstruction(
      HloInstruction::CreateTuple(output_elements));

  std::vector<HloInstruction*> materialized_users = crs->users();
  // Use the crs' shape temporarily, in order to pass checks in
  // ReplaceUseWith.
  *tuple->mutable_shape() = crs->shape();
  for (auto* user : materialized_users) {
    TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple));
  }
  *tuple->mutable_shape() = original_shape;
  return Status::OK();
}

Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) {
  std::vector<int64> bf16_operands;
  std::vector<int64> f32_operands;
  bool has_f32 = false;
  bool has_bf16 = false;

  for (int64 i = 0; i < hlo->operand_count(); ++i) {
    if (hlo->operand(i)->shape().element_type() == F32) {
      f32_operands.push_back(i);
      has_f32 = true;
    } else if (hlo->operand(i)->shape().element_type() == BF16) {
      bf16_operands.push_back(i);
      has_bf16 = true;
    }
  }

  if (hlo->shape().element_type() == F32) {
    has_f32 = true;
  } else if (hlo->shape().element_type() == BF16) {
    has_bf16 = true;
  }

  std::vector<HloComputation*> bf16_called_comps;
  for (auto* comp : hlo->called_computations()) {
    bool comp_has_bf16 = false;
    if (comp->root_instruction()->shape().element_type() == F32) {
      has_f32 = true;
    } else if (comp->root_instruction()->shape().element_type() == BF16) {
      has_bf16 = true;
      comp_has_bf16 = true;
    }
    for (auto* param : comp->parameter_instructions()) {
      if (param->shape().element_type() == F32) {
        has_f32 = true;
      } else if (param->shape().element_type() == BF16) {
        has_bf16 = true;
        comp_has_bf16 = true;
      }
    }
    if (comp_has_bf16) {
      bf16_called_comps.push_back(comp);
    }
  }

  if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 &&
      has_f32) {
    // Resolve unsupported mixed precision.
    //
    // See if we can change everything to BF16.
    if (hlo->called_computations().empty() &&
        hlo->shape().element_type() == BF16) {
      bool can_use_bf16 = true;
      for (int i : f32_operands) {
        if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
                                                                          i) &&
            bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
          continue;
        }
        can_use_bf16 = false;
        break;
      }
      if (can_use_bf16) {
        for (int i : f32_operands) {
          TF_RETURN_IF_ERROR(
              InsertConvertBeforeOperand(hlo, i, BF16, computation_));
        }
        return Status::OK();
      }
    }
    if (hlo->shape().element_type() == BF16) {
      TF_RETURN_IF_ERROR(
          ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
    }
    for (int i : bf16_operands) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
    }
    return ConvertCalledComputations(hlo, bf16_called_comps);
  }

  for (int i : bf16_operands) {
    if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) {
      TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_));
    }
  }

  if (hlo->shape().element_type() == BF16 &&
      !bfloat16_support_->SupportsBF16Output(*hlo)) {
    TF_RETURN_IF_ERROR(
        ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_));
  }

  return Status::OK();
}

Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
  // Do not change instructions related to entry and exit of a computation,
  // tuples, fusion, convert, and control flow.
  if (hlo->opcode() == HloOpcode::kTuple ||            //
      hlo->opcode() == HloOpcode::kGetTupleElement ||  //
      hlo->opcode() == HloOpcode::kInfeed ||           //
      hlo->opcode() == HloOpcode::kOutfeed ||          //
      hlo->opcode() == HloOpcode::kConstant ||         //
      hlo->opcode() == HloOpcode::kParameter ||        //
      hlo->opcode() == HloOpcode::kFusion ||           //
      hlo->opcode() == HloOpcode::kConvert ||          //
      hlo->opcode() == HloOpcode::kCall ||             //
      hlo->opcode() == HloOpcode::kCustomCall ||       //
      hlo->opcode() == HloOpcode::kWhile ||            //
      hlo->opcode() == HloOpcode::kConditional) {
    return Status::OK();
  }
  return HandleInstruction(hlo);
}

StatusOr<bool> BFloat16Normalization::Run(HloModule* module) {
  XLA_VLOG_LINES(
      2, "BFloat16Normalization::Run(), before:\n" + module->ToString());
  bool changed = false;
  for (auto* comp : module->MakeComputationPostOrder()) {
    if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) {
      changed = true;
    }
  }
  XLA_VLOG_LINES(2,
                 "BFloat16Normalization::Run(), after:\n" + module->ToString());
  return changed;
}

}  // namespace xla
92
tensorflow/compiler/xla/service/bfloat16_normalization.h
Normal file
@ -0,0 +1,92 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_

#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"

namespace xla {

// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not
// support BF16 input/output or mixed precision, according to the passed-in
// backend-specific BF16 support rules.
class BFloat16Normalization : public HloPassInterface {
 public:
  explicit BFloat16Normalization(const BFloat16Support* bfloat16_support)
      : bfloat16_support_(bfloat16_support) {}

  ~BFloat16Normalization() override = default;
  tensorflow::StringPiece name() const override { return "bf16-normalization"; }

  // Run BF16 normalization on the given computation. Returns whether the
  // computation was changed.
  StatusOr<bool> Run(HloModule* module) override;

 private:
  const BFloat16Support* bfloat16_support_;
};

// A pass that unconditionally removes the mixed F32/BF16 uses in HLO
// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike
// BFloat16Normalization, this pass does not use a backend-specific
// BFloat16Support, and does not change HLOs that have BF16 data if they do not
// use mixed precision; it removes mixed precision even if the backend supports
// it. This pass is used to make the HLO module valid for other HLO passes which
// do not support mixed precision.
class BFloat16MixedPrecisionRemoval : public HloPassInterface {
 public:
  BFloat16MixedPrecisionRemoval() {}

  ~BFloat16MixedPrecisionRemoval() override = default;

  tensorflow::StringPiece name() const override {
    return "bf16-mixed-precision-removal";
  }

  // Run mixed precision removal on the given computation. Returns whether the
  // computation was changed.
  StatusOr<bool> Run(HloModule* module) override {
    BFloat16Normalization normalization(&no_mixed_precision_support_);
    return normalization.Run(module);
  }

 private:
  class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support {
   public:
    BFloat16SupportForMixedPrecisionRemoval() {}

    ~BFloat16SupportForMixedPrecisionRemoval() override = default;

    bool SupportsBF16Operand(const HloInstruction& hlo,
                             int64 operand_index) const override {
      return true;
    }

    bool SupportsBF16Output(const HloInstruction& hlo) const override {
      return true;
    }

    bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
      return false;
    }
  } no_mixed_precision_support_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_
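To make the backend-facing side concrete, a target would express its BF16 rules by subclassing BFloat16Support and passing that to the pass, roughly as sketched below (an assumed example, not part of this change; the choice of kAdd is arbitrary):

// Assumed backend-side BFloat16Support: only kAdd accepts BF16
// operands/outputs, and mixed precision is never allowed.
class ExampleBF16Support : public BFloat16Support {
 public:
  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    return hlo.opcode() == HloOpcode::kAdd;
  }
  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    return hlo.opcode() == HloOpcode::kAdd;
  }
  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    return false;
  }
};

// Running the pass inserts F32 <-> BF16 converts wherever these rules are
// violated. BFloat16MixedPrecisionRemoval above is this same pass driven by a
// support object that always rejects mixed precision.
ExampleBF16Support example_support;
BFloat16Normalization normalization(&example_support);
TF_ASSIGN_OR_RETURN(bool changed, normalization.Run(module));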
248
tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
Normal file
@ -0,0 +1,248 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

class TestBFloat16Support : public BFloat16Support {
 public:
  TestBFloat16Support() {}
  ~TestBFloat16Support() override {}

  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    if (hlo.opcode() == HloOpcode::kAdd ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce ||
        hlo.opcode() == HloOpcode::kSubtract ||
        hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }

  bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
    if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple ||
        hlo.opcode() == HloOpcode::kGetTupleElement) {
      return true;
    }
    return false;
  }
};

class BFloat16NormalizationTest : public HloTestBase {
 protected:
  bool Normalize(HloModule* module) {
    TestBFloat16Support bfloat16_support_;
    BFloat16Normalization normalization(&bfloat16_support_);
    StatusOr<bool> result = normalization.Run(module);
    EXPECT_IS_OK(result.status());
    return result.ValueOrDie();
  }
};

TEST_F(BFloat16NormalizationTest, NoopIfSupported) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* add0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b));

  HloInstruction* add1 = builder.AddInstruction(
      HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_FALSE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), add1);
  EXPECT_EQ(add0->shape().element_type(), BF16);
  EXPECT_EQ(add1->shape().element_type(), F32);
}

TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* mul0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b));

  HloInstruction* mul1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), mul1);
  EXPECT_EQ(mul0->shape().element_type(), F32);
  EXPECT_EQ(mul1->shape().element_type(), F32);
  EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert);
}

TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) {
  auto builder = HloComputation::Builder(TestName());
  Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  HloInstruction* a = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_shape, "a"));
  HloInstruction* b = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_shape, "b"));
  HloInstruction* c = builder.AddInstruction(
      HloInstruction::CreateParameter(2, f32_shape, "c"));

  HloInstruction* sub0 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b));

  HloInstruction* sub1 = builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c));

  auto module = CreateNewModule();
  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert);
  EXPECT_EQ(computation->root_instruction()->operand(0), sub1);
  EXPECT_EQ(sub0->shape().element_type(), F32);
  EXPECT_EQ(sub1->shape().element_type(), F32);
  EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert);
}

TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
  Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4});
  Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4});

  Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4});

  auto reduce_comp_builder = HloComputation::Builder("reduce_comp");
  auto reduce_comp_param0 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0"));
  auto reduce_comp_param1 = reduce_comp_builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1"));
  reduce_comp_builder.AddInstruction(
      HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd,
                                   reduce_comp_param0, reduce_comp_param1));

  auto module = CreateNewModule();
  auto reduce_computation =
      module->AddEmbeddedComputation(reduce_comp_builder.Build());

  auto builder = HloComputation::Builder(TestName());
  HloInstruction* input = builder.AddInstruction(
      HloInstruction::CreateParameter(0, f32_input_shape, "a"));
  HloInstruction* init = builder.AddInstruction(
      HloInstruction::CreateParameter(1, bf16_scalar_shape, "init"));
  HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce(
      f32_output_shape, input, init, {0}, reduce_computation));

  auto computation = module->AddEntryComputation(builder.Build());

  EXPECT_TRUE(Normalize(module.get()));

  EXPECT_EQ(computation->root_instruction(), reduce);
  EXPECT_EQ(reduce->called_computations().size(), 1);
  EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2);
  EXPECT_EQ(reduce->called_computations()[0]
                ->parameter_instruction(0)
                ->shape()
                .element_type(),
F32);
|
||||
EXPECT_EQ(reduce->called_computations()[0]
|
||||
->parameter_instruction(1)
|
||||
->shape()
|
||||
.element_type(),
|
||||
F32);
|
||||
EXPECT_EQ(reduce->called_computations()[0]
|
||||
->root_instruction()
|
||||
->shape()
|
||||
.element_type(),
|
||||
F32);
|
||||
EXPECT_EQ(reduce->shape().element_type(), F32);
|
||||
EXPECT_EQ(reduce->operand(0), input);
|
||||
EXPECT_EQ(input->shape().element_type(), F32);
|
||||
EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert);
|
||||
EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32);
|
||||
}
|
||||
|
||||
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
|
||||
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
|
||||
|
||||
HloInstruction* a = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, f32_shape, "a"));
|
||||
HloInstruction* b = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, bf16_shape, "b"));
|
||||
|
||||
HloInstruction* crs =
|
||||
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
|
||||
ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
|
||||
HloInstruction* gte = builder.AddInstruction(
|
||||
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
|
||||
|
||||
auto module = CreateNewModule();
|
||||
auto computation = module->AddEntryComputation(builder.Build());
|
||||
|
||||
EXPECT_TRUE(Normalize(module.get()));
|
||||
|
||||
EXPECT_EQ(computation->root_instruction(), gte);
|
||||
EXPECT_EQ(gte->shape().element_type(), BF16);
|
||||
EXPECT_EQ(crs->operand(1)->shape().element_type(), F32);
|
||||
EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32);
|
||||
}
|
||||
|
||||
} // namespace xla
|
111
tensorflow/compiler/xla/service/bfloat16_support.cc
Normal file
@ -0,0 +1,111 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_support.h"

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

namespace xla {

bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
                                          int64 operand_index) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      CHECK_EQ(operand_index, 0);
      return hlo.operand(0)->shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}

bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      return hlo.shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}

bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    default:
      break;
  }
  return false;
}

/* static */
bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
    const HloInstruction& hlo, int64 operand_index) {
  switch (hlo.opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kBroadcast:
    case HloOpcode::kClamp:
    case HloOpcode::kConcatenate:
    case HloOpcode::kCopy:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
      return true;
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    case HloOpcode::kDynamicUpdateSlice:
      return operand_index == 0 || operand_index == 1;
    case HloOpcode::kSelect:
      return operand_index == 1 || operand_index == 2;
    default:
      break;
  }
  return false;
}

bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
    const HloInstruction& hlo, int64 operand_index) const {
  return false;
}

}  // namespace xla
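The static EffectiveOperandPrecisionIsOutputPrecision() helper above whitelists opcodes that simply forward operand elements to the output. A minimal, hypothetical usage sketch (not part of this commit; the helper name is illustrative and int64 is XLA's integer alias):

// Hypothetical helper: count the operands of an instruction that actually
// feed arithmetic, i.e. whose precision is not simply forwarded to the output.
int64 CountArithmeticOperands(const HloInstruction& hlo) {
  int64 count = 0;
  for (int64 i = 0; i < hlo.operand_count(); ++i) {
    if (!BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(hlo, i)) {
      ++count;
    }
  }
  return count;
}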
60
tensorflow/compiler/xla/service/bfloat16_support.h
Normal file
@ -0,0 +1,60 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_

#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

namespace xla {

class BFloat16Support {
 public:
  BFloat16Support() {}
  virtual ~BFloat16Support() {}

  // Returns whether the backend supports a BF16 operand for the HLO
  // instruction at the given index.
  virtual bool SupportsBF16Operand(const HloInstruction& hlo,
                                   int64 operand_index) const;

  // Returns whether the backend supports BF16 output for the HLO instruction.
  virtual bool SupportsBF16Output(const HloInstruction& hlo) const;

  // Returns whether the backend supports mixed precision: the operands,
  // output, and parameters/output of the called computations can have
  // different precisions (BF16 and F32).
  virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const;

  // Returns whether the given HLO inherits its BF16 operand precision at the
  // given index, so even if the output is F32, elements in the output that
  // depend on the BF16 operand will still have BF16 effective precision even
  // if they have F32 format. Similarly, this also means that if the output is
  // BF16, then increasing the operand precision from BF16 to F32 will not
  // change the output. This typically includes HLOs that pass elements from
  // the operand to the output without arithmetic operations.
  static bool EffectiveOperandPrecisionIsOutputPrecision(
      const HloInstruction& hlo, int64 operand_index);

  // Returns whether the backend uses BF16 precision only for the operand at
  // the specified index, even if the operand is F32.
  virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
                                               int64 operand_index) const;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_
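A backend opts into BF16 handling by subclassing this interface and handing it to the normalization pass, exactly as the TestBFloat16Support helper does in the test file above. A minimal sketch, assuming a hypothetical backend that additionally supports BF16 for element-wise adds (the class name and opcode choice are illustrative, not part of this change):

// Illustrative backend-specific support policy; anything not listed falls
// back to the conservative defaults of the BFloat16Support base class.
class MyBackendBFloat16Support : public BFloat16Support {
 public:
  bool SupportsBF16Operand(const HloInstruction& hlo,
                           int64 operand_index) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           BFloat16Support::SupportsBF16Operand(hlo, operand_index);
  }
  bool SupportsBF16Output(const HloInstruction& hlo) const override {
    return hlo.opcode() == HloOpcode::kAdd ||
           BFloat16Support::SupportsBF16Output(hlo);
  }
};

// Wiring it into the pass, as in BFloat16NormalizationTest::Normalize above:
//   MyBackendBFloat16Support support;
//   BFloat16Normalization normalization(&support);
//   StatusOr<bool> changed = normalization.Run(module);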
@ -159,9 +159,6 @@ cc_library(
    deps = [
        ":compiler_functor",
        ":cpu_runtime",
        ":cpu_runtime_avx",
        ":cpu_runtime_neon",
        ":cpu_runtime_sse4_1",
        ":custom_call_target_registry",
        ":disassembler",
        ":external_constant_pool",
@ -408,9 +405,6 @@ cc_library(
    hdrs = ["compiler_functor.h"],
    deps = [
        ":cpu_runtime",
        ":cpu_runtime_avx",
        ":cpu_runtime_neon",
        ":cpu_runtime_sse4_1",
        ":disassembler",
        ":llvm_ir_runtime",
        "//tensorflow/compiler/xla:statusor",
@ -430,43 +424,6 @@ cc_library(
    ],
)

cc_library(
    name = "cpu_runtime_sse4_1",
    srcs = ["cpu_runtime_sse4_1.cc"],
    hdrs = ["cpu_runtime_sse4_1.h"],
    copts = ["-DEIGEN_AVOID_STL_ARRAY"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
)

cc_library(
    name = "cpu_runtime_avx",
    srcs = ["cpu_runtime_avx.cc"],
    hdrs = ["cpu_runtime_avx.h"],
    copts = ["-DEIGEN_AVOID_STL_ARRAY"],
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
)

cc_library(
    name = "cpu_runtime_neon",
    srcs = ["cpu_runtime_neon.cc"],
    hdrs = ["cpu_runtime_neon.h"],
    # runtime_copts() enables -mfpu=neon
    copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(),
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:framework_lite",
        "//third_party/eigen3",
    ],
)

cc_library(
    name = "cpu_runtime",
    srcs = [
@ -37,9 +37,6 @@ limitations under the License.
|
||||
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
@ -50,15 +47,6 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
|
||||
/* static */ CompilerFunctor::VectorIntrinsics
|
||||
CompilerFunctor::AllIntrinsics() {
|
||||
VectorIntrinsics intrinsics;
|
||||
intrinsics.sse_intrinsics = true;
|
||||
intrinsics.avx_intrinsics = true;
|
||||
intrinsics.neon_intrinsics = true;
|
||||
return intrinsics;
|
||||
}
|
||||
|
||||
/* Create filtered versions of the LLVM Pass Managers to filter out some
|
||||
of the expensive passes.
|
||||
Profiling:
|
||||
@ -192,31 +180,8 @@ operator()(llvm::Module& module) const {
|
||||
std::move(object_file), std::move(memory_buffer));
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Returns the set of vectorized library functions supported for the target.
|
||||
std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
|
||||
llvm::Triple::ArchType arch, llvm::StringRef feature_string,
|
||||
CompilerFunctor::VectorIntrinsics const& available_intrinsics) {
|
||||
std::vector<llvm::VecDesc> vector_functions;
|
||||
|
||||
const llvm::VecDesc four_wide_vector_functions_neon[] = {
|
||||
{"logf", runtime::kLogV4F32NEONSymbolName, 4},
|
||||
{"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4},
|
||||
};
|
||||
|
||||
const llvm::VecDesc four_wide_vector_functions_sse[] = {
|
||||
{"logf", runtime::kLogV4F32SSESymbolName, 4},
|
||||
{"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4},
|
||||
};
|
||||
|
||||
const llvm::VecDesc eight_wide_vector_functions_avx[] = {
|
||||
{"logf", runtime::kLogV8F32AVXSymbolName, 8},
|
||||
{"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8},
|
||||
};
|
||||
|
||||
// These functions are generated by XLA as LLVM IR, so they're always
|
||||
// available.
|
||||
const llvm::VecDesc ir_vector_functions[] = {
|
||||
static std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl() {
|
||||
std::vector<llvm::VecDesc> result = {
|
||||
{"tanhf", runtime::kTanhV4F32SymbolName, 4},
|
||||
{"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4},
|
||||
|
||||
@ -228,50 +193,15 @@ std::vector<llvm::VecDesc> VectorFunctionsForTargetLibraryInfoImpl(
|
||||
|
||||
{"expf", runtime::kExpV8F32SymbolName, 8},
|
||||
{"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8},
|
||||
|
||||
{"logf", runtime::kLogV4F32SymbolName, 4},
|
||||
{"llvm.log.f32", runtime::kLogV4F32SymbolName, 4},
|
||||
|
||||
{"logf", runtime::kLogV8F32SymbolName, 8},
|
||||
{"llvm.log.f32", runtime::kLogV8F32SymbolName, 8},
|
||||
};
|
||||
|
||||
llvm::SmallVector<llvm::StringRef, 32> features;
|
||||
feature_string.split(features, ',', -1, /*KeepEmpty=*/false);
|
||||
auto has_feature = [&features](const llvm::StringRef feature) {
|
||||
return std::find(features.begin(), features.end(), feature) !=
|
||||
features.end();
|
||||
};
|
||||
|
||||
switch (arch) {
|
||||
case llvm::Triple::x86:
|
||||
case llvm::Triple::x86_64: {
|
||||
if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) {
|
||||
vector_functions.insert(vector_functions.end(),
|
||||
std::begin(four_wide_vector_functions_sse),
|
||||
std::end(four_wide_vector_functions_sse));
|
||||
}
|
||||
if (has_feature("+avx") && available_intrinsics.avx_intrinsics) {
|
||||
vector_functions.insert(vector_functions.end(),
|
||||
std::begin(eight_wide_vector_functions_avx),
|
||||
std::end(eight_wide_vector_functions_avx));
|
||||
}
|
||||
break;
|
||||
}
|
||||
case llvm::Triple::arm:
|
||||
case llvm::Triple::aarch64: {
|
||||
if (has_feature("+neon") && available_intrinsics.neon_intrinsics) {
|
||||
vector_functions.insert(vector_functions.end(),
|
||||
std::begin(four_wide_vector_functions_neon),
|
||||
std::end(four_wide_vector_functions_neon));
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
vector_functions.insert(vector_functions.end(),
|
||||
std::begin(ir_vector_functions),
|
||||
std::end(ir_vector_functions));
|
||||
|
||||
return vector_functions;
|
||||
return result;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void CompilerFunctor::AddTargetInfoPasses(
|
||||
llvm::legacy::PassManagerBase* passes) const {
|
||||
@ -279,9 +209,7 @@ void CompilerFunctor::AddTargetInfoPasses(
|
||||
auto target_library_info_impl =
|
||||
MakeUnique<llvm::TargetLibraryInfoImpl>(target_triple);
|
||||
target_library_info_impl->addVectorizableFunctions(
|
||||
VectorFunctionsForTargetLibraryInfoImpl(
|
||||
target_triple.getArch(), target_machine_->getTargetFeatureString(),
|
||||
available_intrinsics_));
|
||||
VectorFunctionsForTargetLibraryInfoImpl());
|
||||
passes->add(
|
||||
new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl));
|
||||
passes->add(createTargetTransformInfoWrapperPass(
|
||||
|
@ -31,21 +31,10 @@ namespace cpu {
|
||||
// Orc JIT compile layer.
|
||||
class CompilerFunctor {
|
||||
public:
|
||||
// Describes the set of vector intrinsics available to the generated code.
|
||||
struct VectorIntrinsics {
|
||||
bool sse_intrinsics;
|
||||
bool avx_intrinsics;
|
||||
bool neon_intrinsics;
|
||||
};
|
||||
|
||||
// Returns a VectorIntrinsics where all intrinsics are available.
|
||||
static VectorIntrinsics AllIntrinsics();
|
||||
|
||||
explicit CompilerFunctor(
|
||||
llvm::TargetMachine* target_machine, const Disassembler* disassembler,
|
||||
int opt_level, bool optimize_for_size, bool enable_fast_math,
|
||||
bool disable_expensive_passes,
|
||||
const VectorIntrinsics& available_intrinsics,
|
||||
LLVMCompiler::ModuleHook pre_optimization_hook = nullptr,
|
||||
LLVMCompiler::ModuleHook post_optimization_hook = nullptr)
|
||||
: target_machine_(target_machine),
|
||||
@ -54,7 +43,6 @@ class CompilerFunctor {
|
||||
optimize_for_size_(optimize_for_size),
|
||||
enable_fast_math_(enable_fast_math),
|
||||
disable_expensive_passes_(disable_expensive_passes),
|
||||
available_intrinsics_(available_intrinsics),
|
||||
pre_optimization_hook_(pre_optimization_hook),
|
||||
post_optimization_hook_(post_optimization_hook) {}
|
||||
|
||||
@ -78,7 +66,6 @@ class CompilerFunctor {
|
||||
const bool optimize_for_size_;
|
||||
const bool enable_fast_math_;
|
||||
const bool disable_expensive_passes_;
|
||||
const VectorIntrinsics available_intrinsics_;
|
||||
LLVMCompiler::ModuleHook pre_optimization_hook_;
|
||||
LLVMCompiler::ModuleHook post_optimization_hook_;
|
||||
};
|
||||
|
@ -888,8 +888,7 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
        options::OptimizeForSizeRequested(module->config()),
        module->config().debug_options().xla_enable_fast_math(),
        module->config().debug_options().xla_llvm_disable_expensive_passes(),
        CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook,
        post_optimization_ir_dump_hook);
        pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook);
    llvm::object::OwningBinary<llvm::object::ObjectFile> object_file =
        compiler_functor(llvm_module);
    llvm::StringRef object_file_data_ref = object_file.getBinary()->getData();
@ -1,37 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#ifdef TF_XLA_HAS_AVX
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x) {
|
||||
return Eigen::internal::plog(x);
|
||||
}
|
||||
#endif // TF_XLA_HAS_AVX
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX";
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
@ -1,59 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header declares functions which may be called by the generated code on
|
||||
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
|
||||
// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
|
||||
// which is used to cache expensive state and resources utilized by the
|
||||
// aforementioned functions.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
#if defined(__AVX__)
|
||||
#include <immintrin.h>
|
||||
#define TF_XLA_HAS_AVX
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
extern const char *const kLogV8F32AVXSymbolName;
|
||||
|
||||
#ifdef TF_XLA_HAS_AVX
|
||||
typedef __m256 V8F32AVX;
|
||||
#endif
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
extern "C" {
|
||||
|
||||
#ifdef TF_XLA_HAS_AVX
|
||||
// The following functions are vectorized versions of a selection of libm
|
||||
// library functions.
|
||||
// References to these functions are created by the LLVM vectorizer.
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x);
|
||||
|
||||
xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX(
|
||||
xla::cpu::runtime::V8F32AVX x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_
|
@ -1,46 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#ifdef TF_XLA_HAS_NEON
|
||||
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x) {
|
||||
return Eigen::internal::pexp(x);
|
||||
}
|
||||
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x) {
|
||||
Eigen::internal::Packet4f p = x;
|
||||
return Eigen::internal::plog(p);
|
||||
}
|
||||
|
||||
#endif // TF_XLA_HAS_NEON
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON";
|
||||
const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON";
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
@ -1,62 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
||||
|
||||
// This header declares functions which may be called by the generated code on
|
||||
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
|
||||
// xla::cpu::SimpleResolver.
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
#ifdef __ARM_NEON__
|
||||
// For the other runtimes (AVX, SSE4.1) we define the vector type directly using
|
||||
// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM
|
||||
// NEON SIMD types is not portable, so the type has to come from <arm_neon.h>
|
||||
#include <arm_neon.h>
|
||||
#define TF_XLA_HAS_NEON
|
||||
#endif // __ARM_NEON__
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
extern const char *const kExpV4F32NEONSymbolName;
|
||||
extern const char *const kLogV4F32NEONSymbolName;
|
||||
|
||||
#ifdef TF_XLA_HAS_NEON
|
||||
typedef float32x4_t V4F32NEON;
|
||||
#endif // TF_XLA_HAS_NEON
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
extern "C" {
|
||||
|
||||
#ifdef TF_XLA_HAS_NEON
|
||||
// The following functions are vectorized versions of a selection of libm
|
||||
// library functions.
|
||||
// References to these functions are created by the LLVM vectorizer.
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x);
|
||||
|
||||
xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON(
|
||||
xla::cpu::runtime::V4F32NEON x);
|
||||
#endif // TF_XLA_HAS_NEON
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_
|
@ -1,40 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
|
||||
#ifdef TF_XLA_HAS_SSE4_1
|
||||
|
||||
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
|
||||
xla::cpu::runtime::V4F32SSE x) {
|
||||
Eigen::internal::Packet4f p = x;
|
||||
return Eigen::internal::plog(p);
|
||||
}
|
||||
|
||||
#endif // TF_XLA_HAS_SSE4_1
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE";
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
@ -1,59 +0,0 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This header declares functions which may be called by the generated code on
|
||||
// the CPU. Calls to these functions must be resolved explicitly in the JIT in
|
||||
// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context
|
||||
// which is used to cache expensive state and resources utilized by the
|
||||
// aforementioned functions.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
|
||||
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1
|
||||
// when __AVX__ is defined, we should do the same.
|
||||
#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__))
|
||||
#include <smmintrin.h>
|
||||
#define TF_XLA_HAS_SSE4_1
|
||||
#endif
|
||||
|
||||
namespace xla {
|
||||
namespace cpu {
|
||||
namespace runtime {
|
||||
|
||||
extern const char *const kLogV4F32SSESymbolName;
|
||||
|
||||
#ifdef TF_XLA_HAS_SSE4_1
|
||||
typedef __m128 V4F32SSE;
|
||||
#endif
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace cpu
|
||||
} // namespace xla
|
||||
|
||||
extern "C" {
|
||||
|
||||
#ifdef TF_XLA_HAS_SSE4_1
|
||||
// The following functions are vectorized versions of a selection of libm
|
||||
// library functions.
|
||||
// References to these functions are created by the LLVM vectorizer.
|
||||
xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE(
|
||||
xla::cpu::runtime::V4F32SSE x);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/IR/Verifier.h"
|
||||
#include "llvm/Transforms/Utils/Cloning.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
|
||||
#include "tensorflow/core/lib/core/casts.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
@ -31,6 +32,8 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32";
|
||||
const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32";
|
||||
const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32";
|
||||
const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32";
|
||||
const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX";
|
||||
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
|
||||
|
||||
namespace {
|
||||
llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
|
||||
@ -116,19 +119,19 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
|
||||
|
||||
// This implements the same polynomial approximation as implemented in Eigen3.
|
||||
|
||||
const double exp_hi = 88.3762626647950;
|
||||
const double exp_lo = -88.3762626647949;
|
||||
const float exp_hi = 88.3762626647950;
|
||||
const float exp_lo = -88.3762626647949;
|
||||
|
||||
const double cephes_LOG2EF = 1.44269504088896341;
|
||||
const double cephes_exp_C1 = 0.693359375;
|
||||
const double cephes_exp_C2 = -2.12194440e-4;
|
||||
const float cephes_LOG2EF = 1.44269504088896341;
|
||||
const float cephes_exp_C1 = 0.693359375;
|
||||
const float cephes_exp_C2 = -2.12194440e-4;
|
||||
|
||||
const double cephes_exp_p0 = 1.9875691500E-4;
|
||||
const double cephes_exp_p1 = 1.3981999507E-3;
|
||||
const double cephes_exp_p2 = 8.3334519073E-3;
|
||||
const double cephes_exp_p3 = 4.1665795894E-2;
|
||||
const double cephes_exp_p4 = 1.6666665459E-1;
|
||||
const double cephes_exp_p5 = 5.0000001201E-1;
|
||||
const float cephes_exp_p0 = 1.9875691500E-4;
|
||||
const float cephes_exp_p1 = 1.3981999507E-3;
|
||||
const float cephes_exp_p2 = 8.3334519073E-3;
|
||||
const float cephes_exp_p3 = 4.1665795894E-2;
|
||||
const float cephes_exp_p4 = 1.6666665459E-1;
|
||||
const float cephes_exp_p5 = 5.0000001201E-1;
|
||||
|
||||
llvm::Value* input = &*vector_exp_function->arg_begin();
|
||||
llvm::Value* input_clamped =
|
||||
@ -146,7 +149,7 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
|
||||
y = vsl.MulAdd(y, x, cephes_exp_p4);
|
||||
y = vsl.MulAdd(y, x, cephes_exp_p5);
|
||||
y = vsl.MulAdd(y, z, x);
|
||||
y = vsl.Add(1.0, y);
|
||||
y = vsl.Add(1.0f, y);
|
||||
|
||||
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
|
||||
// time so drop down to IRBuilder for this bit.
|
||||
@ -167,9 +170,133 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
|
||||
|
||||
ir_builder.CreateRet(result);
|
||||
|
||||
CHECK(!llvm::verifyFunction(*vector_exp_function));
|
||||
DCHECK(!llvm::verifyFunction(*vector_exp_function));
|
||||
return vector_exp_function;
|
||||
}
|
||||
|
||||
llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
|
||||
llvm::StringRef function_name,
|
||||
int vector_width,
|
||||
bool enable_fast_math) {
|
||||
llvm::Function* vector_log_function = module->getFunction(function_name);
|
||||
if (vector_log_function == nullptr) {
|
||||
// If the function declaration is not present in the module, there can't be
|
||||
// any calls to resolve. Don't emit the function in this case.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
llvm::LLVMContext* context = &module->getContext();
|
||||
|
||||
llvm::BasicBlock* vector_log_body =
|
||||
llvm::BasicBlock::Create(*context, "body", vector_log_function);
|
||||
|
||||
llvm::IRBuilder<> ir_builder(vector_log_body);
|
||||
llvm::FastMathFlags fast_math_flags;
|
||||
fast_math_flags.setFast();
|
||||
ir_builder.setFastMathFlags(fast_math_flags);
|
||||
|
||||
llvm::Value* input = &*vector_log_function->arg_begin();
|
||||
VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
|
||||
|
||||
const float half = 0.5;
|
||||
|
||||
// This implements the same polynomial approximation as implemented in Eigen3.
|
||||
// Returns NaN for x < 0, -INF for x = 0
|
||||
const float cephes_SQRTHF = 0.707106781186547524;
|
||||
const float cephes_log_p0 = 7.0376836292E-2;
|
||||
const float cephes_log_p1 = -1.1514610310E-1;
|
||||
const float cephes_log_p2 = 1.1676998740E-1;
|
||||
const float cephes_log_p3 = -1.2420140846E-1;
|
||||
const float cephes_log_p4 = +1.4249322787E-1;
|
||||
const float cephes_log_p5 = -1.6668057665E-1;
|
||||
const float cephes_log_p6 = +2.0000714765E-1;
|
||||
const float cephes_log_p7 = -2.4999993993E-1;
|
||||
const float cephes_log_p8 = +3.3333331174E-1;
|
||||
const float cephes_log_q1 = -2.12194440e-4;
|
||||
const float cephes_log_q2 = 0.693359375;
|
||||
|
||||
// The smallest non denormalized float number.
|
||||
const float min_norm_pos = tensorflow::bit_cast<float, int32>(0x00800000);
|
||||
const float minus_inf = tensorflow::bit_cast<float, int32>(0xff800000);
|
||||
|
||||
// NB! This number is denormal and since TF sets the denormals-are-zero flag
|
||||
// (and if TF didn't, -ffast-math would) trying to operate on this float using
|
||||
// C++ operations (including, for instance, implicit conversion to double)
|
||||
// will coerce this to zero.
|
||||
const float inv_mant_mask = tensorflow::bit_cast<float, int32>(~0x7f800000);
|
||||
|
||||
// invalid_mask is set if x is negative or NaN (and therefore output
|
||||
// must be NaN).
|
||||
llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
|
||||
llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
|
||||
|
||||
// Cut off denormalized stuff.
|
||||
input = vsl.Max(min_norm_pos, input);
|
||||
|
||||
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
|
||||
// time so drop down to IRBuilder for this bit.
|
||||
llvm::Value* vector_constant_0x7f =
|
||||
ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
|
||||
llvm::Value* vector_constant_23 =
|
||||
ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
|
||||
llvm::Type* i32_vector_type =
|
||||
llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
|
||||
|
||||
llvm::Value* emm0 = ir_builder.CreateLShr(
|
||||
ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23);
|
||||
|
||||
// Keep only the fractional part.
|
||||
input = vsl.FloatAnd(input, inv_mant_mask);
|
||||
input = vsl.FloatOr(input, half);
|
||||
|
||||
emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
|
||||
llvm::Value* e =
|
||||
vsl.Add(1.0f, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
|
||||
|
||||
// part2:
|
||||
// if( x < SQRTHF ) {
|
||||
// e -= 1;
|
||||
// x = x + x - 1.0;
|
||||
// } else { x = x - 1.0; }
|
||||
llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
|
||||
llvm::Value* tmp = vsl.FloatAnd(input, mask);
|
||||
input = vsl.Sub(input, 1.0);
|
||||
e = vsl.Sub(e, vsl.FloatAnd(mask, 1.0));
|
||||
input = vsl.Add(input, tmp);
|
||||
|
||||
llvm::Value* x2 = vsl.Mul(input, input);
|
||||
llvm::Value* x3 = vsl.Mul(x2, input);
|
||||
|
||||
llvm::Value *y, *y1, *y2;
|
||||
y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1);
|
||||
y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4);
|
||||
y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7);
|
||||
y = vsl.MulAdd(y, input, cephes_log_p2);
|
||||
y1 = vsl.MulAdd(y1, input, cephes_log_p5);
|
||||
y2 = vsl.MulAdd(y2, input, cephes_log_p8);
|
||||
y = vsl.MulAdd(y, x3, y1);
|
||||
y = vsl.MulAdd(y, x3, y2);
|
||||
y = vsl.Mul(y, x3);
|
||||
|
||||
y1 = vsl.Mul(cephes_log_q1, e);
|
||||
tmp = vsl.Mul(half, x2);
|
||||
y = vsl.Add(y, y1);
|
||||
input = vsl.Sub(input, tmp);
|
||||
y2 = vsl.Mul(cephes_log_q2, e);
|
||||
input = vsl.Add(input, y);
|
||||
input = vsl.Add(input, y2);
|
||||
|
||||
// Negative arg will be NAN, 0 will be -INF.
|
||||
llvm::Value* or_lhs =
|
||||
vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask));
|
||||
llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf);
|
||||
llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs);
|
||||
|
||||
ir_builder.CreateRet(result);
|
||||
|
||||
DCHECK(!llvm::verifyFunction(*vector_log_function));
|
||||
return vector_log_function;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
|
||||
@ -187,11 +314,21 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
|
||||
EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName,
|
||||
/*vector_width=*/8, enable_fast_math);
|
||||
|
||||
auto* log_v4f32 =
|
||||
EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName,
|
||||
/*vector_width=*/4, enable_fast_math);
|
||||
auto* log_v8f32 =
|
||||
EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName,
|
||||
/*vector_width=*/8, enable_fast_math);
|
||||
|
||||
// Gather all the call sites, force inline them and then delete the vector
|
||||
// function bodies.
|
||||
//
|
||||
// TODO(b/73081976): Should we avoid inlining these intrinsics in some cases?
|
||||
|
||||
std::vector<llvm::CallInst*> calls_to_inline;
|
||||
for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) {
|
||||
for (auto* function :
|
||||
{tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
|
||||
if (function != nullptr) {
|
||||
for (auto* user : function->users()) {
|
||||
calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
|
||||
@ -204,7 +341,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
|
||||
CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
|
||||
}
|
||||
|
||||
for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) {
|
||||
for (auto* function :
|
||||
{tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
|
||||
if (function != nullptr) {
|
||||
function->eraseFromParent();
|
||||
}
|
||||
|
@ -27,6 +27,8 @@ extern const char* const kTanhV4F32SymbolName;
extern const char* const kTanhV8F32SymbolName;
extern const char* const kExpV4F32SymbolName;
extern const char* const kExpV8F32SymbolName;
extern const char* const kLogV4F32SymbolName;
extern const char* const kLogV8F32SymbolName;

// The following CPU runtime functions have LLVM-IR only implementations:
//
@ -28,9 +28,6 @@ limitations under the License.
|
||||
#include "llvm/Support/Host.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h"
|
||||
#include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h"
|
||||
@ -101,27 +98,6 @@ llvm::StringRef GetHostCpuName() {
|
||||
cpu_name.consume_back("-avx512");
|
||||
return cpu_name;
|
||||
}
|
||||
|
||||
CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() {
|
||||
CompilerFunctor::VectorIntrinsics intrinsics;
|
||||
#ifdef TF_XLA_HAS_SSE4_1
|
||||
intrinsics.sse_intrinsics = true;
|
||||
#else
|
||||
intrinsics.sse_intrinsics = false;
|
||||
#endif
|
||||
#ifdef TF_XLA_HAS_AVX
|
||||
intrinsics.avx_intrinsics = true;
|
||||
#else
|
||||
intrinsics.avx_intrinsics = false;
|
||||
#endif
|
||||
#ifdef TF_XLA_HAS_NEON
|
||||
intrinsics.neon_intrinsics = true;
|
||||
#else
|
||||
intrinsics.neon_intrinsics = false;
|
||||
#endif
|
||||
return intrinsics;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
@ -169,13 +145,12 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options,
|
||||
orc_jit_memory_mapper::GetInstance());
|
||||
},
|
||||
[this](llvm::orc::VModuleKey K) { return symbol_resolver_; }),
|
||||
compile_layer_(
|
||||
object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_, opt_level,
|
||||
optimize_for_size, enable_fast_math,
|
||||
disable_expensive_passes, GetAvailableIntrinsics(),
|
||||
std::move(pre_optimization_hook),
|
||||
std::move(post_optimization_hook))) {
|
||||
compile_layer_(object_layer_,
|
||||
CompilerFunctor(target_machine_.get(), &disassembler_,
|
||||
opt_level, optimize_for_size,
|
||||
enable_fast_math, disable_expensive_passes,
|
||||
std::move(pre_optimization_hook),
|
||||
std::move(post_optimization_hook))) {
|
||||
VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str()
|
||||
<< " features: " << target_machine_->getTargetFeatureString().str();
|
||||
}
|
||||
@ -240,15 +215,6 @@ bool RegisterKnownJITSymbols() {
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64);
|
||||
#ifdef TF_XLA_HAS_NEON
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON);
|
||||
#endif
|
||||
#ifdef TF_XLA_HAS_SSE4_1
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE);
|
||||
#endif
|
||||
#ifdef TF_XLA_HAS_AVX
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX);
|
||||
#endif
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue);
|
||||
REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation);
|
||||
|
@ -103,15 +103,92 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) {
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, double low,
|
||||
double high) {
|
||||
llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, float low,
|
||||
float high) {
|
||||
AssertCorrectTypes({a});
|
||||
llvm::Type* type = a->getType();
|
||||
CHECK_LT(low, high);
|
||||
CHECK(scalar_type_->isFloatingPointTy());
|
||||
return llvm_ir::EmitFloatMin(
|
||||
llvm_ir::EmitFloatMax(a, llvm::ConstantFP::get(type, low), ir_builder_),
|
||||
llvm::ConstantFP::get(type, high), ir_builder_);
|
||||
llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_),
|
||||
GetConstantFloat(type, high), ir_builder_);
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs,
|
||||
llvm::Value* rhs) {
|
||||
AssertCorrectTypes({lhs, rhs});
|
||||
return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name()));
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs,
|
||||
llvm::Value* rhs) {
|
||||
AssertCorrectTypes({lhs, rhs});
|
||||
return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name()));
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs,
|
||||
llvm::Value* rhs) {
|
||||
AssertCorrectTypes({lhs, rhs});
|
||||
return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name()));
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) {
|
||||
bool is_vector = llvm::isa<llvm::VectorType>(i1->getType());
|
||||
llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector);
|
||||
return ir_builder()->CreateBitCast(
|
||||
ir_builder()->CreateSExt(i1, integer_type, name()),
|
||||
is_vector ? vector_type() : scalar_type(), name());
|
||||
}
|
||||
|
||||
llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) {
|
||||
CHECK(scalar_type()->isFloatingPointTy());
|
||||
const llvm::DataLayout& data_layout =
|
||||
ir_builder()->GetInsertBlock()->getModule()->getDataLayout();
|
||||
int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type());
|
||||
llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits);
|
||||
if (vector) {
|
||||
return llvm::VectorType::get(scalar_int_type, vector_size());
|
||||
} else {
|
||||
return scalar_int_type;
|
||||
}
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) {
|
||||
CHECK_EQ(x->getType(), scalar_type());
|
||||
return ir_builder()->CreateVectorSplat(vector_size(), x, name());
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs,
|
||||
llvm::Value* rhs) {
|
||||
AssertCorrectTypes({lhs, rhs});
|
||||
llvm::Type* int_type =
|
||||
IntegerTypeForFloatSize(lhs->getType() == vector_type());
|
||||
return ir_builder()->CreateBitCast(
|
||||
ir_builder()->CreateAnd(
|
||||
ir_builder()->CreateBitCast(lhs, int_type, name()),
|
||||
ir_builder()->CreateBitCast(rhs, int_type, name()), name()),
|
||||
vector_type());
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) {
|
||||
AssertCorrectTypes({lhs});
|
||||
llvm::Type* int_type =
|
||||
IntegerTypeForFloatSize(lhs->getType() == vector_type());
|
||||
return ir_builder()->CreateBitCast(
|
||||
ir_builder()->CreateNot(
|
||||
ir_builder()->CreateBitCast(lhs, int_type, name()), name()),
|
||||
vector_type());
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) {
|
||||
AssertCorrectTypes({lhs, rhs});
|
||||
llvm::Type* int_type =
|
||||
IntegerTypeForFloatSize(lhs->getType() == vector_type());
|
||||
return ir_builder()->CreateBitCast(
|
||||
ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()),
|
||||
ir_builder()->CreateBitCast(rhs, int_type, name()),
|
||||
name()),
|
||||
vector_type(), name());
|
||||
}
|
||||
|
||||
llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs,
|
||||
|
@ -41,40 +41,82 @@ class VectorSupportLibrary {
|
||||
llvm::Value* Mul(int64 lhs, llvm::Value* rhs) {
|
||||
return Mul(ir_builder()->getInt64(lhs), rhs);
|
||||
}
|
||||
llvm::Value* Mul(double lhs, llvm::Value* rhs) {
|
||||
return Mul(llvm::ConstantFP::get(rhs->getType(), lhs), rhs);
|
||||
llvm::Value* Mul(float lhs, llvm::Value* rhs) {
|
||||
return Mul(GetConstantFloat(rhs->getType(), lhs), rhs);
|
||||
}
|
||||
|
||||
llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* Add(int64 lhs, llvm::Value* rhs) {
|
||||
return Add(ir_builder()->getInt64(lhs), rhs);
|
||||
}
|
||||
llvm::Value* Add(double lhs, llvm::Value* rhs) {
|
||||
return Add(llvm::ConstantFP::get(vector_type(), lhs), rhs);
|
||||
llvm::Value* Add(float lhs, llvm::Value* rhs) {
|
||||
return Add(GetConstantFloat(rhs->getType(), lhs), rhs);
|
||||
}
|
||||
|
||||
llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* Sub(llvm::Value* lhs, float rhs) {
|
||||
return Sub(lhs, GetConstantFloat(lhs->getType(), rhs));
|
||||
}
|
||||
llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* Max(float lhs, llvm::Value* rhs) {
|
||||
return Max(GetConstantFloat(rhs->getType(), lhs), rhs);
|
||||
}
|
||||
llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs);
|
||||
|
||||
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) {
|
||||
return Add(c, Mul(a, b));
|
||||
}
|
||||
|
||||
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, double c) {
|
||||
return Add(llvm::ConstantFP::get(vector_type(), c), Mul(a, b));
|
||||
llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, float c) {
|
||||
return Add(GetConstantFloat(vector_type(), c), Mul(a, b));
|
||||
}
|
||||
|
||||
llvm::Value* MulAdd(llvm::Value* a, double b, double c) {
|
||||
return Add(llvm::ConstantFP::get(a->getType(), c),
|
||||
Mul(a, llvm::ConstantFP::get(a->getType(), b)));
|
||||
llvm::Value* MulAdd(llvm::Value* a, float b, float c) {
|
||||
return Add(GetConstantFloat(a->getType(), c),
|
||||
Mul(a, GetConstantFloat(a->getType(), b)));
|
||||
}
|
||||
|
||||
llvm::Value* Floor(llvm::Value* a);
|
||||
|
||||
llvm::Value* Clamp(llvm::Value* a, double low, double high);
|
||||
llvm::Value* SplatFloat(double d) {
|
||||
return llvm::ConstantFP::get(vector_type(), d);
|
||||
llvm::Value* Clamp(llvm::Value* a, float low, float high);
|
||||
llvm::Value* SplatFloat(float d) {
|
||||
return GetConstantFloat(vector_type(), d);
|
||||
}
|
||||
|
||||
// These compare instructions return a floating point typed mask instead of an
|
||||
// i1. For instance, on a vector typed input, lanes where the predicate is
|
||||
// true get a float with all ones and other lanes get a float with all zeros.
|
||||
// This is slightly odd from the perspective of LLVM's type system, but it
|
||||
// makes kernel IR generation code written using VectorSupportLibrary (its
|
||||
// raison d'etre) less cluttered.
|
||||
|
||||
llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* FCmpOLTMask(llvm::Value* lhs, float rhs) {
|
||||
return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs));
|
||||
}
|
||||
|
||||
// These boolean operations operate on the bitwise values of the floating
|
||||
// point inputs. They return a (vector of) float(s) but like in the mask
|
||||
// generating predicates above this type system oddity makes the kernel IR
|
||||
// generation code less cluttered.
|
||||
llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* FloatAnd(llvm::Value* lhs, float rhs) {
|
||||
return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs));
|
||||
}
|
||||
llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs);
|
||||
llvm::Value* FloatOr(llvm::Value* lhs, float rhs) {
|
||||
return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs));
|
||||
}
|
||||
llvm::Value* FloatNot(llvm::Value* lhs);
|
||||
llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) {
|
||||
return FloatAnd(FloatNot(lhs), rhs);
|
||||
}
|
||||
|
||||
llvm::Value* BroadcastScalar(llvm::Value* x);
|
||||
llvm::Value* BroadcastScalar(float d) {
|
||||
return BroadcastScalar(GetConstantFloat(scalar_type(), d));
|
||||
}
|
||||
|
||||
llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer,
|
||||
@ -194,6 +236,17 @@ class VectorSupportLibrary {
|
||||
std::vector<llvm::Value*> ComputeAvxOptimizedHorizontalSums(
|
||||
std::vector<llvm::Value*> vectors, llvm::Value* init_values);
|
||||
|
||||
llvm::Type* IntegerTypeForFloatSize(bool vector);
|
||||
llvm::Value* I1ToFloat(llvm::Value* i1);
|
||||
llvm::Value* GetConstantFloat(llvm::Type* type, float f) {
|
||||
llvm::Constant* scalar_value =
|
||||
llvm::ConstantFP::get(type->getContext(), llvm::APFloat(f));
|
||||
if (llvm::isa<llvm::VectorType>(type)) {
|
||||
return llvm::ConstantVector::getSplat(vector_size(), scalar_value);
|
||||
}
|
||||
return scalar_value;
|
||||
}
|
||||
|
||||
int64 vector_size_;
|
||||
PrimitiveType primitive_type_;
|
||||
llvm::IRBuilder<>* ir_builder_;
|
||||
|
@ -45,7 +45,7 @@ ConvolutionThunk::ConvolutionThunk(
    const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
    const Shape& filter_shape, const Shape& output_shape, const Window& window,
    const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
    const HloInstruction* hlo)
    bool tensor_ops_enabled, const HloInstruction* hlo)
    : Thunk(Kind::kConvolution, hlo),
      convolution_kind_(convolution_kind),
      input_buffer_(input_buffer),
@ -58,7 +58,8 @@ ConvolutionThunk::ConvolutionThunk(
      output_shape_(output_shape),
      window_(window),
      dim_nums_(dim_nums),
      algorithm_(algorithm) {}
      algorithm_(algorithm),
      tensor_ops_enabled_(tensor_ops_enabled) {}

Status ConvolutionThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream) {
@ -72,7 +73,7 @@ Status ConvolutionThunk::ExecuteOnStream(
      buffer_allocations.GetDeviceAddress(scratch_buffer_);

  se::dnn::AlgorithmConfig algorithm_config(
      se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false));
      se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));

  TF_RETURN_IF_ERROR(RunCudnnConvolution(
      convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
@ -59,7 +59,7 @@ class ConvolutionThunk : public Thunk {
|
||||
const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
|
||||
const HloInstruction* hlo);
|
||||
bool tensor_ops_enabled, const HloInstruction* hlo);
|
||||
|
||||
ConvolutionThunk(const ConvolutionThunk&) = delete;
|
||||
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
|
||||
@ -99,6 +99,7 @@ class ConvolutionThunk : public Thunk {
|
||||
const Window window_;
|
||||
const ConvolutionDimensionNumbers dim_nums_;
|
||||
int64 algorithm_;
|
||||
bool tensor_ops_enabled_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -172,7 +172,7 @@ string NumBytesToString(int64 bytes) {
|
||||
// cache misses and doing extra work. Overall, caching doesn't seem worth the
|
||||
// trouble, but we may want to revisit this if we ever find a model where
|
||||
// caching would speed up compilation a lot.
|
||||
optional<std::pair<int64, int64>>
|
||||
optional<std::tuple<int64, bool, int64>>
|
||||
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape, const Window& window,
|
||||
@ -260,8 +260,9 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
|
||||
<< AlgorithmToString(best_result.algorithm()) << ", takes "
|
||||
<< best_result.elapsed_time_in_ms() << "ms, and uses "
|
||||
<< best_result_bytes_used << "B of scratch memory.";
|
||||
return std::make_pair(best_result.algorithm().algo_id(),
|
||||
best_result_bytes_used);
|
||||
return std::make_tuple(best_result.algorithm().algo_id(),
|
||||
best_result.algorithm().tensor_ops_enabled(),
|
||||
best_result_bytes_used);
|
||||
}
|
||||
|
||||
LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
|
||||
@ -277,19 +278,19 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
|
||||
const auto& lhs_shape = instr->operand(0)->shape();
|
||||
const auto& rhs_shape = instr->operand(1)->shape();
|
||||
const auto& conv_result_shape = instr->shape().tuple_shapes(0);
|
||||
optional<std::pair<int64, int64>> alg_and_scratch_bytes;
|
||||
optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
|
||||
if (call_target == kCudnnConvForwardCallTarget) {
|
||||
alg_and_scratch_bytes = PickBestAlgorithm(
|
||||
alg_scratch_and_tc = PickBestAlgorithm(
|
||||
CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
|
||||
/*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
|
||||
instr->window(), instr->convolution_dimension_numbers(), instr);
|
||||
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
|
||||
alg_and_scratch_bytes = PickBestAlgorithm(
|
||||
alg_scratch_and_tc = PickBestAlgorithm(
|
||||
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
|
||||
/*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
|
||||
instr->convolution_dimension_numbers(), instr);
|
||||
} else if (call_target == kCudnnConvBackwardFilterCallTarget) {
|
||||
alg_and_scratch_bytes = PickBestAlgorithm(
|
||||
alg_scratch_and_tc = PickBestAlgorithm(
|
||||
CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
|
||||
/*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
|
||||
instr->window(), instr->convolution_dimension_numbers(), instr);
|
||||
@ -298,17 +299,20 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
|
||||
<< instr->ToString();
|
||||
}
|
||||
|
||||
if (!alg_and_scratch_bytes.has_value()) {
|
||||
if (!alg_scratch_and_tc.has_value()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
int64 algorithm;
|
||||
bool tensor_ops_enabled;
|
||||
int64 scratch_bytes;
|
||||
std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes;
|
||||
|
||||
std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
|
||||
|
||||
VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
|
||||
<< NumBytesToString(scratch_bytes)
|
||||
<< " of scratch memory: " << instr->ToString();
|
||||
<< " of scratch memory: " << instr->ToString()
|
||||
<< " tensor_ops_enabled: " << tensor_ops_enabled;
|
||||
|
||||
// Replace instr with a new CustomCall which has the correct algorithm, and
|
||||
// whose output shape has the appropriate amount of scratch memory.
|
||||
@ -318,10 +322,15 @@ StatusOr<bool> CudnnConvolutionAlgorithmPicker::RunOnInstruction(
|
||||
ShapeUtil::MakeShape(U8, {scratch_bytes})});
|
||||
HloInstruction* algorithm_hlo = computation->AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
|
||||
HloInstruction* tensor_ops_enabled_hlo =
|
||||
computation->AddInstruction(HloInstruction::CreateConstant(
|
||||
Literal::CreateR0<bool>(tensor_ops_enabled)));
|
||||
|
||||
HloInstruction* new_call =
|
||||
computation->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
new_call_shape,
|
||||
{instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo},
|
||||
{instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo,
|
||||
tensor_ops_enabled_hlo},
|
||||
instr->custom_call_target()));
|
||||
new_call->set_window(instr->window());
|
||||
new_call->set_convolution_dimension_numbers(
|
||||
|
@ -47,7 +47,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface {
|
||||
private:
|
||||
StatusOr<bool> RunOnComputation(HloComputation* computation);
|
||||
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
|
||||
tensorflow::gtl::optional<std::pair<int64, int64>> PickBestAlgorithm(
|
||||
tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
|
||||
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
|
||||
const Shape& output_shape, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
|
||||
|
@ -106,6 +106,9 @@ Status RunCudnnConvolution(
|
||||
se::ScratchAllocator* scratch_allocator, const Window& window,
|
||||
const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
|
||||
Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
|
||||
VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
|
||||
VLOG(3) << "tensor_ops_enabled: "
|
||||
<< algorithm.algorithm().tensor_ops_enabled();
|
||||
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
|
||||
VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
|
||||
VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
|
||||
|
@ -79,9 +79,9 @@ StatusOr<bool> GpuCopyInsertion::Run(HloModule* module) {
|
||||
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
|
||||
}
|
||||
} else if (IsCustomCallToDnnConvolution(*hlo)) {
|
||||
// The last argument to a CUDNN convolution is its algorithm, which must
|
||||
// be an HLO constant -- it shouldn't be copied.
|
||||
for (int64 i = 0; i < hlo->operand_count() - 1; ++i) {
|
||||
// The last two arguments to a CUDNN convolution are two HLO constants for
|
||||
// cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied.
|
||||
for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
|
||||
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
|
||||
}
|
||||
} else if (ImplementedAsLibraryCall(*hlo)) {
|
||||
|
@ -63,10 +63,11 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo);
|
||||
// strings.
|
||||
//
|
||||
// These CustomCalls have window() and convolution_dimension_numbers() set like
|
||||
// regular convolution ops. They have the same LHS and RHS operands, plus one
|
||||
// additional int64 operand, representing which cudnn algorithm to run. This
|
||||
// operand must be an HLO constant. A value of -1 means that the implementation
|
||||
// is free to choose the best algorithm it can.
|
||||
// regular convolution ops. They have the same LHS and RHS operands, plus two
|
||||
// additional constant operands: an int64 operand for the cudnn algorithm and
|
||||
// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
|
||||
// algorithm means that the implementation is free to choose the best algorithm
|
||||
// it can.
|
||||
//
|
||||
// These calls output a tuple (conv_result, scratch_memory), where conv_result
|
||||
// is the actual result of the convolution, and scratch_memory is temporary
|
||||
|
@ -393,6 +393,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
|
||||
int64 algorithm = algorithm_inst->literal().Get<int64>({});
|
||||
|
||||
const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
|
||||
CHECK(tensor_ops_enabled_inst->IsConstant())
|
||||
<< tensor_ops_enabled_inst->ToString();
|
||||
bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
|
||||
|
||||
const auto& target = custom_call->custom_call_target();
|
||||
std::unique_ptr<ConvolutionThunk> thunk;
|
||||
if (target == kCudnnConvForwardCallTarget) {
|
||||
@ -407,7 +412,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
/*filter_shape=*/rhs_shape,
|
||||
/*output_shape=*/conv_result_shape, //
|
||||
custom_call->window(), custom_call->convolution_dimension_numbers(),
|
||||
algorithm, custom_call);
|
||||
algorithm, tensor_ops_enabled, custom_call);
|
||||
} else if (target == kCudnnConvBackwardInputCallTarget) {
|
||||
thunk = MakeUnique<ConvolutionThunk>(
|
||||
CudnnConvKind::kBackwardInput,
|
||||
@ -420,7 +425,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
/*filter_shape=*/rhs_shape,
|
||||
/*output_shape=*/lhs_shape, //
|
||||
custom_call->window(), custom_call->convolution_dimension_numbers(),
|
||||
algorithm, custom_call);
|
||||
algorithm, tensor_ops_enabled, custom_call);
|
||||
} else if (target == kCudnnConvBackwardFilterCallTarget) {
|
||||
thunk = MakeUnique<ConvolutionThunk>(
|
||||
CudnnConvKind::kBackwardFilter,
|
||||
@ -433,7 +438,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
|
||||
/*filter_shape=*/conv_result_shape,
|
||||
/*output_shape=*/rhs_shape, //
|
||||
custom_call->window(), custom_call->convolution_dimension_numbers(),
|
||||
algorithm, custom_call);
|
||||
algorithm, tensor_ops_enabled, custom_call);
|
||||
} else {
|
||||
LOG(FATAL) << "Unexpected custom call target: "
|
||||
<< custom_call->custom_call_target();
|
||||
|
@ -1805,7 +1805,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) {
|
||||
|
||||
Status HloInstruction::ReplaceUseWith(HloInstruction* user,
|
||||
HloInstruction* new_producer) {
|
||||
TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape()))
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape()))
|
||||
<< "this shape: " << ShapeUtil::HumanString(shape())
|
||||
<< ", replacement shape: "
|
||||
<< ShapeUtil::HumanString(new_producer->shape());
|
||||
@ -1828,8 +1829,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num,
|
||||
TF_RET_CHECK(operand_num >= 0);
|
||||
TF_RET_CHECK(operand_num < operand_count());
|
||||
HloInstruction* old_operand = mutable_operand(operand_num);
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::Compatible(old_operand->shape(), new_operand->shape()))
|
||||
TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(),
|
||||
new_operand->shape()))
|
||||
<< old_operand->shape().ShortDebugString() << " is not compatible with "
|
||||
<< new_operand->shape().ShortDebugString();
|
||||
operands_[operand_num] = new_operand;
|
||||
|
@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <set>
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
@ -164,6 +166,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
|
||||
// HLO broadcast has no exact analog at the proto level so there is no
|
||||
// ShapeInference method. Check the output shape explicitly.
|
||||
const Shape& operand_shape = broadcast->operand(0)->shape();
|
||||
// Check for mixed precision.
|
||||
TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape()));
|
||||
TF_RET_CHECK(ShapeUtil::Rank(operand_shape) ==
|
||||
broadcast->dimensions().size());
|
||||
for (int64 operand_dimension = 0;
|
||||
@ -178,6 +182,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) {
|
||||
}
|
||||
|
||||
Status ShapeVerifier::HandleReshape(HloInstruction* reshape) {
|
||||
// Check for mixed precision.
|
||||
TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape()));
|
||||
TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) ==
|
||||
ShapeUtil::ElementsIn(reshape->operand(0)->shape()));
|
||||
return tensorflow::Status::OK();
|
||||
@ -359,13 +365,122 @@ Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) {
|
||||
batch_norm_grad->feature_index()));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Checks that the instruction does not have mixed precision floating point
|
||||
// inputs.
|
||||
Status CheckMixedPrecisionOperands(const HloInstruction* instruction) {
|
||||
switch (instruction->opcode()) {
|
||||
// White list the following opcodes for mixed-precision check, because they
|
||||
// involve data pass through or grouping via tuples, where the precisions
|
||||
// of buffers can be different.
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kCrossReplicaSum:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kFusion:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kRecv:
|
||||
case HloOpcode::kRecvDone:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kSelect:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kWhile:
|
||||
break;
|
||||
default: {
|
||||
PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID;
|
||||
for (auto operand : instruction->operands()) {
|
||||
TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
|
||||
operand->shape(),
|
||||
[&](const Shape& subshape, const ShapeIndex& index) {
|
||||
if (!ShapeUtil::ElementIsFloating(subshape)) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (fp_type == PRIMITIVE_TYPE_INVALID) {
|
||||
fp_type = subshape.element_type();
|
||||
} else if (fp_type != subshape.element_type()) {
|
||||
return FailedPrecondition(
|
||||
"Seen floating point types of different precisions in "
|
||||
"%s, but mixed precision is disallowed.",
|
||||
instruction->ToString().c_str());
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||
const Shape& expected_shape) {
|
||||
if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) {
|
||||
const Shape& inferred_shape) {
|
||||
// If allow_mixed_precision_ is false, check if there are operands with
|
||||
// different precisions. We need this check because ShapeInference allows
|
||||
// mixed precision inputs.
|
||||
if (!allow_mixed_precision_) {
|
||||
TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction));
|
||||
}
|
||||
|
||||
// Check if the output shape matches the expected shape.
|
||||
bool compatible;
|
||||
// We treat BF16 and F32 as compatible types if mixed precision is allowed,
|
||||
// but only when the instruction defines the BF16/F32 buffer.
|
||||
switch (instruction->opcode()) {
|
||||
case HloOpcode::kSelect:
|
||||
if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) {
|
||||
// Select only defines the top-level buffer, which in this case is the
|
||||
// tuple, so we cannot allow mixed precision.
|
||||
compatible =
|
||||
ShapeUtil::Compatible(instruction->shape(), inferred_shape);
|
||||
} else {
|
||||
compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
|
||||
instruction->shape(), inferred_shape);
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kTuple:
|
||||
// Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed
|
||||
// precision is disallowed.
|
||||
case HloOpcode::kConstant:
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kBitcastConvert:
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kConditional:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kParameter:
|
||||
case HloOpcode::kRecv:
|
||||
case HloOpcode::kRecvDone:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kSendDone:
|
||||
case HloOpcode::kWhile:
|
||||
// The above opcodes should match the expected shapes exactly.
|
||||
compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape);
|
||||
break;
|
||||
default:
|
||||
if (allow_mixed_precision_) {
|
||||
compatible = ShapeUtil::CompatibleIgnoringFpPrecision(
|
||||
instruction->shape(), inferred_shape);
|
||||
} else {
|
||||
compatible =
|
||||
ShapeUtil::Compatible(instruction->shape(), inferred_shape);
|
||||
}
|
||||
}
|
||||
if (!compatible) {
|
||||
return InvalidArgument(
|
||||
"Expected instruction to have shape compatible with %s, actual "
|
||||
"shape is %s:\n%s",
|
||||
ShapeUtil::HumanString(expected_shape).c_str(),
|
||||
ShapeUtil::HumanString(inferred_shape).c_str(),
|
||||
ShapeUtil::HumanString(instruction->shape()).c_str(),
|
||||
instruction->ToString().c_str());
|
||||
}
|
||||
@ -373,14 +488,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||
}
|
||||
|
||||
Status ShapeVerifier::CheckShape(const HloInstruction* instruction,
|
||||
const StatusOr<Shape>& expected_shape_status) {
|
||||
if (!expected_shape_status.ok()) {
|
||||
Status s = expected_shape_status.status();
|
||||
const StatusOr<Shape>& inferred_shape_status) {
|
||||
if (!inferred_shape_status.ok()) {
|
||||
Status s = inferred_shape_status.status();
|
||||
tensorflow::errors::AppendToMessage(&s, ", for instruction ",
|
||||
instruction->ToString());
|
||||
return s;
|
||||
}
|
||||
return CheckShape(instruction, expected_shape_status.ValueOrDie());
|
||||
return CheckShape(instruction, inferred_shape_status.ValueOrDie());
|
||||
}
|
||||
|
||||
Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) {
|
||||
|
@ -27,6 +27,10 @@ namespace xla {
|
||||
// TODO(b/26024837): Check output shape for all instruction types.
|
||||
class ShapeVerifier : public DfsHloVisitor {
|
||||
public:
|
||||
explicit ShapeVerifier() : allow_mixed_precision_(false) {}
|
||||
explicit ShapeVerifier(bool allow_mixed_precision)
|
||||
: allow_mixed_precision_(allow_mixed_precision) {}
|
||||
|
||||
Status HandleElementwiseUnary(HloInstruction* hlo) override;
|
||||
Status HandleElementwiseBinary(HloInstruction* hlo) override;
|
||||
Status HandleClamp(HloInstruction* clamp) override;
|
||||
@ -81,14 +85,14 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
}
|
||||
|
||||
protected:
|
||||
// Check the instruction's shape against the given expected shape and return
|
||||
// an appropriate error if there is a mismatch.
|
||||
// Check the instruction's shape against the shape given by ShapeInference
|
||||
// and return an appropriate error if there is a mismatch.
|
||||
Status CheckShape(const HloInstruction* instruction,
|
||||
const Shape& expected_shape);
|
||||
const Shape& inferred_shape);
|
||||
|
||||
// Overload which takes a StatusOr to reduce boilerplate in the caller.
|
||||
Status CheckShape(const HloInstruction* instruction,
|
||||
const StatusOr<Shape>& expected_shape_status);
|
||||
const StatusOr<Shape>& inferred_shape_status);
|
||||
|
||||
// Check a unary (binary, etc) instruction's shape against the inferred shape.
|
||||
Status CheckUnaryShape(const HloInstruction* instruction);
|
||||
@ -99,19 +103,32 @@ class ShapeVerifier : public DfsHloVisitor {
|
||||
// Checks if the given two instructions shares the same channel id.
|
||||
Status CheckSameChannel(const HloInstruction* instr1,
|
||||
const HloInstruction* instr2);
|
||||
|
||||
private:
|
||||
// Whether the inputs and output of an instruction can contain both F32s and
|
||||
// BF16s. Tuples that include both F32s and BF16s are allowed regardless of
|
||||
// this flag.
|
||||
bool allow_mixed_precision_;
|
||||
};
|
||||
|
||||
// HLO pass that verifies invariants of HLO instructions for each computation in
|
||||
// the module.
|
||||
class HloVerifier : public HloPassInterface {
|
||||
public:
|
||||
using ShapeVerifierFactory = std::function<std::unique_ptr<ShapeVerifier>()>;
|
||||
|
||||
// Uses standard shape inference.
|
||||
explicit HloVerifier()
|
||||
: shape_verifier_factory_([] { return MakeUnique<ShapeVerifier>(); }) {}
|
||||
: shape_verifier_factory_(
|
||||
[] { return MakeUnique<ShapeVerifier>(false); }) {}
|
||||
|
||||
explicit HloVerifier(bool allow_mixed_precision)
|
||||
: shape_verifier_factory_([allow_mixed_precision] {
|
||||
return MakeUnique<ShapeVerifier>(allow_mixed_precision);
|
||||
}) {}
|
||||
|
||||
// Uses custom shape verification.
|
||||
explicit HloVerifier(
|
||||
std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory)
|
||||
explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory)
|
||||
: shape_verifier_factory_(std::move(shape_verifier_factory)) {}
|
||||
|
||||
~HloVerifier() override = default;
|
||||
@ -129,7 +146,7 @@ class HloVerifier : public HloPassInterface {
|
||||
// expectations. This is a factory function because ShapeVerifier, Note that
|
||||
// ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object
|
||||
// for each run of the verifier.
|
||||
std::function<std::unique_ptr<ShapeVerifier>()> shape_verifier_factory_;
|
||||
ShapeVerifierFactory shape_verifier_factory_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -209,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
|
||||
}
|
||||
|
||||
// Check that init_value's shape is suitable for reducer_shape.
|
||||
if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) {
|
||||
if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
|
||||
init_value_shape)) {
|
||||
return InvalidArgument(
|
||||
"Reduction function's accumulator shape differs from the "
|
||||
"init_value shape: %s vs %s",
|
||||
@ -220,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
|
||||
// Check that the inputs can be passed in as the second argument.
|
||||
const Shape& input_element_shape =
|
||||
ShapeUtil::MakeShape(input_element_type, {});
|
||||
if (!ShapeUtil::Compatible(input_element_shape,
|
||||
reducer_shape.parameters(1))) {
|
||||
if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape,
|
||||
reducer_shape.parameters(1))) {
|
||||
return InvalidArgument(
|
||||
"Reduction function's second parameter shape differs from the "
|
||||
"input type element type: %s vs %s",
|
||||
@ -231,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape,
|
||||
|
||||
// Currently the accumulator and inputs must be the same type,
|
||||
// though that restriction could be relaxed.
|
||||
if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) {
|
||||
if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape,
|
||||
reducer_shape.parameters(1))) {
|
||||
return InvalidArgument(
|
||||
"Reduction function's second parameter shape currently must "
|
||||
"match the result shape. Got %s vs %s",
|
||||
@ -394,11 +396,13 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
dimension);
|
||||
}
|
||||
const Shape* arg_shape = nullptr;
|
||||
PrimitiveType element_type = PRIMITIVE_TYPE_INVALID;
|
||||
for (const Shape* shape : arg_shapes) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpectNotTupleOrOpaque(*shape, "operand of concatenation"));
|
||||
if (!arg_shape) {
|
||||
arg_shape = shape;
|
||||
element_type = arg_shape->element_type();
|
||||
continue;
|
||||
}
|
||||
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
|
||||
@ -409,7 +413,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
|
||||
ShapeUtil::HumanString(*shape).c_str());
|
||||
}
|
||||
if (arg_shape->element_type() != shape->element_type()) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) {
|
||||
return InvalidArgument(
|
||||
"cannot concatenate arrays with different element types: %s vs %s",
|
||||
PrimitiveType_Name(arg_shape->element_type()).c_str(),
|
||||
@ -431,6 +435,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
ShapeUtil::HumanString(*shape).c_str(), dimension);
|
||||
}
|
||||
}
|
||||
element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape);
|
||||
}
|
||||
|
||||
std::vector<int64> new_dimensions(arg_shape->dimensions().begin(),
|
||||
@ -438,7 +443,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
for (size_t i = 1; i < arg_shapes.size(); ++i) {
|
||||
new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension);
|
||||
}
|
||||
return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions);
|
||||
return ShapeUtil::MakeShape(element_type, new_dimensions);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferConvertShape(
|
||||
@ -536,7 +541,8 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
ShapeUtil::HumanString(operand_shape).c_str(),
|
||||
padding_config.ShortDebugString().c_str());
|
||||
}
|
||||
if (operand_shape.element_type() != padding_value_shape.element_type()) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
|
||||
padding_value_shape)) {
|
||||
return InvalidArgument(
|
||||
"the element types of the operands to pad do not match");
|
||||
}
|
||||
@ -548,7 +554,9 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
|
||||
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
|
||||
padding_config.dimensions(i).interior_padding();
|
||||
}
|
||||
return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
|
||||
return ShapeUtil::MakeShape(
|
||||
ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape),
|
||||
dimensions);
|
||||
}
|
||||
|
||||
// Current DotDimensionNumbers Requirements:
|
||||
@ -673,7 +681,7 @@ Status ValidateDotDimensionNumbers(
|
||||
};
|
||||
|
||||
// Check if both element types are the same.
|
||||
if (lhs.element_type() != rhs.element_type()) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
||||
return fail("element types do not match");
|
||||
}
|
||||
|
||||
@ -736,7 +744,8 @@ Status ValidateDotDimensionNumbers(
|
||||
dimensions.push_back(rhs.dimensions(i));
|
||||
}
|
||||
}
|
||||
Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions);
|
||||
Shape result = ShapeUtil::MakeShape(
|
||||
ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions);
|
||||
|
||||
TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result));
|
||||
VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result);
|
||||
@ -767,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
ShapeUtil::HumanString(rhs).c_str());
|
||||
}
|
||||
}
|
||||
return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions);
|
||||
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
|
||||
output_dimensions);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferInDimBroadcastShape(
|
||||
@ -829,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
// specified in broadcast_dimensions are then changed to match the
|
||||
// corresponding dimension size in smaller_shape.
|
||||
Shape output_shape(larger_shape);
|
||||
output_shape.set_element_type(
|
||||
ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape));
|
||||
|
||||
for (int i = 0; i < smaller_shape.dimensions_size(); ++i) {
|
||||
int64 dimension_to_match = broadcast_dimensions.at(i);
|
||||
@ -878,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation"));
|
||||
|
||||
if (!ShapeUtil::SameElementType(lhs, rhs)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
||||
return InvalidArgument(
|
||||
"binary op %s with different element types: %s and %s",
|
||||
BinaryOperation_Name(operation).c_str(),
|
||||
@ -897,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
}
|
||||
}
|
||||
|
||||
if (ShapeUtil::Compatible(lhs, rhs)) {
|
||||
if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) {
|
||||
// If the shapes are the same other than layout, the output shape is the
|
||||
// same (elementwise op).
|
||||
return lhs;
|
||||
return ShapeUtil::ChangeElementType(
|
||||
lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs));
|
||||
}
|
||||
|
||||
if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) {
|
||||
@ -973,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
TF_ASSIGN_OR_RETURN(const Shape& shape,
|
||||
InferElementwiseBinaryOpShape(operation, lhs, rhs,
|
||||
broadcast_dimensions));
|
||||
if (lhs.element_type() == F32) {
|
||||
if (lhs.element_type() == F32 && rhs.element_type() == F32) {
|
||||
return ShapeUtil::ChangeElementType(shape, C64);
|
||||
} else {
|
||||
return Unimplemented("complex component type not supported");
|
||||
@ -1078,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
TF_RETURN_IF_ERROR(
|
||||
ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map"));
|
||||
|
||||
if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) {
|
||||
if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) {
|
||||
continue;
|
||||
}
|
||||
if (!ShapeUtil::IsTuple(*arg_shapes[i]) &&
|
||||
!ShapeUtil::IsTuple(*arg_shape) &&
|
||||
ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) {
|
||||
ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i],
|
||||
*arg_shape)) {
|
||||
if (ShapeUtil::IsScalar(*arg_shapes[i])) {
|
||||
continue;
|
||||
}
|
||||
@ -1148,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
i, ShapeUtil::HumanString(parameter_shape).c_str());
|
||||
}
|
||||
|
||||
if (parameter_shape.element_type() != arg_shape->element_type()) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape,
|
||||
*arg_shape)) {
|
||||
return InvalidArgument(
|
||||
"mapped computation's parameter type has to match argument element "
|
||||
"type; got parameter %d shape: %s, argument shape: %s",
|
||||
@ -1221,7 +1236,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for batch-norm-training, "
|
||||
"but the shape of offset factor is %s "
|
||||
@ -1230,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for batch-norm-training, "
|
||||
"but the shape of scale factor is %s "
|
||||
@ -1329,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for "
|
||||
"batch-norm-inference, "
|
||||
@ -1339,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for "
|
||||
"batch-norm-inference, "
|
||||
@ -1349,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for "
|
||||
"batch-norm-inference, "
|
||||
@ -1359,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for "
|
||||
"batch-norm-inference, "
|
||||
@ -1481,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(output_grad_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for batch-norm-grad, "
|
||||
"but the element type of output_grad is %s "
|
||||
@ -1490,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for batch-norm-grad, "
|
||||
"but the element type of scale factor is %s "
|
||||
@ -1499,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for batch-norm-grad, "
|
||||
"but the element type of mean is %s "
|
||||
@ -1508,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
PrimitiveType_Name(operand_shape.element_type()).c_str());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameElementType(var_shape, operand_shape)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape,
|
||||
operand_shape)) {
|
||||
return InvalidArgument(
|
||||
"The inputs should have the same element type for batch-norm-grad, "
|
||||
"but the element type of mean is %s "
|
||||
@ -1569,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution"));
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution"));
|
||||
|
||||
if (!ShapeUtil::SameElementType(lhs, rhs)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
||||
return InvalidArgument(
|
||||
"Convolution with different element types: %s and %s",
|
||||
ShapeUtil::HumanString(lhs).c_str(),
|
||||
@ -1714,8 +1739,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
dimensions[dnums.output_spatial_dimensions(i)] =
|
||||
window_output_shape.dimensions(i);
|
||||
}
|
||||
|
||||
return ShapeUtil::MakeShape(lhs.element_type(), dimensions);
|
||||
return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs),
|
||||
dimensions);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferFftShape(
|
||||
@ -1877,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
}
|
||||
const Shape& operand_element_shape =
|
||||
ShapeUtil::MakeShape(operand_shape.element_type(), {});
|
||||
if (!ShapeUtil::Compatible(operand_element_shape,
|
||||
select_shape.parameters(0))) {
|
||||
if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
|
||||
select_shape.parameters(0))) {
|
||||
return InvalidArgument(
|
||||
"select function's first parameter shape currently must "
|
||||
"match the operand element shape. Got %s vs %s",
|
||||
ShapeUtil::HumanString(select_shape.parameters(0)).c_str(),
|
||||
ShapeUtil::HumanString(operand_element_shape).c_str());
|
||||
}
|
||||
if (!ShapeUtil::Compatible(operand_element_shape,
|
||||
select_shape.parameters(1))) {
|
||||
if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape,
|
||||
select_shape.parameters(1))) {
|
||||
return InvalidArgument(
|
||||
"select function's second parameter shape currently must "
|
||||
"match the operand element shape. Got %s vs %s",
|
||||
@ -1903,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
InferWindowOutputShape(operand_shape, window,
|
||||
operand_shape.element_type(),
|
||||
/*allow_negative_padding=*/false));
|
||||
if (!ShapeUtil::Compatible(source_shape, window_result_shape)) {
|
||||
if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape,
|
||||
window_result_shape)) {
|
||||
return InvalidArgument(
|
||||
"source shape does not match the shape of window-reduced operand: "
|
||||
"source(%s), window-reduced operand(%s)",
|
||||
@ -2086,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape));
|
||||
}
|
||||
|
||||
if (operand_shape.element_type() != update_shape.element_type()) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape,
|
||||
update_shape)) {
|
||||
return InvalidArgument(
|
||||
"dynamic update slice update element type does not match argument. "
|
||||
"operand.element_type: %s vs update.element_type: %s",
|
||||
@ -2322,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min"));
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand"));
|
||||
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max"));
|
||||
if (!ShapeUtil::SameElementType(min, operand) ||
|
||||
!ShapeUtil::SameElementType(max, operand)) {
|
||||
if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) ||
|
||||
!ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) {
|
||||
return InvalidArgument("clamp op with different operand types: %s, %s, %s",
|
||||
ShapeUtil::HumanString(min).c_str(),
|
||||
ShapeUtil::HumanString(operand).c_str(),
|
||||
ShapeUtil::HumanString(max).c_str());
|
||||
}
|
||||
if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) &&
|
||||
(ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) {
|
||||
if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) ||
|
||||
ShapeUtil::IsScalar(min)) &&
|
||||
(ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) ||
|
||||
ShapeUtil::IsScalar(max)))) {
|
||||
return operand;
|
||||
}
|
||||
if (ShapeUtil::IsScalar(operand)) {
|
||||
if (ShapeUtil::Compatible(min, max)) {
|
||||
return min;
|
||||
if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) {
|
||||
return ShapeUtil::ChangeElementType(min, operand.element_type());
|
||||
} else if (ShapeUtil::IsScalar(min)) {
|
||||
return max;
|
||||
return ShapeUtil::ChangeElementType(max, operand.element_type());
|
||||
} else if (ShapeUtil::IsScalar(max)) {
|
||||
return min;
|
||||
return ShapeUtil::ChangeElementType(min, operand.element_type());
|
||||
}
|
||||
}
|
||||
return Unimplemented(
|
||||
@ -2352,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
// broadcast from all operands, not just the predicate.
|
||||
/* static */ StatusOr<Shape> ShapeInference::InferSelectShape(
|
||||
const Shape& pred, const Shape& on_true, const Shape& on_false) {
|
||||
if (!ShapeUtil::Compatible(on_true, on_false)) {
|
||||
bool compatible;
|
||||
if (ShapeUtil::IsTuple(on_true)) {
|
||||
// Select only defines the top-level buffer, so if it's a tuple, the two
|
||||
// input must match exactly.
|
||||
compatible = ShapeUtil::Compatible(on_true, on_false);
|
||||
} else {
|
||||
compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false);
|
||||
}
|
||||
if (!compatible) {
|
||||
return InvalidArgument(
|
||||
"operands to select must be the same shape; got %s and %s",
|
||||
ShapeUtil::HumanString(on_true).c_str(),
|
||||
@ -2367,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||
// By this stage we know that pred's element type is PRED. Therefore, this
|
||||
// check restricts pred to be a PRED scalar, or a PRED array with the same
|
||||
// dimensions as on_true and on_false.
|
||||
return on_true;
|
||||
return ShapeUtil::ChangeElementType(
|
||||
on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
|
||||
} else {
|
||||
return Unimplemented(
|
||||
"select operation with non-scalar predicate with dimensionality "
|
||||
|
@ -630,6 +630,19 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
|
||||
return SameDimensions(lhs, rhs);
|
||||
}
|
||||
|
||||
/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs,
|
||||
const Shape& rhs) {
|
||||
if (lhs.element_type() == TUPLE) {
|
||||
return rhs.element_type() == TUPLE &&
|
||||
ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(),
|
||||
CompatibleIgnoringFpPrecision);
|
||||
}
|
||||
if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) {
|
||||
return CompatibleIgnoringElementType(lhs, rhs);
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* static */ int64 ShapeUtil::GetDimension(const Shape& shape,
|
||||
int64 dimension_number) {
|
||||
return shape.dimensions(GetDimensionNumber(shape, dimension_number));
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -211,6 +212,31 @@ class ShapeUtil {
|
||||
return lhs.element_type() == rhs.element_type();
|
||||
}
|
||||
|
||||
// As SameElementType, but allows floating point types to have different
|
||||
// precisions.
|
||||
static bool SameElementTypeIgnoringFpPrecision(const Shape& a,
|
||||
const Shape& b) {
|
||||
if (ElementIsFloating(a) && ElementIsFloating(b)) {
|
||||
return true;
|
||||
}
|
||||
return ShapeUtil::SameElementType(a, b);
|
||||
}
|
||||
|
||||
// Returns the higher-precision element type if a and b are both floating
|
||||
// point types; otherwise, checks that that they have the same element type
|
||||
// and returns it.
|
||||
static PrimitiveType HigherPrecisionElementType(const Shape& a,
|
||||
const Shape& b) {
|
||||
if (SameElementType(a, b)) {
|
||||
return a.element_type();
|
||||
}
|
||||
CHECK(SameElementTypeIgnoringFpPrecision(a, b));
|
||||
return primitive_util::BitWidth(a.element_type()) <
|
||||
primitive_util::BitWidth(b.element_type())
|
||||
? b.element_type()
|
||||
: a.element_type();
|
||||
}
|
||||
|
||||
// Returns true if the rank, dimension sizes, and element type are
|
||||
// identical. Layout is ignored. Tuple elements are compared recursively for
|
||||
// compatibility.
|
||||
@ -221,6 +247,10 @@ class ShapeUtil {
|
||||
// compatibility.
|
||||
static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// As Compatible, but allow one of lhs and rhs to be BF16 while the other
|
||||
// being F32. Tuple elements are compared recursively for compatibility.
|
||||
static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
// Returns whether the lhs and rhs shapes are identical protobufs.
|
||||
static bool Equal(const Shape& lhs, const Shape& rhs);
|
||||
|
||||
|
@ -170,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) {
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) {
|
||||
Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
|
||||
Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2});
|
||||
ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) {
|
||||
Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2});
|
||||
Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2});
|
||||
ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) {
|
||||
Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2});
|
||||
Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2});
|
||||
@ -184,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) {
|
||||
EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) {
|
||||
Shape tuple1 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})});
|
||||
Shape tuple2 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
|
||||
EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
|
||||
Shape tuple1 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
|
||||
@ -193,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) {
|
||||
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) {
|
||||
Shape tuple1 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
|
||||
Shape tuple2 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})});
|
||||
EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2));
|
||||
}
|
||||
|
||||
TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) {
|
||||
Shape tuple1 = ShapeUtil::MakeTupleShape(
|
||||
{ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})});
|
||||
|
@ -2121,6 +2121,44 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) {
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) {
|
||||
// The input tensor is large enough to exercise the vectorized exp
|
||||
// implementation on XLA CPU.
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
|
||||
std::unique_ptr<Literal> input_literal = Literal::CreateR1<float>(
|
||||
{-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198,
|
||||
-167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9,
|
||||
198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04,
|
||||
1.74e+04, 1.89e+05, 1.9e+05, 1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07,
|
||||
1.66e+07, 1e+07, 1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09,
|
||||
1.44e+10, 1.5e+10, 1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12,
|
||||
1.4e+12, 1.03e+13, 1.6e+13, 1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15,
|
||||
1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17,
|
||||
2e+18, 1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20,
|
||||
1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21, 1.35e+22, 1.84e+22, 1.02e+22,
|
||||
1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25,
|
||||
1.62e+25, 1.2e+26, 1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28,
|
||||
1.5e+28, 1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30, 1.81e+30, 1.34e+30,
|
||||
1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33,
|
||||
1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35});
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_data,
|
||||
client_->TransferToServer(*input_literal));
|
||||
|
||||
auto input = builder.Parameter(0, input_literal->shape(), "input");
|
||||
builder.Log(input);
|
||||
|
||||
std::vector<float> expected_result;
|
||||
int64 input_size = input_literal->shape().dimensions(0);
|
||||
expected_result.reserve(input_size);
|
||||
for (int64 i = 0; i < input_size; i++) {
|
||||
expected_result.push_back(std::log(input_literal->Get<float>({i})));
|
||||
}
|
||||
|
||||
ComputeAndCompareR1<float>(&builder, expected_result, {input_data.get()},
|
||||
error_spec_);
|
||||
}
|
||||
|
||||
XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) {
|
||||
// a ------ (add) --------- (add)
|
||||
// / /
|
||||
|
@ -276,8 +276,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
||||
"${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows.
|
||||
"${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py"
|
||||
# Broken tensorboard test due to cmake issues.
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561
|
||||
# tensor_forest tests (also note that we exclude the hybrid tests for now)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order.
|
||||
|
@ -166,8 +166,8 @@ def crf_log_likelihood(inputs,
|
||||
sequence_lengths: A [batch_size] vector of true sequence lengths.
|
||||
transition_params: A [num_tags, num_tags] transition matrix, if available.
|
||||
Returns:
|
||||
log_likelihood: A scalar containing the log-likelihood of the given sequence
|
||||
of tag indices.
|
||||
log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of
|
||||
each example, given the sequence of tag indices.
|
||||
transition_params: A [num_tags, num_tags] transition matrix. This is either
|
||||
provided by the caller or created in this function.
|
||||
"""
|
||||
@ -182,7 +182,7 @@ def crf_log_likelihood(inputs,
|
||||
transition_params)
|
||||
log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
|
||||
|
||||
# Normalize the scores to get the log-likelihood.
|
||||
# Normalize the scores to get the log-likelihood per example.
|
||||
log_likelihood = sequence_scores - log_norm
|
||||
return log_likelihood, transition_params
|
||||
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import sqlite3
|
||||
|
||||
from tensorflow.contrib.data.python.ops import readers
|
||||
|
@ -790,7 +790,7 @@ def _extract_tensors(tensors_and_vars):
|
||||
tensor, _ = tensor_and_var
|
||||
if isinstance(tensor, ops_lib.IndexedSlices):
|
||||
tensors.append(tensor.values)
|
||||
else:
|
||||
elif tensor is not None:
|
||||
tensors.append(tensor)
|
||||
return tensors
|
||||
|
||||
|
@ -240,6 +240,13 @@ class ReplicateModelTest(test_util.TensorFlowTestCase):
|
||||
labels = np.array([[1.0], [2.0]])
|
||||
|
||||
with self.test_session() as session:
|
||||
# Add another trainable variable that doesn't produce a gradient to
|
||||
# verify that None gradients are supported.
|
||||
_ = variable_scope.get_variable(
|
||||
'another_variable',
|
||||
initializer=constant_op.constant(1, dtype=dtypes.float64),
|
||||
dtype=dtypes.float64)
|
||||
|
||||
replicated_model_fn = replicate_model_fn.replicate_model_fn(
|
||||
self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1'])
|
||||
estimator_spec = replicated_model_fn(
|
||||
@ -1119,8 +1126,6 @@ class SplitBatchTest(test_util.TensorFlowTestCase):
|
||||
feature_shards, label_shards = replicate_model_fn._split_batch(
|
||||
features, labels, 2, device='/gpu:0')
|
||||
|
||||
print(feature_shards[0]['x'].eval())
|
||||
print(feature_shards[1]['x'].eval())
|
||||
self.assertSparseValuesEqual(
|
||||
sparse_tensor.SparseTensorValue(
|
||||
indices=[[0, 0], [1, 0], [1, 1]],
|
||||
|
@ -181,7 +181,8 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
batch_size = 3
|
||||
img = array_ops.ones([batch_size, 299, 299, 3])
|
||||
pool = _run_with_mock(
|
||||
classifier_metrics.run_inception, img,
|
||||
classifier_metrics.run_inception,
|
||||
img,
|
||||
output_tensor=classifier_metrics.INCEPTION_FINAL_POOL)
|
||||
|
||||
self.assertTrue(isinstance(pool, ops.Tensor))
|
||||
@ -195,9 +196,12 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
batch_size = 3
|
||||
img = array_ops.ones([batch_size, 299, 299, 3])
|
||||
logits, pool = _run_with_mock(
|
||||
classifier_metrics.run_inception, img,
|
||||
output_tensor=[classifier_metrics.INCEPTION_OUTPUT,
|
||||
classifier_metrics.INCEPTION_FINAL_POOL])
|
||||
classifier_metrics.run_inception,
|
||||
img,
|
||||
output_tensor=[
|
||||
classifier_metrics.INCEPTION_OUTPUT,
|
||||
classifier_metrics.INCEPTION_FINAL_POOL
|
||||
])
|
||||
|
||||
self.assertTrue(isinstance(logits, ops.Tensor))
|
||||
self.assertTrue(isinstance(pool, ops.Tensor))
|
||||
@ -209,8 +213,10 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
|
||||
def test_inception_score_graph(self):
|
||||
"""Test `inception_score` graph construction."""
|
||||
score = _run_with_mock(classifier_metrics.inception_score,
|
||||
array_ops.zeros([6, 299, 299, 3]), num_batches=3)
|
||||
score = _run_with_mock(
|
||||
classifier_metrics.inception_score,
|
||||
array_ops.zeros([6, 299, 299, 3]),
|
||||
num_batches=3)
|
||||
self.assertTrue(isinstance(score, ops.Tensor))
|
||||
score.shape.assert_has_rank(0)
|
||||
|
||||
@ -248,12 +254,14 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
|
||||
classifier_metrics._kl_divergence(
|
||||
p, array_ops.zeros([8, 10], dtype=dtypes.int32), q)
|
||||
classifier_metrics._kl_divergence(p,
|
||||
array_ops.zeros(
|
||||
[8, 10], dtype=dtypes.int32), q)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must be floating type'):
|
||||
classifier_metrics._kl_divergence(
|
||||
p, p_logits, array_ops.zeros([10], dtype=dtypes.int32))
|
||||
classifier_metrics._kl_divergence(p, p_logits,
|
||||
array_ops.zeros(
|
||||
[10], dtype=dtypes.int32))
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'must have rank 2'):
|
||||
classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q)
|
||||
@ -266,8 +274,9 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
|
||||
def test_inception_score_value(self):
|
||||
"""Test that `inception_score` gives the correct value."""
|
||||
logits = np.array([np.array([1, 2] * 500 + [4]),
|
||||
np.array([4, 5] * 500 + [6])])
|
||||
logits = np.array(
|
||||
[np.array([1, 2] * 500 + [4]),
|
||||
np.array([4, 5] * 500 + [6])])
|
||||
unused_image = array_ops.zeros([2, 299, 299, 3])
|
||||
incscore = _run_with_mock(classifier_metrics.inception_score, unused_image)
|
||||
|
||||
@ -285,9 +294,11 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
test_pool_real_a = np.float32(np.random.randn(512, 256))
|
||||
test_pool_gen_a = np.float32(np.random.randn(512, 256))
|
||||
|
||||
fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance,
|
||||
test_pool_real_a, test_pool_gen_a,
|
||||
classifier_fn=lambda x: x)
|
||||
fid_op = _run_with_mock(
|
||||
classifier_metrics.frechet_classifier_distance,
|
||||
test_pool_real_a,
|
||||
test_pool_gen_a,
|
||||
classifier_fn=lambda x: x)
|
||||
|
||||
with self.test_session() as sess:
|
||||
actual_fid = sess.run(fid_op)
|
||||
@ -296,6 +307,33 @@ class ClassifierMetricsTest(test.TestCase):
|
||||
|
||||
self.assertAllClose(expected_fid, actual_fid, 0.0001)
|
||||
|
||||
def test_frechet_classifier_distance_covariance(self):
|
||||
"""Test that `frechet_classifier_distance` takes covariance into account."""
|
||||
np.random.seed(0)
|
||||
|
||||
# Make num_examples > num_features to ensure scipy's sqrtm function
|
||||
# doesn't return a complex matrix.
|
||||
test_pool_reals, test_pool_gens = [], []
|
||||
for i in range(1, 11, 2):
|
||||
test_pool_reals.append(np.float32(np.random.randn(2048, 256) * i))
|
||||
test_pool_gens.append(np.float32(np.random.randn(2048, 256) * i))
|
||||
|
||||
fid_ops = []
|
||||
for i in range(len(test_pool_reals)):
|
||||
fid_ops.append(_run_with_mock(
|
||||
classifier_metrics.frechet_classifier_distance,
|
||||
test_pool_reals[i],
|
||||
test_pool_gens[i],
|
||||
classifier_fn=lambda x: x))
|
||||
|
||||
fids = []
|
||||
with self.test_session() as sess:
|
||||
for fid_op in fid_ops:
|
||||
fids.append(sess.run(fid_op))
|
||||
|
||||
# Check that the FIDs increase monotonically.
|
||||
self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:])))
|
||||
|
||||
def test_trace_sqrt_product_value(self):
|
||||
"""Test that `trace_sqrt_product` gives the correct value."""
|
||||
np.random.seed(0)
|
||||
|
@ -305,6 +305,7 @@ def wasserstein_gradient_penalty(
|
||||
discriminator_fn,
|
||||
discriminator_scope,
|
||||
epsilon=1e-10,
|
||||
target=1.0,
|
||||
weights=1.0,
|
||||
scope=None,
|
||||
loss_collection=ops.GraphKeys.LOSSES,
|
||||
@ -324,6 +325,8 @@ def wasserstein_gradient_penalty(
|
||||
discriminator_scope: If not `None`, reuse discriminators from this scope.
|
||||
epsilon: A small positive number added for numerical stability when
|
||||
computing the gradient norm.
|
||||
target: Optional Python number or `Tensor` indicating the target value of
|
||||
gradient norm. Defaults to 1.0.
|
||||
weights: Optional `Tensor` whose rank is either 0, or the same rank as
|
||||
`real_data` and `generated_data`, and must be broadcastable to
|
||||
them (i.e., all dimensions must be either `1`, or the same as the
|
||||
@ -374,7 +377,7 @@ def wasserstein_gradient_penalty(
|
||||
# For numerical stability, add epsilon to the sum before taking the square
|
||||
# root. Note tf.norm does not add epsilon.
|
||||
slopes = math_ops.sqrt(gradient_squares + epsilon)
|
||||
penalties = math_ops.square(slopes - 1.0)
|
||||
penalties = math_ops.square(slopes / target - 1.0)
|
||||
penalty = losses.compute_weighted_loss(
|
||||
penalties, weights, scope=scope, loss_collection=loss_collection,
|
||||
reduction=reduction)
|
||||
|
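The two replaced lines above change the penalty from (slope - 1)^2 to (slope / target - 1)^2, so the discriminator's gradient norm is pushed toward `target` instead of 1.0. A minimal NumPy sketch of that computation (the function name is illustrative, not the library's API):

import numpy as np

def gradient_penalty_sketch(gradient_squares, target=1.0, epsilon=1e-10):
  # Mirror the hunk above: add epsilon before the sqrt for stability, then
  # penalize how far the gradient norm is from `target`.
  slopes = np.sqrt(gradient_squares + epsilon)
  return np.square(slopes / target - 1.0)

# A gradient whose norm already equals the target incurs (almost) no penalty.
print(gradient_penalty_sketch(np.array([4.0]), target=2.0))  # ~[0.]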
@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
|
||||
})
|
||||
self.assertAlmostEqual(self._expected_loss, loss, 5)
|
||||
|
||||
def test_loss_with_gradient_norm_target(self):
|
||||
"""Test loss value with non default gradient norm target."""
|
||||
generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
|
||||
real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
|
||||
|
||||
loss = tfgan_losses.wasserstein_gradient_penalty(
|
||||
generated_data,
|
||||
real_data,
|
||||
self._kwargs['generator_inputs'],
|
||||
self._kwargs['discriminator_fn'],
|
||||
self._kwargs['discriminator_scope'],
|
||||
target=2.0)
|
||||
|
||||
with self.test_session() as sess:
|
||||
variables.global_variables_initializer().run()
|
||||
loss = sess.run(
|
||||
loss,
|
||||
feed_dict={
|
||||
generated_data: self._generated_data_np,
|
||||
real_data: self._real_data_np,
|
||||
})
|
||||
self.assertAlmostEqual(1.0, loss, 5)
|
||||
|
||||
def test_reuses_scope(self):
|
||||
"""Test that gradient penalty reuses discriminator scope."""
|
||||
num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||
|
@ -460,6 +460,7 @@ def gan_loss(
|
||||
# Auxiliary losses.
|
||||
gradient_penalty_weight=None,
|
||||
gradient_penalty_epsilon=1e-10,
|
||||
gradient_penalty_target=1.0,
|
||||
mutual_information_penalty_weight=None,
|
||||
aux_cond_generator_weight=None,
|
||||
aux_cond_discriminator_weight=None,
|
||||
@ -481,6 +482,9 @@ def gan_loss(
|
||||
small positive value used by the gradient penalty function for numerical
|
||||
stability. Note some applications will need to increase this value to
|
||||
avoid NaNs.
|
||||
gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
|
||||
number or `Tensor` indicating the target value of gradient norm. See the
|
||||
CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
|
||||
mutual_information_penalty_weight: If not `None`, must be a non-negative
|
||||
Python number or Tensor indicating how much to weight the mutual
|
||||
information penalty. See https://arxiv.org/abs/1606.03657 for more
|
||||
@ -539,7 +543,10 @@ def gan_loss(
|
||||
# Add optional extra losses.
|
||||
if _use_aux_loss(gradient_penalty_weight):
|
||||
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
|
||||
model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
|
||||
model,
|
||||
epsilon=gradient_penalty_epsilon,
|
||||
target=gradient_penalty_target,
|
||||
add_summaries=add_summaries)
|
||||
dis_loss += gradient_penalty_weight * gp_loss
|
||||
if _use_aux_loss(mutual_information_penalty_weight):
|
||||
info_loss = tfgan_losses.mutual_information_penalty(
|
||||
|
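Because `gan_loss` now plumbs `gradient_penalty_target` through to the penalty computation, callers can override the default target of 1.0. A hedged usage sketch, assuming `gan_model` is an already-built GANModel from `tf.contrib.gan`:

import tensorflow as tf
tfgan = tf.contrib.gan

# `gan_model` is hypothetical here; it would come from tfgan.gan_model(...).
gan_loss = tfgan.gan_loss(
    gan_model,
    gradient_penalty_weight=1.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=2.0,  # the argument added in this change
    add_summaries=True)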
@ -64,8 +64,11 @@ void resize(T* out, uint8_t* in, int image_height, int image_width,
|
||||
ops::builtin::BuiltinOpResolver resolver;
|
||||
TfLiteRegistration* resize_op =
|
||||
resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR);
|
||||
interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr,
|
||||
resize_op, nullptr);
|
||||
auto* params = reinterpret_cast<TfLiteResizeBilinearParams*>(
|
||||
malloc(sizeof(TfLiteResizeBilinearParams)));
|
||||
params->align_corners = false;
|
||||
interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params, resize_op,
|
||||
nullptr);
|
||||
|
||||
interpreter->AllocateTensors();
|
||||
|
||||
|
@ -28,6 +28,7 @@ py_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:platform",
|
||||
],
|
||||
)
|
||||
|
@ -73,6 +73,7 @@ import itertools as _itertools
|
||||
import uuid as _uuid
|
||||
|
||||
from tensorflow.contrib import framework as _framework
|
||||
from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
|
||||
from tensorflow.python.framework import ops as _ops
|
||||
from tensorflow.python.ops import array_ops as _array_ops
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
@ -133,10 +134,17 @@ class OpHint(object):
|
||||
|
||||
def augmented_identity(arg):
|
||||
identity_op = _array_ops.identity(arg)
|
||||
attr = identity_op.op.node_def.attr
|
||||
attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name
|
||||
attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id
|
||||
attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i = self._curr_input_index
|
||||
# pylint: disable=protected-access
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_NAME_ATTR,
|
||||
_attr_value_pb2.AttrValue(s=self._function_name))
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_UUID_ATTR,
|
||||
_attr_value_pb2.AttrValue(s=self._unique_function_id))
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_INPUT_INDEX_ATTR,
|
||||
_attr_value_pb2.AttrValue(i=self._curr_input_index))
|
||||
# pylint: enable=protected-access
|
||||
self._curr_input_index += 1
|
||||
return identity_op
|
||||
|
||||
@ -154,10 +162,17 @@ class OpHint(object):
|
||||
|
||||
def augmented_identity(arg):
|
||||
identity_op = _array_ops.identity(arg)
|
||||
attr = identity_op.op.node_def.attr
|
||||
attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name
|
||||
attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id
|
||||
attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i = self._curr_output_index
|
||||
# pylint: disable=protected-access
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_NAME_ATTR,
|
||||
_attr_value_pb2.AttrValue(s=self._function_name))
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_UUID_ATTR,
|
||||
_attr_value_pb2.AttrValue(s=self._unique_function_id))
|
||||
identity_op.op._set_attr(
|
||||
OpHint.FUNCTION_OUTPUT_INDEX_ATTR,
|
||||
_attr_value_pb2.AttrValue(i=self._curr_output_index))
|
||||
# pylint: enable=protected-access
|
||||
self._curr_output_index += 1
|
||||
return identity_op
|
||||
|
||||
|
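The two hunks above swap direct mutation of `identity_op.op.node_def.attr` for `Operation._set_attr`, presumably so the attribute lands on the underlying graph operation rather than only on a Python-side proto copy. A minimal sketch of the pattern; the attribute name and value are illustrative, and `_set_attr` is a protected TensorFlow API:

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.ops import array_ops

some_tensor = array_ops.zeros([2, 2])  # stand-in input for this sketch
identity_op = array_ops.identity(some_tensor)
# pylint: disable=protected-access
identity_op.op._set_attr('_example_attr',
                         attr_value_pb2.AttrValue(s=b'example-value'))
# pylint: enable=protected-access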
@ -139,14 +139,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) {
|
||||
output_buffer_size * sizeof(output_float_data[0]));
|
||||
} else if (unary_op->type == OperatorType::kTensorFlowSum) {
|
||||
// At the moment only full reduction across all dimensions is supported.
|
||||
for (int i = 0; i < output_dims_count; i++) {
|
||||
CHECK_EQ(output_shape.dims(i), 1);
|
||||
}
|
||||
float sum = 0.f;
|
||||
for (int i = 0; i < input_buffer_size; i++) {
|
||||
sum += (*input_float_data)[i];
|
||||
}
|
||||
output_float_data[0] = sum;
|
||||
for (int i = 0; i < output_buffer_size; ++i) {
|
||||
output_float_data[i] = sum;
|
||||
}
|
||||
} else if (unary_op->type == OperatorType::kTensorFlowMin) {
|
||||
// At the moment only full reduction across all dimensions is supported.
|
||||
// TODO(starka): Output should not be padded.
|
||||
|
@ -35,6 +35,7 @@ NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
|
||||
# TODO(mdan): Verify that these names are not hidden by generated code.
|
||||
# TODO(mdan): Make sure copybara renames the reference below.
|
||||
COMPILED_IMPORT_STATEMENTS = (
|
||||
'from __future__ import print_function',
|
||||
'import tensorflow as tf',
|
||||
'from tensorflow.contrib.py2tf import utils as '
|
||||
'py2tf_utils')
|
||||
|
@ -28,13 +28,11 @@ from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
# TODO(suharshs): Add tests for testing experimental APIs and additional
|
||||
# input arguments
|
||||
class QuantizeGraphTest(test_util.TensorFlowTestCase):
|
||||
# We have a lot of other tests that test the details of the rewrite, here we
|
||||
# just test the specific features of the quantize_graph API.
|
||||
|
||||
def _RunTestOverParameters(self, test_fn):
|
||||
def _RunTestOverAllRewrites(self, test_fn):
|
||||
rewrite_fns = [
|
||||
quantize_graph.create_training_graph,
|
||||
quantize_graph.create_eval_graph,
|
||||
@ -44,71 +42,125 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
|
||||
for fn in rewrite_fns:
|
||||
test_fn(fn)
|
||||
|
||||
def testRewrite(self):
|
||||
self._RunTestOverParameters(self._TestRewrite)
|
||||
def _RunTestOverTrainingRewrites(self, test_fn):
|
||||
rewrite_fns = [
|
||||
quantize_graph.create_training_graph,
|
||||
quantize_graph.experimental_create_training_graph,
|
||||
]
|
||||
for fn in rewrite_fns:
|
||||
test_fn(fn)
|
||||
|
||||
def _TestRewrite(self, fn):
|
||||
def _RunTestOverExperimentalRewrites(self, test_fn):
|
||||
rewrite_fns = [
|
||||
quantize_graph.experimental_create_training_graph,
|
||||
quantize_graph.experimental_create_eval_graph,
|
||||
]
|
||||
for fn in rewrite_fns:
|
||||
test_fn(fn)
|
||||
|
||||
def testRewrite(self):
|
||||
self._RunTestOverAllRewrites(self._TestRewrite)
|
||||
|
||||
def _TestRewrite(self, rewrite_fn):
|
||||
graph = ops.Graph()
|
||||
with graph.as_default():
|
||||
batch_size, height, width, depth = 5, 128, 128, 3
|
||||
inputs = array_ops.zeros((batch_size, height, width, depth))
|
||||
conv = layers.conv2d(
|
||||
inputs,
|
||||
32, [5, 5],
|
||||
stride=2,
|
||||
padding='SAME',
|
||||
weights_initializer=self._WeightInit(0.09),
|
||||
activation_fn=None,
|
||||
scope='test')
|
||||
_ = nn_ops.relu6(conv)
|
||||
self._ConvLayer()
|
||||
|
||||
orig_variable_names = set(
|
||||
[v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||
|
||||
fn(graph)
|
||||
rewrite_fn(graph)
|
||||
|
||||
q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
# Ensure that variables were added.
|
||||
self.assertTrue(len(orig_variable_names) < len(q_variables))
|
||||
|
||||
def testDefaultGraph(self):
|
||||
self._RunTestOverParameters(self._TestRewrite)
|
||||
self._RunTestOverAllRewrites(self._TestRewrite)
|
||||
|
||||
def _TestDefaultGraph(self, fn):
|
||||
def _TestDefaultGraph(self, rewrite_fn):
|
||||
# Tests that the default graph is correctly used when no args are provided
|
||||
# to rewrite_fn.
|
||||
with ops.Graph().as_default() as g:
|
||||
batch_size, height, width, depth = 5, 128, 128, 3
|
||||
inputs = array_ops.zeros((batch_size, height, width, depth))
|
||||
conv = layers.conv2d(
|
||||
inputs,
|
||||
32, [5, 5],
|
||||
stride=2,
|
||||
padding='SAME',
|
||||
weights_initializer=self._WeightInit(0.09),
|
||||
activation_fn=None,
|
||||
scope='test')
|
||||
_ = nn_ops.relu6(conv)
|
||||
|
||||
self._ConvLayer()
|
||||
orig_variable_names = set(
|
||||
[v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
|
||||
|
||||
fn()
|
||||
rewrite_fn()
|
||||
|
||||
q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
|
||||
# Ensure that variables were added.
|
||||
self.assertTrue(len(orig_variable_names) < len(q_variables))
|
||||
|
||||
def _WeightInit(self, stddev):
|
||||
"""Returns truncated normal variable initializer.
|
||||
def testQuantDelay(self):
|
||||
self._RunTestOverTrainingRewrites(self._TestQuantDelay)
|
||||
|
||||
Function is defined purely to shorten the name so that it stops wrapping.
|
||||
def _TestQuantDelay(self, rewrite_fn):
|
||||
with ops.Graph().as_default() as g:
|
||||
self._ConvLayer()
|
||||
quant_delay = 100
|
||||
rewrite_fn(quant_delay=quant_delay)
|
||||
|
||||
Args:
|
||||
stddev: Standard deviation of normal variable.
|
||||
quant_delay_found = False
|
||||
for op in g.get_operations():
|
||||
# Check to see if the quant_delay is correctly set.
|
||||
if 'activate_quant' in op.name and op.type == 'Const':
|
||||
quant_delay_found = True
|
||||
const_value = str(op.get_attr('value'))
|
||||
self.assertTrue(('int64_val: %i' % quant_delay) in const_value)
|
||||
self.assertTrue(quant_delay_found)
|
||||
|
||||
Returns:
|
||||
An initializer that initializes with a truncated normal variable.
|
||||
"""
|
||||
return init_ops.truncated_normal_initializer(stddev=stddev)
|
||||
def testWeightBits(self):
|
||||
self._RunTestOverExperimentalRewrites(self._TestWeightBits)
|
||||
|
||||
def _TestWeightBits(self, rewrite_fn):
|
||||
with ops.Graph().as_default() as g:
|
||||
self._ConvLayer()
|
||||
weight_bits = 4
|
||||
rewrite_fn(weight_bits=weight_bits)
|
||||
|
||||
weights_quant_found = False
|
||||
for op in g.get_operations():
|
||||
# Check to see if FakeQuant operations for weights have the right bits
|
||||
# set.
|
||||
if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars':
|
||||
weights_quant_found = True
|
||||
self.assertEqual(op.get_attr('num_bits'), weight_bits)
|
||||
self.assertTrue(weights_quant_found)
|
||||
|
||||
def testActivationBits(self):
|
||||
self._RunTestOverExperimentalRewrites(self._TestActivationBits)
|
||||
|
||||
def _TestActivationBits(self, rewrite_fn):
|
||||
with ops.Graph().as_default() as g:
|
||||
self._ConvLayer()
|
||||
activation_bits = 4
|
||||
rewrite_fn(activation_bits=activation_bits)
|
||||
|
||||
act_quant_found = False
|
||||
for op in g.get_operations():
|
||||
# Check to see if FakeQuant operations for activations have the right bits
|
||||
# set.
|
||||
act_quant_names = ['act_quant', 'conv_quant', 'add_quant']
|
||||
if any(s in op.name
|
||||
for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars':
|
||||
act_quant_found = True
|
||||
self.assertEqual(op.get_attr('num_bits'), activation_bits)
|
||||
self.assertTrue(act_quant_found)
|
||||
|
||||
def _ConvLayer(self):
|
||||
"""Add a basic convolution layer to the default graph."""
|
||||
batch_size, height, width, depth = 5, 128, 128, 3
|
||||
inputs = array_ops.zeros((batch_size, height, width, depth))
|
||||
weight_init = init_ops.truncated_normal_initializer
|
||||
conv = layers.conv2d(
|
||||
inputs,
|
||||
32, [5, 5],
|
||||
stride=2,
|
||||
padding='SAME',
|
||||
weights_initializer=weight_init(0.09),
|
||||
activation_fn=None,
|
||||
scope='test')
|
||||
_ = nn_ops.relu6(conv)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
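For reference, a hedged sketch of how the rewrites exercised by these tests are typically invoked; the import path and the `build_model` helper are assumptions, while the keyword names follow the calls above:

from tensorflow.contrib.quantize.python import quantize_graph  # path assumed
from tensorflow.python.framework import ops

with ops.Graph().as_default() as train_graph:
  build_model()  # hypothetical: adds conv + relu6 ops to the default graph
  # Delay activation of the fake-quant ops for the first 100 training steps.
  quantize_graph.create_training_graph(input_graph=train_graph, quant_delay=100)

with ops.Graph().as_default() as eval_graph:
  build_model()
  # The experimental variants additionally expose the quantization bit widths.
  quantize_graph.experimental_create_eval_graph(
      input_graph=eval_graph, weight_bits=4, activation_bits=4)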
@ -349,7 +349,8 @@ class Image(ItemHandler):
|
||||
shape=None,
|
||||
channels=3,
|
||||
dtype=dtypes.uint8,
|
||||
repeated=False):
|
||||
repeated=False,
|
||||
dct_method=''):
|
||||
"""Initializes the image.
|
||||
|
||||
Args:
|
||||
@ -368,6 +369,11 @@ class Image(ItemHandler):
|
||||
tf.decode_raw,
|
||||
repeated: if False, decodes a single image. If True, decodes a
|
||||
variable number of image strings from a 1D tensor of strings.
|
||||
dct_method: An optional string. Defaults to empty string. It only takes
|
||||
effect when image format is jpeg, used to specify a hint about the
|
||||
algorithm used for jpeg decompression. Currently valid values
|
||||
are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for
|
||||
example, if the jpeg library does not have that specific option.
|
||||
"""
|
||||
if not image_key:
|
||||
image_key = 'image/encoded'
|
||||
@ -381,6 +387,7 @@ class Image(ItemHandler):
|
||||
self._channels = channels
|
||||
self._dtype = dtype
|
||||
self._repeated = repeated
|
||||
self._dct_method = dct_method
|
||||
|
||||
def tensors_to_item(self, keys_to_tensors):
|
||||
"""See base class."""
|
||||
@ -406,9 +413,25 @@ class Image(ItemHandler):
|
||||
A tensor that represents decoded image of self._shape, or
|
||||
(?, ?, self._channels) if self._shape is not specified.
|
||||
"""
|
||||
|
||||
def decode_image():
|
||||
"""Decodes a png or jpg based on the headers."""
|
||||
return image_ops.decode_image(image_buffer, self._channels)
|
||||
"""Decodes a image based on the headers."""
|
||||
return image_ops.decode_image(image_buffer, channels=self._channels)
|
||||
|
||||
def decode_jpeg():
|
||||
"""Decodes a jpeg image with specified '_dct_method'."""
|
||||
return image_ops.decode_jpeg(
|
||||
image_buffer, channels=self._channels, dct_method=self._dct_method)
|
||||
|
||||
def check_jpeg():
|
||||
"""Checks if an image is jpeg."""
|
||||
# For jpeg, we directly use image_ops.decode_jpeg rather than decode_image
|
||||
# in order to feed the jpeg specify parameter 'dct_method'.
|
||||
return control_flow_ops.cond(
|
||||
image_ops.is_jpeg(image_buffer),
|
||||
decode_jpeg,
|
||||
decode_image,
|
||||
name='cond_jpeg')
|
||||
|
||||
def decode_raw():
|
||||
"""Decodes a raw image."""
|
||||
@ -420,7 +443,7 @@ class Image(ItemHandler):
|
||||
math_ops.equal(image_format, 'RAW')): decode_raw,
|
||||
}
|
||||
image = control_flow_ops.case(
|
||||
pred_fn_pairs, default=decode_image, exclusive=True)
|
||||
pred_fn_pairs, default=check_jpeg, exclusive=True)
|
||||
|
||||
image.set_shape([None, None, self._channels])
|
||||
if self._shape is not None:
|
||||
|
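A hedged usage sketch of the new `dct_method` argument on the slim `Image` item handler; the feature keys and the surrounding TF-Example decoder wiring are assumptions, while the constructor arguments follow the diff above:

from tensorflow.contrib.slim.python.slim.data import tfexample_decoder

# Keys below are assumed; they depend on how the TF-Examples were written.
items_to_handlers = {
    'image': tfexample_decoder.Image(
        image_key='image/encoded',
        format_key='image/format',
        channels=3,
        # Only takes effect for JPEG inputs; per the docstring the valid hints
        # are 'INTEGER_FAST' and 'INTEGER_ACCURATE', and the hint may be ignored.
        dct_method='INTEGER_ACCURATE'),
}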
60
tensorflow/contrib/summary/summary_test_internal.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Internal helpers for tests in this directory."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
import sqlite3
|
||||
|
||||
from tensorflow.contrib.summary import summary_ops
|
||||
from tensorflow.python.framework import test_util
|
||||
|
||||
|
||||
class SummaryDbTest(test_util.TensorFlowTestCase):
|
||||
"""Helper for summary database testing."""
|
||||
|
||||
def setUp(self):
|
||||
super(SummaryDbTest, self).setUp()
|
||||
self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
|
||||
if os.path.exists(self.db_path):
|
||||
os.unlink(self.db_path)
|
||||
self.db = sqlite3.connect(self.db_path)
|
||||
self.create_db_writer = functools.partial(
|
||||
summary_ops.create_db_writer,
|
||||
db_uri=self.db_path,
|
||||
experiment_name='experiment',
|
||||
run_name='run',
|
||||
user_name='user')
|
||||
|
||||
def tearDown(self):
|
||||
self.db.close()
|
||||
super(SummaryDbTest, self).tearDown()
|
||||
|
||||
|
||||
def get_one(db, q, *p):
|
||||
return db.execute(q, p).fetchone()[0]
|
||||
|
||||
|
||||
def get_all(db, q, *p):
|
||||
return unroll(db.execute(q, p).fetchall())
|
||||
|
||||
|
||||
def unroll(list_of_tuples):
|
||||
return sum(list_of_tuples, ())
|
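A hedged sketch of a test built on this helper; the query is deliberately schema-agnostic, and the module path is taken from the file name above:

from tensorflow.contrib.summary import summary_test_internal
from tensorflow.python.platform import googletest


class ExampleDbTest(summary_test_internal.SummaryDbTest):

  def testHelpersForwardSqlParameters(self):
    # get_one/get_all forward their positional args as SQL parameters.
    num_tables = summary_test_internal.get_one(
        self.db, 'SELECT COUNT(*) FROM sqlite_master WHERE type = ?', 'table')
    self.assertGreaterEqual(num_tables, 0)


if __name__ == '__main__':
  googletest.main()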
@ -21,6 +21,7 @@ from __future__ import print_function
|
||||
|
||||
import functools
|
||||
import os
|
||||
|
||||
import sqlite3
|
||||
|
||||
from tensorflow.contrib.summary import summary_ops
|
||||
|
@ -106,7 +106,9 @@ class _TPUContext(object):
|
||||
# pylint: disable=protected-access
|
||||
tpu_system_metadata = (
|
||||
tpu_system_metadata_lib._query_tpu_system_metadata(
|
||||
master, query_topology=self.model_parallelism_enabled))
|
||||
master,
|
||||
run_config=self._config,
|
||||
query_topology=self.model_parallelism_enabled))
|
||||
|
||||
self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata
|
||||
return tpu_system_metadata
|
||||
@ -409,6 +411,29 @@ class _TPUContext(object):
|
||||
'Tensorflow master address and TPU worker(s). Available devices '
|
||||
'are {}.'.format(tpu_system_metadata.devices))
|
||||
|
||||
if self._config.tpu_config.num_shards:
|
||||
user_provided_num_replicas = self._config.tpu_config.num_shards
|
||||
if user_provided_num_replicas != num_replicas:
|
||||
message = (
|
||||
'TPUConfig.num_shards is not set correctly. According to TPU '
|
||||
'system metadata for Tensorflow master ({}): num_replicas should '
|
||||
'be ({}), got ({}). For non-model-parallelism, num_replicas should '
|
||||
'be the total num of TPU cores in the system. For '
|
||||
'model-parallelism, the total number of TPU cores should be '
|
||||
'product(computation_shape) * num_replicas. Please set it '
|
||||
'accordingly or leave it as `None`'.format(
|
||||
self._get_master_address(), num_replicas,
|
||||
user_provided_num_replicas))
|
||||
|
||||
if self.model_parallelism_enabled:
|
||||
raise ValueError(message)
|
||||
else:
|
||||
logging.warning(message)
|
||||
logging.warning(
|
||||
'For non-model-parallelism, TPUEstimator currently '
|
||||
'automatically queries the TPU system information so ignores '
|
||||
'this field.')
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
if self._train_batch_size % num_replicas != 0:
|
||||
raise ValueError(
|
||||
|
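A small worked example of the relationship described in the error message above (all numbers are illustrative):

# With model parallelism: total TPU cores = product(computation_shape) * num_replicas.
cores_per_replica = 2 * 2 * 1   # e.g. a computation_shape of [2, 2, 1]
num_replicas = 8
total_cores = cores_per_replica * num_replicas   # 32 cores must be available
# Without model parallelism, num_replicas is simply the total core count.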
@ -45,7 +45,8 @@ _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
|
||||
])
|
||||
|
||||
|
||||
def _query_tpu_system_metadata(master_address, query_topology=False):
|
||||
def _query_tpu_system_metadata(master_address, run_config,
|
||||
query_topology=False):
|
||||
"""Automatically detects the TPU system metadata in the system."""
|
||||
tpu_core_count = 0
|
||||
devices = []
|
||||
@ -59,8 +60,8 @@ def _query_tpu_system_metadata(master_address, query_topology=False):
|
||||
with ops.Graph().as_default():
|
||||
with session_lib.Session(
|
||||
master_address,
|
||||
config=config_pb2.ConfigProto(
|
||||
operation_timeout_in_ms=_PINGING_MASTER_TIMEOUT_IN_MS)) as sess:
|
||||
config=_get_session_config_with_timeout(
|
||||
_PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess:
|
||||
devices = sess.list_devices()
|
||||
for device in devices:
|
||||
match = _TPU_DEVICE_REG.match(device.name)
|
||||
@ -104,7 +105,7 @@ def _query_tpu_system_metadata(master_address, query_topology=False):
|
||||
'TPU worker has some problems. Available devices: {}'.format(
|
||||
master_address, devices))
|
||||
|
||||
topology = _obtain_topology(master_address)
|
||||
topology = _obtain_topology(master_address, run_config)
|
||||
|
||||
metadata = _TPUSystemMetadata(
|
||||
num_cores=tpu_core_count,
|
||||
@ -113,19 +114,26 @@ def _query_tpu_system_metadata(master_address, query_topology=False):
|
||||
topology=topology,
|
||||
devices=devices)
|
||||
|
||||
msg = 'Found TPU system %s' if tpu_core_count else 'Failed to find TPU: %s'
|
||||
logging.info(msg, metadata)
|
||||
if tpu_core_count:
|
||||
logging.info('Found TPU system:')
|
||||
logging.info('*** Num TPU Cores: %d', metadata.num_cores)
|
||||
logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
|
||||
logging.info('*** Num TPU Cores Per Worker: %d',
|
||||
metadata.num_of_cores_per_host)
|
||||
logging.info('*** Available Devices: %s', metadata.devices)
|
||||
else:
|
||||
logging.info('Failed to find TPU: %s', metadata)
|
||||
return metadata
|
||||
|
||||
|
||||
def _obtain_topology(master_address):
|
||||
def _obtain_topology(master_address, run_config):
|
||||
try:
|
||||
logging.info('Initializing TPU system (master: %s) to fetch topology '
|
||||
'for model parallelism. This might take a while.',
|
||||
master_address)
|
||||
with ops.Graph().as_default():
|
||||
session_config = config_pb2.ConfigProto(
|
||||
operation_timeout_in_ms=_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS)
|
||||
session_config = _get_session_config_with_timeout(
|
||||
_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config)
|
||||
with session_lib.Session(
|
||||
master_address, config=session_config) as sess:
|
||||
topology = sess.run(tpu.initialize_system())
|
||||
@ -137,3 +145,11 @@ def _obtain_topology(master_address):
|
||||
master_address))
|
||||
|
||||
|
||||
def _get_session_config_with_timeout(timeout_in_secs, run_config):
|
||||
cluster_def = None
|
||||
if run_config.session_config and run_config.session_config.cluster_def.job:
|
||||
cluster_def = run_config.session_config.cluster_def
|
||||
|
||||
config = config_pb2.ConfigProto(
|
||||
operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
|
||||
return config
|
||||
|
@ -31,6 +31,11 @@ END
|
||||
description: <<END
|
||||
Input images can be of different types but output images are always float.
|
||||
|
||||
The range of pixel values for the output image might be slightly different
|
||||
from the range for the input image because of limited numerical precision.
|
||||
To guarantee an output range, for example `[0.0, 1.0]`, apply
|
||||
`tf.clip_by_value` to the output.
|
||||
|
||||
Each output pixel is computed by first transforming the pixel's footprint into
|
||||
the input tensor and then averaging the pixels that intersect the footprint. An
|
||||
input pixel's contribution to the average is weighted by the fraction of its
|
||||
|
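A hedged sketch of the clamp suggested by the note above; `images` is assumed to be a float tensor roughly in [0.0, 1.0], and the resize call stands in for whichever resize op this doc page describes:

import tensorflow as tf

resized = tf.image.resize_images(images, size=[224, 224])  # `images` assumed
clipped = tf.clip_by_value(resized, clip_value_min=0.0, clip_value_max=1.0)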
@ -56,9 +56,9 @@ limitations under the License.
|
||||
//
|
||||
// To add values to feature_lists:
|
||||
// AppendFeatureValues({4.0},
|
||||
// GetFeatureList("movie_ratings", &se)->Add());
|
||||
// GetFeatureList("images", &se)->Add());
|
||||
// AppendFeatureValues({5.0, 3.0},
|
||||
// GetFeatureList("movie_ratings", &se)->Add());
|
||||
// GetFeatureList("images", &se)->Add());
|
||||
// This will create a feature list keyed as "images" with two features:
|
||||
// feature_lists {
|
||||
// feature_list {
|
||||
|
@ -1025,9 +1025,8 @@ StringPiece Tensor::tensor_data() const {
|
||||
}
|
||||
|
||||
bool Tensor::SharesBufferWith(const Tensor& b) const {
|
||||
CHECK_NE(nullptr, buf_);
|
||||
CHECK_NE(nullptr, b.buf_);
|
||||
return buf_->root_buffer() == b.buf_->root_buffer();
|
||||
return buf_ != nullptr && b.buf_ != nullptr &&
|
||||
buf_->root_buffer() == b.buf_->root_buffer();
|
||||
}
|
||||
|
||||
string Tensor::DebugString() const {
|
||||
|
@ -1085,6 +1085,21 @@ class DummyCPUAllocator : public Allocator {
|
||||
void DeallocateRaw(void* ptr) override {}
|
||||
};
|
||||
|
||||
TEST(Tensor, SharesBufferWith) {
|
||||
Tensor a_empty;
|
||||
Tensor b_empty;
|
||||
Tensor a(DT_FLOAT, TensorShape({1}));
|
||||
Tensor b(DT_FLOAT, TensorShape({1}));
|
||||
Tensor copy(a);
|
||||
EXPECT_FALSE(a_empty.SharesBufferWith(a_empty));
|
||||
EXPECT_FALSE(a_empty.SharesBufferWith(b_empty));
|
||||
EXPECT_FALSE(a_empty.SharesBufferWith(a));
|
||||
EXPECT_FALSE(a_empty.SharesBufferWith(copy));
|
||||
EXPECT_TRUE(a.SharesBufferWith(a));
|
||||
EXPECT_FALSE(a.SharesBufferWith(b));
|
||||
EXPECT_TRUE(a.SharesBufferWith(copy));
|
||||
}
|
||||
|
||||
TEST(Tensor, FailureToAllocate) {
|
||||
TensorShape shape({1});
|
||||
DummyCPUAllocator allocator;
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/notification.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -86,7 +87,9 @@ Status SingleMachine::Provision() {
|
||||
attr = GetLocalCPUInfo();
|
||||
} else if (dev.device_type() == "GPU") {
|
||||
attr = GetLocalGPUInfo(gpu_id++);
|
||||
} else {
|
||||
} else if (dev.device_type().find("XLA") == string::npos) {
|
||||
// Filter out the fake XLA devices to avoid double counting the actual
|
||||
// hardware resources that are available.
|
||||
attr.set_type(dev.device_type());
|
||||
}
|
||||
// Overwrite the memory size since users might have requested to use only a
|
||||
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -446,13 +447,14 @@ Status VirtualScheduler::Init() {
|
||||
}
|
||||
|
||||
if (ready_nodes_->Empty()) {
|
||||
return Status(error::UNAVAILABLE, "No ready nodes in the graph.");
|
||||
return errors::InvalidArgument("No ready nodes in the graph.");
|
||||
}
|
||||
|
||||
if (!feed_nodes.empty())
|
||||
LOG(ERROR) << "Some feed nodes were not found in the graph: "
|
||||
<< str_util::Join(feed_nodes, ",");
|
||||
|
||||
if (!feed_nodes.empty()) {
|
||||
return errors::InvalidArgument(
|
||||
strings::StrCat("Some feed nodes were not found in the graph: ",
|
||||
str_util::Join(feed_nodes, ",")));
|
||||
}
|
||||
initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -371,6 +371,7 @@ cc_library(
|
||||
":dependency_optimizer",
|
||||
":graph_optimizer",
|
||||
":layout_optimizer",
|
||||
":loop_optimizer",
|
||||
":memory_optimizer",
|
||||
":model_pruner",
|
||||
"//tensorflow/core:framework",
|
||||
@ -380,3 +381,39 @@ cc_library(
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "loop_optimizer",
|
||||
srcs = ["loop_optimizer.cc"],
|
||||
hdrs = [
|
||||
"loop_optimizer.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_optimizer",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/costs:graph_properties",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "loop_optimizer_test",
|
||||
size = "small",
|
||||
srcs = ["loop_optimizer_test.cc"],
|
||||
deps = [
|
||||
":loop_optimizer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
],
|
||||
)
|
||||
|
@ -1717,13 +1717,28 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
|
||||
|
||||
protected:
|
||||
bool ShouldProcess() const override {
|
||||
return !MustPreserve() && IsPortZeroDimsN(*node_, 2) && HasOutputs() &&
|
||||
IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() &&
|
||||
IsOnGPU();
|
||||
bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) ||
|
||||
(IsPortZeroDimsN(*node_, 1) && IsAlongNHW());
|
||||
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
|
||||
IsInputConvertible() && is_dims_supported && IsOnGPU();
|
||||
}
|
||||
|
||||
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
|
||||
|
||||
Status CustomizedProcessing() override {
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
|
||||
auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
|
||||
if (list->i_size() == 2) {
|
||||
list->set_i(0, 2);
|
||||
list->set_i(1, 3);
|
||||
} else if (list->i_size() == 3) {
|
||||
list->set_i(1, 2);
|
||||
list->set_i(2, 3);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsInputConvertible() const {
|
||||
int input_port;
|
||||
auto input = node_map_->GetNode(node_->input(0));
|
||||
@ -1736,33 +1751,31 @@ class SqueezeProcessor : public AgnosticNodeProcessor {
|
||||
if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsAlongDimHW() const {
|
||||
if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
|
||||
auto list = node_->attr().at("squeeze_dims").list();
|
||||
// If list is empty, Squeeze op will squeeze all dimensions of size 1.
|
||||
if (list.i_size() == 0) return true;
|
||||
if (list.i_size() == 2) {
|
||||
if (list.i(0) == 1 && list.i(1) == 2) {
|
||||
return true;
|
||||
}
|
||||
if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 &&
|
||||
shape.dim(2).size() == 1) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
Status CustomizedProcessing() override {
|
||||
TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims"));
|
||||
auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list();
|
||||
if (list->i_size() == 2) {
|
||||
list->set_i(0, 2);
|
||||
list->set_i(1, 3);
|
||||
bool IsAlongAxis(const std::vector<int>& axis) const {
|
||||
if (node_->attr().find("squeeze_dims") != node_->attr().end()) {
|
||||
auto list = node_->attr().at("squeeze_dims").list();
|
||||
// If list is empty, Squeeze op will squeeze all dimensions of size 1.
|
||||
if (list.i_size() == 0) return true;
|
||||
if (list.i_size() == axis.size()) {
|
||||
bool along_axis = true;
|
||||
for (int i = 0; i < axis.size(); i++) {
|
||||
along_axis = along_axis && (list.i(i) == axis[i]);
|
||||
}
|
||||
if (along_axis) return true;
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
return false;
|
||||
}
|
||||
bool IsAlongHW() const { return IsAlongAxis({1, 2}); }
|
||||
bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); }
|
||||
};
|
||||
|
||||
class ReduceProcessor : public AgnosticNodeProcessor {
|
||||
@ -1789,12 +1802,18 @@ class ReduceProcessor : public AgnosticNodeProcessor {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
|
||||
Status AddLayoutTransposeToOutputs() override {
|
||||
if (KeepDims()) {
|
||||
return AddTransformToOutputs("Transpose");
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
private:
|
||||
bool IsReduceAxisSupported() const {
|
||||
return IsAlongAllFourDims() || IsAlongHWC() ||
|
||||
((IsAlongNHW() || IsAlongHW() || IsAlongC()) && !KeepDims());
|
||||
return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() ||
|
||||
IsAlongNHW() || IsAlongHW() || IsAlongC()) &&
|
||||
!KeepDims());
|
||||
}
|
||||
|
||||
bool IsAlongAxis(const std::vector<int>& axis) const {
|
||||
|
46
tensorflow/core/grappler/optimizers/loop_optimizer.cc
Normal file
@ -0,0 +1,46 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
|
||||
#include <unordered_map>
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
*optimized_graph = item.graph;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
|
||||
const GraphDef& /*optimized_graph*/,
|
||||
double /*result*/) {
|
||||
// Nothing to do for LoopOptimizer.
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
49
tensorflow/core/grappler/optimizers/loop_optimizer.h
Normal file
@ -0,0 +1,49 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
|
||||
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class LoopOptimizer : public GraphOptimizer {
|
||||
public:
|
||||
LoopOptimizer() : opt_level_(RewriterConfig::ON) {}
|
||||
explicit LoopOptimizer(RewriterConfig::Toggle opt_level)
|
||||
: opt_level_(opt_level) {}
|
||||
~LoopOptimizer() override {}
|
||||
|
||||
string name() const override { return "loop_optimizer"; };
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) override;
|
||||
|
||||
private:
|
||||
RewriterConfig::Toggle opt_level_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_
|
62
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
Normal file
@ -0,0 +1,62 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace {
|
||||
|
||||
class LoopOptimizerTest : public ::testing::Test {};
|
||||
|
||||
void VerifyGraphsEqual(const GraphDef& original_graph,
|
||||
const GraphDef& optimized_graph, const string& func) {
|
||||
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func;
|
||||
for (int i = 0; i < original_graph.node_size(); ++i) {
|
||||
const NodeDef& original = original_graph.node(i);
|
||||
const NodeDef& optimized = optimized_graph.node(i);
|
||||
EXPECT_EQ(original.name(), optimized.name()) << func;
|
||||
EXPECT_EQ(original.op(), optimized.op()) << func;
|
||||
EXPECT_EQ(original.input_size(), optimized.input_size()) << func;
|
||||
for (int j = 0; j < original.input_size(); ++j) {
|
||||
EXPECT_EQ(original.input(j), optimized.input(j)) << func;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(LoopOptimizerTest, NoOp) {
|
||||
// This trivial graph is so basic there's nothing to optimize.
|
||||
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
|
||||
GrapplerItem item;
|
||||
CHECK(fake_input.NextItem(&item));
|
||||
|
||||
LoopOptimizer optimizer;
|
||||
GraphDef output;
|
||||
Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
VerifyGraphsEqual(item.graph, output, __FUNCTION__);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -490,12 +490,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level,
|
||||
}
|
||||
|
||||
bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
// Look for AddN nodes and record input names.
|
||||
// Look for AddN nodes (and equivalent) and record input names.
|
||||
GraphView view(&item->graph);
|
||||
|
||||
std::unordered_map<string, std::unordered_set<NodeDef*>> addn_list;
|
||||
for (NodeDef& node : *item->graph.mutable_node()) {
|
||||
if (!IsAddN(node)) {
|
||||
if (!IsAddN(node) && node.op() != "AccumulateNV2") {
|
||||
continue;
|
||||
}
|
||||
// There is nothing to gain by optimizing nodes with 2 or fewer inputs.
|
||||
@ -511,6 +511,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
}
|
||||
}
|
||||
|
||||
if (addn_list.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
GraphMemory memory(*item);
|
||||
const std::unordered_map<string, DeviceProperties>& devices =
|
||||
cluster->GetDevices();
|
||||
@ -560,6 +564,13 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
VLOG(1) << "Missing properties for " << node->name();
|
||||
continue;
|
||||
}
|
||||
const TensorShapeProto& shape =
|
||||
properties.GetOutputProperties(node->name())[0].shape();
|
||||
PartialTensorShape shp(shape);
|
||||
if (!shp.IsFullyDefined()) {
|
||||
VLOG(1) << "Shape not fully known for " << node->name();
|
||||
continue;
|
||||
}
|
||||
|
||||
// Compute a topological ordering for the node fanin.
|
||||
std::unordered_map<NodeDef*, int> topo_order;
|
||||
@ -608,8 +619,6 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) {
|
||||
}
|
||||
}
|
||||
|
||||
const TensorShapeProto& shape =
|
||||
properties.GetOutputProperties(node->name())[0].shape();
|
||||
DataType dtype = node->attr().at("T").type();
|
||||
const string& device = node->device();
|
||||
|
||||
@ -721,6 +730,7 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap,
|
||||
*swap_in_node->add_input() = swap_out_node->name();
|
||||
|
||||
// Colocate the swap_in_ node with the node itself.
|
||||
swap_in_node->set_device(node->device());
|
||||
string coloc_group = strings::StrCat("loc@", tensor_to_swap);
|
||||
(*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
|
||||
(*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
|
||||
@ -1223,7 +1233,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
bool updated_graph = true;
|
||||
for (int i = 0; i < 25 && updated_graph; ++i) {
|
||||
updated_graph = false;
|
||||
if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
|
||||
if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT ||
|
||||
optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS ||
|
||||
optimization_level_ == RewriterConfig::HEURISTICS) &&
|
||||
cluster != nullptr) {
|
||||
updated_graph |= SchedulingPass(cluster, &optimized_item);
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
@ -75,6 +76,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
|
||||
graph_optimizer.reset(
|
||||
new DependencyOptimizer(cfg_.dependency_optimization()));
|
||||
}
|
||||
if (optimizer == "loop") {
|
||||
graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization()));
|
||||
}
|
||||
return graph_optimizer;
|
||||
}
|
||||
|
||||
@ -97,11 +101,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
|
||||
new DependencyOptimizer(cfg_.dependency_optimization())));
|
||||
}
|
||||
if (cfg_.loop_optimization() != RewriterConfig::OFF) {
|
||||
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
|
||||
new LoopOptimizer(cfg_.loop_optimization())));
|
||||
}
|
||||
if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
|
||||
optimizers.push_back(
|
||||
std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
|
||||
}
|
||||
if (cfg_.memory_optimization() > 1) {
|
||||
if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) {
|
||||
if (cfg_.memory_optimizer_target_node_name_prefix().empty()) {
|
||||
optimizers.push_back(std::unique_ptr<GraphOptimizer>(
|
||||
// Use the default target node name prefix "gradients/"
|
||||
@ -119,8 +127,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
} else {
|
||||
std::set<string> available_optimizers = {
|
||||
"pruning", "constfold", "layout", "memory",
|
||||
"autoparallel", "arithmetic", "dependency"};
|
||||
"pruning", "constfold", "layout", "memory",
|
||||
"autoparallel", "arithmetic", "dependency", "loop"};
|
||||
for (const auto& optimizer : cfg_.optimizers()) {
|
||||
if (available_optimizers.find(optimizer) != available_optimizers.end()) {
|
||||
optimizers.push_back(NewOptimizer(optimizer));
|
||||
@ -136,7 +144,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
bool already_optimized = false;
|
||||
for (const auto& optimizer : optimizers) {
|
||||
if (!already_optimized) {
|
||||
auto status = optimizer->Optimize(cluster, item, optimized_graph);
|
||||
Status status = optimizer->Optimize(cluster, item, optimized_graph);
|
||||
string result;
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << "Not able to apply optimizer " << optimizer->name()
|
||||
@ -152,7 +160,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
<< " return status: " << result;
|
||||
} else {
|
||||
GrapplerItem optimized_item(item, std::move(*optimized_graph));
|
||||
auto status =
|
||||
Status status =
|
||||
optimizer->Optimize(cluster, optimized_item, optimized_graph);
|
||||
string result;
|
||||
if (!status.ok()) {
|
||||
@ -204,8 +212,10 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
|
||||
cfg.layout_optimizer() != RewriterConfig::OFF ||
|
||||
cfg.constant_folding() != RewriterConfig::OFF ||
|
||||
cfg.dependency_optimization() != RewriterConfig::OFF ||
|
||||
cfg.loop_optimization() == RewriterConfig::ON ||
|
||||
cfg.arithmetic_optimization() != RewriterConfig::OFF ||
|
||||
cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 ||
|
||||
cfg.auto_parallel().enable() ||
|
||||
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
|
||||
!cfg.optimizers().empty();
|
||||
}
|
||||
|
||||
|
@ -67,7 +67,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// let's be conservative and preserve the graph as is.
|
||||
return errors::InvalidArgument("Invalid input graph.");
|
||||
}
|
||||
// Try to keep the nodes ordored somewhat topologically since this helps
|
||||
// Try to keep the nodes ordered somewhat topologically since this helps
|
||||
// further optimizations perform better.
|
||||
for (int i = keep.size() - 1; i >= 0; --i) {
|
||||
*runnable_item.graph.add_node() = *keep[i];
|
||||
|
File diff suppressed because it is too large
@ -37,11 +37,13 @@ message RewriterConfig {
|
||||
Toggle arithmetic_optimization = 7;
|
||||
// Control dependency optimizations (default is ON).
|
||||
Toggle dependency_optimization = 8;
|
||||
// Loop optimizations (default is OFF).
|
||||
Toggle loop_optimization = 9;
|
||||
// If true, don't remove unnecessary ops from the graph
|
||||
bool disable_model_pruning = 2;
|
||||
|
||||
enum MemOptType {
|
||||
// The default setting (currently disabled)
|
||||
// The default setting (SCHEDULING_HEURISTICS only)
|
||||
DEFAULT_MEM_OPT = 0;
|
||||
// Disabled in the meta-optimizer.
|
||||
NO_MEM_OPT = 1;
|
||||
|
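Since the new loop optimizer defaults to OFF, here is a hedged sketch of turning it on from Python through the rewriter options carried in a session config:

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2

# Enable the (default-OFF) loop optimizer added by this change.
rewrite_options = rewriter_config_pb2.RewriterConfig(
    loop_optimization=rewriter_config_pb2.RewriterConfig.ON)
graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options)
session_config = config_pb2.ConfigProto(graph_options=graph_options)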
@ -36,6 +36,7 @@ the following three:
|
||||
alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor"
|
||||
src="../images/iris_three_species.jpg">
|
||||
</div>
|
||||
|
||||
**From left to right,
|
||||
[*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by
|
||||
[Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0),
|
||||
@ -188,6 +189,7 @@ provides a programming stack consisting of multiple API layers:
|
||||
<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../images/tensorflow_programming_environment.png">
|
||||
</div>
|
||||
|
||||
**The TensorFlow Programming Environment.**
|
||||
<p> </p>
|
||||
|
||||
@ -331,7 +333,7 @@ interpret data is such a rich topic that we devote an entire
|
||||
From a code perspective, you build a list of `feature_column` objects by calling
|
||||
functions from the @{tf.feature_column} module. Each object describes an input
|
||||
to the model. To tell the model to interpret data as a floating-point value,
|
||||
call @{tf.feature_column.numeric_column). In `premade_estimator.py`, all
|
||||
call @{tf.feature_column.numeric_column}. In `premade_estimator.py`, all
|
||||
four features should be interpreted as literal floating-point values, so
|
||||
the code to create a feature column looks as follows:
|
||||
|
||||
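The snippet that sentence refers to falls outside this hunk; a hedged reconstruction of what such a column list typically looks like for the four Iris measurements (the key names are assumptions):

import tensorflow as tf

# Key names assumed from the Iris example; the real script derives them from
# its training data.
my_feature_columns = [
    tf.feature_column.numeric_column(key='SepalLength'),
    tf.feature_column.numeric_column(key='SepalWidth'),
    tf.feature_column.numeric_column(key='PetalLength'),
    tf.feature_column.numeric_column(key='PetalWidth'),
]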
@ -380,6 +382,7 @@ fully connected neural network consisting of three hidden layers:
|
||||
<div style="margin:auto; margin-bottom:10px; margin-top:20px;">
|
||||
<img style="width:100%" src="../images/simple_dnn.svg">
|
||||
</div>
|
||||
|
||||
**A neural network with three hidden layers.**
|
||||
<p> </p>
|
||||
|
||||
@ -568,6 +571,7 @@ of 0.5. The following suggests a more effective model:
|
||||
<tr> <td>5.5</td> <td>2.5</td> <td>4.0</td> <td>1.3</td> <td>1</td>
|
||||
<td style="background-color:green">1</td></tr>
|
||||
</table>
|
||||
|
||||
**A model that is 80% accurate.**
|
||||
<p> </p>
|
||||
|
||||
|
@ -98,6 +98,7 @@ classifies Iris flowers into three different species based on the size of their
|
||||
alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor"
|
||||
src="../images/iris_three_species.jpg">
|
||||
</div>
|
||||
|
||||
**From left to right,
|
||||
[*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by
|
||||
[Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0),
|
||||
|
Some files were not shown because too many files have changed in this diff.