Merge changes from github.
Change: 153925676
This commit is contained in:
parent
3c0900a49c
commit
326942394e
15
README.md
15
README.md
@ -34,12 +34,13 @@ and discussion.**
|
||||
|
||||
People who are a little more adventurous can also try our nightly binaries:
|
||||
|
||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
|
||||
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc1-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/))
|
||||
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc1-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/))
|
||||
|
||||
* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-python35-linux-cpu/))
|
||||
* Linux GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-linux-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/))
|
||||
* Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-1.1.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/))
|
||||
* Mac GPU: [Python 2](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-mac/)) / [Python 3](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow_gpu-1.1.0rc2-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-mac-gpu/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-mac/))
|
||||
* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow-1.1.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=cpu,OS=windows/))
|
||||
* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tensorflow_gpu-1.1.0rc2-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-win/DEVICE=gpu,OS=windows/))
|
||||
* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/)
|
||||
([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/))
|
||||
|
||||
@ -62,7 +63,7 @@ $ python
|
||||
|
||||
## For more information
|
||||
|
||||
* [TensorFlow website](http://tensorflow.org)
|
||||
* [TensorFlow website](https://tensorflow.org)
|
||||
* [TensorFlow whitepaper](http://download.tensorflow.org/paper/whitepaper2015.pdf)
|
||||
* [TensorFlow Model Zoo](https://github.com/tensorflow/models)
|
||||
* [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730)
|
||||
|
@ -36,6 +36,7 @@
|
||||
* New navigation bar in Curses-based UI
|
||||
* NodeStepper (command `invoke_stepper`) now uses intermediate tensor dumps. It also uses `TensorHandles` as direct feeds during successive `cont` calls for improved performance and reduced memory consumption.
|
||||
* Initial release of installation guides for Java, C, and Go.
|
||||
* Added Text Dashboard to TensorBoard.
|
||||
|
||||
## Deprecations
|
||||
|
||||
@ -91,6 +92,8 @@
|
||||
* Command history now persists across runs.
|
||||
* Bug fix in graph validation related to `tf.while_loops`.
|
||||
* Java Maven fixes for bugs with Windows installation.
|
||||
* Backport fixes and improvements from external keras.
|
||||
* Keras config file handling fix.
|
||||
|
||||
## Thanks to our Contributors
|
||||
|
||||
|
22
configure
vendored
22
configure
vendored
@ -94,10 +94,10 @@ write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
|
||||
if false; then # Disable building with MKL for now
|
||||
while [ "$TF_NEED_MKL" == "" ]; do
|
||||
fromuser=""
|
||||
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
|
||||
read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
|
||||
fromuser="1"
|
||||
case $INPUT in
|
||||
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
|
||||
[Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
|
||||
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
|
||||
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
|
||||
* ) echo "Invalid selection: " $INPUT;;
|
||||
@ -244,6 +244,24 @@ if [[ "$TF_ENABLE_XLA" == "1" ]]; then
|
||||
write_to_bazelrc 'build --define with_xla_support=true'
|
||||
fi
|
||||
|
||||
# Verbs configuration
|
||||
while [ "$TF_NEED_VERBS" == "" ]; do
|
||||
read -p "Do you wish to build TensorFlow with "\
|
||||
"VERBS support? [y/N] " INPUT
|
||||
case $INPUT in
|
||||
[Yy]* ) echo "VERBS support will be enabled for "\
|
||||
"TensorFlow"; TF_NEED_VERBS=1;;
|
||||
[Nn]* ) echo "No VERBS support will be enabled for "\
|
||||
"TensorFlow"; TF_NEED_VERBS=0;;
|
||||
"" ) echo "No VERBS support will be enabled for "\
|
||||
"TensorFlow"; TF_NEED_VERBS=0;;
|
||||
* ) echo "Invalid selection: " $INPUT;;
|
||||
esac
|
||||
done
|
||||
|
||||
if [[ "$TF_NEED_VERBS" == "1" ]]; then
|
||||
write_to_bazelrc 'build --define with_verbs_support=true'
|
||||
fi
|
||||
|
||||
# Invoke python_config and set up symlinks to python includes
|
||||
./util/python/python_config.sh "$PYTHON_BIN_PATH"
|
||||
|
@ -84,6 +84,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "linux_ppc64le",
|
||||
values = {"cpu": "ppc"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "debug",
|
||||
values = {
|
||||
@ -108,7 +114,7 @@ config_setting(
|
||||
|
||||
# TODO(jhseu): Enable on other platforms other than Linux.
|
||||
config_setting(
|
||||
name = "with_jemalloc",
|
||||
name = "with_jemalloc_linux_x86_64",
|
||||
values = {
|
||||
"cpu": "k8",
|
||||
"define": "with_jemalloc=true",
|
||||
@ -116,6 +122,15 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "with_jemalloc_linux_ppc64le",
|
||||
values = {
|
||||
"cpu": "ppc",
|
||||
"define": "with_jemalloc=true",
|
||||
},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "with_gcp_support",
|
||||
values = {"define": "with_gcp_support=true"},
|
||||
@ -134,6 +149,12 @@ config_setting(
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "with_verbs_support",
|
||||
values = {"define": "with_verbs_support=true"},
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
package_group(
|
||||
name = "internal",
|
||||
packages = ["//tensorflow/..."],
|
||||
@ -249,6 +270,7 @@ filegroup(
|
||||
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
|
||||
"//tensorflow/contrib/training:all_files",
|
||||
"//tensorflow/contrib/util:all_files",
|
||||
"//tensorflow/contrib/verbs:all_files",
|
||||
"//tensorflow/contrib/xla_tf_graph:all_files",
|
||||
"//tensorflow/core:all_files",
|
||||
"//tensorflow/core/debug:all_files",
|
||||
|
@ -282,5 +282,6 @@ def target_llvm_triple():
|
||||
"//tensorflow:android_arm": "armv7-none-android",
|
||||
"//tensorflow:android_arm64": "aarch64-none-android",
|
||||
"//tensorflow:android_x86": "i686-none-android",
|
||||
"//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu",
|
||||
"//conditions:default": "x86_64-pc-linux",
|
||||
})
|
||||
|
@ -142,7 +142,7 @@ template <typename UInt>
|
||||
string LittleEndianData(UInt data) {
|
||||
static_assert(std::is_unsigned<UInt>::value, "UInt must be unsigned");
|
||||
string str;
|
||||
for (int i = 0; i < sizeof(UInt); ++i) {
|
||||
for (size_t i = 0; i < sizeof(UInt); ++i) {
|
||||
const unsigned char bits = static_cast<unsigned char>(data & 0xFFU);
|
||||
char ch;
|
||||
::memcpy(&ch, &bits, sizeof(bits));
|
||||
|
@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
from tensorflow.contrib.learn.python.learn.datasets import base
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import random_seed
|
||||
|
||||
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
|
||||
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
|
||||
@ -109,12 +110,16 @@ class DataSet(object):
|
||||
fake_data=False,
|
||||
one_hot=False,
|
||||
dtype=dtypes.float32,
|
||||
reshape=True):
|
||||
reshape=True,
|
||||
seed=None):
|
||||
"""Construct a DataSet.
|
||||
one_hot arg is used only if fake_data is true. `dtype` can be either
|
||||
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
|
||||
`[0, 1]`.
|
||||
`[0, 1]`. Seed arg provides for convenient deterministic testing.
|
||||
"""
|
||||
seed1, seed2 = random_seed.get_seed(seed)
|
||||
# If op level seed is not set, use whatever graph level seed is returned
|
||||
numpy.random.seed(seed1 if seed is None else seed2)
|
||||
dtype = dtypes.as_dtype(dtype).base_dtype
|
||||
if dtype not in (dtypes.uint8, dtypes.float32):
|
||||
raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
|
||||
@ -208,11 +213,13 @@ def read_data_sets(train_dir,
|
||||
one_hot=False,
|
||||
dtype=dtypes.float32,
|
||||
reshape=True,
|
||||
validation_size=5000):
|
||||
validation_size=5000,
|
||||
seed=None):
|
||||
if fake_data:
|
||||
|
||||
def fake():
|
||||
return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
|
||||
return DataSet(
|
||||
[], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
|
||||
|
||||
train = fake()
|
||||
validation = fake()
|
||||
@ -254,12 +261,16 @@ def read_data_sets(train_dir,
|
||||
train_images = train_images[validation_size:]
|
||||
train_labels = train_labels[validation_size:]
|
||||
|
||||
train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
|
||||
validation = DataSet(validation_images,
|
||||
validation_labels,
|
||||
dtype=dtype,
|
||||
reshape=reshape)
|
||||
test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)
|
||||
train = DataSet(
|
||||
train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
|
||||
validation = DataSet(
|
||||
validation_images,
|
||||
validation_labels,
|
||||
dtype=dtype,
|
||||
reshape=reshape,
|
||||
seed=seed)
|
||||
test = DataSet(
|
||||
test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
|
||||
|
||||
return base.Datasets(train=train, validation=validation, test=test)
|
||||
|
||||
|
@ -52,7 +52,6 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import variables as variables_lib
|
||||
@ -63,7 +62,6 @@ from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import input as input_lib
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import queue_runner_impl
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
from tensorflow.python.training import session_run_hook
|
||||
from tensorflow.python.util import compat
|
||||
@ -82,18 +80,6 @@ def boston_input_fn(num_epochs=None):
|
||||
return features, labels
|
||||
|
||||
|
||||
def boston_input_fn_with_queue(num_epochs=None):
|
||||
features, labels = boston_input_fn(num_epochs=num_epochs)
|
||||
|
||||
# Create a minimal queue runner.
|
||||
fake_queue = data_flow_ops.FIFOQueue(30, dtypes.int32)
|
||||
queue_runner = queue_runner_impl.QueueRunner(fake_queue,
|
||||
[constant_op.constant(0)])
|
||||
queue_runner_impl.add_queue_runner(queue_runner)
|
||||
|
||||
return features, labels
|
||||
|
||||
|
||||
def iris_input_fn():
|
||||
iris = base.load_iris()
|
||||
features = array_ops.reshape(
|
||||
|
@ -38,7 +38,7 @@ then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
./configure --prefix="${GENDIR}"
|
||||
./configure --prefix="${GENDIR}" --with-pic
|
||||
if [ $? -ne 0 ]
|
||||
then
|
||||
echo "./configure command failed."
|
||||
|
@ -68,8 +68,8 @@ class LSTMBlockCellTest(test.TestCase):
|
||||
m3 = array_ops.zeros([1, 2])
|
||||
g, ((out_m0, out_m1),
|
||||
(out_m2, out_m3)) = core_rnn_cell_impl.MultiRNNCell(
|
||||
[lstm_ops.LSTMBlockCell(2)] * 2, state_is_tuple=True)(x, (
|
||||
(m0, m1), (m2, m3)))
|
||||
[lstm_ops.LSTMBlockCell(2) for _ in range(2)],
|
||||
state_is_tuple=True)(x, ((m0, m1), (m2, m3)))
|
||||
sess.run([variables.global_variables_initializer()])
|
||||
res = sess.run([g, out_m0, out_m1, out_m2, out_m3], {
|
||||
x.name: np.array([[1., 1.]]),
|
||||
|
@ -473,7 +473,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
|
||||
if probability_fn is None:
|
||||
probability_fn = nn_ops.softmax
|
||||
else:
|
||||
if not callable(cell_input_fn):
|
||||
if not callable(probability_fn):
|
||||
raise TypeError(
|
||||
"probability_fn must be callable, saw type: %s"
|
||||
% type(probability_fn).__name__)
|
||||
|
@ -447,7 +447,7 @@ vgg = tf.contrib.slim.nets.vgg
|
||||
images, labels = ...
|
||||
|
||||
# Create the model.
|
||||
predictions = vgg.vgg16(images)
|
||||
predictions = vgg.vgg_16(images)
|
||||
|
||||
# Define the loss functions and get the total loss.
|
||||
loss = slim.losses.softmax_cross_entropy(predictions, labels)
|
||||
|
168
tensorflow/contrib/verbs/BUILD
Normal file
168
tensorflow/contrib/verbs/BUILD
Normal file
@ -0,0 +1,168 @@
|
||||
# Description:
|
||||
# Verbs RDMA communication interfaces and implementations for TensorFlow.
|
||||
|
||||
package(default_visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "c_srcs",
|
||||
data = glob([
|
||||
"**/*.cc",
|
||||
"**/*.h",
|
||||
]),
|
||||
)
|
||||
|
||||
# For platform specific build config
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_proto_library_cc",
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "verbs_service_proto",
|
||||
srcs = ["verbs_service.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
visibility = [
|
||||
"//tensorflow:__subpackages__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "verbs_util",
|
||||
srcs = ["verbs_util.cc"],
|
||||
hdrs = ["verbs_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_verbs_service",
|
||||
srcs = ["grpc_verbs_service.cc"],
|
||||
hdrs = ["grpc_verbs_service.h"],
|
||||
deps = [
|
||||
":grpc_verbs_service_impl",
|
||||
":rdma_mgr",
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc:async_service_interface",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_call",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"@grpc//:grpc++_unsecure",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_verbs_service_impl",
|
||||
srcs = ["grpc_verbs_service_impl.cc"],
|
||||
hdrs = ["grpc_verbs_service_impl.h"],
|
||||
deps = [
|
||||
":verbs_service_proto_cc",
|
||||
"@grpc//:grpc++_unsecure",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "grpc_verbs_client",
|
||||
srcs = ["grpc_verbs_client.cc"],
|
||||
hdrs = ["grpc_verbs_client.h"],
|
||||
deps = [
|
||||
":grpc_verbs_service_impl",
|
||||
":verbs_service_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:call_options",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rdma_rendezvous_mgr",
|
||||
srcs = ["rdma_rendezvous_mgr.cc"],
|
||||
hdrs = ["rdma_rendezvous_mgr.h"],
|
||||
deps = [
|
||||
":rdma_mgr",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rdma_mgr",
|
||||
srcs = ["rdma_mgr.cc"],
|
||||
hdrs = ["rdma_mgr.h"],
|
||||
deps = [
|
||||
":grpc_verbs_client",
|
||||
":rdma",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_channel",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "rdma",
|
||||
srcs = ["rdma.cc"],
|
||||
hdrs = ["rdma.h"],
|
||||
linkopts = select({
|
||||
"//tensorflow:with_verbs_support": ["-libverbs"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":verbs_util",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/distributed_runtime:rendezvous_mgr_interface",
|
||||
"//tensorflow/core/distributed_runtime:session_mgr",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "verbs_server_lib",
|
||||
srcs = ["verbs_server_lib.cc"],
|
||||
hdrs = ["verbs_server_lib.h"],
|
||||
linkstatic = 1, # Seems to be needed since alwayslink is broken in bazel
|
||||
deps = [
|
||||
":grpc_verbs_service",
|
||||
":rdma_mgr",
|
||||
":rdma_rendezvous_mgr",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
77
tensorflow/contrib/verbs/README.md
Normal file
77
tensorflow/contrib/verbs/README.md
Normal file
@ -0,0 +1,77 @@
|
||||
## How to compile and use Rdma-enabled tensorflow
|
||||
1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based Rdma support, answer yes to this question:
|
||||
|
||||
```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]```
|
||||
|
||||
2. To turn on Rdma connection, add the protocol "grpc+verbs" in server definition:
|
||||
|
||||
```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'```
|
||||
|
||||
## Overview
|
||||
The design is based on Tensorflow r1.0. An Rdma path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the Rdma path, exchanging computation graphs, etc.
|
||||
|
||||
During the server setup, an Rdma manager is created to manage low-level Rdma components such as Rdma channel and Rdma adapter, an Rdma rendezvous manager is created to oversee send/recv operations between servers. Following the distributed Tensorflow design philosophy, the send operation is passive, i.e. merely placing a tensor in the local out-going table. It is the receive operation that actually initiates the tensor transfer.
|
||||
|
||||
Tensorflow dynamically allocates memory for tensors that are to be sent or received. This causes difficulty for Rdma operations where pinned memory is required. Two remedies are possible, either the memory is pinned, transfer, then unpinned for each and every tensor to be transferred, or a buffer is pre-allocated and pinned for each tensor. The former incurs significant operation overhead since pinning and unpinning memory for each dynamically generated tensor is slow. The latter incurs large memory overhead and extra copying from the tensor to its pinned buffer, but may still be faster than the former. The second approach is adopted in this design. Each Rdma channel, representing a Rdma connection to a peer, contains a table of pinned buffers for all the seen tensors that requires transfer. It is assumed that the tensor size rarely changes across different steps. So only one buffer is created for the same tensor across all the steps. In the rare case when the tensor size does increases, the old buffer is discarded and new buffer of larger size is created and pinned.
|
||||
|
||||
When a tensor is prepared fro transfer, it is first converted to TensorProto, then the proto is serialized to byte array and copied to the pinned buffer. The content of the buffer is transferred to the remote node via Rdma write. On the remote side, the process is reversed. This is illustrated in the diagram below. The conversion of TensorProto is introduced to simplify transfer of string-tensors. Also since the TensorProto lives in host memory, even if the origin tensor lives in the device, the pinned buffers are all allocated in the host memory.
|
||||

|
||||
|
||||
The following improvements can be made in the future. First, conversion to TensorProto and serialization can be avoided for numeric (float/int) tensors since their internal buffer can be access directly as byte array. Second, the pinned buffer may be allocated on device if the tensor is located in the device. This avoids extra device-to-host copy at the expense of extra device memory consumption.
|
||||
## Design details
|
||||
|
||||
### Rdma components
|
||||
|
||||
* **Rdma adapter:** The base for Rdma communications. It may contain multiple channels and buffers. It is responsible for handling various incoming Rdma messages.
|
||||
* **Rdma channel:** Responsible for Rdma connection to a particular node. It manages multiple buffers. A channel has a callback table which stores all the callbacks for the requested tensors.
|
||||
* **Rdma buffer:** Responsible for sending or receiving data. It has a fixed size memory to store the data. It has a queue to store the pending jobs. There are three types of buffers, message buffer, ACK buffer and tensor buffer. A channel has two message buffers, two ack buffers and many tensor buffers.
|
||||
* **Rdma manager:** Manages the adapter and channels, including channel creation, channel setup via GRPC service, channel lookup, etc.
|
||||
* **Rdma rendezvous manager:** manages multiple rdma rendezvous.
|
||||
* **Rdma rendezvous:** a derived class of BaseRemoteRendezvous. This class is the back end for "send" and "recv" ops. When the sendrecv_op wants to send or receive a tensor, it calls the rendezvous' "send" and "recv" functions respectively. Rendezvous are identified by "step_id", a random number, so that tensors for different iterations don't get mixed up.
|
||||
|
||||
### The SEND operation
|
||||
|
||||
In tensorflow, when rendezvous sends a tensor, it merely puts a tensor in a local table in the corresponding rendezvous. If the tensor has been requested, a callback exists in the table. "send" will activate the callback, which tries to send the tensor across the node.
|
||||
|
||||
|
||||
### The RECV operation
|
||||
|
||||
When a tensor is requested, rendezvous' recv function is called. The function first places a callback in the channel's callback table, which will be activated once the tensor is sent from the source. In the next step, a message is sent to notify the source of the requested tensor. Once the source receives the message, it will check locally for the tensor, if not found, a callback is placed in the table, otherwise, the tensor id will be placed at corresponding Rdma buffer's job queue for future transmission. When a tensor is scheduled to be transmitted, the Rdma buffer needs to have the memory allocated and initialized (registered with the remote buffer info). If the memory is not ready, the transmission is deferred, a message is sent to the destination to establish the memory first. The other case a transimssion can be deferred is when the buffer is still being used by an on-going transmission.
|
||||
|
||||
### Three types of Rdma buffers
|
||||
|
||||
* **Message buffer:** responsible for sending message only.
|
||||
* **Ack buffer:** once a message is sent, the recipient needs to send an ack via the ack buffer to free up the message buffer. An ack buffer is exclusively for its coupled message buffer.
|
||||
* **Tensor buffer:** responsible for sending tensors. The recipient needs to send back a message to free up the sending buffer.
|
||||
|
||||
### Rdma packet format
|
||||
|
||||
|type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|data_type|tensor_shape|tensor_bytes|tensor_buffer|
|
||||
|
||||
### Six types of Rdma messages
|
||||
* RDMA_MESSAGE_ACK
|
||||
* RDMA_MESSAGE_BUFFER_IDLE
|
||||
* RDMA_MESSAGE_BUFFER_REQUEST
|
||||
* RDMA_MESSAGE_BUFFER_RESPONSE
|
||||
* RDMA_MESSAGE_TENSOR_REQUEST
|
||||
* RDMA_MESSAGE_TENSOR_WRITE
|
||||
|
||||
### Actions upon receiving Rdma messages
|
||||
* RDMA_MESSAGE_ACK
|
||||
* sender: mark local ack buffer idle.
|
||||
* receiver: mark remote message buffer idle, send next item.
|
||||
* RDMA_MESSAGE_BUFFER_IDLE
|
||||
* sender: mark local message buffer idle, send next item.
|
||||
* receiver: send ack, set remote tensor buffer idle, send next item.
|
||||
* RDMA_MESSAGE_BUFFER_REQUEST
|
||||
* sender: mark local message buffer idle, send next item.
|
||||
* receiver: send ack, find or create tensor buffer, send BUFFER_RESPONSE.
|
||||
* RDMA_MESSAGE_BUFFER_RESPONSE
|
||||
* sender: mark local message buffer idle, send next item.
|
||||
* receiver: send ack, set remote buffer info, set local and remote buffer idle, send next item.
|
||||
* RDMA_MESSAGE_TENSOR_REQUEST
|
||||
* sender: mark local message buffer idle, send next item.
|
||||
* receiver: send ack, find or create tensor buffer, enqueue tensor id, send next item.
|
||||
* RDMA_MESSAGE_TENSOR_WRITE
|
||||
* sender: mark local message buffer idle, send next item.
|
||||
* receiver: run callback.
|
BIN
tensorflow/contrib/verbs/design_diagram.png
Normal file
BIN
tensorflow/contrib/verbs/design_diagram.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 13 KiB |
47
tensorflow/contrib/verbs/grpc_verbs_client.cc
Normal file
47
tensorflow/contrib/verbs/grpc_verbs_client.cc
Normal file
@ -0,0 +1,47 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
|
||||
const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
::grpc::ClientContext ctx;
|
||||
ctx.set_fail_fast(false);
|
||||
SetDeadline(&ctx, call_options->GetTimeout());
|
||||
return FromGrpcStatus(stub_->GetRemoteAddress(&ctx, *request, response));
|
||||
}
|
||||
|
||||
Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
CallOptions call_options;
|
||||
call_options.SetTimeout(-1); // no time out
|
||||
return GetRemoteAddress(&call_options, request, response);
|
||||
}
|
||||
|
||||
void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
|
||||
int64 time_in_ms) {
|
||||
if (time_in_ms > 0) {
|
||||
ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
50
tensorflow/contrib/verbs/grpc_verbs_client.h
Normal file
50
tensorflow/contrib/verbs/grpc_verbs_client.h
Normal file
@ -0,0 +1,50 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/call_options.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// GrpcVerbsClient is a client that uses gRPC to talk to the Verbs service.
|
||||
class GrpcVerbsClient {
|
||||
public:
|
||||
explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
|
||||
: stub_(grpc::VerbsService::NewStub(client_channel)) {}
|
||||
~GrpcVerbsClient() {}
|
||||
|
||||
Status GetRemoteAddress(CallOptions* call_options,
|
||||
const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response);
|
||||
Status GetRemoteAddress(const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response);
|
||||
|
||||
private:
|
||||
std::unique_ptr<grpc::VerbsService::Stub> stub_;
|
||||
|
||||
void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
|
165
tensorflow/contrib/verbs/grpc_verbs_service.cc
Normal file
165
tensorflow/contrib/verbs/grpc_verbs_service.cc
Normal file
@ -0,0 +1,165 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "grpc++/alarm.h"
|
||||
#include "grpc++/grpc++.h"
|
||||
#include "grpc++/server_builder.h"
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder)
|
||||
: is_shutdown_(false), worker_env_(worker_env) {
|
||||
builder->RegisterService(&verbs_service_);
|
||||
cq_ = builder->AddCompletionQueue().release();
|
||||
}
|
||||
|
||||
GrpcVerbsService::~GrpcVerbsService() {
|
||||
delete shutdown_alarm_;
|
||||
delete cq_;
|
||||
}
|
||||
|
||||
void GrpcVerbsService::Shutdown() {
|
||||
bool did_shutdown = false;
|
||||
{
|
||||
mutex_lock l(shutdown_mu_);
|
||||
if (!is_shutdown_) {
|
||||
LOG(INFO) << "Shutting down GrpcWorkerService.";
|
||||
is_shutdown_ = true;
|
||||
did_shutdown = true;
|
||||
}
|
||||
}
|
||||
if (did_shutdown) {
|
||||
shutdown_alarm_ =
|
||||
new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// This macro creates a new request for the given RPC method name
|
||||
// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
|
||||
// `this->cq_`.
|
||||
//
|
||||
// This macro is invoked one or more times for each RPC method to
|
||||
// ensure that there are sufficient completion queue entries to
|
||||
// handle incoming requests without blocking.
|
||||
//
|
||||
// The implementation of the request handler for each RPC method
|
||||
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
|
||||
// to keep accepting new requests.
|
||||
#define ENQUEUE_REQUEST(method, supports_cancel) \
|
||||
do { \
|
||||
mutex_lock l(shutdown_mu_); \
|
||||
if (!is_shutdown_) { \
|
||||
Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
|
||||
method##Request, method##Response>:: \
|
||||
EnqueueRequest(&verbs_service_, cq_, \
|
||||
&grpc::VerbsService::AsyncService::Request##method, \
|
||||
&GrpcVerbsService::method##Handler, \
|
||||
(supports_cancel)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// This method blocks forever handling requests from the completion queue.
|
||||
void GrpcVerbsService::HandleRPCsLoop() {
|
||||
for (int i = 0; i < 10; ++i) {
|
||||
ENQUEUE_REQUEST(GetRemoteAddress, false);
|
||||
}
|
||||
|
||||
void* tag;
|
||||
bool ok;
|
||||
|
||||
while (cq_->Next(&tag, &ok)) {
|
||||
UntypedCall<GrpcVerbsService>::Tag* callback_tag =
|
||||
static_cast<UntypedCall<GrpcVerbsService>::Tag*>(tag);
|
||||
if (callback_tag) {
|
||||
callback_tag->OnCompleted(this, ok);
|
||||
} else {
|
||||
cq_->Shutdown();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GrpcVerbsService::GetRemoteAddressHandler(
|
||||
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
|
||||
Status s = GetRemoteAddressSync(&call->request, &call->response);
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
ENQUEUE_REQUEST(GetRemoteAddress, false);
|
||||
}
|
||||
|
||||
// synchronous method
|
||||
Status GrpcVerbsService::GetRemoteAddressSync(
|
||||
const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
// analyzing request
|
||||
// the channel setting part is redundant.
|
||||
const string remote_host_name = request->host_name();
|
||||
RdmaChannel* rc = rdma_mgr_->FindChannel(remote_host_name);
|
||||
CHECK(rc);
|
||||
RdmaAddress ra;
|
||||
ra.lid = request->channel().lid();
|
||||
ra.qpn = request->channel().qpn();
|
||||
ra.psn = request->channel().psn();
|
||||
rc->SetRemoteAddress(ra, false);
|
||||
rc->Connect();
|
||||
int i = 0;
|
||||
int idx[] = {1, 0, 3, 2};
|
||||
std::vector<RdmaBuffer*> mb(rc->message_buffers());
|
||||
CHECK_EQ(request->mr_size(), 4);
|
||||
for (const auto& mr : request->mr()) {
|
||||
// the connections are crossed, i.e.
|
||||
// local tx_message_buffer <---> remote rx_message_buffer_
|
||||
// local rx_message_buffer <---> remote tx_message_buffer_
|
||||
// local tx_ack_buffer <---> remote rx_ack_buffer_
|
||||
// local rx_ack_buffer <---> remote tx_ack_buffer_
|
||||
// hence idx[] = {1, 0, 3, 2}.
|
||||
RdmaBuffer* rb = mb[idx[i]];
|
||||
RemoteMR rmr;
|
||||
rmr.remote_addr = mr.remote_addr();
|
||||
rmr.rkey = mr.rkey();
|
||||
rb->SetRemoteMR(rmr, false);
|
||||
i++;
|
||||
}
|
||||
CHECK(i == RdmaChannel::kNumMessageBuffers);
|
||||
|
||||
// setting up response
|
||||
response->set_host_name(
|
||||
worker_env_->session_mgr->LegacySession()->worker_name);
|
||||
Channel* channel_info = response->mutable_channel();
|
||||
channel_info->set_lid(rc->self().lid);
|
||||
channel_info->set_qpn(rc->self().qpn);
|
||||
channel_info->set_psn(rc->self().psn);
|
||||
for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
|
||||
MemoryRegion* mr = response->add_mr();
|
||||
mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
|
||||
mr->set_rkey(mb[i]->self()->rkey);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create a GrpcVerbsService, then assign it to a given handle.
|
||||
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder) {
|
||||
*handle = new GrpcVerbsService(worker_env, builder);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
72
tensorflow/contrib/verbs/grpc_verbs_service.h
Normal file
72
tensorflow/contrib/verbs/grpc_verbs_service.h
Normal file
@ -0,0 +1,72 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
|
||||
namespace grpc {
|
||||
class ServerBuilder;
|
||||
class ServerCompletionQueue;
|
||||
class Alarm;
|
||||
} // namespace grpc
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class GrpcVerbsService : public AsyncServiceInterface {
|
||||
public:
|
||||
GrpcVerbsService(const WorkerEnv* worker_env, ::grpc::ServerBuilder* builder);
|
||||
~GrpcVerbsService();
|
||||
void HandleRPCsLoop() override;
|
||||
void Shutdown() override;
|
||||
void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
|
||||
|
||||
private:
|
||||
template <class RequestMessage, class ResponseMessage>
|
||||
using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
|
||||
RequestMessage, ResponseMessage>;
|
||||
void GetRemoteAddressHandler(
|
||||
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
|
||||
Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
|
||||
GetRemoteAddressResponse* response);
|
||||
|
||||
::grpc::ServerCompletionQueue* cq_;
|
||||
grpc::VerbsService::AsyncService verbs_service_;
|
||||
mutex shutdown_mu_;
|
||||
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
|
||||
::grpc::Alarm* shutdown_alarm_;
|
||||
// not owned
|
||||
RdmaMgr* rdma_mgr_;
|
||||
const WorkerEnv* const worker_env_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
|
||||
};
|
||||
|
||||
// Create a GrpcVerbsService, then assign it to a given handle.
|
||||
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_GRPC_VERBS_SERVICE_H_
|
68
tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
Normal file
68
tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
Normal file
@ -0,0 +1,68 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
|
||||
|
||||
#include "grpc++/impl/codegen/async_stream.h"
|
||||
#include "grpc++/impl/codegen/async_unary_call.h"
|
||||
#include "grpc++/impl/codegen/channel_interface.h"
|
||||
#include "grpc++/impl/codegen/client_unary_call.h"
|
||||
#include "grpc++/impl/codegen/method_handler_impl.h"
|
||||
#include "grpc++/impl/codegen/rpc_service_method.h"
|
||||
#include "grpc++/impl/codegen/service_type.h"
|
||||
#include "grpc++/impl/codegen/sync_stream.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace grpc {
|
||||
|
||||
static const char* grpcVerbsService_method_names[] = {
|
||||
"/tensorflow.VerbsService/GetRemoteAddress",
|
||||
};
|
||||
|
||||
std::unique_ptr<VerbsService::Stub> VerbsService::NewStub(
|
||||
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
|
||||
const ::grpc::StubOptions& options) {
|
||||
std::unique_ptr<VerbsService::Stub> stub(new VerbsService::Stub(channel));
|
||||
return stub;
|
||||
}
|
||||
|
||||
VerbsService::Stub::Stub(
|
||||
const std::shared_ptr< ::grpc::ChannelInterface>& channel)
|
||||
: channel_(channel),
|
||||
rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
|
||||
::grpc::RpcMethod::NORMAL_RPC, channel) {}
|
||||
|
||||
::grpc::Status VerbsService::Stub::GetRemoteAddress(
|
||||
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
|
||||
GetRemoteAddressResponse* response) {
|
||||
return ::grpc::BlockingUnaryCall(channel_.get(), rpcmethod_GetRemoteAddress_,
|
||||
context, request, response);
|
||||
}
|
||||
|
||||
VerbsService::AsyncService::AsyncService() {
|
||||
for (int i = 0; i < 1; ++i) {
|
||||
AddMethod(new ::grpc::RpcServiceMethod(grpcVerbsService_method_names[i],
|
||||
::grpc::RpcMethod::NORMAL_RPC,
|
||||
nullptr));
|
||||
::grpc::Service::MarkMethodAsync(i);
|
||||
}
|
||||
}
|
||||
|
||||
VerbsService::AsyncService::~AsyncService() {}
|
||||
|
||||
} // namespace grpc
|
||||
|
||||
} // namespace tensorflow
|
89
tensorflow/contrib/verbs/grpc_verbs_service_impl.h
Normal file
89
tensorflow/contrib/verbs/grpc_verbs_service_impl.h
Normal file
@ -0,0 +1,89 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
|
||||
|
||||
#include "grpc++/impl/codegen/async_stream.h"
|
||||
#include "grpc++/impl/codegen/async_unary_call.h"
|
||||
#include "grpc++/impl/codegen/proto_utils.h"
|
||||
#include "grpc++/impl/codegen/rpc_method.h"
|
||||
#include "grpc++/impl/codegen/service_type.h"
|
||||
#include "grpc++/impl/codegen/status.h"
|
||||
#include "grpc++/impl/codegen/stub_options.h"
|
||||
#include "grpc++/impl/codegen/sync_stream.h"
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
|
||||
namespace grpc {
|
||||
class CompletionQueue;
|
||||
class Channel;
|
||||
class RpcService;
|
||||
class ServerCompletionQueue;
|
||||
class ServerContext;
|
||||
} // namespace grpc
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace grpc {
|
||||
|
||||
// Implementation of `tensorflow.VerbsService`, based on the
|
||||
// definition in "//tensorflow/contrib/verbs/verbs_service.proto",
|
||||
// and the gRPC generated stub and service classes.
|
||||
// See the proto file for the definition of methods and messages.
|
||||
class VerbsService GRPC_FINAL {
|
||||
public:
|
||||
class StubInterface {
|
||||
public:
|
||||
virtual ~StubInterface() {}
|
||||
virtual ::grpc::Status GetRemoteAddress(
|
||||
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
|
||||
GetRemoteAddressResponse* response) = 0;
|
||||
};
|
||||
class Stub GRPC_FINAL : public StubInterface {
|
||||
public:
|
||||
Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
|
||||
::grpc::Status GetRemoteAddress(
|
||||
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
|
||||
GetRemoteAddressResponse* response) GRPC_OVERRIDE;
|
||||
|
||||
private:
|
||||
std::shared_ptr< ::grpc::ChannelInterface> channel_;
|
||||
const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_;
|
||||
};
|
||||
static std::unique_ptr<Stub> NewStub(
|
||||
const std::shared_ptr< ::grpc::ChannelInterface>& channel,
|
||||
const ::grpc::StubOptions& options = ::grpc::StubOptions());
|
||||
|
||||
class AsyncService : public ::grpc::Service {
|
||||
public:
|
||||
AsyncService();
|
||||
virtual ~AsyncService();
|
||||
void RequestGetRemoteAddress(
|
||||
::grpc::ServerContext* context, GetRemoteAddressRequest* request,
|
||||
::grpc::ServerAsyncResponseWriter<GetRemoteAddressResponse>* response,
|
||||
::grpc::CompletionQueue* new_call_cq,
|
||||
::grpc::ServerCompletionQueue* notification_cq, void* tag) {
|
||||
::grpc::Service::RequestAsyncUnary(0, context, request, response,
|
||||
new_call_cq, notification_cq, tag);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace grpc
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_SERVICE_IMPL_H_
|
874
tensorflow/contrib/verbs/rdma.cc
Normal file
874
tensorflow/contrib/verbs/rdma.cc
Normal file
@ -0,0 +1,874 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma.h"
|
||||
#include <cstdlib>
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
#include "tensorflow/core/framework/rendezvous.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/stringpiece.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/lib/random/random.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// hash name to 32-bit integer
|
||||
uint32_t NameHash(const string& name) {
|
||||
return Hash32(name.data(), name.size(), 0x1234ABCD);
|
||||
}
|
||||
|
||||
// convenience function for printing message
|
||||
string MessageTypeToString(RdmaMessageType rmt) {
|
||||
switch (rmt) {
|
||||
case RDMA_MESSAGE_ACK:
|
||||
return "RDMA_MESSAGE_ACK";
|
||||
break;
|
||||
case RDMA_MESSAGE_BUFFER_IDLE:
|
||||
return "RDMA_MESSAGE_BUFFER_IDLE";
|
||||
break;
|
||||
case RDMA_MESSAGE_BUFFER_REQUEST:
|
||||
return "RDMA_MESSAGE_BUFFER_REQUEST";
|
||||
break;
|
||||
case RDMA_MESSAGE_BUFFER_RESPONSE:
|
||||
return "RDMA_MESSAGE_BUFFER_RESPONSE";
|
||||
break;
|
||||
case RDMA_MESSAGE_TENSOR_REQUEST:
|
||||
return "RDMA_MESSAGE_TENSOR_REQUEST";
|
||||
break;
|
||||
case RDMA_MESSAGE_TENSOR_WRITE:
|
||||
return "RDMA_MESSAGE_TENSOR_WRITE";
|
||||
break;
|
||||
default:
|
||||
return "UNKNOWN MESSAGE";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
ibv_context* open_default_device() {
|
||||
ibv_device** dev_list;
|
||||
ibv_device* ib_dev;
|
||||
dev_list = ibv_get_device_list(NULL);
|
||||
CHECK(dev_list) << "No InfiniBand device found";
|
||||
ib_dev = dev_list[0];
|
||||
CHECK(ib_dev) << "No InfiniBand device found";
|
||||
ibv_context* context = ibv_open_device(ib_dev);
|
||||
CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev);
|
||||
return context;
|
||||
}
|
||||
|
||||
ibv_pd* alloc_protection_domain(ibv_context* context) {
|
||||
ibv_pd* pd = ibv_alloc_pd(context);
|
||||
CHECK(pd) << "Failed to allocate protection domain";
|
||||
return pd;
|
||||
}
|
||||
|
||||
RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
|
||||
: context_(open_default_device()),
|
||||
pd_(alloc_protection_domain(context_)),
|
||||
worker_env_(worker_env) {
|
||||
event_channel_ = ibv_create_comp_channel(context_);
|
||||
CHECK(event_channel_) << "Failed to create completion channel";
|
||||
cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
|
||||
0);
|
||||
CHECK(cq_) << "Failed to create completion queue";
|
||||
CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
|
||||
polling_thread_.reset(Env::Default()->StartThread(
|
||||
ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
|
||||
VLOG(2) << "Start RdmaAdapter: " << name();
|
||||
}
|
||||
|
||||
RdmaAdapter::~RdmaAdapter() {
|
||||
polling_thread_.reset();
|
||||
CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
|
||||
CHECK(!ibv_destroy_comp_channel(event_channel_))
|
||||
<< "Failed to destroy channel";
|
||||
CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
|
||||
CHECK(!ibv_close_device(context_)) << "Failed to release context";
|
||||
}
|
||||
|
||||
string RdmaAdapter::name() const { return string(context_->device->name); }
|
||||
|
||||
// Function to process incoming messages
|
||||
// There are two types of messages:
|
||||
// 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
|
||||
// 2. IBV_WC_RDMA_WRITE (send))
|
||||
void RdmaAdapter::Process_CQ() {
|
||||
while (true) {
|
||||
ibv_cq* cq;
|
||||
void* cq_context;
|
||||
CHECK(!ibv_get_cq_event(event_channel_, &cq, &cq_context));
|
||||
CHECK(cq == cq_);
|
||||
ibv_ack_cq_events(cq, 1);
|
||||
CHECK(!ibv_req_notify_cq(cq_, 0));
|
||||
|
||||
int ne =
|
||||
ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
|
||||
CHECK_GE(ne, 0);
|
||||
for (int i = 0; i < ne; ++i) {
|
||||
CHECK(wc_[i].status == IBV_WC_SUCCESS)
|
||||
<< "Failed status \n"
|
||||
<< ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
|
||||
<< static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
|
||||
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
|
||||
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
|
||||
// put back a recv wr.
|
||||
rc->Recv();
|
||||
// imm_data is the index of RX buffer in the buffer table.
|
||||
uint32_t imm_data = wc_[i].imm_data;
|
||||
RdmaBuffer* rb = rc->FindBuffer(imm_data);
|
||||
RdmaMessage rm;
|
||||
RdmaMessage::ParseMessage(rm, rb->buffer_);
|
||||
VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
|
||||
|
||||
if (rm.type_ == RDMA_MESSAGE_ACK) {
|
||||
// receive an ack to a message
|
||||
rb = rc->tx_message_buffer_;
|
||||
rb->SetBufferStatus(remote, idle);
|
||||
rb->SendNextItem();
|
||||
} else if (rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST) {
|
||||
// received a request-for-tensor message
|
||||
// send ack to release remote tx message buffer
|
||||
RdmaBuffer* ab = rc->tx_ack_buffer_;
|
||||
ab->SendNextItem();
|
||||
// find or create buffer
|
||||
RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
|
||||
string key_with_step_id =
|
||||
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
|
||||
tb->EnqueueItem(key_with_step_id);
|
||||
// send the next tensor
|
||||
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
|
||||
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
|
||||
// receive tensor-buffer-ready message
|
||||
// send ack to release remote tx message buffer
|
||||
RdmaBuffer* ab = rc->tx_ack_buffer_;
|
||||
ab->SendNextItem();
|
||||
// find buffer
|
||||
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
|
||||
tb->SetBufferStatus(remote, idle);
|
||||
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
|
||||
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
|
||||
// remote host requests to create a tensor buffer;
|
||||
// send ack to release remote tx message buffer
|
||||
RdmaBuffer* ab = rc->tx_ack_buffer_;
|
||||
ab->SendNextItem();
|
||||
// find or create the buffer
|
||||
RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_, TENSOR);
|
||||
RemoteMR rmr;
|
||||
rmr.remote_addr = rm.remote_addr_;
|
||||
rmr.rkey = rm.rkey_;
|
||||
tb->SetRemoteMR(rmr, true);
|
||||
tb->CreateCPUBuffer(rm.buffer_size_);
|
||||
// create RDMA_MESSAGE_BUFFER_RESPONSE message
|
||||
RdmaMessage br;
|
||||
br.type_ = RDMA_MESSAGE_BUFFER_RESPONSE;
|
||||
br.name_size_ = rm.name_.size();
|
||||
br.name_ = rm.name_;
|
||||
br.buffer_size_ = rm.buffer_size_;
|
||||
br.remote_addr_ = reinterpret_cast<uint64_t>(tb->buffer_);
|
||||
br.rkey_ = tb->self_->rkey;
|
||||
string message = RdmaMessage::CreateMessage(br);
|
||||
RdmaBuffer* mb = rc->tx_message_buffer_;
|
||||
mb->EnqueueItem(message);
|
||||
mb->SendNextItem();
|
||||
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
|
||||
// remote creates a buffer and responds
|
||||
// send ack to release remote tx message buffer
|
||||
RdmaBuffer* ab = rc->tx_ack_buffer_;
|
||||
ab->SendNextItem();
|
||||
// find buffer
|
||||
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
|
||||
CHECK(rm.buffer_size_ == tb->size_)
|
||||
<< "rm.buffer_size = " << rm.buffer_size_
|
||||
<< "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
|
||||
RemoteMR rmr;
|
||||
rmr.remote_addr = rm.remote_addr_;
|
||||
rmr.rkey = rm.rkey_;
|
||||
tb->SetRemoteMR(rmr, true);
|
||||
tb->SetBufferStatus(local, idle);
|
||||
tb->SetBufferStatus(remote, idle);
|
||||
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
|
||||
} else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
|
||||
// tensor RDMA write completed
|
||||
worker_env_->compute_pool->Schedule([rm, rc]() {
|
||||
string key_with_step_id =
|
||||
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
|
||||
rc->RunRecvCallback(key_with_step_id);
|
||||
});
|
||||
}
|
||||
} else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
|
||||
RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
|
||||
rb->SetBufferStatus(local, idle);
|
||||
RdmaMessage rm;
|
||||
RdmaMessage::ParseMessage(rm, rb->buffer_);
|
||||
VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
|
||||
if (rm.type_ != RDMA_MESSAGE_ACK) {
|
||||
worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
|
||||
const string remote_name)
|
||||
: adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
|
||||
// Create queue pair
|
||||
{
|
||||
struct ibv_qp_init_attr attr;
|
||||
memset(&attr, 0, sizeof(ibv_qp_init_attr));
|
||||
attr.send_cq = adapter_->cq_;
|
||||
attr.recv_cq = adapter_->cq_;
|
||||
attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
|
||||
attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES;
|
||||
attr.cap.max_send_sge = 1;
|
||||
attr.cap.max_recv_sge = 1;
|
||||
attr.qp_type = IBV_QPT_RC;
|
||||
|
||||
qp_ = ibv_create_qp(adapter_->pd_, &attr);
|
||||
CHECK(qp_) << "Failed to create queue pair";
|
||||
}
|
||||
|
||||
// Init queue pair
|
||||
{
|
||||
struct ibv_qp_attr attr;
|
||||
memset(&attr, 0, sizeof(ibv_qp_attr));
|
||||
attr.qp_state = IBV_QPS_INIT;
|
||||
attr.pkey_index = 0;
|
||||
attr.port_num = 1;
|
||||
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
|
||||
|
||||
int mask =
|
||||
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
|
||||
CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
|
||||
}
|
||||
|
||||
// Local address
|
||||
{
|
||||
struct ibv_port_attr attr;
|
||||
CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
|
||||
<< "Query port";
|
||||
self_.lid = attr.lid;
|
||||
self_.qpn = qp_->qp_num;
|
||||
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
|
||||
}
|
||||
|
||||
// create message and ack buffers, then initialize the tables.
|
||||
{
|
||||
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
|
||||
"tx_ack_buffer", "rx_ack_buffer"};
|
||||
tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]);
|
||||
rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]);
|
||||
tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]);
|
||||
rx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[3]);
|
||||
message_buffers_.reserve(kNumMessageBuffers);
|
||||
message_buffers_.push_back(tx_message_buffer_);
|
||||
message_buffers_.push_back(rx_message_buffer_);
|
||||
message_buffers_.push_back(tx_ack_buffer_);
|
||||
message_buffers_.push_back(rx_ack_buffer_);
|
||||
// create buffer on host
|
||||
tx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
|
||||
rx_message_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaMessageBufferSize);
|
||||
tx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
|
||||
rx_ack_buffer_->CreateCPUBuffer(RdmaMessage::kRdmaAckBufferSize);
|
||||
// bt_mu_.lock() is not used in constructor.
|
||||
for (int i = 0; i < kNumMessageBuffers; i++) {
|
||||
uint32_t index = NameHash(buffer_names[i]);
|
||||
buffer_table_.insert({index, message_buffers_[i]});
|
||||
buffer_index_name_table_.insert({index, buffer_names[i]});
|
||||
buffer_name_index_table_.insert({buffer_names[i], index});
|
||||
}
|
||||
|
||||
// Initiate recv
|
||||
for (int i = 0; i < 100; i++) {
|
||||
Recv();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RdmaChannel::~RdmaChannel() {
|
||||
CHECK(!ibv_destroy_qp(qp_)) << "Failed to destroy QP";
|
||||
delete tx_message_buffer_;
|
||||
delete rx_message_buffer_;
|
||||
delete tx_ack_buffer_;
|
||||
delete rx_ack_buffer_;
|
||||
}
|
||||
|
||||
void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
|
||||
mutex_lock lock{mu_};
|
||||
if ((override) || (!remote_set_)) {
|
||||
remote_.lid = ra.lid;
|
||||
remote_.qpn = ra.qpn;
|
||||
remote_.psn = ra.psn;
|
||||
remote_set_ = true;
|
||||
} else {
|
||||
CHECK(remote_.lid == ra.lid);
|
||||
CHECK(remote_.qpn == ra.qpn);
|
||||
CHECK(remote_.psn == ra.psn);
|
||||
}
|
||||
}
|
||||
|
||||
// Adding tokens to the completion queue
|
||||
// Tokens are needed to process future messages.
|
||||
void RdmaChannel::Recv() {
|
||||
struct ibv_recv_wr wr;
|
||||
memset(&wr, 0, sizeof(wr));
|
||||
wr.wr_id = (uint64_t)this;
|
||||
struct ibv_recv_wr* bad_wr;
|
||||
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
|
||||
}
|
||||
|
||||
// Lookup 32-bit buffer index from buffer name
|
||||
// Args:
|
||||
// buffer_name: name of the buffer
|
||||
// Returns:
|
||||
// 32-bit index
|
||||
uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
|
||||
mutex_lock lock{bt_mu_};
|
||||
BufferNameIndexTable::iterator iter =
|
||||
buffer_name_index_table_.find(buffer_name);
|
||||
CHECK(iter != buffer_name_index_table_.end());
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// Find a buffer by its 32-bit index
|
||||
// Args:
|
||||
// index: 32-bit hash code of the tensor buffer name
|
||||
// Returns:
|
||||
// name of the tensor buffer
|
||||
RdmaBuffer* RdmaChannel::FindBuffer(const uint32_t index) {
|
||||
mutex_lock lock{bt_mu_};
|
||||
BufferTable::iterator iter = buffer_table_.find(index);
|
||||
CHECK(iter != buffer_table_.end());
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// Find a buffer by its name
|
||||
// Args:
|
||||
// name: name of the buffer
|
||||
// Returns:
|
||||
// the named rdma buffer
|
||||
RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
|
||||
uint32_t index = LookupBufferIndex(name);
|
||||
return FindBuffer(index);
|
||||
}
|
||||
|
||||
// Find a buffer if it exists, otherwise create one.
|
||||
// The memory inside the created buffer is not allocated.
|
||||
// Args:
|
||||
// name: the name of the buffer
|
||||
// buffer_type: TENSOR, MESSAGE or ACK.
|
||||
// Returns:
|
||||
// the named buffer
|
||||
RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
|
||||
BufferType buffer_type) {
|
||||
mutex_lock lock{bt_mu_};
|
||||
RdmaBuffer* rb;
|
||||
// find index
|
||||
BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(name);
|
||||
if (iter != buffer_name_index_table_.end()) {
|
||||
uint32_t index = iter->second;
|
||||
// find buffer
|
||||
BufferTable::iterator iter = buffer_table_.find(index);
|
||||
CHECK(iter != buffer_table_.end());
|
||||
rb = iter->second;
|
||||
} else {
|
||||
uint32_t index = NameHash(name);
|
||||
if (buffer_type == TENSOR) {
|
||||
rb = new RdmaTensorBuffer(this, name);
|
||||
} else if (buffer_type == MESSAGE) {
|
||||
rb = new RdmaMessageBuffer(this, name);
|
||||
} else if (buffer_type == ACK) {
|
||||
rb = new RdmaAckBuffer(this, name);
|
||||
}
|
||||
buffer_name_index_table_.insert({name, index});
|
||||
buffer_index_name_table_.insert({index, name});
|
||||
buffer_table_.insert({index, rb});
|
||||
}
|
||||
CHECK(rb);
|
||||
return rb;
|
||||
}
|
||||
|
||||
// Insert callback to the callback_table.
|
||||
// The callback is activated when the corresponding tensor is received.
|
||||
// Arg:
|
||||
// key: the name of the tensor
|
||||
// recv_done: the callback associated with the tensor.
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaChannel::InsertRecvCallback(const string& key,
|
||||
std::function<void()> recv_done) {
|
||||
mutex_lock lock{ct_mu_};
|
||||
callback_table_.insert({key, recv_done});
|
||||
}
|
||||
|
||||
// Remove callback from the callback_table.
|
||||
// Arg:
|
||||
// key: the name of the tensor
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaChannel::RemoveRecvCallback(const string& key) {
|
||||
mutex_lock lock{ct_mu_};
|
||||
callback_table_.erase(key);
|
||||
}
|
||||
|
||||
// Run named callback in the callback_table.
|
||||
// Arg:
|
||||
// key: the name of the tensor
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaChannel::RunRecvCallback(const string& key) {
|
||||
std::function<void()> recv_done;
|
||||
{
|
||||
mutex_lock lock{ct_mu_};
|
||||
CallbackTable::iterator iter = callback_table_.find(key);
|
||||
CHECK(iter != callback_table_.end());
|
||||
recv_done = iter->second;
|
||||
}
|
||||
recv_done();
|
||||
}
|
||||
|
||||
void RdmaChannel::Connect() {
|
||||
{
|
||||
mutex_lock lock{mu_};
|
||||
CHECK(remote_set_) << "remote channel is not set";
|
||||
}
|
||||
Connect(remote_);
|
||||
}
|
||||
|
||||
// Setup channel to a remote node
|
||||
// Args:
|
||||
// remoteAddr: the rdma address of a remote channel.
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
|
||||
mutex_lock lock{mu_};
|
||||
if (!connected_) {
|
||||
struct ibv_qp_attr attr;
|
||||
memset(&attr, 0, sizeof(ibv_qp_attr));
|
||||
attr.qp_state = IBV_QPS_RTR;
|
||||
attr.path_mtu = IBV_MTU_4096;
|
||||
attr.dest_qp_num = remoteAddr.qpn;
|
||||
attr.rq_psn = remoteAddr.psn;
|
||||
attr.max_dest_rd_atomic = 1;
|
||||
attr.min_rnr_timer = 12;
|
||||
attr.ah_attr.is_global = 0;
|
||||
attr.ah_attr.dlid = remoteAddr.lid;
|
||||
attr.ah_attr.sl = 0;
|
||||
attr.ah_attr.src_path_bits = 0;
|
||||
attr.ah_attr.port_num = 1;
|
||||
|
||||
int r;
|
||||
CHECK(!(r = ibv_modify_qp(qp_, &attr,
|
||||
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
|
||||
IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
|
||||
IBV_QP_MAX_DEST_RD_ATOMIC |
|
||||
IBV_QP_MIN_RNR_TIMER)))
|
||||
<< "QP to Ready to Receive " << r;
|
||||
|
||||
memset(&attr, 0, sizeof(ibv_qp_attr));
|
||||
attr.qp_state = IBV_QPS_RTS;
|
||||
attr.sq_psn = self_.psn;
|
||||
attr.timeout = 14;
|
||||
attr.retry_cnt = 7;
|
||||
attr.rnr_retry = 7; /* infinite */
|
||||
attr.max_rd_atomic = 1;
|
||||
|
||||
CHECK(!(r = ibv_modify_qp(qp_, &attr,
|
||||
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
|
||||
IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
|
||||
IBV_QP_MAX_QP_RD_ATOMIC)))
|
||||
<< "QP to Ready to Send " << r;
|
||||
|
||||
connected_ = true;
|
||||
} else {
|
||||
LOG(INFO) << "channel already connected";
|
||||
}
|
||||
}
|
||||
|
||||
RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
|
||||
: channel_(channel), name_(name) {}
|
||||
|
||||
RdmaBuffer::~RdmaBuffer() {
|
||||
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
|
||||
FreeBuffer();
|
||||
}
|
||||
|
||||
void RdmaBuffer::FreeBuffer() {
|
||||
if ((buffer_ != nullptr) && buffer_on_host_) {
|
||||
free(buffer_);
|
||||
}
|
||||
// TODO
|
||||
// release buffer if it is on device.
|
||||
// We don't support RDMABuffer on device at this moment.
|
||||
}
|
||||
|
||||
// Allocate CPU memory for the Rdma buffer
|
||||
// Args:
|
||||
// size: to-be-allocated memory size
|
||||
// lock: whether or not mutex_lock the process to protect concurrency.
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
|
||||
CHECK(size > 0);
|
||||
if (lock) {
|
||||
mu_.lock();
|
||||
}
|
||||
if (local_status_ != none) {
|
||||
// delete existing buffer
|
||||
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
|
||||
FreeBuffer();
|
||||
}
|
||||
size_ = size;
|
||||
buffer_ = malloc(size_);
|
||||
self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
|
||||
CHECK(self_) << "Failed to register memory region";
|
||||
buffer_on_host_ = true;
|
||||
local_status_ = idle;
|
||||
if (lock) {
|
||||
mu_.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
// Set address of remote memory region
|
||||
// Args:
|
||||
// rmr: address of remote memory region
|
||||
// override: whether override existing information
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
|
||||
mutex_lock lock{mu_};
|
||||
if ((override) || (remote_status_ == none)) {
|
||||
remote_.remote_addr = rmr.remote_addr;
|
||||
remote_.rkey = rmr.rkey;
|
||||
remote_status_ = idle;
|
||||
} else {
|
||||
CHECK(remote_.remote_addr == rmr.remote_addr);
|
||||
CHECK(remote_.rkey == rmr.rkey);
|
||||
}
|
||||
}
|
||||
|
||||
// Put a task in the buffer's job queue
|
||||
void RdmaBuffer::EnqueueItem(string item) {
|
||||
mutex_lock lock{mu_};
|
||||
queue_.push(item);
|
||||
}
|
||||
|
||||
// Rdma-Write the content of the buffer
|
||||
void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
|
||||
struct ibv_sge list;
|
||||
list.addr = (uint64_t)buffer_;
|
||||
list.length = buffer_size;
|
||||
list.lkey = self_->lkey;
|
||||
|
||||
struct ibv_send_wr wr;
|
||||
memset(&wr, 0, sizeof(wr));
|
||||
wr.wr_id = (uint64_t)this;
|
||||
wr.sg_list = &list;
|
||||
wr.num_sge = 1;
|
||||
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wr.send_flags = IBV_SEND_SIGNALED;
|
||||
wr.imm_data = imm_data;
|
||||
wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
|
||||
wr.wr.rdma.rkey = remote_.rkey;
|
||||
|
||||
struct ibv_send_wr* bad_wr;
|
||||
CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
|
||||
}
|
||||
|
||||
RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
// Send the next ack from the buffer's job queue.
|
||||
void RdmaAckBuffer::SendNextItem() {
|
||||
uint32_t imm_data = LookupBufferIndex("rx_ack_buffer");
|
||||
RdmaMessage rm;
|
||||
rm.name_ = "rx_ack_buffer";
|
||||
rm.type_ = RDMA_MESSAGE_ACK;
|
||||
rm.name_size_ = rm.name_.size();
|
||||
string message = RdmaMessage::CreateMessage(rm);
|
||||
memcpy(buffer_, message.data(), message.size());
|
||||
Write(imm_data, message.size());
|
||||
}
|
||||
|
||||
// Send the next message from the buffer's job queue.
|
||||
void RdmaMessageBuffer::SendNextItem() {
|
||||
uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
|
||||
mu_.lock();
|
||||
if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
|
||||
local_status_ = busy;
|
||||
remote_status_ = busy;
|
||||
string message = queue_.front();
|
||||
queue_.pop();
|
||||
// local/remote_status_ won't be set back to idle
|
||||
// unitl Write() is successful
|
||||
mu_.unlock();
|
||||
memcpy(buffer_, message.data(), message.size());
|
||||
Write(imm_data, message.size());
|
||||
} else {
|
||||
mu_.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
// Send the next tensor from the buffer's job queue.
|
||||
void RdmaTensorBuffer::SendNextItem() {
|
||||
// get the key
|
||||
string key_with_step_id = "";
|
||||
{
|
||||
mutex_lock lock{mu_};
|
||||
if (!queue_.empty()) {
|
||||
key_with_step_id = queue_.front();
|
||||
queue_.pop();
|
||||
}
|
||||
}
|
||||
// send the tensor if a key is acquired.
|
||||
if (key_with_step_id != "") {
|
||||
VLOG(2) << "try to send tensor: " << key_with_step_id;
|
||||
string key;
|
||||
int64 step_id;
|
||||
VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
|
||||
CHECK(key.compare(name_) == 0);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Rendezvous::ParseKey(key, &parsed);
|
||||
Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
|
||||
parsed](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
|
||||
<< " error message: " << status.error_message();
|
||||
size_t buffer_size = RdmaMessage::kMessageTotalBytes;
|
||||
size_t tensor_bytes = 0;
|
||||
TensorProto proto;
|
||||
// Figures out which device the tensor is hosted on.
|
||||
Device* src_dev = nullptr;
|
||||
Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
|
||||
parsed.src_device, &src_dev);
|
||||
CHECK(s.ok()) << "src device not found";
|
||||
// Does the device have the right incarnation number we expect?
|
||||
CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
|
||||
<< "RecvTensor expects a different device incarnation: "
|
||||
<< parsed.src_incarnation << " vs. "
|
||||
<< src_dev->attributes().incarnation()
|
||||
<< ". Your worker job was probably restarted. Check your "
|
||||
<< "worker job for the reason why it was restarted.";
|
||||
Device* dst_dev = nullptr;
|
||||
// destination is on CPU.
|
||||
s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
|
||||
&dst_dev);
|
||||
CHECK(s.ok()) << "dst device not found";
|
||||
AllocatorAttributes dst_alloc_attr;
|
||||
dst_alloc_attr.set_on_host(true);
|
||||
// string tensor needs to be serialized
|
||||
if (src_dev->tensorflow_gpu_device_info() &&
|
||||
(!send_args.alloc_attrs.on_host())) {
|
||||
CHECK(send_args.device_context)
|
||||
<< "send dev name: " << src_dev->name()
|
||||
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
|
||||
// "val" is on a GPU. Uses GPUUtil to fill the proto.
|
||||
s = VerbsUtil::SetProtoFromGPUSync(
|
||||
in, src_dev, send_args.device_context, &proto, is_dead);
|
||||
CHECK(s.ok()) << "set proto from gpu sync";
|
||||
} else {
|
||||
// tensor is in CPU memory.
|
||||
in.AsProtoTensorContent(&proto);
|
||||
}
|
||||
tensor_bytes = proto.ByteSize();
|
||||
// maybe some margin for string tensor?
|
||||
buffer_size += tensor_bytes;
|
||||
// prepare message
|
||||
RdmaMessage rm;
|
||||
rm.name_size_ = key.size();
|
||||
rm.name_ = key;
|
||||
rm.tensor_shape_ = in.shape();
|
||||
rm.data_type_ = in.dtype();
|
||||
rm.step_id_ = step_id;
|
||||
rm.is_dead_ = is_dead;
|
||||
rm.tensor_bytes_ = tensor_bytes;
|
||||
rm.buffer_size_ = buffer_size;
|
||||
mu_.lock();
|
||||
if (local_status_ == none ||
|
||||
(buffer_size > size_ && local_status_ == idle &&
|
||||
remote_status_ == idle)) {
|
||||
if ((local_status_ != none) && (buffer_size > size_)) {
|
||||
CHECK(rm.data_type_ == DT_STRING)
|
||||
<< "Only string tensor allows to change size";
|
||||
}
|
||||
CreateCPUBuffer(buffer_size, false);
|
||||
mu_.unlock();
|
||||
// put back the key since it is not sent;
|
||||
EnqueueItem(key_with_step_id);
|
||||
// ask the remote to create the same buffer
|
||||
rm.type_ = RDMA_MESSAGE_BUFFER_REQUEST;
|
||||
rm.remote_addr_ = reinterpret_cast<uint64_t>(buffer_);
|
||||
rm.rkey_ = self_->rkey;
|
||||
string message = RdmaMessage::CreateMessage(rm);
|
||||
channel_->tx_message_buffer_->EnqueueItem(message);
|
||||
channel_->tx_message_buffer_->SendNextItem();
|
||||
} else if ((local_status_ == idle) && (remote_status_ == idle)) {
|
||||
// both buffers are ready, send the tensor
|
||||
local_status_ = busy;
|
||||
remote_status_ = busy;
|
||||
// local/remote_status_ won't be set back to idle
|
||||
// unitl Write() is successful
|
||||
mu_.unlock();
|
||||
CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
|
||||
(buffer_size <= size_ && rm.data_type_ == DT_STRING))
|
||||
<< "tensor and buffer size do not agree!"
|
||||
<< " buffer_size = " << size_
|
||||
<< " requested tensor size = " << buffer_size << in.DebugString();
|
||||
uint32_t imm_data = LookupBufferIndex(key);
|
||||
rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
|
||||
string message = RdmaMessage::CreateMessage(rm);
|
||||
memcpy(buffer_, message.data(), message.size());
|
||||
if (!is_dead) {
|
||||
// copy the tensor buffer content
|
||||
void* output =
|
||||
static_cast<void*>(static_cast<char*>(buffer_) +
|
||||
RdmaMessage::kTensorBufferStartIndex);
|
||||
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
|
||||
proto.SerializeToArray(output, tensor_bytes);
|
||||
} else {
|
||||
buffer_size = RdmaMessage::kMessageTotalBytes;
|
||||
}
|
||||
Write(imm_data, buffer_size);
|
||||
} else {
|
||||
mu_.unlock();
|
||||
// put back the key since it is not sent;
|
||||
EnqueueItem(key_with_step_id);
|
||||
}
|
||||
};
|
||||
// Use default session (legacy_session_)
|
||||
// TODO use WorkerSessionForSession
|
||||
// need to pass in session handle
|
||||
channel_->adapter_->worker_env_->session_mgr->LegacySession()
|
||||
->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
|
||||
}
|
||||
}
|
||||
|
||||
// Create a RdmaMessage according to the pre-defined format
|
||||
// Args:
|
||||
// rm: the message structure
|
||||
// Returns:
|
||||
// message in string format
|
||||
string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
|
||||
// Rdma Message format
|
||||
// type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
|
||||
// 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
|
||||
// ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
|
||||
// ...| XB | XB | 8B |...
|
||||
//
|
||||
// ACK: type|13|"rx_ack_buffer"
|
||||
// TENSOR_REQUEST: type|name_size|tensor_name|step_id
|
||||
// TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
|
||||
// |data_type|tensor_shape|tensor_bytes
|
||||
// BUFFER_IDLE: type|name_size|buffer_name
|
||||
// BUFFER_REQUEST:
|
||||
// type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
|
||||
// BUFFER_RESPONSE:
|
||||
// type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
|
||||
char message[kMessageTotalBytes];
|
||||
// type
|
||||
message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
|
||||
// size of name
|
||||
memcpy(&message[kNameSizeStartIndex], &rm.name_size_, sizeof(rm.name_size_));
|
||||
// name
|
||||
memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
|
||||
// buffer_size, remote_addr, rkey
|
||||
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
|
||||
(rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
|
||||
memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
|
||||
sizeof(rm.buffer_size_));
|
||||
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
|
||||
sizeof(rm.remote_addr_));
|
||||
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
|
||||
}
|
||||
// step_id
|
||||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
|
||||
(rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
|
||||
memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
|
||||
}
|
||||
// is_dead, data_type, tensor_shape, tensor_bytes
|
||||
if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
|
||||
memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
|
||||
|
||||
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
|
||||
sizeof(rm.data_type_));
|
||||
memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
|
||||
sizeof(rm.tensor_shape_));
|
||||
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
|
||||
sizeof(rm.tensor_bytes_));
|
||||
}
|
||||
return string(message, kMessageTotalBytes);
|
||||
}
|
||||
|
||||
// Parse a RdmaMessage according to the pre-defined format
|
||||
// Args:
|
||||
// rm: the message structure where the parsed message will be saved
|
||||
// buffer: the place where the raw message is stored
|
||||
// Returns:
|
||||
// None
|
||||
void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
|
||||
char* message = static_cast<char*>(buffer);
|
||||
// type
|
||||
rm.type_ = static_cast<RdmaMessageType>(message[kTypeStartIndex]);
|
||||
// name_size_
|
||||
memcpy(&rm.name_size_, &message[kNameSizeStartIndex], sizeof(rm.name_size_));
|
||||
// name
|
||||
rm.name_ = string(&message[kNameStartIndex], rm.name_size_);
|
||||
// buffer_size, remote_addr, rkey
|
||||
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
|
||||
(rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
|
||||
memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
|
||||
sizeof(rm.buffer_size_));
|
||||
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
|
||||
sizeof(rm.remote_addr_));
|
||||
memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
|
||||
}
|
||||
// step_id
|
||||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
|
||||
(rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
|
||||
memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
|
||||
}
|
||||
// data_type, tensor_bytes, tensor_shape, is_dead
|
||||
if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
|
||||
memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
|
||||
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
|
||||
sizeof(rm.data_type_));
|
||||
memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
|
||||
sizeof(rm.tensor_shape_));
|
||||
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
|
||||
sizeof(rm.tensor_bytes_));
|
||||
}
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif
|
277
tensorflow/contrib/verbs/rdma.h
Normal file
277
tensorflow/contrib/verbs/rdma.h
Normal file
@ -0,0 +1,277 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include <infiniband/verbs.h>
|
||||
#include <cstring> // for memset
|
||||
#include <functional>
|
||||
#include <memory> // for shared_ptr
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// structure to save the address of remote channels.
|
||||
struct RdmaAddress {
|
||||
uint32_t lid;
|
||||
uint32_t qpn;
|
||||
uint32_t psn;
|
||||
};
|
||||
// structure to save information for remote memory regions.
|
||||
struct RemoteMR {
|
||||
uint64_t remote_addr;
|
||||
uint32_t rkey;
|
||||
};
|
||||
enum BufferStatus { none, idle, busy };
|
||||
enum Location { local, remote };
|
||||
enum BufferType { ACK, MESSAGE, TENSOR };
|
||||
enum RdmaMessageType {
|
||||
RDMA_MESSAGE_ACK,
|
||||
RDMA_MESSAGE_BUFFER_IDLE,
|
||||
RDMA_MESSAGE_BUFFER_REQUEST,
|
||||
RDMA_MESSAGE_BUFFER_RESPONSE,
|
||||
RDMA_MESSAGE_TENSOR_REQUEST,
|
||||
RDMA_MESSAGE_TENSOR_WRITE
|
||||
};
|
||||
class RdmaBuffer;
|
||||
// Class that represents the Rdma Adapter.
|
||||
// Responsible for creation of the completion queue, and handling
|
||||
// of work completions.
|
||||
class RdmaAdapter {
|
||||
friend class RdmaChannel;
|
||||
friend class RdmaBuffer;
|
||||
friend class RdmaAckBuffer;
|
||||
friend class RdmaMessageBuffer;
|
||||
friend class RdmaTensorBuffer;
|
||||
friend class RdmaMgr;
|
||||
friend class RdmaRemoteRendezvous;
|
||||
|
||||
public:
|
||||
RdmaAdapter(const WorkerEnv* worker_env);
|
||||
~RdmaAdapter();
|
||||
// Adapter name, e.g. mlx5_0.
|
||||
string name() const;
|
||||
void Process_CQ();
|
||||
|
||||
protected:
|
||||
static const int MAX_CONCURRENT_WRITES = 1000;
|
||||
ibv_context* context_;
|
||||
// ibverbs protection domain
|
||||
ibv_pd* pd_;
|
||||
// Completion event channel, to wait for work completions
|
||||
ibv_comp_channel* event_channel_;
|
||||
// Completion queue, to poll on work completions
|
||||
ibv_cq* cq_;
|
||||
// Pre-allocated work completions array used for polling
|
||||
ibv_wc wc_[MAX_CONCURRENT_WRITES * 2];
|
||||
// worker env for thread
|
||||
const WorkerEnv* worker_env_;
|
||||
// thread for cq.
|
||||
std::unique_ptr<Thread> polling_thread_;
|
||||
};
|
||||
|
||||
// Class that represents a connection to a remote Rdma peer.
|
||||
// Responsible for connecting queue pairs.
|
||||
class RdmaChannel {
|
||||
friend class RdmaAdapter;
|
||||
friend class RdmaBuffer;
|
||||
friend class RdmaAckBuffer;
|
||||
friend class RdmaMessageBuffer;
|
||||
friend class RdmaTensorBuffer;
|
||||
friend class RdmaMgr;
|
||||
friend class RdmaRemoteRendezvous;
|
||||
|
||||
public:
|
||||
explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
|
||||
const string remote_name_);
|
||||
~RdmaChannel();
|
||||
inline const RdmaAddress& self() { return self_; }
|
||||
RdmaAddress address() const;
|
||||
inline const std::vector<RdmaBuffer*>& message_buffers() const {
|
||||
return message_buffers_;
|
||||
}
|
||||
void Connect(const RdmaAddress& remoteAddr);
|
||||
void Connect();
|
||||
void Recv();
|
||||
RdmaBuffer* FindBuffer(const uint32_t index);
|
||||
RdmaBuffer* FindBuffer(const string& name);
|
||||
RdmaBuffer* FindOrCreateBuffer(const string& name,
|
||||
BufferType buffer_type = TENSOR);
|
||||
uint32_t LookupBufferIndex(const string& buffer_name);
|
||||
void SetRemoteAddress(const RdmaAddress& ra, bool override);
|
||||
void InsertRecvCallback(const string& key, std::function<void()> recv_done);
|
||||
void RemoveRecvCallback(const string& key);
|
||||
void RunRecvCallback(const string& key);
|
||||
static const int kNumMessageBuffers = 4;
|
||||
|
||||
protected:
|
||||
const RdmaAdapter* adapter_;
|
||||
RdmaAddress self_;
|
||||
string local_name_;
|
||||
string remote_name_;
|
||||
ibv_qp* qp_;
|
||||
mutex mu_;
|
||||
bool connected_ GUARDED_BY(bt_mu_) = false;
|
||||
RdmaAddress remote_ GUARDED_BY(bt_mu_);
|
||||
bool remote_set_ GUARDED_BY(bt_mu_) = false;
|
||||
mutex ct_mu_;
|
||||
typedef std::unordered_map<string, std::function<void()> > CallbackTable;
|
||||
CallbackTable callback_table_ GUARDED_BY(ct_mu_);
|
||||
mutex bt_mu_;
|
||||
typedef std::unordered_map<unsigned int, RdmaBuffer*> BufferTable;
|
||||
BufferTable buffer_table_ GUARDED_BY(bt_mu_);
|
||||
typedef std::unordered_map<uint32_t, string> BufferIndexNameTable;
|
||||
BufferIndexNameTable buffer_index_name_table_ GUARDED_BY(bt_mu_);
|
||||
typedef std::unordered_map<string, uint32_t> BufferNameIndexTable;
|
||||
BufferNameIndexTable buffer_name_index_table_ GUARDED_BY(bt_mu_);
|
||||
RdmaBuffer* tx_message_buffer_;
|
||||
RdmaBuffer* rx_message_buffer_;
|
||||
RdmaBuffer* tx_ack_buffer_;
|
||||
RdmaBuffer* rx_ack_buffer_;
|
||||
std::vector<RdmaBuffer*> message_buffers_;
|
||||
};
|
||||
|
||||
// Class that represents a buffer for Rdma writes and reads.
|
||||
class RdmaBuffer {
|
||||
friend class RdmaChannel;
|
||||
friend class RdmaAdapter;
|
||||
friend class RdmaMgr;
|
||||
friend class RdmaRemoteRendezvous;
|
||||
|
||||
public:
|
||||
explicit RdmaBuffer(RdmaChannel* channel, string name);
|
||||
virtual ~RdmaBuffer();
|
||||
|
||||
inline void* buffer() const { return buffer_; }
|
||||
inline ibv_mr* self() const { return self_; }
|
||||
inline void SetBufferStatus(Location loc, BufferStatus status) {
|
||||
mu_.lock();
|
||||
if (loc == local) {
|
||||
local_status_ = status;
|
||||
} else {
|
||||
remote_status_ = status;
|
||||
}
|
||||
mu_.unlock();
|
||||
}
|
||||
void FreeBuffer();
|
||||
void EnqueueItem(string Item);
|
||||
virtual void SendNextItem(){};
|
||||
void CreateCPUBuffer(size_t size, bool lock = true);
|
||||
void SetRemoteMR(RemoteMR rmi, bool override);
|
||||
uint32_t LookupBufferIndex(const string& buffer_name) {
|
||||
return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
|
||||
}
|
||||
void Write(uint32_t imm_data, size_t buffer_size);
|
||||
|
||||
protected:
|
||||
const RdmaChannel* channel_;
|
||||
void* buffer_ = nullptr;
|
||||
bool buffer_on_host_ = true;
|
||||
size_t size_ = 0;
|
||||
const string name_;
|
||||
ibv_mr* self_ = nullptr;
|
||||
mutex mu_;
|
||||
RemoteMR remote_;
|
||||
std::queue<string> queue_ GUARDED_BY(mu_);
|
||||
BufferStatus local_status_ GUARDED_BY(mu_) = none;
|
||||
BufferStatus remote_status_ GUARDED_BY(mu_) = none;
|
||||
};
|
||||
|
||||
class RdmaAckBuffer : public RdmaBuffer {
|
||||
public:
|
||||
explicit RdmaAckBuffer(RdmaChannel* channel, string name);
|
||||
virtual ~RdmaAckBuffer() override {}
|
||||
void SendNextItem() override;
|
||||
};
|
||||
|
||||
class RdmaMessageBuffer : public RdmaBuffer {
|
||||
friend class RdmaChannel;
|
||||
friend class RdmaAapater;
|
||||
|
||||
public:
|
||||
explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
|
||||
virtual ~RdmaMessageBuffer() override {}
|
||||
void SendNextItem() override;
|
||||
};
|
||||
|
||||
class RdmaTensorBuffer : public RdmaBuffer {
|
||||
public:
|
||||
explicit RdmaTensorBuffer(RdmaChannel* channel, string name);
|
||||
virtual ~RdmaTensorBuffer() override {}
|
||||
void SendNextItem() override;
|
||||
};
|
||||
|
||||
struct RdmaMessage {
|
||||
RdmaMessageType type_;
|
||||
uint16_t name_size_;
|
||||
string name_;
|
||||
int64 step_id_;
|
||||
uint64_t buffer_size_;
|
||||
uint64_t remote_addr_;
|
||||
uint32_t rkey_;
|
||||
bool is_dead_;
|
||||
DataType data_type_;
|
||||
TensorShape tensor_shape_;
|
||||
size_t tensor_bytes_;
|
||||
|
||||
// type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
|
||||
// 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
|
||||
// ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
|
||||
// ...| XB | XB | 8B |...
|
||||
//
|
||||
static const size_t kNameCapacity = 512;
|
||||
static const size_t kTypeStartIndex = 0;
|
||||
static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
|
||||
static const size_t kNameStartIndex =
|
||||
kNameSizeStartIndex + sizeof(name_size_);
|
||||
static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
|
||||
static const size_t kBufferSizeStartIndex =
|
||||
kStepIdStartIndex + sizeof(step_id_);
|
||||
static const size_t kRemoteAddrStartIndex =
|
||||
kBufferSizeStartIndex + sizeof(buffer_size_);
|
||||
static const size_t kRkeyStartIndex =
|
||||
kRemoteAddrStartIndex + sizeof(remote_addr_);
|
||||
static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
|
||||
static const size_t kDataTypeStartIndex =
|
||||
kIsDeadStartIndex + sizeof(is_dead_);
|
||||
static const size_t kTensorShapeStartIndex =
|
||||
kDataTypeStartIndex + sizeof(data_type_);
|
||||
static const size_t kTensorBytesStartIndex =
|
||||
kTensorShapeStartIndex + sizeof(TensorShape);
|
||||
static const size_t kTensorBufferStartIndex =
|
||||
kTensorBytesStartIndex + sizeof(tensor_bytes_);
|
||||
static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
|
||||
static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
|
||||
static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
|
||||
static string CreateMessage(const RdmaMessage& rm);
|
||||
static void ParseMessage(RdmaMessage& rm, void* buffer);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
|
133
tensorflow/contrib/verbs/rdma_mgr.cc
Normal file
133
tensorflow/contrib/verbs/rdma_mgr.cc
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include <vector>
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
|
||||
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
|
||||
#include "tensorflow/core/distributed_runtime/session_mgr.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
|
||||
GrpcChannelCache* const channel_cache)
|
||||
: worker_env_(worker_env), channel_cache_(channel_cache) {
|
||||
rdma_adapter_ = new RdmaAdapter(worker_env_);
|
||||
// hardcoded to default session (legacy_session_)
|
||||
// TODO: use WorkerSessionForSession
|
||||
// need to pass in session handle
|
||||
local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
|
||||
std::vector<string> workers;
|
||||
worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
|
||||
&workers);
|
||||
num_remote_workers_ = workers.size() - 1;
|
||||
VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
|
||||
for (size_t i = 0; i < workers.size(); i++) {
|
||||
if (local_worker_.compare(workers[i]) != 0) {
|
||||
channel_table_.insert(
|
||||
{workers[i],
|
||||
new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Setup Rdma channels between peers.
|
||||
// This is done at the beginning of the server setup.
|
||||
|
||||
void RdmaMgr::SetupChannels() {
|
||||
for (const auto& p : channel_table_) {
|
||||
string worker_name = p.first;
|
||||
LOG(INFO) << "connecting to remote node " << worker_name;
|
||||
RdmaChannel* rc = p.second;
|
||||
GetRemoteAddressRequest req;
|
||||
GetRemoteAddressResponse resp;
|
||||
// get the channel cache
|
||||
SharedGrpcChannelPtr client_channel =
|
||||
channel_cache_->FindWorkerChannel(worker_name);
|
||||
GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
|
||||
CHECK(client != nullptr) << "No worker known as " << worker_name;
|
||||
|
||||
// setting up request
|
||||
req.set_host_name(local_worker_);
|
||||
Channel* channel_info = req.mutable_channel();
|
||||
channel_info->set_lid(rc->self_.lid);
|
||||
channel_info->set_qpn(rc->self_.qpn);
|
||||
channel_info->set_psn(rc->self_.psn);
|
||||
for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
|
||||
MemoryRegion* mr = req.add_mr();
|
||||
mr->set_remote_addr(
|
||||
reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
|
||||
mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
|
||||
}
|
||||
// synchronous call
|
||||
Status s = client->GetRemoteAddress(&req, &resp);
|
||||
// save obtained remote addresses
|
||||
// connect to the remote channel
|
||||
if (s.ok()) {
|
||||
CHECK(worker_name.compare(resp.host_name()) == 0);
|
||||
RdmaAddress ra;
|
||||
ra.lid = resp.channel().lid();
|
||||
ra.qpn = resp.channel().qpn();
|
||||
ra.psn = resp.channel().psn();
|
||||
rc->SetRemoteAddress(ra, false);
|
||||
rc->Connect();
|
||||
int i = 0;
|
||||
int idx[] = {1, 0, 3, 2};
|
||||
for (const auto& mr : resp.mr()) {
|
||||
// the connections are crossed, i.e.
|
||||
// local tx_message_buffer <---> remote rx_message_buffer_
|
||||
// local rx_message_buffer <---> remote tx_message_buffer_
|
||||
// local tx_ack_buffer <---> remote rx_ack_buffer_
|
||||
// local rx_ack_buffer <---> remote tx_ack_buffer_
|
||||
// hence idx[] = {1, 0, 3, 2}.
|
||||
RdmaBuffer* rb = rc->message_buffers_[idx[i]];
|
||||
RemoteMR rmr;
|
||||
rmr.remote_addr = mr.remote_addr();
|
||||
rmr.rkey = mr.rkey();
|
||||
rb->SetRemoteMR(rmr, false);
|
||||
i++;
|
||||
}
|
||||
CHECK(i == RdmaChannel::kNumMessageBuffers);
|
||||
} else {
|
||||
LOG(ERROR) << s.error_message();
|
||||
}
|
||||
delete client;
|
||||
}
|
||||
}
|
||||
|
||||
RdmaMgr::~RdmaMgr() {
|
||||
for (const auto& p : channel_table_) delete p.second;
|
||||
channel_table_.clear();
|
||||
delete rdma_adapter_;
|
||||
}
|
||||
|
||||
// Find a channel via the given name.
|
||||
// Args:
|
||||
// name: peer name, e.g. worker1
|
||||
// Returns
|
||||
// channel object that is connected to the named peer.
|
||||
RdmaChannel* RdmaMgr::FindChannel(const string& name) {
|
||||
ChannelTable::iterator iter = channel_table_.find(name);
|
||||
CHECK(iter != channel_table_.end());
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif
|
54
tensorflow/contrib/verbs/rdma_mgr.h
Normal file
54
tensorflow/contrib/verbs/rdma_mgr.h
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RdmaMgr {
|
||||
public:
|
||||
explicit RdmaMgr(const WorkerEnv* const worker_env,
|
||||
GrpcChannelCache* const channel_cache);
|
||||
~RdmaMgr();
|
||||
RdmaChannel* FindChannel(const string& key);
|
||||
void SetupChannels();
|
||||
const string& local_worker() { return local_worker_; }
|
||||
|
||||
private:
|
||||
string local_worker_;
|
||||
size_t num_remote_workers_;
|
||||
const WorkerEnv* const worker_env_;
|
||||
GrpcChannelCache* const channel_cache_;
|
||||
RdmaAdapter* rdma_adapter_;
|
||||
typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
|
||||
ChannelTable channel_table_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
|
149
tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
Normal file
149
tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
Normal file
@ -0,0 +1,149 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/dma_helper.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
|
||||
public:
|
||||
RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
|
||||
int64 step_id, RdmaMgr* rdma_mgr)
|
||||
: BaseRemoteRendezvous(env, worker_name, step_id, true),
|
||||
rdma_mgr_(rdma_mgr) {}
|
||||
|
||||
protected:
|
||||
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
|
||||
const Rendezvous::Args& args,
|
||||
DoneCallback done) override;
|
||||
|
||||
private:
|
||||
~RdmaRemoteRendezvous() override {}
|
||||
RdmaMgr* rdma_mgr_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RdmaRemoteRendezvous);
|
||||
};
|
||||
|
||||
void RdmaRemoteRendezvous::RecvFromRemoteAsync(
|
||||
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
|
||||
DoneCallback done) {
|
||||
Status s;
|
||||
// parse src_name and dst_name
|
||||
string src_name, dst_name, unused;
|
||||
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
|
||||
&unused)) {
|
||||
s = errors::Internal("Could not parse src name.");
|
||||
}
|
||||
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor{}, false);
|
||||
return;
|
||||
}
|
||||
if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
|
||||
&unused)) {
|
||||
s = errors::Internal("Could not parse dst name.");
|
||||
}
|
||||
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor{}, false);
|
||||
return;
|
||||
}
|
||||
CHECK(dst_name.compare(rdma_mgr_->local_worker()) == 0);
|
||||
RdmaChannel* rc = rdma_mgr_->FindChannel(src_name);
|
||||
string key(std::move(parsed.FullKey().ToString()));
|
||||
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
|
||||
// insert callback
|
||||
rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
|
||||
recv_args, parsed, done]() {
|
||||
Status s;
|
||||
Device* src_dev;
|
||||
s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
|
||||
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), true);
|
||||
return;
|
||||
}
|
||||
Device* dst_dev;
|
||||
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
|
||||
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
|
||||
if (!s.ok()) {
|
||||
done(s, Args(), recv_args, Tensor(), true);
|
||||
return;
|
||||
}
|
||||
RdmaBuffer* rb = rc->FindBuffer(key);
|
||||
RdmaMessage rm;
|
||||
CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
|
||||
RdmaMessage::ParseMessage(rm, rb->buffer_);
|
||||
CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
|
||||
Tensor val;
|
||||
if (!rm.is_dead_) {
|
||||
void* input = static_cast<char*>(rb->buffer_) +
|
||||
RdmaMessage::kTensorBufferStartIndex;
|
||||
TensorProto proto;
|
||||
CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
|
||||
rb->size_);
|
||||
CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
|
||||
<< "fail to parse proto from array";
|
||||
s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
|
||||
}
|
||||
|
||||
rc->RemoveRecvCallback(key_with_step_id);
|
||||
// create message
|
||||
RdmaMessage br;
|
||||
br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
|
||||
br.name_size_ = key.size();
|
||||
br.name_ = key;
|
||||
string message = RdmaMessage::CreateMessage(br);
|
||||
RdmaBuffer* tb = rc->tx_message_buffer_;
|
||||
tb->EnqueueItem(message);
|
||||
tb->SendNextItem();
|
||||
done(s, Args(), recv_args, val, rm.is_dead_);
|
||||
});
|
||||
// append key to message queue
|
||||
RdmaBuffer* rb = rc->tx_message_buffer_;
|
||||
RdmaMessage rm;
|
||||
rm.type_ = RDMA_MESSAGE_TENSOR_REQUEST;
|
||||
rm.name_size_ = key.size();
|
||||
rm.name_ = key;
|
||||
rm.step_id_ = step_id_;
|
||||
string message = RdmaMessage::CreateMessage(rm);
|
||||
rb->EnqueueItem(message);
|
||||
rb->SendNextItem();
|
||||
}
|
||||
|
||||
RdmaRendezvousMgr::RdmaRendezvousMgr(const WorkerEnv* env,
|
||||
const string& worker_name,
|
||||
WorkerCacheInterface* worker_cache)
|
||||
: BaseRendezvousMgr(env, worker_name) {}
|
||||
|
||||
BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
|
||||
const WorkerEnv* worker_env,
|
||||
const string& worker_name) {
|
||||
return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif
|
64
tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
Normal file
64
tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
Normal file
@ -0,0 +1,64 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/worker_env.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// RendezvousMgr keeps track of a set of local rendezvous instances.
|
||||
// All tensors sent by this worker are buffered in a RendezvousMgr
|
||||
// until the tensor is received. Each global unique "step_id"
|
||||
// corresponds to one local rendezvous instance managed by a
|
||||
// RendezvousMgr.
|
||||
//
|
||||
// E.g.,
|
||||
// Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
|
||||
// fork execution of an graph executor using "rendez" on thread 1;
|
||||
// fork execution of another graph executor using "rendez" on thread 2;
|
||||
// ...
|
||||
// join threads 1 and 2;
|
||||
//
|
||||
// In the example above, execution in thread 1 and 2 communicates with
|
||||
// each other by send/recv operations through the "rend".
|
||||
//
|
||||
// Tensors sent and recved through rendezvous managed by this
|
||||
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
|
||||
class RdmaRendezvousMgr : public BaseRendezvousMgr {
|
||||
public:
|
||||
explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name,
|
||||
WorkerCacheInterface* worker_cache);
|
||||
void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
|
||||
|
||||
protected:
|
||||
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
|
||||
const string& worker_name) override;
|
||||
|
||||
private:
|
||||
RdmaMgr* rdma_mgr_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
|
172
tensorflow/contrib/verbs/verbs_server_lib.cc
Normal file
172
tensorflow/contrib/verbs/verbs_server_lib.cc
Normal file
@ -0,0 +1,172 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_server_lib.h"
|
||||
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/server_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// static utility function
|
||||
RendezvousMgrInterface* NewRdmaRendezvousMgr(
|
||||
const WorkerEnv* env, const string& worker_name,
|
||||
WorkerCacheInterface* worker_cache) {
|
||||
return new RdmaRendezvousMgr(env, worker_name, worker_cache);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
VerbsServer::VerbsServer(const ServerDef& server_def, Env* env)
|
||||
: GrpcServer(server_def, env), verbs_state_(DISCONNECTED) {}
|
||||
|
||||
VerbsServer::~VerbsServer() {
|
||||
TF_CHECK_OK(Stop());
|
||||
TF_CHECK_OK(Join());
|
||||
delete rdma_mgr_;
|
||||
delete verbs_service_;
|
||||
delete channel_cache_;
|
||||
}
|
||||
|
||||
Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
|
||||
GrpcChannelCache** channel_cache) {
|
||||
string name_prefix =
|
||||
strings::StrCat("/job:", server_def.job_name(), "/replica:0",
|
||||
"/task:", server_def.task_index());
|
||||
|
||||
GrpcChannelSpec channel_spec;
|
||||
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
|
||||
|
||||
*channel_cache =
|
||||
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def));
|
||||
|
||||
const string host_port = (*channel_cache)->TranslateTask(name_prefix);
|
||||
int requested_port;
|
||||
|
||||
if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
|
||||
&requested_port)) {
|
||||
return errors::Internal("Could not parse port for local server from \"",
|
||||
(*channel_cache)->TranslateTask(name_prefix),
|
||||
"\".");
|
||||
}
|
||||
if (requested_port != bound_port()) {
|
||||
return errors::InvalidArgument("Requested port ", requested_port,
|
||||
" differs from expected port ",
|
||||
bound_port());
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status VerbsServer::Init(ServiceInitFunction service_func,
|
||||
RendezvousMgrCreationFunction rendezvous_mgr_func) {
|
||||
Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
CHECK_EQ(verbs_state_, DISCONNECTED);
|
||||
CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
|
||||
rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
|
||||
// set rdma_mgr for verbs_service and rdma_rendezvous_mgr
|
||||
verbs_service_->SetRdmaMgr(rdma_mgr_);
|
||||
// hardcoded to default session (legacy_session_)
|
||||
// TODO: use WorkerSessionForSession
|
||||
// need to pass in session handle
|
||||
dynamic_cast<RdmaRendezvousMgr*>(
|
||||
worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get())
|
||||
->SetRdmaMgr(rdma_mgr_);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status VerbsServer::Start() {
|
||||
Status s = GrpcServer::Start();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (verbs_state_ == DISCONNECTED) {
|
||||
// verbs_thread needs to be initiated
|
||||
// before rdma_mgr sets up the rdma channels.
|
||||
verbs_thread_.reset(worker_env()->env->StartThread(
|
||||
ThreadOptions(), "TF_verbs_service",
|
||||
[this] { verbs_service_->HandleRPCsLoop(); }));
|
||||
rdma_mgr_->SetupChannels();
|
||||
verbs_state_ = CONNECTED;
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status VerbsServer::Join() {
|
||||
Status s = GrpcServer::Join();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (verbs_state_ == CONNECTED) {
|
||||
verbs_state_ = DISCONNECTED;
|
||||
verbs_thread_.reset();
|
||||
}
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
/* static */
|
||||
Status VerbsServer::Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
|
||||
ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
|
||||
::grpc::ServerBuilder* builder) {
|
||||
return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
|
||||
};
|
||||
TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
|
||||
*out_server = std::move(ret);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
class VerbsServerFactory : public ServerFactory {
|
||||
public:
|
||||
bool AcceptsOptions(const ServerDef& server_def) override {
|
||||
return server_def.protocol() == "grpc+verbs";
|
||||
}
|
||||
|
||||
Status NewServer(const ServerDef& server_def,
|
||||
std::unique_ptr<ServerInterface>* out_server) override {
|
||||
return VerbsServer::Create(server_def, Env::Default(), out_server);
|
||||
}
|
||||
};
|
||||
|
||||
// Registers a `ServerFactory` for `VerbsServer` instances.
|
||||
class VerbsServerRegistrar {
|
||||
public:
|
||||
VerbsServerRegistrar() {
|
||||
gpr_allocation_functions alloc_fns;
|
||||
alloc_fns.malloc_fn = port::Malloc;
|
||||
alloc_fns.realloc_fn = port::Realloc;
|
||||
alloc_fns.free_fn = port::Free;
|
||||
gpr_set_allocation_functions(alloc_fns);
|
||||
ServerFactory::Register("VERBS_SERVER", new VerbsServerFactory());
|
||||
}
|
||||
};
|
||||
static VerbsServerRegistrar registrar;
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif
|
66
tensorflow/contrib/verbs/verbs_server_lib.h
Normal file
66
tensorflow/contrib/verbs/verbs_server_lib.h
Normal file
@ -0,0 +1,66 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
|
||||
|
||||
#ifdef TENSORFLOW_USE_VERBS
|
||||
|
||||
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
|
||||
#include "tensorflow/contrib/verbs/rdma_mgr.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class VerbsServer : public GrpcServer {
|
||||
protected:
|
||||
VerbsServer(const ServerDef& server_def, Env* env);
|
||||
|
||||
public:
|
||||
static Status Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<ServerInterface>* out_server);
|
||||
|
||||
// Destruction is only supported in the factory method. Clean
|
||||
// shutdown is not currently implemented for this server type.
|
||||
virtual ~VerbsServer() override;
|
||||
|
||||
// Implementations of ServerInterface methods.
|
||||
Status Start() override;
|
||||
Status Join() override;
|
||||
|
||||
protected:
|
||||
Status Init(ServiceInitFunction service_func,
|
||||
RendezvousMgrCreationFunction rendezvous_mgr_func);
|
||||
Status ChannelCacheFactory(const ServerDef& server_def,
|
||||
GrpcChannelCache** channel_cache);
|
||||
|
||||
private:
|
||||
RdmaMgr* rdma_mgr_;
|
||||
|
||||
// Guards state transitions.
|
||||
mutex mu_;
|
||||
|
||||
enum State { DISCONNECTED, CONNECTED };
|
||||
State verbs_state_ GUARDED_BY(mu_);
|
||||
|
||||
GrpcVerbsService* verbs_service_ = nullptr;
|
||||
std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
|
||||
GrpcChannelCache* channel_cache_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_USE_VERBS
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
|
60
tensorflow/contrib/verbs/verbs_service.proto
Normal file
60
tensorflow/contrib/verbs/verbs_service.proto
Normal file
@ -0,0 +1,60 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
option java_outer_classname = "VerbsServiceProtos";
|
||||
option java_multiple_files = true;
|
||||
option java_package = "org.tensorflow.contrib.verbs";
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// GRPC Helper messages used to exchange RDMA information.
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
message Channel {
|
||||
int32 lid = 1;
|
||||
int32 qpn = 2;
|
||||
int32 psn = 3;
|
||||
}
|
||||
|
||||
message MemoryRegion {
|
||||
uint64 remote_addr = 1;
|
||||
uint32 rkey = 2;
|
||||
}
|
||||
message GetRemoteAddressRequest {
|
||||
string host_name = 1;
|
||||
Channel channel = 2;
|
||||
repeated MemoryRegion mr = 3;
|
||||
}
|
||||
|
||||
message GetRemoteAddressResponse {
|
||||
string host_name = 1;
|
||||
Channel channel = 2;
|
||||
repeated MemoryRegion mr = 3;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
//
|
||||
// VerbsService
|
||||
//
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
service VerbsService {
|
||||
rpc GetRemoteAddress(GetRemoteAddressRequest)
|
||||
returns (GetRemoteAddressResponse);
|
||||
}
|
61
tensorflow/contrib/verbs/verbs_util.cc
Normal file
61
tensorflow/contrib/verbs/verbs_util.cc
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/verbs/verbs_util.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
namespace tensorflow {
|
||||
|
||||
// static sync wrapper:
|
||||
Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
|
||||
const DeviceContext* device_context,
|
||||
TensorProto* proto, bool is_dead) {
|
||||
Notification n;
|
||||
Status status;
|
||||
GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
|
||||
[&n, &status](const Status& s) {
|
||||
status = s;
|
||||
n.Notify();
|
||||
});
|
||||
n.WaitForNotification();
|
||||
return status;
|
||||
}
|
||||
|
||||
// static
|
||||
string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
|
||||
return strings::StrCat(key, ";", step_id);
|
||||
}
|
||||
|
||||
// static
|
||||
void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
|
||||
int64& step_id) {
|
||||
StringPiece s(key_with_step_id);
|
||||
// a key (with step_id) has exact 6 parts if split by ";"
|
||||
// part 1: src_device;
|
||||
// part 2: src_incarnation;
|
||||
// part 3: dst_device;
|
||||
// part 4: name;
|
||||
// part 5: frame_iter.frame_id:frame_iter.iter_id
|
||||
// part 6: step_id
|
||||
std::vector<string> parts = str_util::Split(s, ';');
|
||||
CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
|
||||
strings::safe_strto64(parts[5], &step_id);
|
||||
parts.pop_back(); // remove step_id
|
||||
key.assign(str_util::Join(parts, ";")); // stitch them together
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
41
tensorflow/contrib/verbs/verbs_util.h
Normal file
41
tensorflow/contrib/verbs/verbs_util.h
Normal file
@ -0,0 +1,41 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CONTRIB_RDMA_UTIL_H_
|
||||
#define TENSORFLOW_CONTRIB_RDMA_UTIL_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class TensorProto;
|
||||
|
||||
class VerbsUtil {
|
||||
public:
|
||||
// synchronous wrapper of SetProtoFromGPU
|
||||
static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
|
||||
const DeviceContext* device_context,
|
||||
TensorProto* proto, bool is_dead);
|
||||
static string AppendStepidToKey(const string& key, int64 step_id);
|
||||
static void GetKeyAndStepId(const string& key_with_step_id, string& key,
|
||||
int64& step_id);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CONTRIB_RDMA_UTIL_H_
|
@ -108,6 +108,7 @@ load(
|
||||
"tf_additional_cloud_op_deps",
|
||||
"tf_additional_cloud_kernel_deps",
|
||||
"tf_lib_proto_parsing_deps",
|
||||
"tf_additional_verbs_lib_defines",
|
||||
)
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config_root.bzl",
|
||||
@ -732,9 +733,13 @@ cc_library(
|
||||
"//tensorflow/core/kernels:math_not_windows",
|
||||
"//tensorflow/core/kernels:quantized_ops",
|
||||
]) + if_mkl([
|
||||
"//tensorflow/core/kernels:mkl_concat_op",
|
||||
"//tensorflow/core/kernels:mkl_conv_op",
|
||||
"//tensorflow/core/kernels:mkl_fused_batch_norm_op",
|
||||
"//tensorflow/core/kernels:mkl_lrn_op",
|
||||
"//tensorflow/core/kernels:mkl_pooling_ops",
|
||||
"//tensorflow/core/kernels:mkl_relu_op",
|
||||
"//tensorflow/core/kernels:mkl_reshape_op",
|
||||
"//tensorflow/core/kernels:mkl_tfconv_op",
|
||||
]),
|
||||
)
|
||||
@ -1272,7 +1277,9 @@ cc_library(
|
||||
"platform/tracing.h",
|
||||
],
|
||||
copts = tf_copts(),
|
||||
defines = tf_additional_lib_defines() + ["SNAPPY"],
|
||||
defines = tf_additional_lib_defines() + [
|
||||
"SNAPPY",
|
||||
] + tf_additional_verbs_lib_defines(),
|
||||
linkopts = select({
|
||||
"//tensorflow:freebsd": [],
|
||||
"//conditions:default": ["-ldl"],
|
||||
@ -2089,7 +2096,6 @@ tf_cc_test_mkl(
|
||||
size = "small",
|
||||
srcs = [
|
||||
"graph/mkl_layout_pass_test.cc",
|
||||
"graph/mkl_optimizer_merge_test.cc",
|
||||
"graph/mkl_tfconversion_pass_test.cc",
|
||||
],
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
@ -2110,9 +2116,13 @@ tf_cc_test_mkl(
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:scope",
|
||||
"//tensorflow/cc:sendrecv_ops",
|
||||
"//tensorflow/core/kernels:mkl_concat_op",
|
||||
"//tensorflow/core/kernels:mkl_conv_op",
|
||||
"//tensorflow/core/kernels:mkl_fused_batch_norm_op",
|
||||
"//tensorflow/core/kernels:mkl_lrn_op",
|
||||
"//tensorflow/core/kernels:mkl_pooling_ops",
|
||||
"//tensorflow/core/kernels:mkl_relu_op",
|
||||
"//tensorflow/core/kernels:mkl_reshape_op",
|
||||
"//tensorflow/core/kernels:mkl_tfconv_op",
|
||||
"//tensorflow/core/kernels:ops_util",
|
||||
"//third_party/eigen3",
|
||||
|
@ -17,7 +17,9 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#if defined(PLATFORM_GOOGLE)
|
||||
#include "grpc++/create_channel.h"
|
||||
#endif
|
||||
|
||||
#if defined(PLATFORM_WINDOWS)
|
||||
// winsock2.h is used in grpc, so Ws2_32.lib is needed
|
||||
|
@ -62,6 +62,13 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
|
||||
plugins) override {}
|
||||
};
|
||||
|
||||
// static utility function
|
||||
RendezvousMgrInterface* NewRpcRendezvousMgr(
|
||||
const WorkerEnv* env, const string& worker_name,
|
||||
WorkerCacheInterface* worker_cache) {
|
||||
return new RpcRendezvousMgr(env, worker_name, worker_cache);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
|
||||
@ -93,7 +100,8 @@ GrpcServer::~GrpcServer() {
|
||||
// - worker_env_.compute_pool
|
||||
}
|
||||
|
||||
Status GrpcServer::Init() {
|
||||
Status GrpcServer::Init(ServiceInitFunction service_func,
|
||||
RendezvousMgrCreationFunction rendevous_mgr_func) {
|
||||
mutex_lock l(mu_);
|
||||
CHECK_EQ(state_, NEW);
|
||||
master_env_.env = env_;
|
||||
@ -170,6 +178,10 @@ Status GrpcServer::Init() {
|
||||
worker_impl_ = NewGrpcWorker(&worker_env_);
|
||||
worker_service_ =
|
||||
NewGrpcWorkerService(worker_impl_.get(), &builder).release();
|
||||
// extra service:
|
||||
if (service_func != nullptr) {
|
||||
service_func(&worker_env_, &builder);
|
||||
}
|
||||
server_ = builder.BuildAndStart();
|
||||
|
||||
if (!server_) {
|
||||
@ -182,7 +194,9 @@ Status GrpcServer::Init() {
|
||||
|
||||
// Set up worker environment.
|
||||
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
|
||||
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache));
|
||||
rendevous_mgr_func == nullptr ?
|
||||
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
|
||||
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
|
||||
worker_env_.session_mgr = new SessionMgr(
|
||||
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
|
||||
std::unique_ptr<WorkerCacheInterface>(worker_cache),
|
||||
@ -211,6 +225,10 @@ Status GrpcServer::Init() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GrpcServer::Init() {
|
||||
return Init(nullptr, nullptr);
|
||||
}
|
||||
|
||||
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
|
||||
GrpcChannelSpec* channel_spec) {
|
||||
for (const auto& job : server_def.cluster().job()) {
|
||||
@ -248,6 +266,7 @@ Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
|
||||
channel_spec, GetChannelCreationFunction(server_def)));
|
||||
const string host_port = channel_cache->TranslateTask(name_prefix);
|
||||
int requested_port;
|
||||
|
||||
if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
|
||||
&requested_port)) {
|
||||
return errors::Internal("Could not parse port for local server from \"",
|
||||
@ -346,7 +365,8 @@ Status GrpcServer::Create(const ServerDef& server_def, Env* env,
|
||||
std::unique_ptr<ServerInterface>* out_server) {
|
||||
std::unique_ptr<GrpcServer> ret(
|
||||
new GrpcServer(server_def, env == nullptr ? Env::Default() : env));
|
||||
TF_RETURN_IF_ERROR(ret->Init());
|
||||
ServiceInitFunction service_func = nullptr;
|
||||
TF_RETURN_IF_ERROR(ret->Init(service_func, NewRpcRendezvousMgr));
|
||||
*out_server = std::move(ret);
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -36,6 +36,17 @@ namespace tensorflow {
|
||||
class GrpcWorker;
|
||||
class Master;
|
||||
|
||||
// function that creates a RendezvousMgr.
|
||||
typedef std::function<RendezvousMgrInterface*(
|
||||
const WorkerEnv*, const std::string& worker_name,
|
||||
WorkerCacheInterface* worker_cache)>
|
||||
RendezvousMgrCreationFunction;
|
||||
|
||||
// function that registers a service to the server. The service needs to
|
||||
// be registered before builder.BuildAndStart().
|
||||
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
|
||||
ServiceInitFunction;
|
||||
|
||||
class GrpcServer : public ServerInterface {
|
||||
protected:
|
||||
GrpcServer(const ServerDef& server_def, Env* env);
|
||||
@ -55,6 +66,9 @@ class GrpcServer : public ServerInterface {
|
||||
const string target() const override;
|
||||
|
||||
protected:
|
||||
Status Init(ServiceInitFunction service_func,
|
||||
RendezvousMgrCreationFunction rendezvous_mgr_func);
|
||||
|
||||
Status Init();
|
||||
|
||||
// A subclass can override this method to support secure credentials.
|
||||
@ -78,6 +92,10 @@ class GrpcServer : public ServerInterface {
|
||||
// This method may only be called after `this->Init()` returns successfully.
|
||||
int bound_port() const { return bound_port_; }
|
||||
|
||||
WorkerEnv* worker_env() { return &worker_env_; }
|
||||
|
||||
const ServerDef& server_def() const { return server_def_; }
|
||||
|
||||
private:
|
||||
// The overall server configuration.
|
||||
const ServerDef server_def_;
|
||||
|
@ -126,25 +126,33 @@ FunctionDef XTimes16() {
|
||||
{{"y", "y:y:0"}});
|
||||
}
|
||||
|
||||
FunctionDef WXPlusB() {
|
||||
return FDH::Define(
|
||||
// Name
|
||||
"WXPlusB",
|
||||
// Args
|
||||
{"w: T", "x: T", "b: T"},
|
||||
// Return values
|
||||
{"y: T"},
|
||||
// Attr def
|
||||
{"T: {float, double}"},
|
||||
// Nodes
|
||||
{{{"mm"},
|
||||
"MatMul",
|
||||
{"w", "x"},
|
||||
{{"T", "$T"},
|
||||
{"transpose_a", false},
|
||||
{"transpose_b", false},
|
||||
FunctionDef WXPlusB(){return FDH::Define(
|
||||
// Name
|
||||
"WXPlusB",
|
||||
// Args
|
||||
{"w: T", "x: T", "b: T"},
|
||||
// Return values
|
||||
{"y: T"},
|
||||
// Attr def
|
||||
{"T: {float, double}"},
|
||||
// Nodes
|
||||
{
|
||||
{{"mm"},
|
||||
"MatMul",
|
||||
{"w", "x"},
|
||||
{
|
||||
{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
|
||||
#ifdef INTEL_MKL
|
||||
}},
|
||||
#else
|
||||
{"_kernel", "eigen"}}},
|
||||
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
|
||||
#endif
|
||||
{
|
||||
{"y"}, "Add", {"mm", "b"}, {
|
||||
{ "T", "$T" }
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
FunctionDef Swap() {
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -110,9 +110,11 @@ class MklLayoutPassTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
REGISTER_OP("Input").Output("o: float").SetIsStateful();
|
||||
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
|
||||
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
|
||||
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
|
||||
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
|
||||
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to node merge optiimization
|
||||
@ -133,20 +135,22 @@ TEST_F(MklLayoutPassTest, Basic) {
|
||||
|
||||
// Test set 1: Conv2D + AddBias
|
||||
|
||||
// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y)
|
||||
// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
|
||||
// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
@ -157,26 +161,28 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
|
||||
"M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
|
||||
"DMT/_0->E:5;E->Z;M->E:1;N->E:3;Y->Z:1");
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
|
||||
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
|
||||
}
|
||||
|
||||
// C=MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y)
|
||||
// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
|
||||
// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
|
||||
// Test for correct output slots selected
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput2'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput2'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput2'}"
|
||||
"node { name: 'N' op: '_MklInput2'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M:1', 'B', 'N:1']}"
|
||||
" input: ['A', 'B', 'M:1', 'N:1']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
@ -187,16 +193,17 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
|
||||
"M(MklInput2);N(MklInput2);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
|
||||
"DMT/_0->E:5;E->Z;M:1->E:1;N:1->E:3;Y->Z:1");
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
|
||||
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
|
||||
}
|
||||
|
||||
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
|
||||
// This is a case of node rewrite followed by node merge.
|
||||
// We will first rewrite Conv2D to MklConv2D, and then merge MklConv2D
|
||||
// with BiasAdd to produce MklConv2DWithBias.
|
||||
// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
|
||||
// with BiasAdd to produce _MklConv2DWithBias.
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
@ -218,70 +225,70 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);E(MklConv2DWithBias);Y(Input);Z(Sub)|"
|
||||
"A->E;B->E:2;D->E:4;DMT/_0->E:1;DMT/_1->E:3;DMT/_2->E:5;"
|
||||
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
|
||||
"A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
|
||||
"E->Z;Y->Z:1");
|
||||
}
|
||||
|
||||
// Graph contains only MklConv2D, no AddBias.
|
||||
// Graph contains only _MklConv2D, no AddBias.
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}");
|
||||
" input: ['A', 'B', 'M', 'N']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
|
||||
"A->C;B->C:2;M->C:1;N->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
|
||||
"A->C;B->C:1;M->C:2;N->C:3");
|
||||
}
|
||||
|
||||
// MklConv2D output does not go to BiasAdd.
|
||||
// _MklConv2D output does not go to BiasAdd.
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd.
|
||||
" input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd.
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
|
||||
"M(MklInput);N(MklInput)|A->C;B->C:2;D->F;E->F:1;M->C:1;N->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
|
||||
"M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
|
||||
}
|
||||
|
||||
// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
|
||||
// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
|
||||
// Merge should not be done in such case.
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'BiasAdd'"
|
||||
@ -293,9 +300,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'E'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
|
||||
"G(Add);M(MklInput);N(MklInput)|A->C;B->C:2;C->G;D->F;"
|
||||
"E->F:1;E->G:1;M->C:1;N->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
|
||||
"G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
|
||||
"E->F:1;E->G:1;M->C:2;N->C:3");
|
||||
}
|
||||
|
||||
// data_format attribute value mismatch. Merge should not be done
|
||||
@ -303,43 +310,81 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHCW' } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
|
||||
"N(MklInput)|A->C;B->C:2;C->E;D->E:1;M->C:1;N->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
|
||||
"N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
|
||||
}
|
||||
|
||||
// No MklConv2D in context, but Conv2D in context.
|
||||
// Only Conv2D would be rewritten to MklConv2D, but no rewrite
|
||||
// for BiasAddGrad should happen.
|
||||
// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
|
||||
// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
|
||||
// of BiasAddGrad into BackpropBias
|
||||
#if 0
|
||||
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
|
||||
// rewrite tests
|
||||
|
||||
// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'O' op: '_MklInput'}"
|
||||
"node { name: 'D' op: '_MklConv2DWithBias'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
" input: ['A', 'B', 'C', 'M', 'N', 'O']}"
|
||||
"node { name: 'E' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'A']}"
|
||||
"node { name: 'F' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['E'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
|
||||
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
|
||||
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
|
||||
"M->D:3;N->D:4;O->D:5");
|
||||
}
|
||||
#endif
|
||||
|
||||
// No _MklConv2D in context, but Conv2D in context.
|
||||
// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
|
||||
// for BiasAddGrad should happen.
|
||||
// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
|
||||
// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
@ -348,9 +393,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Sub);E(BiasAddGrad);"
|
||||
"M(MklInput);N(MklInput)|A->C;A->D:1;B->C:2;C->D;D->E;"
|
||||
"M->C:1;N->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);"
|
||||
"M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
|
||||
"M->C:2;N->C:3");
|
||||
}
|
||||
|
||||
// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
|
||||
@ -462,8 +507,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['B', 'C'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
|
||||
"A->C;B->C:2;B->D;C->D:1;DMT/_0->C:1;DMT/_1->C:3");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
|
||||
"A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
|
||||
}
|
||||
|
||||
// 2 Conv2D Ops in sequence. Both should get transformed and 1st Conv2D will
|
||||
@ -489,9 +534,9 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:2;C->D:2;C->E;"
|
||||
"C:1->D:3;D->E:1;DMT/_0->C:1;DMT/_1->C:3;DMT/_2->D:1");
|
||||
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
|
||||
"C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
|
||||
}
|
||||
|
||||
// Conv2D with INT32 which is not supported by Mkl
|
||||
@ -513,10 +558,374 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Negative_UnsupportedType) {
|
||||
"A->C;B->C:1;B->D;C->D:1");
|
||||
}
|
||||
|
||||
// Concat Op test: Concat with no Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'B' op: 'InputList'"
|
||||
" attr { key: 'N' value { i: 2 } }}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Concat'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
|
||||
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
|
||||
}
|
||||
|
||||
// Concat with 2 Mkl layers feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'F' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['C', 'D']}"
|
||||
"node { name: 'G' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'H' op: 'Concat'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['G', 'E', 'F']}"
|
||||
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'H'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
|
||||
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
|
||||
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
|
||||
}
|
||||
|
||||
// Concat with 1 Mkl and 1 non-Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D']}"
|
||||
"node { name: 'G' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'H' op: 'Concat'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['G', 'E', 'F']}"
|
||||
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'H'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
|
||||
"H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
|
||||
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
|
||||
"G->H;H->I:1");
|
||||
}
|
||||
|
||||
#if 0
|
||||
// ConcatV2 Op test: ConcatV2 with no Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'B' op: 'InputList'"
|
||||
" attr { key: 'N' value { i: 2 } }}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'ConcatV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'Tidx' value { type: DT_INT32 } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['B:0', 'B:1', 'A']}"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
|
||||
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
|
||||
}
|
||||
#endif
|
||||
|
||||
// ConcatV2 with 2 Mkl layers feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'F' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['C', 'D']}"
|
||||
"node { name: 'G' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'H' op: 'ConcatV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'Tidx' value { type: DT_INT32 } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['E', 'F', 'G']}"
|
||||
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'H'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
|
||||
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
|
||||
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
|
||||
"DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
|
||||
}
|
||||
|
||||
// ConcatV2 with 1 Mkl and 1 non-Mkl layer feeding it
|
||||
TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D']}"
|
||||
"node { name: 'G' op: 'Const' "
|
||||
" attr { key: 'dtype' value { type: DT_INT32 } }"
|
||||
" attr { key: 'value' value { "
|
||||
" tensor { dtype: DT_INT32 tensor_shape { dim { size: 1 } } "
|
||||
" int_val: 0 } } } }"
|
||||
"node { name: 'H' op: 'ConcatV2'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'Tidx' value { type: DT_INT32 } }"
|
||||
" attr { key: 'N' value { i: 2 } }"
|
||||
" input: ['E', 'F', 'G']}"
|
||||
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'H'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
|
||||
"H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
|
||||
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
|
||||
"G->H:2;H->I:1");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to rewriting node for workspace edges
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
||||
/* Test LRN->MaxPool->MaxPoolGrad->LRNGrad replacement by workspace nodes. */
|
||||
TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'LRN'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'MaxPool'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
|
||||
" input: ['B'] }"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'MaxPoolGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'ksize' value { list: {i: 1, i:1, i:3, i:3} } }"
|
||||
" attr { key: 'padding' value { s: 'VALID' } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:2, i:2} } }"
|
||||
" input: ['B', 'C', 'D'] }"
|
||||
"node { name: 'F' op: 'Input'}"
|
||||
"node { name: 'G' op: 'LRNGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['E', 'F', 'B'] }"
|
||||
"node { name: 'H' op: 'Input'}"
|
||||
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['H', 'G'] }");
|
||||
EXPECT_EQ(
|
||||
DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
|
||||
"A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
|
||||
"C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
|
||||
"DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad replacement by workspace nodes. */
|
||||
TEST_F(MklLayoutPassTest, LRN_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'LRN'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'LRNGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['C', 'D', 'B'] }"
|
||||
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'E'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
|
||||
"A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
|
||||
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad replacement when only one of them is present. */
|
||||
TEST_F(MklLayoutPassTest, LRN_Negative1) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'LRN'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
|
||||
"A->B;A->C;B->C:1;DMT/_0->B:1");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad replacement when only one of them is present. */
|
||||
TEST_F(MklLayoutPassTest, LRN_Negative2) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'LRNGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['A', 'B', 'C'] }"
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
|
||||
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
|
||||
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
|
||||
}
|
||||
|
||||
/* Test LRN->LRNGrad negative case, where single LRN feeds
|
||||
2 LRNGrad nodes at different slots. */
|
||||
TEST_F(MklLayoutPassTest, LRN_Negative3) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'LRN'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['A'] }"
|
||||
"node { name: 'C' op: 'Input'}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'LRNGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['C', 'D', 'B'] }"
|
||||
"node { name: 'F' op: 'LRNGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'alpha' value { f: 0.001 } }"
|
||||
" attr { key: 'beta' value { f: 0.75 } }"
|
||||
" attr { key: 'bias' value { f: 1.0 } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'depth_radius' value { i: 2 } }"
|
||||
" input: ['C', 'B', 'D'] }"
|
||||
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'F'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
|
||||
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
|
||||
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
|
||||
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
|
||||
"D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
|
||||
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
|
||||
}
|
||||
|
||||
/* Test MaxPool->MaxPoolGrad replacement by workspace+rewrite nodes. */
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
|
||||
InitGraph(
|
||||
@ -540,10 +949,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
|
||||
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'E'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(MklMaxPoolGrad);F(Mul)|"
|
||||
"A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:4;"
|
||||
"DMT/_0->B:1;DMT/_1->E:1;DMT/_2->E:5;E->F:1");
|
||||
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
|
||||
"A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
|
||||
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
|
||||
}
|
||||
|
||||
// Test MaxPool>MaxPoolGrad replacement when only one of them is present.
|
||||
@ -562,11 +971,11 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(MklMaxPool);C(Mul);DMT/_0(Const)|"
|
||||
"A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
|
||||
"A->B;A->C;B->C:1;DMT/_0->B:1");
|
||||
}
|
||||
|
||||
// Test MaxPool->MaxPoolGrad replacement when only one of them is present.
|
||||
// Test MaxPoolGrad replacement when only one of them is present.
|
||||
// In this case, we will rewrite MaxPoolGrad and for workspace tensor and
|
||||
// its Mkl part, we will generate dummy tensor.
|
||||
TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
|
||||
@ -584,10 +993,10 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
|
||||
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Input);D(MklMaxPoolGrad);DMT/_0(Const);"
|
||||
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
|
||||
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
|
||||
"A->D;A->E;B->D:2;C->D:4;D->E:1;DMT/_0->D:1;DMT/_1->D:3;"
|
||||
"DMT/_2->D:5;DMT/_3->D:6;DMT/_4->D:7");
|
||||
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
|
||||
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
@ -1,651 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
// This module implements node merging optimization on the graph.
|
||||
// We process the nodes in the graph in reverse postorder
|
||||
// (i.e. inputs before their downstream dependencies).
|
||||
//
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/graph/mkl_optimizer_merge.h"
|
||||
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/hash/hash.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// How many hops do we search for matching node in the backward dataflow graph?
|
||||
// We use maxhop of 10 based on empirical observations. Also, these are
|
||||
// maxhops in backward data-flow graph. Since input of forward nodes (Conv2D)
|
||||
// directly goes to backward nodes, we do not expect the hop-distance
|
||||
// would be more than few nodes.
|
||||
static size_t kNodeMergeContextMaxDepth = 10;
|
||||
|
||||
// This optimization pass performs two tasks: merge
|
||||
// nodes in the forward pass, and rewrite the gradient ops
|
||||
// corresponding to merged forward ops.
|
||||
//
|
||||
// Merging nodes in the graph: Currently, it merges Conv2D+AddBias together.
|
||||
//
|
||||
// Rewriting nodes in the graph: This is neded in order to optimize
|
||||
// gradient ops of Conv2D+AddBias. Gradient op of both the Conv2D and
|
||||
// MatMul is BiasAddGrad, and we need to rewrite BiasAddGrad into
|
||||
// Conv2D-specific BiasAddGrad, and MatMul-specific BiasAddGrad.
|
||||
// This is context-specific optimization, where the context is the
|
||||
// forward operator that the BiasAddGrad corresponds to.
|
||||
class NodeMergeRewritePass : public GraphOptimizationPass {
|
||||
public:
|
||||
NodeMergeRewritePass() {
|
||||
csinfo_.conv2d = "MklConv2D";
|
||||
csinfo_.conv2dwithbias = "MklConv2DWithBias";
|
||||
csinfo_.conv2dwithbiasbackpropbias = "Conv2DWithBiasBackpropBias";
|
||||
csinfo_.biasadd = "BiasAdd";
|
||||
csinfo_.matmul = "MatMul";
|
||||
csinfo_.biasaddgrad = "BiasAddGrad";
|
||||
|
||||
minfo_.push_back(
|
||||
{csinfo_.conv2d, csinfo_.biasadd, 0, csinfo_.conv2dwithbias});
|
||||
|
||||
// We use maxhop of 10 based on emperical observations. Also, these are
|
||||
// maxhops in backward data-flow graph. Since input of forward nodes
|
||||
// (Conv2D) directly goes to backward nodes, we do not expect the
|
||||
// hop-distance would be more than few nodes.
|
||||
// TODO(nhasabni) Temporarily disabling rewrite of BiasAddGrad.
|
||||
// Will enable it once we support Conv2DWithBiasBackpropBias op.
|
||||
#if 0
|
||||
rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
|
||||
{csinfo_.conv2dwithbias, kNodeMergeContextMaxDepth}});
|
||||
rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.conv2dwithbiasbackpropbias,
|
||||
{csinfo_.conv2d, kNodeMergeContextMaxDepth}});
|
||||
// For now, we are rewriting BiasAddGrad to BiasAddGrad for MatMul. This is
|
||||
// because we do not have a separate Op for MatMulwithBias.
|
||||
rinfo_.push_back({csinfo_.biasaddgrad, csinfo_.biasaddgrad,
|
||||
{csinfo_.matmul, kNodeMergeContextMaxDepth}});
|
||||
#endif
|
||||
}
|
||||
|
||||
// Standard interface to run optimization pass
|
||||
Status Run(const GraphOptimizationPassOptions& options);
|
||||
|
||||
// Helper function which does most of heavy lifting for node merge
|
||||
//
|
||||
// Extracts common functionality between Run public interface and
|
||||
// test interface.
|
||||
//
|
||||
// @return true, if and only if graph is mutated; false otherwise.
|
||||
bool RunPass(std::unique_ptr<Graph>* g);
|
||||
|
||||
private:
|
||||
/// Structure to specify information used in node merge
|
||||
typedef struct {
|
||||
string pred; // Predecessor node string
|
||||
string succ; // Successor node string
|
||||
int op; // What operand no the predecessor node corresponds
|
||||
// to successor node?
|
||||
string newnode; // Name of the node after merge
|
||||
} MergeInfo;
|
||||
|
||||
/// Structure to specify information used in node rewrite
|
||||
typedef struct {
|
||||
string node; // Name of the node to be rewritten
|
||||
string rewrite; // New name of the node after rewrite
|
||||
typedef struct {
|
||||
string fwd; // Node name in forward pass that this node
|
||||
// corresponds to
|
||||
size_t maxhop; // Maximum number of hops the mfwd_ is located
|
||||
// from this node. If mfwd_ is farther than mmaxhop_
|
||||
// then we do not rewrite the node.
|
||||
} ContextInfo;
|
||||
ContextInfo cinfo; // Context for rewrite
|
||||
} RewriteInfo;
|
||||
|
||||
/// Structure to store all constant strings
|
||||
typedef struct {
|
||||
string conv2d;
|
||||
string conv2dwithbias;
|
||||
string conv2dwithbiasbackpropbias;
|
||||
string biasadd;
|
||||
string matmul;
|
||||
string biasaddgrad;
|
||||
} ConstStringInfo;
|
||||
|
||||
ConstStringInfo csinfo_;
|
||||
std::vector<MergeInfo> minfo_;
|
||||
std::vector<RewriteInfo> rinfo_;
|
||||
|
||||
private:
|
||||
// Return a node that can be merged with input node
|
||||
//
|
||||
// @return pointer to the node if we can find such a
|
||||
// node. Otherwise, it returns nullptr.
|
||||
Node* FindNodeForMerge(const Node* a) const;
|
||||
|
||||
// Merge predecessor node with its successor.
|
||||
// Currently, we merge Conv2D with AddBias only.
|
||||
//
|
||||
// Input nodes succ and pred may be deleted if the call to
|
||||
// this function is successful. Attempt to use the pointers
|
||||
// after the call to function may result is undefined behaviors.
|
||||
//
|
||||
// @input g - input graph, succ - successor node, pred - predecessor node
|
||||
// @return Status::OK(), if merging is successful and supported.
|
||||
// Returns appropriate Status error code otherwise.
|
||||
// Graph is updated in case nodes are merged. Otherwise, it is
|
||||
// not updated.
|
||||
Status MergeNode(std::unique_ptr<Graph>* g, Node* succ, Node* pred);
|
||||
|
||||
// Is input node (n) a candidate for rewrite?
|
||||
//
|
||||
// @return true, if it can be rewritten; false, otherwise.
|
||||
bool IsApplicableRewriteNode(const Node* n) const;
|
||||
|
||||
// Rewrites input node to a new node specified by its matching rewrite info.
|
||||
//
|
||||
// Method first searches matching rewrite info for input node and then
|
||||
// uses that info to rewrite.
|
||||
//
|
||||
// Input node may be deleted in case of rewrite. Attempt to use the node
|
||||
// after the call can result in undefined behaviors.
|
||||
//
|
||||
// @input g - input graph, n - Node to be rewritten
|
||||
// @return Status::OK(), if the input node is rewritten;
|
||||
// Returns appropriate Status error code otherwise.
|
||||
// Graph is updated in case the input node is rewritten.
|
||||
// Otherwise, it is not updated.
|
||||
Status RewriteNode(std::unique_ptr<Graph>* g, Node* n);
|
||||
|
||||
// Helper function that searches the matching rewriteinfo for the node.
|
||||
// Implements depth-first search in the data dependence graph for the
|
||||
// gradient op in backward direction.
|
||||
//
|
||||
// @input n - Node (gradient op) whose rewriteinfo is to be searched,
|
||||
// fwdn - pointer to node from the forward pass that this node
|
||||
// belongs to
|
||||
// @return Matching rewriteinfo in case a match is found; null otherwise.
|
||||
const RewriteInfo* FindMatchingRewriteInfo(const Node* n,
|
||||
const Node** fwdn) const;
|
||||
|
||||
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
|
||||
// and return it in '*out'.
|
||||
// TODO(nhasabni) We should move this to mkl_util.h
|
||||
void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out);
|
||||
};
|
||||
|
||||
// We register merge optimizer for phase 2 in pre-placement group.
|
||||
// Do not change the ordering of the Mkl passes.
|
||||
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 2,
|
||||
NodeMergeRewritePass);
|
||||
|
||||
static void FillInputs(const Node* n,
|
||||
gtl::InlinedVector<Node*, 4>* control_edges,
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4>* in) {
|
||||
DCHECK_EQ(in->size(), n->num_inputs());
|
||||
control_edges->clear();
|
||||
for (const Edge* e : n->in_edges()) {
|
||||
if (e->IsControlEdge()) {
|
||||
control_edges->push_back(e->src());
|
||||
} else {
|
||||
(*in)[e->dst_input()] = std::make_pair(e->src(), e->src_output());
|
||||
}
|
||||
}
|
||||
std::sort(control_edges->begin(), control_edges->end());
|
||||
if (n->op_def().is_commutative()) {
|
||||
// For commutative inputs, we sort the input by the input Node*
|
||||
// to get a canonical ordering (so that add(a,b) and add(b, a) will
|
||||
// hash to the same value if is_commutative is true for 'add').
|
||||
std::sort(in->begin(), in->end());
|
||||
}
|
||||
}
|
||||
|
||||
Node* NodeMergeRewritePass::FindNodeForMerge(const Node* a) const {
|
||||
// Search for all matching mergeinfo.
|
||||
// We allow more than one match for extensibility.
|
||||
std::vector<const MergeInfo*> matching_mi;
|
||||
for (auto mi = minfo_.cbegin(); mi != minfo_.cend(); ++mi) {
|
||||
if (a->type_string() == mi->succ) {
|
||||
matching_mi.push_back(&*mi);
|
||||
}
|
||||
}
|
||||
|
||||
for (const MergeInfo* mi : matching_mi) {
|
||||
const int N_in = a->num_inputs();
|
||||
if (mi->op >= N_in) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get the control edges and input of node
|
||||
gtl::InlinedVector<Node*, 4> a_control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> a_in(N_in);
|
||||
FillInputs(a, &a_control_edges, &a_in);
|
||||
|
||||
// Get operand op of the operator
|
||||
Node* b = nullptr;
|
||||
b = a_in[mi->op].first;
|
||||
if (b == nullptr || (b->type_string() != mi->pred)) {
|
||||
// NOTE: Should the first check be assert?
|
||||
continue;
|
||||
}
|
||||
|
||||
gtl::InlinedVector<Node*, 4> b_control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> b_in(N_in);
|
||||
FillInputs(b, &b_control_edges, &b_in);
|
||||
|
||||
// Shouldn't merge if a and b have different control edges.
|
||||
if (a_control_edges != b_control_edges) {
|
||||
continue;
|
||||
} else {
|
||||
// We found a match.
|
||||
return b;
|
||||
}
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void NodeMergeRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
|
||||
Node** out) {
|
||||
const DataType dt = DataTypeToEnum<uint8>::v();
|
||||
TensorProto proto;
|
||||
proto.set_dtype(dt);
|
||||
uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
|
||||
8);
|
||||
TensorShape dummy_shape({8});
|
||||
dummy_shape.AsProto(proto.mutable_tensor_shape());
|
||||
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
|
||||
.Attr("value", proto)
|
||||
.Attr("dtype", dt)
|
||||
.Finalize(&**g, out));
|
||||
}
|
||||
|
||||
Status NodeMergeRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
|
||||
Node* pred) {
|
||||
CHECK_NOTNULL(succ);
|
||||
CHECK_NOTNULL(pred);
|
||||
|
||||
if (succ->type_string() == csinfo_.biasadd &&
|
||||
pred->type_string() == csinfo_.conv2d) {
|
||||
// 1. Get all attributes from input nodes.
|
||||
DataType T_pred, T_succ;
|
||||
string padding;
|
||||
std::vector<int32> strides;
|
||||
string data_format_pred, data_format_succ;
|
||||
bool use_cudnn_on_gnu;
|
||||
TF_CHECK_OK(GetNodeAttr(pred->def(), "T", &T_pred));
|
||||
TF_CHECK_OK(GetNodeAttr(succ->def(), "T", &T_succ));
|
||||
TF_CHECK_OK(GetNodeAttr(pred->def(), "padding", &padding));
|
||||
TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
|
||||
TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
|
||||
TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
|
||||
TF_CHECK_OK(
|
||||
GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
|
||||
// We check to ensure that data formats of both succ and pred are same.
|
||||
// We expect them to be same, so we can enforce this as assert.
|
||||
// But assert can be too strict, so we enforce this as a check.
|
||||
// If the check fails, then we do not merge two nodes.
|
||||
// We also do same check for devices.
|
||||
if (data_format_pred != data_format_succ || T_pred != T_succ ||
|
||||
pred->assigned_device_name() != succ->assigned_device_name() ||
|
||||
pred->def().device() != succ->def().device()) {
|
||||
return Status(error::Code::INVALID_ARGUMENT,
|
||||
"data_format or T attribute or devices of Conv2D and "
|
||||
"BiasAdd do not match. Will skip node merge optimization");
|
||||
}
|
||||
|
||||
// 2. Get inputs from both the nodes.
|
||||
// Find the 2 inputs from the conv and the bias from the add Bias.
|
||||
Node* oper1 = nullptr;
|
||||
Node* oper1_mkl = nullptr; // Mkl tensor corresponding to oper1
|
||||
Node* oper2 = nullptr;
|
||||
Node* oper2_mkl = nullptr; // Mkl tensor corresponding to oper2
|
||||
Node* oper3 = nullptr;
|
||||
Node* oper3_mkl = nullptr; // Mkl tensor corresponding to oper3
|
||||
|
||||
const int succ_num = succ->num_inputs();
|
||||
gtl::InlinedVector<Node*, 4> succ_control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> succ_in(succ_num);
|
||||
FillInputs(succ, &succ_control_edges, &succ_in);
|
||||
|
||||
const int pred_num = pred->num_inputs();
|
||||
gtl::InlinedVector<Node*, 4> pred_control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> pred_in(pred_num);
|
||||
FillInputs(pred, &pred_control_edges, &pred_in);
|
||||
|
||||
// We need to ensure that there is only 1 edge between Conv2D and AddBias.
|
||||
// Otherwise, merging is semantically incorrect.
|
||||
if (pred->out_edges().size() != 1) {
|
||||
return Status(error::Code::INVALID_ARGUMENT,
|
||||
"Conv2D has multiple outputs."
|
||||
"Will skip node merge optimization");
|
||||
}
|
||||
|
||||
for (const Edge* e : pred->out_edges()) {
|
||||
if (e->dst() != succ) {
|
||||
return Status(error::Code::INVALID_ARGUMENT,
|
||||
"Conv2D does not feed to BiasAdd."
|
||||
"Will skip node merge optimization");
|
||||
}
|
||||
}
|
||||
|
||||
// Get operand 0, 1 of conv2D and their Mkl tensors.
|
||||
CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs.
|
||||
oper1 = pred_in[0].first;
|
||||
oper1_mkl = pred_in[1].first;
|
||||
oper2 = pred_in[2].first;
|
||||
oper2_mkl = pred_in[3].first;
|
||||
// Get operand 1 of add_bias
|
||||
// BiasAdd must have 2 inputs: Conv, bias
|
||||
CHECK_EQ(succ->in_edges().size(), 2);
|
||||
oper3 = succ_in[1].first;
|
||||
GetDummyMklTensorNode(g, &oper3_mkl); // Get dummy Mkl tensor node
|
||||
// as BiasAdd does not have Mkl tensor as input.
|
||||
CHECK_NOTNULL(oper3_mkl);
|
||||
|
||||
Node* ret;
|
||||
// We will use the node name of BiasAdd as the name of new node
|
||||
TF_CHECK_OK(NodeBuilder(succ->name(), csinfo_.conv2dwithbias)
|
||||
.Input(oper1)
|
||||
.Input(oper1_mkl)
|
||||
.Input(oper2)
|
||||
.Input(oper2_mkl)
|
||||
.Input(oper3)
|
||||
.Input(oper3_mkl)
|
||||
.Attr("T", T_pred)
|
||||
.Attr("strides", strides)
|
||||
.Attr("padding", padding)
|
||||
.Attr("data_format", data_format_pred)
|
||||
.Attr("use_cudnn_on_gpu", use_cudnn_on_gnu)
|
||||
.Device(succ->def().device())
|
||||
.Finalize(&**g, &ret));
|
||||
CHECK_NOTNULL(ret);
|
||||
|
||||
// Incoming edges are fixed, we will fix the outgoing edges now.
|
||||
for (const Edge* e : succ->out_edges()) {
|
||||
(*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
|
||||
}
|
||||
|
||||
// Copy device assigned to old node to new node.
|
||||
// It's ok to use pred or succ as we have enforced a check that
|
||||
// both have same device assigned.
|
||||
ret->set_assigned_device_name(pred->assigned_device_name());
|
||||
|
||||
VLOG(1) << "NodeMergeRewritePass: Merged old node:" << pred->DebugString()
|
||||
<< ", and node: " << succ->DebugString()
|
||||
<< ", into node:" << ret->DebugString();
|
||||
|
||||
(*g)->RemoveNode(succ);
|
||||
(*g)->RemoveNode(pred);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return Status(error::Code::UNIMPLEMENTED,
|
||||
"Unimplemented case for node merge optimization.");
|
||||
}
|
||||
|
||||
Status NodeMergeRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* n) {
|
||||
CHECK_NOTNULL(n);
|
||||
|
||||
// Get the matching rewriteinfo for the node
|
||||
const Node* fwdn = nullptr;
|
||||
const RewriteInfo* ri = FindMatchingRewriteInfo(n, &fwdn);
|
||||
if (ri == nullptr || fwdn == nullptr) {
|
||||
VLOG(2) << "NodeMergeRewritePass: Rewriteinfo not found for: "
|
||||
<< n->type_string();
|
||||
return Status(error::Code::INVALID_ARGUMENT,
|
||||
"Rewrite info not found for the node."
|
||||
"Will skip node rewrite optimization");
|
||||
}
|
||||
|
||||
VLOG(1) << "NodeMergeRewritePass: Rewrite called for: " << n->type_string();
|
||||
|
||||
if (n->type_string() == csinfo_.biasaddgrad &&
|
||||
ri->node == csinfo_.biasaddgrad &&
|
||||
(ri->rewrite == csinfo_.conv2dwithbiasbackpropbias ||
|
||||
ri->rewrite == csinfo_.biasaddgrad)) {
|
||||
DataType T;
|
||||
string data_format;
|
||||
TF_CHECK_OK(GetNodeAttr(n->def(), "T", &T));
|
||||
TF_CHECK_OK(GetNodeAttr(n->def(), "data_format", &data_format));
|
||||
|
||||
int n_num = n->num_inputs(); // this must be 1.
|
||||
CHECK_EQ(n_num, 1);
|
||||
|
||||
gtl::InlinedVector<Node*, 4> n_control_edges;
|
||||
gtl::InlinedVector<std::pair<Node*, int>, 4> n_in(n_num);
|
||||
FillInputs(n, &n_control_edges, &n_in);
|
||||
|
||||
Node *ret = nullptr, *op = n_in[0].first;
|
||||
|
||||
if (ri->rewrite == csinfo_.conv2dwithbiasbackpropbias) {
|
||||
// Get strides info from Conv2D (node in the forward pass that this
|
||||
// node corresponds to).
|
||||
std::vector<int32> strides;
|
||||
TF_CHECK_OK(GetNodeAttr(fwdn->def(), "strides", &strides));
|
||||
|
||||
// We use same name as original node name as there may be fetchoutputs
|
||||
// associated with it.
|
||||
TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
|
||||
.Input(op)
|
||||
.Attr("T", T)
|
||||
.Attr("data_format", data_format)
|
||||
.Attr("strides", strides)
|
||||
.Device(n->def().device())
|
||||
.Finalize(&**g, &ret));
|
||||
} else {
|
||||
CHECK_EQ(ri->rewrite, csinfo_.biasaddgrad);
|
||||
TF_CHECK_OK(NodeBuilder(n->name(), ri->rewrite)
|
||||
.Input(op)
|
||||
.Attr("T", T)
|
||||
.Attr("data_format", data_format)
|
||||
.Device(n->def().device())
|
||||
.Finalize(&**g, &ret));
|
||||
}
|
||||
|
||||
CHECK_NOTNULL(ret);
|
||||
|
||||
// Incoming edges are fixed, we will fix the outgoing edges now.
|
||||
for (const Edge* e : n->out_edges()) {
|
||||
(*g)->AddEdge(ret, e->src_output(), e->dst(), e->dst_input());
|
||||
}
|
||||
|
||||
// Copy device assigned to old node to new node.
|
||||
ret->set_assigned_device_name(n->assigned_device_name());
|
||||
|
||||
VLOG(1) << "MKLOptimizerMergePass: Rewrote old node:" << n->DebugString()
|
||||
<< ", into node:" << ret->DebugString();
|
||||
(*g)->RemoveNode(n);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
return Status(error::Code::UNIMPLEMENTED,
|
||||
"Unimplemented case for node rewrite optimization.");
|
||||
}
|
||||
|
||||
const NodeMergeRewritePass::RewriteInfo*
|
||||
NodeMergeRewritePass::FindMatchingRewriteInfo(const Node* n,
|
||||
const Node** fwdn) const {
|
||||
CHECK_NOTNULL(n);
|
||||
CHECK_NOTNULL(fwdn);
|
||||
*fwdn = nullptr;
|
||||
|
||||
// Search for matching rewriteinfo based on node name.
|
||||
// There could be more than one matching rewriteinfos.
|
||||
std::vector<const RewriteInfo*> matching_ri;
|
||||
for (auto ri = rinfo_.cbegin(); ri != rinfo_.cend(); ++ri) {
|
||||
if (n->type_string() == ri->node) {
|
||||
matching_ri.push_back(&*ri);
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(1) << "NodeMergeRewritePass: Searching graph for: " << n->type_string()
|
||||
<< " in backwards.";
|
||||
|
||||
// Now we will check for forward op name for rewrite info in data
|
||||
// flow graph. Get the max hops we should search for the fwd node
|
||||
// We are now going to search (breadth-first) backwards in data
|
||||
// dependence graph (for up to max hops) from n for the node
|
||||
// specified in fwd.
|
||||
// queue to maintain nodes to be visited and depth info for
|
||||
// breadth-first search
|
||||
std::queue<std::pair<const Node*, int>> nqueue;
|
||||
const Node* curr_node = n;
|
||||
size_t curr_depth = 0;
|
||||
nqueue.push(std::make_pair(curr_node, curr_depth));
|
||||
|
||||
while (curr_depth < kNodeMergeContextMaxDepth && !nqueue.empty()) {
|
||||
std::pair<const Node*, int> curr_pair = nqueue.front();
|
||||
nqueue.pop();
|
||||
|
||||
std::set<const Node*> visited_nodes;
|
||||
curr_node = curr_pair.first;
|
||||
curr_depth = curr_pair.second;
|
||||
CHECK_NOTNULL(curr_node);
|
||||
|
||||
VLOG(1) << "NodeMergeRewritePass: Visiting node: "
|
||||
<< curr_node->type_string() << " at depth: " << curr_depth
|
||||
<< " for node: " << n->type_string();
|
||||
|
||||
// If we find a match, we return immediately with the matching rewrite
|
||||
// info.
|
||||
for (const RewriteInfo* ri : matching_ri) {
|
||||
if (curr_node->type_string() == ri->cinfo.fwd) {
|
||||
*fwdn = curr_node;
|
||||
return ri;
|
||||
}
|
||||
}
|
||||
|
||||
// Else we explore backward edges from current node.
|
||||
// Add the source nodes of all incoming edges of the node to the queue.
|
||||
for (const Edge* e : curr_node->in_edges()) {
|
||||
// We do not visit already visited node.
|
||||
if (visited_nodes.find(e->src()) == visited_nodes.end()) {
|
||||
// Depth of these nodes is 1 more than the depth of current node.
|
||||
nqueue.push(std::make_pair(e->src(), curr_depth + 1));
|
||||
visited_nodes.insert(e->src());
|
||||
}
|
||||
}
|
||||
} /* while */
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool NodeMergeRewritePass::IsApplicableRewriteNode(const Node* n) const {
|
||||
CHECK_NOTNULL(n);
|
||||
|
||||
// Search for matching rewriteinfo
|
||||
// Even if we find one match, we return true.
|
||||
bool match_found = false;
|
||||
for (const RewriteInfo& ri : rinfo_) {
|
||||
if (n->type_string() == ri.node) {
|
||||
match_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return match_found;
|
||||
}
|
||||
|
||||
bool NodeMergeRewritePass::RunPass(std::unique_ptr<Graph>* g) {
|
||||
bool result = false;
|
||||
CHECK_NOTNULL(g);
|
||||
|
||||
DumpGraph("Before OptimizeMerge", &**g);
|
||||
|
||||
std::vector<Node*> order;
|
||||
GetReversePostOrder(**g, &order);
|
||||
std::vector<std::pair<Node*, Node*>> nodes_to_be_merged;
|
||||
std::vector<Node*> nodes_to_be_rewritten;
|
||||
|
||||
for (Node* n : order) {
|
||||
if (!n->IsOp()) continue;
|
||||
Node* n1 = nullptr;
|
||||
if ((n1 = FindNodeForMerge(n)) != nullptr) {
|
||||
VLOG(1) << "NodeMergeRewritePass: Scheduled nodes " << n->name()
|
||||
<< " and " << n1->name() << " for merging";
|
||||
nodes_to_be_merged.push_back(std::make_pair(n, n1));
|
||||
} else if (IsApplicableRewriteNode(n)) {
|
||||
VLOG(1) << "NodeMergeRewritePass: Scheduled node " << n->name()
|
||||
<< " for rewrite";
|
||||
nodes_to_be_rewritten.push_back(n);
|
||||
}
|
||||
}
|
||||
|
||||
for (std::pair<Node*, Node*> i : nodes_to_be_merged) {
|
||||
// Even if MergeNode merges single pair of nodes, we
|
||||
// need to return true.
|
||||
string n1_name = i.first->name();
|
||||
string n2_name = i.second->name();
|
||||
if (MergeNode(g, i.first, i.second) == Status::OK()) {
|
||||
VLOG(1) << "NodeMergeRewritePass: Merged nodes " << n1_name << " and "
|
||||
<< n2_name;
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
|
||||
DumpGraph("After OptimizeMerge(nodemerge)", &**g);
|
||||
|
||||
for (Node* i : nodes_to_be_rewritten) {
|
||||
string name = i->name();
|
||||
if (RewriteNode(g, i) == Status::OK()) {
|
||||
VLOG(1) << "NodeMergeRewritePass: Rewrite node: " << name
|
||||
<< " successful.";
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
|
||||
DumpGraph("After OptimizeMerge(noderewrite)", &**g);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool OptimizeNodeMerge(std::unique_ptr<Graph>* g) {
|
||||
return NodeMergeRewritePass().RunPass(g);
|
||||
}
|
||||
|
||||
Status NodeMergeRewritePass::Run(const GraphOptimizationPassOptions& options) {
|
||||
if (options.graph == nullptr) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get the ownership of graph
|
||||
std::unique_ptr<Graph>* g = std::move(options.graph);
|
||||
|
||||
RunPass(g);
|
||||
|
||||
// Return the ownership of graph back
|
||||
options.graph->reset(g->release());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif
|
@ -1,36 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// An optimization pass that performs node merging and rewrite on graph nodes
|
||||
|
||||
#ifndef TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
|
||||
#define TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include <sys/types.h>
|
||||
#include <memory>
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
|
||||
namespace tensorflow {
|
||||
// Interface to invoke the pass for unit test
|
||||
//
|
||||
// Returns true if and only if 'g' is mutated.
|
||||
extern bool OptimizeNodeMerge(std::unique_ptr<Graph>* g);
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
||||
|
||||
#endif // TENSORFLOW_GRAPH_MKL_OPTIMIZER_MERGE_H_
|
@ -1,470 +0,0 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "tensorflow/core/graph/mkl_optimizer_merge.h"
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/graph_constructor.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
class OptimizerMergeTest : public ::testing::Test {
|
||||
public:
|
||||
OptimizerMergeTest() : graph_(OpRegistry::Global()) {}
|
||||
|
||||
static void InitGraph(const string& s, Graph* graph) {
|
||||
GraphDef graph_def;
|
||||
|
||||
auto parser = protobuf::TextFormat::Parser();
|
||||
CHECK(parser.MergeFromString(s, &graph_def)) << s;
|
||||
GraphConstructorOptions opts;
|
||||
TF_CHECK_OK(ConvertGraphDefToGraph(opts, graph_def, graph));
|
||||
}
|
||||
|
||||
void InitGraph(const string& s) {
|
||||
InitGraph(s, &graph_);
|
||||
original_ = CanonicalGraphString(&graph_);
|
||||
}
|
||||
|
||||
static bool IncludeNode(const Node* n) { return n->IsOp(); }
|
||||
|
||||
static string EdgeId(const Node* n, int index) {
|
||||
if (index == 0) {
|
||||
return n->name();
|
||||
} else if (index == Graph::kControlSlot) {
|
||||
return strings::StrCat(n->name(), ":control");
|
||||
} else {
|
||||
return strings::StrCat(n->name(), ":", index);
|
||||
}
|
||||
}
|
||||
|
||||
string CanonicalGraphString(Graph* g) {
|
||||
std::vector<string> nodes;
|
||||
std::vector<string> edges;
|
||||
for (const Node* n : g->nodes()) {
|
||||
if (IncludeNode(n)) {
|
||||
nodes.push_back(strings::StrCat(n->name(), "(", n->type_string(), ")"));
|
||||
}
|
||||
}
|
||||
for (const Edge* e : g->edges()) {
|
||||
if (IncludeNode(e->src()) && IncludeNode(e->dst())) {
|
||||
edges.push_back(strings::StrCat(EdgeId(e->src(), e->src_output()), "->",
|
||||
EdgeId(e->dst(), e->dst_input())));
|
||||
}
|
||||
}
|
||||
// Canonicalize
|
||||
std::sort(nodes.begin(), nodes.end());
|
||||
std::sort(edges.begin(), edges.end());
|
||||
return strings::StrCat(str_util::Join(nodes, ";"), "|",
|
||||
str_util::Join(edges, ";"));
|
||||
}
|
||||
|
||||
string DoNodeMerge() {
|
||||
string before = CanonicalGraphString(&graph_);
|
||||
LOG(ERROR) << "Before node merge optimize: " << before;
|
||||
|
||||
std::unique_ptr<Graph>* ug = new std::unique_ptr<Graph>(&graph_);
|
||||
OptimizeNodeMerge(ug);
|
||||
|
||||
string result = CanonicalGraphString(&graph_);
|
||||
LOG(ERROR) << "After node merge optimize: " << result;
|
||||
return result;
|
||||
}
|
||||
|
||||
const string& OriginalGraph() const { return original_; }
|
||||
|
||||
Graph graph_;
|
||||
string original_;
|
||||
};
|
||||
|
||||
REGISTER_OP("Input").Output("o: float").SetIsStateful();
|
||||
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
|
||||
|
||||
TEST_F(OptimizerMergeTest, Basic) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }"
|
||||
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(Mul);D(Mul)|"
|
||||
"A->C;A->D;B->C:1;B->D:1");
|
||||
}
|
||||
|
||||
// Test set 1: Conv2D + AddBias
|
||||
|
||||
// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y)
|
||||
TEST_F(OptimizerMergeTest, Conv2DWithBias_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['C', 'D'] }"
|
||||
"node { name: 'Y' op: 'Input'}"
|
||||
"node { name: 'Z' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
|
||||
"M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:2;D->E:4;"
|
||||
"DMT/_0->E:5;E->Z;M->E:1;N->E:3;Y->Z:1");
|
||||
}
|
||||
|
||||
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
|
||||
// We do not merge in this case as op is Conv2D and not MklConv2D.
|
||||
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoMklConv2D) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['C', 'D'] }"
|
||||
"node { name: 'Y' op: 'Input'}"
|
||||
"node { name: 'Z' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|"
|
||||
"A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
|
||||
}
|
||||
|
||||
// Graph contains only MklConv2D, no AddBias.
|
||||
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_NoAddBias) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
|
||||
"A->C;B->C:2;M->C:1;N->C:3");
|
||||
}
|
||||
|
||||
// MklConv2D output does not go to BiasAdd.
|
||||
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow1) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd.
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
|
||||
"M(MklInput);N(MklInput)|A->C;B->C:2;D->F;E->F:1;M->C:1;N->C:3");
|
||||
}
|
||||
|
||||
// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
|
||||
// Merge should not be done in such case.
|
||||
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_Dataflow2) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D', 'E'] }" // Conv2D has two outputs.
|
||||
// No merge should happen.
|
||||
"node { name: 'G' op: 'Add'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'E'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
|
||||
"G(Add);M(MklInput);N(MklInput)|A->C;B->C:2;C->G;D->F;"
|
||||
"E->F:1;E->G:1;M->C:1;N->C:3");
|
||||
}
|
||||
|
||||
// data_format attribute value mismatch. Merge should not be done
|
||||
// in such case.
|
||||
TEST_F(OptimizerMergeTest, Conv2DWithBias_Negative_AttrMismatch) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'BiasAdd'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NHCW' } }"
|
||||
" input: ['C', 'D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
|
||||
"N(MklInput)|A->C;B->C:2;C->E;D->E:1;M->C:1;N->C:3");
|
||||
}
|
||||
|
||||
#if 0
|
||||
// This test set is disabled temporarily as we do not enable node rewrite.
|
||||
// This test set will be enabled when we support Mkl-specific kernels for
|
||||
// backward bias.
|
||||
//
|
||||
// Test set 2: MklConv2D..BiasAddGrad -> Conv2DWithBiasBackpropBias
|
||||
// rewrite tests
|
||||
|
||||
// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(OptimizerMergeTest, Conv2DBackprop_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
"node { name: 'E' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Sub);E(Conv2DWithBiasBackpropBias);"
|
||||
"M(MklInput);N(MklInput)|A->C;A->D:1;B->C:2;C->D;D->E;M->C:1;N->C:3");
|
||||
}
|
||||
|
||||
// No MklConv2D in context, but Conv2D in context. No rewrite should happen.
|
||||
// C=Conv2D(A,B); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoMklConv2D) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Conv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
"node { name: 'E' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(Conv2D);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// No Conv2D in the context for BiasAddGrad. No rewrite should happen.
|
||||
// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Add'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
"node { name: 'E' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// No Conv2D in the context for BiasAddGrad, but MatMul in context.
|
||||
// Rewrite should happen, but name of BiasAddGrad does not change.
|
||||
// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(OptimizerMergeTest, Conv2DBackprop_Negative_NoConv2D_MatMul) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'MatMul'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'transpose_a' value { b: false } }"
|
||||
" attr { key: 'transpose_b' value { b: false } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
"node { name: 'E' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
|
||||
// C=MatMul(A,B); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'MatMul'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'transpose_a' value { b: false } }"
|
||||
" attr { key: 'transpose_b' value { b: false } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
"node { name: 'E' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// No MatMul in the context for BiasAddGrad. No rewrite should happen.
|
||||
// C=Add(A,B); D=Sub(C,A); E=BiasAddGrad(D)
|
||||
TEST_F(OptimizerMergeTest, MatMulBiasAddGrad_Negative_NoMatMul) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'C' op: 'Add'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['A', 'B']}"
|
||||
"node { name: 'D' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'A']}"
|
||||
"node { name: 'E' op: 'BiasAddGrad'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoNodeMerge(),
|
||||
"A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
#endif
|
||||
|
||||
static void BM_NodeMerge(int iters, int op_nodes) {
|
||||
testing::StopTiming();
|
||||
string s;
|
||||
for (int in = 0; in < 10; in++) {
|
||||
s += strings::Printf("node { name: 'in%04d' op: 'Input'}", in);
|
||||
}
|
||||
random::PhiloxRandom philox(301, 17);
|
||||
random::SimplePhilox rnd(&philox);
|
||||
for (int op = 0; op < op_nodes; op++) {
|
||||
s += strings::Printf(
|
||||
"node { name: 'op%04d' op: 'Mul' attr { key: 'T' value { "
|
||||
"type: DT_FLOAT } } input: ['in%04d', 'in%04d' ] }",
|
||||
op, rnd.Uniform(10), rnd.Uniform(10));
|
||||
}
|
||||
|
||||
bool first = true;
|
||||
while (iters > 0) {
|
||||
Graph* graph = new Graph(OpRegistry::Global());
|
||||
OptimizerMergeTest::InitGraph(s, graph);
|
||||
int N = graph->num_node_ids();
|
||||
if (first) {
|
||||
testing::SetLabel(strings::StrCat("Per graph node. Nodes: ", N));
|
||||
first = false;
|
||||
}
|
||||
{
|
||||
testing::StartTiming();
|
||||
std::unique_ptr<Graph> ug(graph);
|
||||
OptimizeNodeMerge(&ug);
|
||||
testing::StopTiming();
|
||||
}
|
||||
iters -= N; // Our benchmark units are individual graph nodes,
|
||||
// not whole graphs
|
||||
// delete graph;
|
||||
}
|
||||
}
|
||||
BENCHMARK(BM_NodeMerge)->Arg(1000)->Arg(10000);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif /* INTEL_MKL */
|
@ -40,16 +40,16 @@ namespace tensorflow {
|
||||
|
||||
// This pass inserts Mkl to Tf tensor conversion nodes (represented by C)
|
||||
// in the graph in between A and B, where A and B match any one
|
||||
// of the following
|
||||
// cases:
|
||||
// 1) A = layer/Op that generates output in Mkl format and,
|
||||
// B = layer/Op that does not accept input in Mkl format and,
|
||||
// of the following cases:
|
||||
//
|
||||
// 1) A = a node that generates output in the Mkl format and,
|
||||
// B = a node that does not accept input in the Mkl format and,
|
||||
// A -> B (there is a direct edge between A and B, then
|
||||
// We will insert C such that A->C->B.
|
||||
//
|
||||
// 2) A = layer/Op that generates output in Mkl format and,
|
||||
// B = NULL (in other words, A is the last layer in the graph), then
|
||||
// We will insert C such that A->C->B. (C will be the last layer.)
|
||||
// 2) A = a node that generates output in the Mkl format and,
|
||||
// B = NULL (in other words, A is the last node in the graph), then
|
||||
// We will insert C such that A->C->B. (C will be the last node.)
|
||||
//
|
||||
// Note that case 1 applies to all outputs of A that are input to B.
|
||||
// In other words, the conversions will be required for every output
|
||||
@ -59,9 +59,9 @@ namespace tensorflow {
|
||||
// do the conversion for A1 and A2 only. We do not need to do any conversion
|
||||
// for A3.
|
||||
//
|
||||
// This pass relies on layers registering themselves about their Mkl compliant.
|
||||
// Mkl compliant layer can accept inputs in Mkl format, and produce output in
|
||||
// Mkl format. Non-compliant layer accepts inputs and outputs in
|
||||
// This pass relies on ops registering themselves about their Mkl compliance.
|
||||
// An Mkl-compliant op can accept inputs in the Mkl format, and produce outputs
|
||||
// in the Mkl format. Non-compliant ops accept inputs and outputs in the
|
||||
// TensorFlow format.
|
||||
//
|
||||
class MklToTfConversionPass : public GraphOptimizationPass {
|
||||
@ -84,7 +84,7 @@ class MklToTfConversionPass : public GraphOptimizationPass {
|
||||
// @input T Datatype to use for checking input op
|
||||
// @return true if op is Mkl supported; false, otherwise.
|
||||
inline bool IsMklSupportedOp(const string& op_name, DataType T) const {
|
||||
return mkl_layer_registry::IsMklLayer(op_name, T);
|
||||
return mkl_op_registry::IsMklOp(op_name, T);
|
||||
}
|
||||
|
||||
// Insert layout conversion node on the edge pointed by 'e' from graph 'g'.
|
||||
@ -129,14 +129,16 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
|
||||
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
|
||||
}
|
||||
|
||||
// Lets build the conversion node and specify src as input.
|
||||
// Build the conversion node and specify src as input.
|
||||
TF_CHECK_OK(
|
||||
NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf")
|
||||
NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
|
||||
.Input(src, e->src_output())
|
||||
.Input(src, e->src_output() + 1) // Mkl tensor immediately
|
||||
// follows Tf tensor.
|
||||
.Device(src->def().device()) // We want to get conversion node
|
||||
// on same device as source node.
|
||||
.Input(src, DataIndexToMetaDataIndex(
|
||||
e->src_output(),
|
||||
src->num_outputs())) // Get an Mkl tensor slot
|
||||
// from the Tf tensor slot.
|
||||
.Device(src->def().device()) // We want to get conversion node
|
||||
// on same device as source node.
|
||||
.Attr("T", src_datatype)
|
||||
.Finalize(&**g, &conversion_node));
|
||||
|
||||
@ -149,8 +151,8 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
|
||||
// We want conversion node to be on the same device as the source node.
|
||||
conversion_node->set_assigned_device_name(src->assigned_device_name());
|
||||
|
||||
// Set the Mkl layer label for this op.
|
||||
conversion_node->AddAttr("_kernel", mkl_layer_registry::kMklLayerLabel);
|
||||
// Set the Mkl op label for this op.
|
||||
conversion_node->AddAttr("_kernel", mkl_op_registry::kMklOpLabel);
|
||||
|
||||
// Now that we have added edge from src->conversion_node, let's add edge from
|
||||
// output of conversion_node to the dest node. Since conversion_node
|
||||
@ -173,11 +175,11 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
|
||||
|
||||
DumpGraph("Before MklToTfConversionPass", &**g);
|
||||
|
||||
// Since we are looking for mkl-supported op node immediately
|
||||
// followed by non-mkl op node, we will just iterate over edge
|
||||
// Since we are looking for an Mkl-supported op node immediately
|
||||
// followed by a non-Mkl op node, we will just iterate over edge
|
||||
// set of the graph.
|
||||
// vector to maintain candiadate edges whose source and destination
|
||||
// are candidate for inserting conversion node
|
||||
// edge set whose source and destination are candidates for
|
||||
// inserting conversion node
|
||||
std::vector<Edge*> candidate_edges;
|
||||
|
||||
for (const Edge* e : (*g)->edges()) {
|
||||
@ -190,9 +192,9 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
|
||||
}
|
||||
|
||||
// We skip adding MklToTf on an edge between X->MklToTf or
|
||||
// MklToTf->X, where X is any layer.
|
||||
if (src->type_string().compare("MklToTf") == 0 ||
|
||||
dst->type_string().compare("MklToTf") == 0) {
|
||||
// MklToTf->X, where X is any node.
|
||||
if (src->type_string().compare("_MklToTf") == 0 ||
|
||||
dst->type_string().compare("_MklToTf") == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -210,7 +212,6 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
|
||||
GetNodeAttr(dst->def(), "T", &dst_datatype);
|
||||
|
||||
// Check if src with is Mkl-compliant, while dst is not Mkl-compliant.
|
||||
|
||||
if (IsMklSupportedOp(src->type_string(), src_datatype) &&
|
||||
!IsMklSupportedOp(dst->type_string(), dst_datatype)) {
|
||||
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
@ -109,7 +110,7 @@ class MklToTfConversionPass : public ::testing::Test {
|
||||
|
||||
REGISTER_OP("Input").Output("o: float").SetIsStateful();
|
||||
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
|
||||
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
|
||||
|
||||
TEST_F(MklToTfConversionPass, Basic) {
|
||||
InitGraph(
|
||||
@ -125,58 +126,116 @@ TEST_F(MklToTfConversionPass, Basic) {
|
||||
}
|
||||
|
||||
// MklConv2D followed by Non-Mkl layer
|
||||
// C=MklConv2D(A,M,B,N); E=Sub(C,D)
|
||||
// C=MklConv2D(A,M,B,N); E=Sub(C,D) (for interleaved ordering)
|
||||
// C=MklConv2D(A,B,M,N); E=Sub(C,D) (for contiguous ordering)
|
||||
TEST_F(MklToTfConversionPass, Positive) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
|
||||
"Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
|
||||
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
|
||||
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
|
||||
"}"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
|
||||
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
|
||||
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
|
||||
} else {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
|
||||
"}"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: 'Input'}"
|
||||
"node { name: 'E' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['C', 'D']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
|
||||
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
|
||||
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
|
||||
}
|
||||
}
|
||||
|
||||
// MklConv2D followed by MklToTf op followed by Non-Mkl layer.
|
||||
// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E)
|
||||
// C=MklConv2D(A,M,B,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for interleaved)
|
||||
// C=MklConv2D(A,B,M,N); D=MklToTf(C:0, C:1) F=Sub(D,E) (for contiguous)
|
||||
// MklToTf node should not be inserted again.
|
||||
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: 'MklToTf'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['C:0', 'C:1']}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'E']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
|
||||
"F(Sub);M(MklInput);N(MklInput)|"
|
||||
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
|
||||
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
|
||||
"}"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'M', 'B', 'N']}"
|
||||
"node { name: 'D' op: '_MklToTf'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['C:0', 'C:1']}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'E']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
|
||||
"F(Sub);M(_MklInput);N(_MklInput)|"
|
||||
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
|
||||
} else {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
|
||||
"}"
|
||||
" attr { key: 'padding' value { s: 'SAME' } }"
|
||||
" input: ['A', 'B', 'M', 'N']}"
|
||||
"node { name: 'D' op: '_MklToTf'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" input: ['C:0', 'C:1']}"
|
||||
"node { name: 'E' op: 'Input'}"
|
||||
"node { name: 'F' op: 'Sub'"
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['D', 'E']}");
|
||||
EXPECT_EQ(DoRunMklToTfConversionPass(),
|
||||
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
|
||||
"F(Sub);M(_MklInput);N(_MklInput)|"
|
||||
"A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
|
||||
}
|
||||
}
|
||||
|
||||
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
|
||||
|
@ -230,7 +230,7 @@ void AutoParallel::BuildGraph(GraphDef* graph) {
|
||||
AddOneReplica(graph, i);
|
||||
}
|
||||
std::set<string> fetches;
|
||||
for (int i = 0; i < item_->fetch.size(); i++) {
|
||||
for (size_t i = 0; i < item_->fetch.size(); i++) {
|
||||
for (int j = 0; j < num_replicas_; j++) {
|
||||
string prefix = strings::StrCat(kAutoParallelPrefix, "-Replica-", j);
|
||||
string fetch = AddPrefixToNodeName(item_->fetch[i], prefix);
|
||||
|
@ -4828,6 +4828,38 @@ tf_mkl_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_mkl_kernel_library(
|
||||
name = "mkl_fused_batch_norm_op",
|
||||
srcs = ["mkl_fused_batch_norm_op.cc"],
|
||||
deps = NN_DEPS + [
|
||||
"//third_party/mkl:intel_binary_blob",
|
||||
],
|
||||
)
|
||||
|
||||
tf_mkl_kernel_library(
|
||||
name = "mkl_concat_op",
|
||||
prefix = "mkl_concat_op",
|
||||
deps = ARRAY_DEPS + [
|
||||
"//third_party/mkl:intel_binary_blob",
|
||||
],
|
||||
)
|
||||
|
||||
tf_mkl_kernel_library(
|
||||
name = "mkl_reshape_op",
|
||||
prefix = "mkl_reshape_op",
|
||||
deps = ARRAY_DEPS + [
|
||||
"//third_party/mkl:intel_binary_blob",
|
||||
],
|
||||
)
|
||||
|
||||
tf_mkl_kernel_library(
|
||||
name = "mkl_lrn_op",
|
||||
prefix = "mkl_lrn_op",
|
||||
deps = NN_DEPS + [
|
||||
"//third_party/mkl:intel_binary_blob",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Google-internal targets. These must be at the end for syncrepo.
|
||||
|
||||
|
@ -28,12 +28,14 @@ namespace tensorflow {
|
||||
class FixedLengthRecordReader : public ReaderBase {
|
||||
public:
|
||||
FixedLengthRecordReader(const string& node_name, int64 header_bytes,
|
||||
int64 record_bytes, int64 footer_bytes, Env* env)
|
||||
int64 record_bytes, int64 footer_bytes,
|
||||
int64 hop_bytes, Env* env)
|
||||
: ReaderBase(
|
||||
strings::StrCat("FixedLengthRecordReader '", node_name, "'")),
|
||||
header_bytes_(header_bytes),
|
||||
record_bytes_(record_bytes),
|
||||
footer_bytes_(footer_bytes),
|
||||
hop_bytes_(hop_bytes),
|
||||
env_(env),
|
||||
file_pos_limit_(-1),
|
||||
record_number_(0) {}
|
||||
@ -62,14 +64,31 @@ class FixedLengthRecordReader : public ReaderBase {
|
||||
|
||||
Status ReadLocked(string* key, string* value, bool* produced,
|
||||
bool* at_end) override {
|
||||
if (input_buffer_->Tell() >= file_pos_limit_) {
|
||||
// The condition `input_buffer_->Tell() + record_bytes_ > file_pos_limit_`
|
||||
// is to confirm that none of record bytes is out of the range of
|
||||
// file_pos_limit_.
|
||||
// This is necessary for the condition `hop_bytes > 0`. For example.
|
||||
// File: "0123456"
|
||||
// Reader setting: `record_bytes=3`, `hop_bytes=2`, `footer_bytes=0`,
|
||||
// `header_bytes=0`
|
||||
// Without this checking condition, the forth time the reader will at
|
||||
// this position: "012345|6" and the reading operation will result in
|
||||
// an error.
|
||||
if (input_buffer_->Tell() >= file_pos_limit_ ||
|
||||
input_buffer_->Tell() + record_bytes_ > file_pos_limit_) {
|
||||
*at_end = true;
|
||||
return Status::OK();
|
||||
}
|
||||
const int64 pos_before_read = input_buffer_->Tell();
|
||||
TF_RETURN_IF_ERROR(input_buffer_->ReadNBytes(record_bytes_, value));
|
||||
*key = strings::StrCat(current_work(), ":", record_number_);
|
||||
*produced = true;
|
||||
++record_number_;
|
||||
|
||||
if (hop_bytes_ > 0) {
|
||||
input_buffer_->Seek(pos_before_read + hop_bytes_).IgnoreError();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -87,6 +106,7 @@ class FixedLengthRecordReader : public ReaderBase {
|
||||
const int64 header_bytes_;
|
||||
const int64 record_bytes_;
|
||||
const int64 footer_bytes_;
|
||||
const int64 hop_bytes_;
|
||||
Env* const env_;
|
||||
int64 file_pos_limit_;
|
||||
int64 record_number_;
|
||||
@ -98,10 +118,12 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
|
||||
public:
|
||||
explicit FixedLengthRecordReaderOp(OpKernelConstruction* context)
|
||||
: ReaderOpKernel(context) {
|
||||
int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1;
|
||||
int64 header_bytes = -1, record_bytes = -1, footer_bytes = -1,
|
||||
hop_bytes = -1;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("header_bytes", &header_bytes));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("record_bytes", &record_bytes));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("footer_bytes", &footer_bytes));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("hop_bytes", &hop_bytes));
|
||||
OP_REQUIRES(context, header_bytes >= 0,
|
||||
errors::InvalidArgument("header_bytes must be >= 0 not ",
|
||||
header_bytes));
|
||||
@ -111,11 +133,15 @@ class FixedLengthRecordReaderOp : public ReaderOpKernel {
|
||||
OP_REQUIRES(context, footer_bytes >= 0,
|
||||
errors::InvalidArgument("footer_bytes must be >= 0 not ",
|
||||
footer_bytes));
|
||||
OP_REQUIRES(
|
||||
context, hop_bytes >= 0,
|
||||
errors::InvalidArgument("hop_bytes must be >= 0 not ", hop_bytes));
|
||||
Env* env = context->env();
|
||||
SetReaderFactory([this, header_bytes, record_bytes, footer_bytes, env]() {
|
||||
return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
|
||||
footer_bytes, env);
|
||||
});
|
||||
SetReaderFactory(
|
||||
[this, header_bytes, record_bytes, footer_bytes, hop_bytes, env]() {
|
||||
return new FixedLengthRecordReader(name(), header_bytes, record_bytes,
|
||||
footer_bytes, hop_bytes, env);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -29,10 +29,9 @@ namespace tensorflow {
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
template <typename Device, typename T>
|
||||
class MklAvgPoolingOp : public UnaryOp<T> {
|
||||
class MklAvgPoolingOp : public OpKernel {
|
||||
public:
|
||||
explicit MklAvgPoolingOp(OpKernelConstruction* context)
|
||||
: UnaryOp<T>(context) {
|
||||
explicit MklAvgPoolingOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
string data_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
|
||||
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
|
||||
@ -78,6 +77,7 @@ class MklAvgPoolingOp : public UnaryOp<T> {
|
||||
Tensor mkl_tmp_input_buf_tensor_;
|
||||
mkl_context.MklCreateLayoutsAndPrimitives(context,
|
||||
&mkl_tmp_input_buf_tensor_);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
Tensor workspace_tensor;
|
||||
void* workspace_buf;
|
||||
@ -120,7 +120,7 @@ class MklAvgPoolingOp : public UnaryOp<T> {
|
||||
mkl_out_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
|
||||
AllocateOutputSetMklshape(context, 0, &output, tensor_out_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
|
||||
mkl_out_shape);
|
||||
mkl_context.pooling_res[dnnResourceDst] =
|
||||
static_cast<void*>(output->flat<T>().data());
|
||||
@ -138,9 +138,10 @@ class MklAvgPoolingOp : public UnaryOp<T> {
|
||||
typedef struct {
|
||||
MklPoolingOpParams params;
|
||||
MklShape input_shape;
|
||||
dnnPrimitive_t prim_pooling_fwd, convert_input;
|
||||
dnnLayout_t lt_user_input, lt_prim_input, lt_workspace;
|
||||
void* input_buf;
|
||||
dnnPrimitive_t prim_pooling_fwd = nullptr, convert_input = nullptr;
|
||||
dnnLayout_t lt_user_input = nullptr, lt_prim_input = nullptr,
|
||||
lt_workspace = nullptr;
|
||||
void* input_buf = nullptr;
|
||||
void* pooling_res[dnnResourceNumber];
|
||||
|
||||
void MklCreateLayoutsAndPrimitives(OpKernelContext* context,
|
||||
@ -243,6 +244,11 @@ class MklAvgPoolingGradOp : public OpKernel {
|
||||
pool_params.Init(context, ksize_, stride_, padding_, data_format_,
|
||||
output_shape);
|
||||
|
||||
if (outbackprop_in_mkl_format == false)
|
||||
mkl_context.params.in_dim = out_backprop.dims();
|
||||
else
|
||||
mkl_context.params.in_dim = mkl_context.out_backprop_shape.GetDimension();
|
||||
|
||||
// Extract the parameters for the op from the pooling specs
|
||||
ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
|
||||
|
||||
@ -250,6 +256,7 @@ class MklAvgPoolingGradOp : public OpKernel {
|
||||
Tensor outbackprop_buf_tensor;
|
||||
void* outbackprop_buf;
|
||||
mkl_context.MklCreateLayoutsAndPrimitives(context);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
// Check if outbackprop layout requires conversion.
|
||||
if (!dnnLayoutCompare_F32(mkl_context.lt_user_outbackprop,
|
||||
@ -304,7 +311,7 @@ class MklAvgPoolingGradOp : public OpKernel {
|
||||
mkl_out_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
|
||||
AllocateOutputSetMklshape(context, 0, &output, tensor_out_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, tensor_out_shape,
|
||||
mkl_out_shape);
|
||||
|
||||
// Set output tensor.
|
||||
@ -323,10 +330,10 @@ class MklAvgPoolingGradOp : public OpKernel {
|
||||
typedef struct {
|
||||
MklPoolingOpParams params;
|
||||
MklShape out_backprop_shape;
|
||||
dnnPrimitive_t prim_pooling_bwd, convert_outbackprop;
|
||||
dnnPrimitive_t prim_pooling_bwd = nullptr, convert_outbackprop = nullptr;
|
||||
void* pooling_res[dnnResourceNumber];
|
||||
dnnLayout_t lt_user_input, lt_user_outbackprop, lt_prim_outbackprop,
|
||||
lt_workspace;
|
||||
dnnLayout_t lt_user_input = nullptr, lt_user_outbackprop = nullptr,
|
||||
lt_prim_outbackprop = nullptr, lt_workspace = nullptr;
|
||||
|
||||
void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
|
||||
const Tensor& tensor_in_shape = MklGetInput(context, 0);
|
||||
@ -348,11 +355,6 @@ class MklAvgPoolingGradOp : public OpKernel {
|
||||
"4-dimensional"));
|
||||
} else {
|
||||
// Input in MKL format.
|
||||
OP_REQUIRES(
|
||||
context, out_backprop.dims() == 2,
|
||||
errors::InvalidArgument("out_backprop in MKL format must be "
|
||||
"2-dimensional"));
|
||||
|
||||
// For avgpooling, out_backprop should have 4 dimensions.
|
||||
OP_REQUIRES(context, out_backprop_shape.GetDimension() == 4,
|
||||
errors::InvalidArgument("out_backprop must be "
|
||||
@ -412,16 +414,16 @@ class MklAvgPoolingGradOp : public OpKernel {
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MklAvgPool")
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.Label(mkl_layer_registry::kMklLayerLabel),
|
||||
.Label(mkl_op_registry::kMklOpLabel),
|
||||
MklAvgPoolingOp<CPUDevice, float>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MklAvgPoolGrad")
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklAvgPoolGrad")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.Label(mkl_layer_registry::kMklLayerLabel),
|
||||
.Label(mkl_op_registry::kMklOpLabel),
|
||||
MklAvgPoolingGradOp<CPUDevice, float>);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
458
tensorflow/core/kernels/mkl_concat_op.cc
Normal file
458
tensorflow/core/kernels/mkl_concat_op.cc
Normal file
@ -0,0 +1,458 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/concat_lib.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
|
||||
|
||||
// TODO(intelft) Check if we can reuse existing EigenConcatOp using Mutable
|
||||
// reference inputs.
|
||||
// --------------------------------------------------------------------------
|
||||
// Eigen Concat Op
|
||||
// --------------------------------------------------------------------------
|
||||
template <typename Device, typename T, AxisArgumentName AxisArgName>
|
||||
class EigenConcatBaseOp : public OpKernel {
|
||||
public:
|
||||
typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
|
||||
ConstMatrixVector;
|
||||
|
||||
explicit EigenConcatBaseOp(OpKernelConstruction* c) : OpKernel(c) {}
|
||||
|
||||
// Although, we modify Compute for this call to accept one extra param,
|
||||
// we need to have empty Compute because Compute is pure virtual function.
|
||||
void Compute(OpKernelContext* c) {}
|
||||
|
||||
void Compute(OpKernelContext* c, const std::vector<Tensor>& values) {
|
||||
const Tensor* concat_dim_tensor;
|
||||
const char* axis_attribute_name =
|
||||
AxisArgName == NAME_IS_AXIS
|
||||
? "axis"
|
||||
: AxisArgName == NAME_IS_CONCAT_DIM ? "concat_dim" : "<invalid>";
|
||||
OP_REQUIRES_OK(c, c->input(axis_attribute_name, &concat_dim_tensor));
|
||||
OP_REQUIRES(c, IsLegacyScalar(concat_dim_tensor->shape()),
|
||||
errors::InvalidArgument(
|
||||
axis_attribute_name,
|
||||
" tensor should be a scalar integer, but got shape ",
|
||||
concat_dim_tensor->shape().DebugString()));
|
||||
const int32 concat_dim =
|
||||
internal::SubtleMustCopy(concat_dim_tensor->scalar<int32>()());
|
||||
// Instead of accessing values from context, we use input to Compute.
|
||||
const int N = values.size();
|
||||
const int input_dims = values[0].dims();
|
||||
const TensorShape& input_shape = values[0].shape();
|
||||
|
||||
int32 axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim;
|
||||
OP_REQUIRES(c,
|
||||
(0 <= axis && axis < input_dims) ||
|
||||
(allow_legacy_scalars() && concat_dim == 0),
|
||||
errors::InvalidArgument(
|
||||
"ConcatOp : Expected concatenating dimensions in the range "
|
||||
"[",
|
||||
-input_dims, ", ", input_dims, "), but got ", concat_dim));
|
||||
// Note that we reduce the concat of n-dimensional tensors into a two
|
||||
// dimensional concat. Assuming the dimensions of any input/output
|
||||
// tensor are {x0, x1,...,xn-1, y0, y1,...,ym-1}, where the concat is along
|
||||
// the dimension indicated with size y0, we flatten it to {x, y}, where y =
|
||||
// Prod_i(yi) and x = ((n > 0) ? Prod_i(xi) : 1).
|
||||
ConstMatrixVector inputs_flat;
|
||||
inputs_flat.reserve(N);
|
||||
int64 inputs_flat_dim0 = 1;
|
||||
for (int d = 0; d < axis; ++d) {
|
||||
inputs_flat_dim0 *= input_shape.dim_size(d);
|
||||
}
|
||||
int64 output_concat_dim = 0;
|
||||
const bool input_is_scalar = IsLegacyScalar(input_shape);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
const auto in = values[i];
|
||||
const bool in_is_scalar = IsLegacyScalar(in.shape());
|
||||
OP_REQUIRES(
|
||||
c, in.dims() == input_dims || (input_is_scalar && in_is_scalar),
|
||||
errors::InvalidArgument(
|
||||
"ConcatOp : Ranks of all input tensors should match: shape[0] = ",
|
||||
input_shape.DebugString(), " vs. shape[", i,
|
||||
"] = ", in.shape().DebugString()));
|
||||
for (int j = 0; j < input_dims; ++j) {
|
||||
if (j == axis) {
|
||||
continue;
|
||||
}
|
||||
OP_REQUIRES(
|
||||
c, in.dim_size(j) == input_shape.dim_size(j),
|
||||
errors::InvalidArgument(
|
||||
"ConcatOp : Dimensions of inputs should match: shape[0] = ",
|
||||
input_shape.DebugString(), " vs. shape[", i,
|
||||
"] = ", in.shape().DebugString()));
|
||||
}
|
||||
if (in.NumElements() > 0) {
|
||||
int64 inputs_flat_dim1 = in.NumElements() / inputs_flat_dim0;
|
||||
inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
|
||||
in.shaped<T, 2>({inputs_flat_dim0, inputs_flat_dim1})));
|
||||
}
|
||||
// TODO(irving): Remove check once !allow_legacy_scalars().
|
||||
output_concat_dim += in.dims() > 0 ? in.dim_size(axis) : 1;
|
||||
}
|
||||
|
||||
TensorShape output_shape(input_shape);
|
||||
// TODO(irving): Remove rank 0 case once !allow_legacy_scalars().
|
||||
if (output_shape.dims() == 0) {
|
||||
output_shape.AddDim(output_concat_dim);
|
||||
} else {
|
||||
output_shape.set_dim(axis, output_concat_dim);
|
||||
}
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
|
||||
if (output->NumElements() > 0) {
|
||||
int64 output_dim1 = output->NumElements() / inputs_flat_dim0;
|
||||
auto output_flat = output->shaped<T, 2>({inputs_flat_dim0, output_dim1});
|
||||
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Mkl Concat Op
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
template <typename Device, typename T, AxisArgumentName AxisArgName>
|
||||
class MklConcatOp : public OpKernel {
|
||||
private:
|
||||
TensorFormat data_format_;
|
||||
EigenConcatBaseOp<Device, T, AxisArgName> eigen_concat_op_;
|
||||
|
||||
public:
|
||||
typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
|
||||
ConstMatrixVector;
|
||||
|
||||
explicit MklConcatOp(OpKernelConstruction* c)
|
||||
: OpKernel(c), eigen_concat_op_(c) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
MklConcatOpContext mkl_context;
|
||||
|
||||
// Get input tensors.
|
||||
OpInputList input_tensors;
|
||||
GetMklInputList(context, "values", &input_tensors);
|
||||
const int N = input_tensors.size();
|
||||
// Get MKL shapes.
|
||||
MklShapeList input_shapes(N);
|
||||
GetMklShapeList(context, "values", &input_shapes);
|
||||
|
||||
// If this is Concat, then concat_dim is 0th input.
|
||||
// If this is ConcatV2, then axis is Nth input.
|
||||
const Tensor& concat_dim_tensor = AxisArgName == NAME_IS_CONCAT_DIM
|
||||
? MklGetInput(context, 0)
|
||||
: MklGetInput(context, N);
|
||||
|
||||
// Sanity checks
|
||||
OP_REQUIRES(
|
||||
context, IsLegacyScalar(concat_dim_tensor.shape()),
|
||||
errors::InvalidArgument(
|
||||
"Concat dim tensor should be a scalar integer, but got shape ",
|
||||
concat_dim_tensor.shape().DebugString()));
|
||||
int32 concat_dim =
|
||||
internal::SubtleMustCopy(concat_dim_tensor.scalar<int32>()());
|
||||
|
||||
MklShape& inpshape0 = input_shapes[0];
|
||||
|
||||
// Check that all tensors are Mkl, if not we call Eigen version.
|
||||
bool invoke_eigen = false;
|
||||
bool is_concat_dim_channel = true;
|
||||
if (!AreAllMklTensors(input_shapes)) {
|
||||
invoke_eigen = true;
|
||||
}
|
||||
|
||||
// Check that total number of dimensions is 4, if not call Eigen.
|
||||
if (!invoke_eigen) {
|
||||
for (auto& s : input_shapes) {
|
||||
if (s.GetDimension() != 4) {
|
||||
invoke_eigen = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check that concat_dim is channel, if not call Eigen version.
|
||||
if (!invoke_eigen) {
|
||||
for (auto& s : input_shapes) {
|
||||
if (!s.IsMklChannelDim(concat_dim)) {
|
||||
invoke_eigen = true;
|
||||
is_concat_dim_channel = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (invoke_eigen) {
|
||||
string msg = std::string("Invoking Eigen version of Concat. Reason:") +
|
||||
(!is_concat_dim_channel
|
||||
? std::string("Concat dimension is not channel")
|
||||
: std::string("Not all tensors are in Mkl layout"));
|
||||
VLOG(1) << "_MklConcatOp: " << msg;
|
||||
CallEigenVersion(context, input_tensors, input_shapes);
|
||||
return;
|
||||
}
|
||||
|
||||
// For MKL format, the channel is dimension number 2.
|
||||
// So if we are concating over channel and _all_ inputs are in MKL
|
||||
// format, then we set concat_dim to 2.
|
||||
// Since we have reached till here, it means we are concating
|
||||
// over channel.
|
||||
concat_dim = MklDims::C;
|
||||
|
||||
// One more sanity check: check that ranks of all tensors match
|
||||
// and that their shapes match except for concat_dim.
|
||||
int i = 0;
|
||||
for (auto& s : input_shapes) {
|
||||
size_t exp_dims = inpshape0.GetDimension();
|
||||
OP_REQUIRES(context, s.GetDimension() == exp_dims,
|
||||
errors::InvalidArgument(
|
||||
"_MklConcatOp : Ranks of all input tensors should match:"
|
||||
" input dimensions = ",
|
||||
s.GetDimension(), " vs. expected rank = ", exp_dims));
|
||||
|
||||
for (int d = 0; d < exp_dims; ++d) {
|
||||
if (d == concat_dim) {
|
||||
continue;
|
||||
}
|
||||
|
||||
size_t exp_size = inpshape0.GetSizes()[d];
|
||||
OP_REQUIRES(
|
||||
context, exp_size == s.GetSizes()[d],
|
||||
errors::InvalidArgument("_MklConcatOp : Dimensions of inputs"
|
||||
"should match: shape[0][",
|
||||
d, "]= ", exp_size, " vs. shape[", i, "][",
|
||||
d, "] = ", s.GetSizes()[d]));
|
||||
}
|
||||
++i;
|
||||
}
|
||||
|
||||
// Use input MKL layout instead of creating new layouts.
|
||||
int64 output_concat_dim_size = 0;
|
||||
for (auto& s : input_shapes) {
|
||||
output_concat_dim_size +=
|
||||
s.GetDimension() > 0 ? s.GetSizes()[concat_dim] : 1;
|
||||
}
|
||||
mkl_context.MklCreateInputLayouts(context, input_shapes);
|
||||
|
||||
CHECK_EQ(dnnConcatCreate_F32(&mkl_context.prim_concat, NULL, N,
|
||||
&mkl_context.lt_inputs[0]),
|
||||
E_SUCCESS);
|
||||
|
||||
// Calculate output sizes and strides
|
||||
TensorFormat data_format;
|
||||
if (inpshape0.IsTensorInNHWCFormat()) {
|
||||
data_format = FORMAT_NHWC;
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
context, inpshape0.IsTensorInNCHWFormat(),
|
||||
errors::InvalidArgument(
|
||||
"_MklConcat only supports all inputs in NCHW or NHWC format "));
|
||||
data_format = FORMAT_NCHW;
|
||||
}
|
||||
|
||||
// Since all tensors are in Mkl layout, we copy sizes from input tensor.
|
||||
mkl_context.out_sizes[MklDims::W] = inpshape0.GetSizes()[MklDims::W];
|
||||
mkl_context.out_sizes[MklDims::H] = inpshape0.GetSizes()[MklDims::H];
|
||||
mkl_context.out_sizes[MklDims::C] = output_concat_dim_size;
|
||||
mkl_context.out_sizes[MklDims::N] = inpshape0.GetSizes()[MklDims::N];
|
||||
GetStridesFromSizes(data_format, mkl_context.out_strides,
|
||||
mkl_context.out_sizes);
|
||||
|
||||
// Set output Mkl shape.
|
||||
int64 dim = 4;
|
||||
MklShape mkl_output_mkl_shape;
|
||||
mkl_output_mkl_shape.SetMklTensor(true);
|
||||
mkl_output_mkl_shape.SetMklLayout(mkl_context.prim_concat, dnnResourceDst);
|
||||
mkl_output_mkl_shape.SetTfLayout(dim, mkl_context.out_sizes,
|
||||
mkl_context.out_strides);
|
||||
mkl_output_mkl_shape.SetTfDimOrder(dim, inpshape0.GetTfToMklDimMap());
|
||||
|
||||
TensorShape mkl_output_tf_shape;
|
||||
mkl_output_tf_shape.AddDim(1);
|
||||
mkl_output_tf_shape.AddDim(
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
|
||||
Tensor* output = nullptr;
|
||||
AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
|
||||
mkl_output_mkl_shape);
|
||||
|
||||
// Set destination resource.
|
||||
mkl_context.concat_res[dnnResourceDst] =
|
||||
const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
|
||||
|
||||
mkl_context.mkl_tmp_tensors.resize(N);
|
||||
mkl_context.MklPrepareConcatInputs(context, input_tensors);
|
||||
|
||||
// Execute primitive.
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.prim_concat, mkl_context.concat_res),
|
||||
E_SUCCESS);
|
||||
|
||||
mkl_context.MklCleanup();
|
||||
}
|
||||
|
||||
private:
|
||||
typedef struct {
|
||||
TensorFormat data_format;
|
||||
size_t out_sizes[4];
|
||||
size_t out_strides[4];
|
||||
dnnPrimitive_t prim_concat;
|
||||
void* concat_res[dnnResourceNumber];
|
||||
std::vector<dnnLayout_t> lt_inputs;
|
||||
std::vector<Tensor> mkl_tmp_tensors;
|
||||
|
||||
// Create MKL dnnLayout_t objects for tensors coming into the layer
|
||||
// We only support case where input tensors are all in Mkl layout.
|
||||
void MklCreateInputLayouts(OpKernelContext* context,
|
||||
MklShapeList& input_shapes) {
|
||||
for (auto& is : input_shapes) {
|
||||
CHECK_EQ(is.IsMklTensor(), true);
|
||||
lt_inputs.push_back((dnnLayout_t)is.GetCurLayout());
|
||||
}
|
||||
}
|
||||
|
||||
void MklPrepareConcatInputs(OpKernelContext* context,
|
||||
OpInputList& input_tensors) {
|
||||
CHECK_EQ(lt_inputs.size(), mkl_tmp_tensors.size());
|
||||
|
||||
for (int i = 0; i < lt_inputs.size(); ++i) {
|
||||
dnnPrimitive_t mkl_prim_convert_input;
|
||||
dnnLayout_t mkl_lt_internal_input;
|
||||
void* mkl_buf_convert_input = nullptr;
|
||||
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
|
||||
&mkl_lt_internal_input, prim_concat,
|
||||
(dnnResourceType_t)(dnnResourceMultipleSrc + i)),
|
||||
E_SUCCESS);
|
||||
|
||||
if (!dnnLayoutCompare_F32(lt_inputs[i], mkl_lt_internal_input)) {
|
||||
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input,
|
||||
lt_inputs[i], mkl_lt_internal_input),
|
||||
E_SUCCESS);
|
||||
|
||||
AllocTmpBuffer(context, &mkl_tmp_tensors[i], mkl_lt_internal_input,
|
||||
&mkl_buf_convert_input);
|
||||
|
||||
CHECK_EQ(dnnConversionExecute_F32(
|
||||
mkl_prim_convert_input,
|
||||
const_cast<void*>(static_cast<const void*>(
|
||||
input_tensors[i].flat<T>().data())),
|
||||
mkl_buf_convert_input),
|
||||
E_SUCCESS);
|
||||
|
||||
concat_res[dnnResourceMultipleSrc + i] = mkl_buf_convert_input;
|
||||
CHECK_EQ(dnnDelete_F32(mkl_prim_convert_input), E_SUCCESS);
|
||||
} else {
|
||||
concat_res[dnnResourceMultipleSrc + i] = const_cast<void*>(
|
||||
static_cast<const void*>(input_tensors[i].flat<T>().data()));
|
||||
}
|
||||
|
||||
CHECK_EQ(dnnLayoutDelete_F32(mkl_lt_internal_input), E_SUCCESS);
|
||||
}
|
||||
}
|
||||
|
||||
void MklCleanup() {
|
||||
for (auto& lt : lt_inputs) {
|
||||
lt = nullptr;
|
||||
}
|
||||
CHECK_EQ(dnnDelete_F32(prim_concat), E_SUCCESS);
|
||||
}
|
||||
} MklConcatOpContext;
|
||||
|
||||
void CallEigenVersion(OpKernelContext* context, const OpInputList& values,
|
||||
const MklShapeList& input_shapes) {
|
||||
// Before calling Eigen version, we need to convert Mkl tensors to TF.
|
||||
// First check that the number of input tensors and the number of Mkl
|
||||
// shapes match.
|
||||
CHECK_EQ(values.size(), input_shapes.size());
|
||||
|
||||
std::vector<Tensor> converted_values;
|
||||
for (int i = 0; i < input_shapes.size(); i++) {
|
||||
if (input_shapes[i].IsMklTensor()) {
|
||||
// If input tensor is Mkl, then do the conversion.
|
||||
Tensor tmp_tensor =
|
||||
ConvertMklToTF<T>(context, values[i], input_shapes[i]);
|
||||
converted_values.push_back(tmp_tensor);
|
||||
} else {
|
||||
// If input tensor is TF already, then we do not need any conversion.
|
||||
converted_values.push_back(values[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Call Eigen concat.
|
||||
eigen_concat_op_.Compute(context, converted_values);
|
||||
|
||||
// Set dummy Mkl tensor as output Mkl tensor for this op.
|
||||
MklShape mkl_tensor_mkl_shape;
|
||||
mkl_tensor_mkl_shape.SetMklTensor(false);
|
||||
mkl_tensor_mkl_shape.SetDimensions(4);
|
||||
mkl_tensor_mkl_shape.SetTfDimOrder(4); // Dimensions
|
||||
Tensor* mkl_tensor = nullptr;
|
||||
TensorShape mkl_tensor_tf_shape;
|
||||
mkl_tensor_tf_shape.AddDim(
|
||||
SIZE_OF_MKL_SERIAL_DATA(mkl_tensor_mkl_shape.GetDimension()));
|
||||
int tf_output_index = 0;
|
||||
context->allocate_output(
|
||||
GetTensorMetaDataIndex(tf_output_index, context->num_outputs()),
|
||||
mkl_tensor_tf_shape, &mkl_tensor);
|
||||
mkl_tensor_mkl_shape.SerializeMklShape(
|
||||
mkl_tensor->flat<uint8>().data(),
|
||||
mkl_tensor->flat<uint8>().size() * sizeof(uint8));
|
||||
}
|
||||
};
|
||||
|
||||
/* Use optimized concat for float type only */
|
||||
#define REGISTER_MKL_CPU(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConcat") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.HostMemory("concat_dim") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_CONCAT_DIM>) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConcatV2") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.TypeConstraint<int32>("Tidx") \
|
||||
.HostMemory("axis") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConcatOp<CPUDevice, type, NAME_IS_AXIS>)
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
|
||||
#undef REGISTER_CONCAT_MKL
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
@ -87,7 +87,7 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
|
||||
Tensor* bias_backprop = nullptr;
|
||||
MklShape output_mkl_shape;
|
||||
output_mkl_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &bias_backprop, output_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &bias_backprop, output_shape,
|
||||
output_mkl_shape);
|
||||
|
||||
mkl_context.in_dims = 4;
|
||||
@ -251,11 +251,11 @@ class MklConv2DCustomBackpropBiasOp : public OpKernel {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(MklConv2DCustomBackpropBiasOp);
|
||||
};
|
||||
|
||||
#define REGISTER_CPU_KERNELS(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBiasBackpropBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
#define REGISTER_CPU_KERNELS(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBiasBackpropBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConv2DCustomBackpropBiasOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_float(REGISTER_CPU_KERNELS);
|
||||
|
@ -217,7 +217,7 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
|
||||
mkl_context.grad_filter_shape.SetTfLayout(mkl_context.filter_dims,
|
||||
mkl_context.filter_sizes,
|
||||
mkl_context.filter_strides);
|
||||
AllocateOutputSetMklshape(context, 0, &grad_filter, filter_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &grad_filter, filter_shape,
|
||||
mkl_context.grad_filter_shape);
|
||||
|
||||
// Need to set member variable for TF layout
|
||||
@ -408,11 +408,11 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_FILTER_KERNELS(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropFilter") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
#define REGISTER_MKL_FILTER_KERNELS(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConv2DCustomBackpropFilterOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
|
||||
|
@ -202,7 +202,7 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
|
||||
mkl_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mklOutputShape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &in_backprop, mkl_out_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &in_backprop, mkl_out_shape,
|
||||
mklOutputShape);
|
||||
|
||||
mkl_context.conv_res[dnnResourceDiffSrc] =
|
||||
@ -341,11 +341,11 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
|
||||
TensorFormat data_format;
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_CPU_KERNELS(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklConv2DBackpropInput") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
#define REGISTER_MKL_CPU_KERNELS(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConv2DCustomBackpropInputOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
|
||||
|
@ -178,7 +178,7 @@ class MklConv2DOp : public OpKernel {
|
||||
// Nothing to do, allocate output tensor and return
|
||||
MklShape mkl_output_mkl_shape;
|
||||
mkl_output_mkl_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &output, input.shape(),
|
||||
AllocateOutputSetMklShape(context, 0, &output, input.shape(),
|
||||
mkl_output_mkl_shape);
|
||||
return;
|
||||
}
|
||||
@ -264,7 +264,7 @@ class MklConv2DOp : public OpKernel {
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &output, mkl_output_tf_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
|
||||
mkl_output_mkl_shape);
|
||||
mkl_context.conv_res[dnnResourceDst] =
|
||||
static_cast<void*>(output->flat<T>().data());
|
||||
@ -437,16 +437,16 @@ class MklConv2DOp : public OpKernel {
|
||||
TensorFormat data_format_;
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
MklConv2DOp<CPUDevice, T, false>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklConv2DWithBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
#define REGISTER_MKL_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConv2DOp<CPUDevice, T, false>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklConv2DOp<CPUDevice, T, true>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
|
689
tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
Normal file
689
tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
Normal file
@ -0,0 +1,689 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
// TODO(inteltf) Address comments from PR 8968.
|
||||
|
||||
namespace tensorflow {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
template <typename Device, typename T>
|
||||
class MklFusedBatchNormOp : public OpKernel {
|
||||
public:
|
||||
explicit MklFusedBatchNormOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
float epsilon;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
|
||||
epsilon_ = T(epsilon);
|
||||
string tensor_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
|
||||
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
MklFusedBatchNormOpContext mkl_context;
|
||||
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
const Tensor& scale = MklGetInput(context, 1);
|
||||
const Tensor& shift = MklGetInput(context, 2);
|
||||
const Tensor& est_mean = MklGetInput(context, 3);
|
||||
const Tensor& est_variance = MklGetInput(context, 4);
|
||||
|
||||
GetMklShape(context, 0, &(mkl_context.mkl_shape_input_shape));
|
||||
bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
|
||||
if (!input_in_mkl_format) {
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().DebugString()));
|
||||
}
|
||||
OP_REQUIRES(context, scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
scale.shape().DebugString()));
|
||||
OP_REQUIRES(context, shift.dims() == 1,
|
||||
errors::InvalidArgument("offset must be 1-dimensional",
|
||||
shift.shape().DebugString()));
|
||||
OP_REQUIRES(context, est_mean.dims() == 1,
|
||||
errors::InvalidArgument("estimated_mean must be 1-dimensional",
|
||||
est_mean.shape().DebugString()));
|
||||
OP_REQUIRES(
|
||||
context, est_variance.dims() == 1,
|
||||
errors::InvalidArgument("estimated_variance must be 1-dimensional",
|
||||
est_variance.shape().DebugString()));
|
||||
if (is_training_) {
|
||||
OP_REQUIRES(context, est_mean.dim_size(0) == 0,
|
||||
errors::InvalidArgument("estimated_mean empty for training",
|
||||
est_mean.shape().DebugString()));
|
||||
OP_REQUIRES(context, est_variance.dim_size(0) == 0,
|
||||
errors::InvalidArgument(
|
||||
"estimated_variance must be empty for training",
|
||||
est_variance.shape().DebugString()));
|
||||
}
|
||||
|
||||
unsigned int flag_batch_norm =
|
||||
is_training_ ? dnnUseScaleShift
|
||||
: (dnnUseInputMeanVariance | dnnUseScaleShift);
|
||||
|
||||
mkl_context.MklExtractParams(context, tensor_format_);
|
||||
|
||||
// Create layout only for input data as it is used in Op primitive.
|
||||
mkl_context.MklCreateInputLayout(context);
|
||||
|
||||
// Create Op primitive.
|
||||
CHECK_EQ(dnnBatchNormalizationCreateForward_v2_F32(
|
||||
&(mkl_context.mkl_prim_batchnorm), nullptr,
|
||||
mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
|
||||
flag_batch_norm),
|
||||
E_SUCCESS);
|
||||
|
||||
// Temporary tensors with buffers for the context inputs, if
|
||||
// conversion to MKL-Op specific layouts are required. It is assumed here
|
||||
// that TF's 1D tensors (scale, shift, est_mean, and est_variance) won't
|
||||
// require any conversion.
|
||||
// Since scale-shift is combined in MKL, a buffer is required.
|
||||
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_scale_shift_buf_tensor;
|
||||
mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
|
||||
&mkl_tmp_scale_shift_buf_tensor);
|
||||
|
||||
// Output data in MKL layout
|
||||
Tensor* output = nullptr;
|
||||
TensorShape tf_shape_output;
|
||||
MklShape mkl_shape_output;
|
||||
mkl_shape_output.SetMklTensor(true);
|
||||
mkl_shape_output.SetMklLayout(mkl_context.mkl_prim_batchnorm,
|
||||
dnnResourceDst);
|
||||
mkl_shape_output.SetTfLayout(mkl_context.mkl_params.in_dim,
|
||||
mkl_context.mkl_params.in_sizes,
|
||||
mkl_context.mkl_params.in_strides);
|
||||
mkl_shape_output.SetTfDimOrder(mkl_context.mkl_params.in_dim,
|
||||
tensor_format_);
|
||||
tf_shape_output.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_shape_output.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklShape(context, 0, &output, tf_shape_output,
|
||||
mkl_shape_output);
|
||||
mkl_context.mkl_res_batchnorm[dnnResourceDst] =
|
||||
static_cast<void*>(output->flat<T>().data());
|
||||
|
||||
// Batch mean in TF layout
|
||||
Tensor* batch_mean = nullptr;
|
||||
MklShape mkl_shape_batch_mean;
|
||||
mkl_shape_batch_mean.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, 1, &batch_mean, scale.shape(),
|
||||
mkl_shape_batch_mean);
|
||||
// Batch variance in TF layout
|
||||
Tensor* batch_variance = nullptr;
|
||||
MklShape mkl_shape_batch_variance;
|
||||
mkl_shape_batch_variance.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, 2, &batch_variance, scale.shape(),
|
||||
mkl_shape_batch_variance);
|
||||
// If training mode, set dnnResourceMean and dnnResourceVariance to
|
||||
// output tensors for batch mean and variance.
|
||||
// Otherwise, set dnnResourceMean and dnnResourceVariance to
|
||||
// estimated mean and variance.
|
||||
if (is_training_)
|
||||
mkl_context.MklSetMeanVariance(*batch_mean, *batch_variance);
|
||||
else
|
||||
mkl_context.MklSetMeanVariance(est_mean, est_variance);
|
||||
|
||||
// Now that all resources are set, it is ready for dnnExecute
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm,
|
||||
mkl_context.mkl_res_batchnorm),
|
||||
E_SUCCESS);
|
||||
|
||||
// Mean and variance (without Bessel's correction) saved for backward
|
||||
// computation to serve as pre-computed mean and variance.
|
||||
Tensor* saved_mean = nullptr;
|
||||
MklShape mkl_shape_saved_mean;
|
||||
mkl_shape_saved_mean.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, 3, &saved_mean, scale.shape(),
|
||||
mkl_shape_saved_mean);
|
||||
std::memcpy(
|
||||
reinterpret_cast<char*>(saved_mean->flat<float>().data()),
|
||||
reinterpret_cast<char*>(mkl_context.mkl_res_batchnorm[dnnResourceMean]),
|
||||
scale.NumElements() * sizeof(float));
|
||||
Tensor* saved_variance = nullptr;
|
||||
MklShape mkl_shape_saved_variance;
|
||||
mkl_shape_saved_variance.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, 4, &saved_variance, scale.shape(),
|
||||
mkl_shape_saved_variance);
|
||||
std::memcpy(reinterpret_cast<char*>(saved_variance->flat<float>().data()),
|
||||
reinterpret_cast<char*>(
|
||||
mkl_context.mkl_res_batchnorm[dnnResourceVariance]),
|
||||
scale.NumElements() * sizeof(float));
|
||||
|
||||
// Bessel's correction on variance, if training mode is on
|
||||
if (is_training_) {
|
||||
float* p_var = static_cast<float*>(batch_variance->flat<T>().data());
|
||||
auto depth = mkl_context.mkl_params.depth;
|
||||
size_t orig_size = mkl_context.mkl_params.in_sizes[0] *
|
||||
mkl_context.mkl_params.in_sizes[1] *
|
||||
mkl_context.mkl_params.in_sizes[3];
|
||||
size_t adjust_size = orig_size - 1;
|
||||
float adjust_factor = (static_cast<float>(orig_size)) / adjust_size;
|
||||
for (int i = 0; i < depth; i++) p_var[i] = adjust_factor * p_var[i];
|
||||
}
|
||||
|
||||
mkl_context.MklCleanup();
|
||||
}
|
||||
|
||||
private:
|
||||
T epsilon_;
|
||||
TensorFormat tensor_format_;
|
||||
bool is_training_;
|
||||
|
||||
// Structure containing all info for MklOp
|
||||
typedef struct {
|
||||
// Parameters used for input and output layouts
|
||||
struct MklBatchNormParams {
|
||||
// BatchNormOp src and
|
||||
size_t in_dim;
|
||||
size_t in_sizes[4];
|
||||
size_t in_strides[4];
|
||||
size_t depth; // Batch normalization is done for per channel.
|
||||
} mkl_params;
|
||||
|
||||
MklShape mkl_shape_input_shape;
|
||||
|
||||
// MKL primitive and resources for BatchNormOp
|
||||
dnnPrimitive_t mkl_prim_batchnorm = nullptr;
|
||||
void* mkl_res_batchnorm[dnnResourceNumber];
|
||||
|
||||
// MKL layouts for inputs in the context
|
||||
dnnLayout_t mkl_lt_input = nullptr;
|
||||
|
||||
void MklCleanup() {
|
||||
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
|
||||
if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
|
||||
if (mkl_prim_batchnorm != nullptr) dnnDelete_F32(mkl_prim_batchnorm);
|
||||
}
|
||||
|
||||
void MklExtractParams(OpKernelContext* context,
|
||||
const TensorFormat& tensor_format) {
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
|
||||
mkl_params.in_dim = input_in_mkl_format
|
||||
? mkl_shape_input_shape.GetDimension()
|
||||
: input.dims();
|
||||
mkl_params.in_sizes[0] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
|
||||
: GetTensorDim(input, tensor_format, 'W'));
|
||||
mkl_params.in_sizes[1] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
|
||||
: GetTensorDim(input, tensor_format, 'H'));
|
||||
mkl_params.in_sizes[2] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
|
||||
: GetTensorDim(input, tensor_format, 'C'));
|
||||
mkl_params.in_sizes[3] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
|
||||
: GetTensorDim(input, tensor_format, 'N'));
|
||||
mkl_params.depth = mkl_params.in_sizes[2];
|
||||
GetStridesFromSizes(tensor_format, mkl_params.in_strides,
|
||||
mkl_params.in_sizes);
|
||||
}
|
||||
|
||||
void MklCreateInputLayout(OpKernelContext* context) {
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
|
||||
if (input_in_mkl_format) {
|
||||
mkl_lt_input =
|
||||
static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
|
||||
} else {
|
||||
CHECK_EQ(
|
||||
dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dim,
|
||||
mkl_params.in_sizes, mkl_params.in_strides),
|
||||
E_SUCCESS);
|
||||
}
|
||||
}
|
||||
|
||||
void MklPrepareContextInputs(OpKernelContext* context,
|
||||
Tensor* mkl_tmp_input_buf_tensor,
|
||||
Tensor* mkl_tmp_scale_shift_buf_tensor) {
|
||||
bool mkl_convert_input;
|
||||
dnnPrimitive_t mkl_prim_convert_input = nullptr;
|
||||
dnnLayout_t mkl_lt_internal_input = nullptr;
|
||||
void* mkl_buf_converted_input = nullptr;
|
||||
// Compare with internal layouts and convert if needed
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
void* mkl_buf_input =
|
||||
const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
|
||||
&mkl_lt_internal_input, mkl_prim_batchnorm, dnnResourceSrc),
|
||||
E_SUCCESS);
|
||||
mkl_convert_input =
|
||||
!dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
|
||||
if (mkl_convert_input) {
|
||||
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
|
||||
mkl_lt_internal_input),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
|
||||
&mkl_buf_converted_input);
|
||||
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
|
||||
mkl_buf_converted_input),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(mkl_prim_convert_input);
|
||||
}
|
||||
dnnLayoutDelete_F32(mkl_lt_internal_input);
|
||||
mkl_res_batchnorm[dnnResourceSrc] =
|
||||
(mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
|
||||
|
||||
// scale-shift layout is created from primitive. So no conversion
|
||||
// is needed, however, a buffer has to be allocated.
|
||||
dnnLayout_t mkl_lt_scale_shift = nullptr;
|
||||
void* mkl_buf_scale_shift = nullptr;
|
||||
CHECK_EQ(
|
||||
dnnLayoutCreateFromPrimitive_F32(
|
||||
&mkl_lt_scale_shift, mkl_prim_batchnorm, dnnResourceScaleShift),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_scale_shift_buf_tensor,
|
||||
mkl_lt_scale_shift, &mkl_buf_scale_shift);
|
||||
// Fill the scale-shift buffer with data, presumably buffer is 2D array
|
||||
const Tensor& scale = MklGetInput(context, 1);
|
||||
const Tensor& shift = MklGetInput(context, 2);
|
||||
float* buf_scale_shift = static_cast<float*>(mkl_buf_scale_shift);
|
||||
float* buf_scale = const_cast<float*>(
|
||||
static_cast<const float*>(scale.flat<float>().data()));
|
||||
float* buf_shift = const_cast<float*>(
|
||||
static_cast<const float*>(shift.flat<float>().data()));
|
||||
auto depth = mkl_params.depth;
|
||||
for (int i = 0; i < depth; i++) {
|
||||
buf_scale_shift[i] = buf_scale[i];
|
||||
buf_scale_shift[i + depth] = buf_shift[i];
|
||||
}
|
||||
mkl_res_batchnorm[dnnResourceScaleShift] = mkl_buf_scale_shift;
|
||||
}
|
||||
|
||||
inline void MklSetMeanVariance(const Tensor& mean, const Tensor& variance) {
|
||||
mkl_res_batchnorm[dnnResourceMean] = const_cast<void*>(
|
||||
static_cast<const void*>(mean.flat<float>().data()));
|
||||
mkl_res_batchnorm[dnnResourceVariance] = const_cast<void*>(
|
||||
static_cast<const void*>(variance.flat<float>().data()));
|
||||
}
|
||||
} MklFusedBatchNormOpContext;
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNorm") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklFusedBatchNormOp<CPUDevice, T>);
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
#undef REGISTER_MKL_CPU
|
||||
|
||||
template <typename Device, typename T>
|
||||
class MklFusedBatchNormGradOp : public OpKernel {
|
||||
public:
|
||||
explicit MklFusedBatchNormGradOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
float epsilon;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("epsilon", &epsilon));
|
||||
epsilon_ = T(epsilon);
|
||||
string tensor_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &tensor_format));
|
||||
OP_REQUIRES(context, FormatFromString(tensor_format, &tensor_format_),
|
||||
errors::InvalidArgument("Invalid data format"));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
MklFusedBatchNormGradOpContext mkl_context;
|
||||
|
||||
const Tensor& out_backprop = MklGetInput(context, 0);
|
||||
const Tensor& input = MklGetInput(context, 1);
|
||||
const Tensor& scale = MklGetInput(context, 2);
|
||||
const Tensor& saved_mean = MklGetInput(context, 3);
|
||||
const Tensor& saved_var = MklGetInput(context, 4);
|
||||
|
||||
// Here scale, mean, and variance are 1D and considered
|
||||
// those having same layout in MKL and TF
|
||||
GetMklShape(context, 0, &(mkl_context.mkl_shape_out_backprop));
|
||||
GetMklShape(context, 1, &(mkl_context.mkl_shape_input_shape));
|
||||
|
||||
bool input_in_mkl_format = mkl_context.mkl_shape_input_shape.IsMklTensor();
|
||||
bool out_backprop_in_mkl_format =
|
||||
mkl_context.mkl_shape_out_backprop.IsMklTensor();
|
||||
if (!out_backprop_in_mkl_format) {
|
||||
OP_REQUIRES(context, out_backprop.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
out_backprop.shape().DebugString()));
|
||||
}
|
||||
if (!input_in_mkl_format) {
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().DebugString()));
|
||||
}
|
||||
OP_REQUIRES(context, scale.dims() == 1,
|
||||
errors::InvalidArgument("scale must be 1-dimensional",
|
||||
scale.shape().DebugString()));
|
||||
OP_REQUIRES(context, saved_mean.dims() == 1,
|
||||
errors::InvalidArgument("saved mean must be 1-dimensional",
|
||||
saved_mean.shape().DebugString()));
|
||||
OP_REQUIRES(context, saved_var.dims() == 1,
|
||||
errors::InvalidArgument("saved variance must be 1-dimensional",
|
||||
saved_var.shape().DebugString()));
|
||||
|
||||
mkl_context.MklExtractParams(context, tensor_format_);
|
||||
|
||||
mkl_context.MklCreateInputLayout(context);
|
||||
|
||||
unsigned int flag_batch_norm_grad = dnnUseScaleShift;
|
||||
|
||||
// Create Backward Op primitive.
|
||||
CHECK_EQ(dnnBatchNormalizationCreateBackward_v2_F32(
|
||||
&(mkl_context.mkl_prim_batchnorm_bwd), nullptr,
|
||||
mkl_context.mkl_lt_input, static_cast<float>(epsilon_),
|
||||
flag_batch_norm_grad),
|
||||
E_SUCCESS);
|
||||
|
||||
// Temporary tensors and their buffers if conversion is required
|
||||
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_outbackprop_buf_tensor,
|
||||
mkl_tmp_scaleshift_buf_tensor;
|
||||
mkl_context.MklPrepareContextInputs(context, &mkl_tmp_input_buf_tensor,
|
||||
&mkl_tmp_outbackprop_buf_tensor,
|
||||
&mkl_tmp_scaleshift_buf_tensor);
|
||||
|
||||
// Allocate tensor for grad w.r.t. input(x)
|
||||
Tensor* in_backprop = nullptr;
|
||||
TensorShape tf_shape_in_backprop;
|
||||
MklShape mkl_shape_in_backprop;
|
||||
mkl_shape_in_backprop.SetMklTensor(true);
|
||||
mkl_shape_in_backprop.SetMklLayout(mkl_context.mkl_prim_batchnorm_bwd,
|
||||
dnnResourceDiffSrc);
|
||||
mkl_shape_in_backprop.SetTfLayout(mkl_context.mkl_params.in_dims,
|
||||
mkl_context.mkl_params.in_sizes,
|
||||
mkl_context.mkl_params.in_strides);
|
||||
mkl_shape_in_backprop.SetTfDimOrder(mkl_context.mkl_params.in_dims,
|
||||
tensor_format_);
|
||||
tf_shape_in_backprop.AddDim(
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_shape_in_backprop.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklShape(context, 0, &in_backprop, tf_shape_in_backprop,
|
||||
mkl_shape_in_backprop);
|
||||
mkl_context.mkl_res_batchnorm_bwd[dnnResourceDiffSrc] =
|
||||
static_cast<void*>(in_backprop->flat<T>().data());
|
||||
|
||||
// grad_scale and grad_shift are combined together in MKL
|
||||
// So create a single temporary buffer for those.
|
||||
// Also set dnnResourceDiffScaleShift to the temporary buffer
|
||||
Tensor mkl_tmp_grad_scale_shift_buf_tensor;
|
||||
mkl_context.MklPrepareGradScaleShift(context,
|
||||
&mkl_tmp_grad_scale_shift_buf_tensor);
|
||||
|
||||
// All dnn resources are set now, ready to execute
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.mkl_prim_batchnorm_bwd,
|
||||
mkl_context.mkl_res_batchnorm_bwd),
|
||||
E_SUCCESS);
|
||||
|
||||
// Now separate out scale and shift grad and copy to individual tensors
|
||||
const TensorShape& tf_shape_scale_shift = scale.shape();
|
||||
// Allocate tensor for grad w.r.t. scale (beta)
|
||||
Tensor* scale_backprop = nullptr;
|
||||
MklShape mkl_shape_scale_backprop;
|
||||
AllocateOutputSetMklShape(context, 1, &scale_backprop, tf_shape_scale_shift,
|
||||
mkl_shape_scale_backprop);
|
||||
|
||||
// Allocate tensor for grad w.r.t. shift(gamma)
|
||||
Tensor* shift_backprop = nullptr;
|
||||
MklShape mkl_shape_shift_backprop;
|
||||
AllocateOutputSetMklShape(context, 2, &shift_backprop, tf_shape_scale_shift,
|
||||
mkl_shape_shift_backprop);
|
||||
|
||||
// copy scale and shift grads to tensors
|
||||
float* mkl_buf_scale_shift = const_cast<float*>(static_cast<const float*>(
|
||||
mkl_tmp_grad_scale_shift_buf_tensor.flat<T>().data()));
|
||||
float* tf_buf_scale = const_cast<float*>(
|
||||
static_cast<const float*>(scale_backprop->flat<T>().data()));
|
||||
float* tf_buf_shift = const_cast<float*>(
|
||||
static_cast<const float*>(shift_backprop->flat<T>().data()));
|
||||
auto depth = mkl_context.mkl_params.depth;
|
||||
for (int i = 0; i < depth; i++) {
|
||||
tf_buf_scale[i] = mkl_buf_scale_shift[i];
|
||||
tf_buf_shift[i] = mkl_buf_scale_shift[i + depth];
|
||||
}
|
||||
|
||||
// Two placeholders for estimated_mean and estimated_variance, which are
|
||||
// used for inference and thus not needed here for gradient computation.
|
||||
Tensor* placeholder_1 = nullptr;
|
||||
MklShape mkl_shape_placeholder_1;
|
||||
AllocateOutputSetMklShape(context, 3, &placeholder_1, TensorShape({}),
|
||||
mkl_shape_placeholder_1);
|
||||
Tensor* placeholder_2 = nullptr;
|
||||
MklShape mkl_shape_placeholder_2;
|
||||
AllocateOutputSetMklShape(context, 4, &placeholder_2, TensorShape({}),
|
||||
mkl_shape_placeholder_2);
|
||||
|
||||
mkl_context.MklCleanup();
|
||||
}
|
||||
|
||||
private:
|
||||
T epsilon_;
|
||||
TensorFormat tensor_format_;
|
||||
|
||||
// Structure containing all info for MklOp
|
||||
typedef struct {
|
||||
// Parameters used for input and output layouts
|
||||
struct MklBatchNormParams {
|
||||
// BatchNormOp src and
|
||||
size_t in_dims;
|
||||
size_t in_sizes[4];
|
||||
size_t in_strides[4];
|
||||
size_t depth; // Batch normalization is done for per channel.
|
||||
} mkl_params;
|
||||
|
||||
MklShape mkl_shape_out_backprop;
|
||||
MklShape mkl_shape_input_shape;
|
||||
|
||||
// MKL primitive and resources for BatchNormOp
|
||||
dnnPrimitive_t mkl_prim_batchnorm_bwd = nullptr;
|
||||
void* mkl_res_batchnorm_bwd[dnnResourceNumber];
|
||||
|
||||
// MKL layouts for inputs in the context
|
||||
dnnLayout_t mkl_lt_out_backprop = nullptr;
|
||||
dnnLayout_t mkl_lt_input = nullptr;
|
||||
|
||||
void MklCleanup() {
|
||||
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
|
||||
bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
|
||||
if (!input_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_input);
|
||||
if (!out_backprop_in_mkl_format) dnnLayoutDelete_F32(mkl_lt_out_backprop);
|
||||
|
||||
dnnDelete_F32(mkl_prim_batchnorm_bwd);
|
||||
}
|
||||
|
||||
void MklExtractParams(OpKernelContext* context,
|
||||
const TensorFormat& tensor_format) {
|
||||
const Tensor& input = MklGetInput(context, 1);
|
||||
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
|
||||
mkl_params.in_dims = input_in_mkl_format
|
||||
? mkl_shape_input_shape.GetDimension()
|
||||
: input.dims();
|
||||
mkl_params.in_sizes[0] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[0]
|
||||
: GetTensorDim(input, tensor_format, 'W'));
|
||||
mkl_params.in_sizes[1] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[1]
|
||||
: GetTensorDim(input, tensor_format, 'H'));
|
||||
mkl_params.in_sizes[2] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[2]
|
||||
: GetTensorDim(input, tensor_format, 'C'));
|
||||
mkl_params.in_sizes[3] = static_cast<size_t>(
|
||||
input_in_mkl_format ? mkl_shape_input_shape.GetSizes()[3]
|
||||
: GetTensorDim(input, tensor_format, 'N'));
|
||||
mkl_params.depth = mkl_params.in_sizes[2];
|
||||
GetStridesFromSizes(tensor_format, mkl_params.in_strides,
|
||||
mkl_params.in_sizes);
|
||||
}
|
||||
|
||||
void MklCreateInputLayout(OpKernelContext* context) {
|
||||
bool input_in_mkl_format = mkl_shape_input_shape.IsMklTensor();
|
||||
if (input_in_mkl_format) {
|
||||
mkl_lt_input =
|
||||
static_cast<dnnLayout_t>(mkl_shape_input_shape.GetCurLayout());
|
||||
} else {
|
||||
CHECK_EQ(
|
||||
dnnLayoutCreate_F32(&mkl_lt_input, mkl_params.in_dims,
|
||||
mkl_params.in_sizes, mkl_params.in_strides),
|
||||
E_SUCCESS);
|
||||
}
|
||||
|
||||
bool out_backprop_in_mkl_format = mkl_shape_out_backprop.IsMklTensor();
|
||||
if (out_backprop_in_mkl_format) {
|
||||
mkl_lt_out_backprop =
|
||||
static_cast<dnnLayout_t>(mkl_shape_out_backprop.GetCurLayout());
|
||||
} else {
|
||||
CHECK_EQ(
|
||||
dnnLayoutCreate_F32(&mkl_lt_out_backprop, mkl_params.in_dims,
|
||||
mkl_params.in_sizes, mkl_params.in_strides),
|
||||
E_SUCCESS);
|
||||
}
|
||||
}
|
||||
|
||||
void MklPrepareContextInputs(OpKernelContext* context,
|
||||
Tensor* mkl_tmp_input_buf_tensor,
|
||||
Tensor* mkl_tmp_outbackprop_buf_tensor,
|
||||
Tensor* mkl_tmp_scaleshift_buf_tensor) {
|
||||
bool mkl_convert_input;
|
||||
dnnPrimitive_t mkl_prim_convert_input = nullptr;
|
||||
dnnLayout_t mkl_lt_internal_input = nullptr;
|
||||
void* mkl_buf_converted_input = nullptr;
|
||||
// Compare with internal layouts and convert if needed
|
||||
const Tensor& input = MklGetInput(context, 1);
|
||||
void* mkl_buf_input =
|
||||
const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
|
||||
CHECK_EQ(
|
||||
dnnLayoutCreateFromPrimitive_F32(
|
||||
&mkl_lt_internal_input, mkl_prim_batchnorm_bwd, dnnResourceSrc),
|
||||
E_SUCCESS);
|
||||
mkl_convert_input =
|
||||
!dnnLayoutCompare_F32(mkl_lt_internal_input, mkl_lt_input);
|
||||
if (mkl_convert_input) {
|
||||
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_input, mkl_lt_input,
|
||||
mkl_lt_internal_input),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
|
||||
&mkl_buf_converted_input);
|
||||
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_input, mkl_buf_input,
|
||||
mkl_buf_converted_input),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(mkl_prim_convert_input);
|
||||
}
|
||||
dnnLayoutDelete_F32(mkl_lt_internal_input);
|
||||
mkl_res_batchnorm_bwd[dnnResourceSrc] =
|
||||
(mkl_convert_input) ? mkl_buf_converted_input : mkl_buf_input;
|
||||
|
||||
bool mkl_convert_out_backprop;
|
||||
dnnPrimitive_t mkl_prim_convert_out_backprop = nullptr;
|
||||
dnnLayout_t mkl_lt_internal_out_backprop = nullptr;
|
||||
void* mkl_buf_converted_out_backprop = nullptr;
|
||||
// Compare with internal layouts and convert if needed
|
||||
const Tensor& out_backprop = MklGetInput(context, 0);
|
||||
void* mkl_buf_out_backprop = const_cast<void*>(
|
||||
static_cast<const void*>(out_backprop.flat<T>().data()));
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_out_backprop,
|
||||
mkl_prim_batchnorm_bwd,
|
||||
dnnResourceDiffDst),
|
||||
E_SUCCESS);
|
||||
mkl_convert_out_backprop = !dnnLayoutCompare_F32(
|
||||
mkl_lt_internal_out_backprop, mkl_lt_out_backprop);
|
||||
if (mkl_convert_out_backprop) {
|
||||
CHECK_EQ(dnnConversionCreate_F32(&mkl_prim_convert_out_backprop,
|
||||
mkl_lt_out_backprop,
|
||||
mkl_lt_internal_out_backprop),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_outbackprop_buf_tensor,
|
||||
mkl_lt_internal_out_backprop,
|
||||
&mkl_buf_converted_out_backprop);
|
||||
CHECK_EQ(dnnConversionExecute_F32(mkl_prim_convert_out_backprop,
|
||||
mkl_buf_out_backprop,
|
||||
mkl_buf_converted_out_backprop),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(mkl_prim_convert_out_backprop);
|
||||
}
|
||||
dnnLayoutDelete_F32(mkl_lt_internal_out_backprop);
|
||||
mkl_res_batchnorm_bwd[dnnResourceDiffDst] =
|
||||
(mkl_convert_out_backprop) ? mkl_buf_converted_out_backprop
|
||||
: mkl_buf_out_backprop;
|
||||
|
||||
// Set dnnResourceMean and dnnResourceVariance
|
||||
const Tensor& saved_mean = MklGetInput(context, 3);
|
||||
const Tensor& saved_var = MklGetInput(context, 4);
|
||||
void* mkl_buf_saved_mean = const_cast<void*>(
|
||||
static_cast<const void*>(saved_mean.flat<T>().data()));
|
||||
void* mkl_buf_saved_var = const_cast<void*>(
|
||||
static_cast<const void*>(saved_var.flat<T>().data()));
|
||||
mkl_res_batchnorm_bwd[dnnResourceMean] = mkl_buf_saved_mean;
|
||||
mkl_res_batchnorm_bwd[dnnResourceVariance] = mkl_buf_saved_var;
|
||||
|
||||
// Set dnnResourceScaleShift
|
||||
// Note backward Op needs only current values of scale parameters,
|
||||
// shift parameters could be garbage and won't be used
|
||||
const Tensor& scale = MklGetInput(context, 2);
|
||||
dnnLayout_t mkl_lt_scale_shift = nullptr;
|
||||
void* mkl_buf_scale_shift = nullptr;
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_scale_shift,
|
||||
mkl_prim_batchnorm_bwd,
|
||||
dnnResourceScaleShift),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_scaleshift_buf_tensor, mkl_lt_scale_shift,
|
||||
&mkl_buf_scale_shift);
|
||||
float* pscale =
|
||||
const_cast<float*>(static_cast<const float*>(scale.flat<T>().data()));
|
||||
float* pscale_shift = static_cast<float*>(mkl_buf_scale_shift);
|
||||
auto depth = mkl_params.depth;
|
||||
for (int i = 0; i < depth; i++) pscale_shift[i] = pscale[i];
|
||||
mkl_res_batchnorm_bwd[dnnResourceScaleShift] = mkl_buf_scale_shift;
|
||||
dnnLayoutDelete_F32(mkl_lt_scale_shift);
|
||||
}
|
||||
|
||||
void MklPrepareGradScaleShift(OpKernelContext* context,
|
||||
Tensor* mkl_tmp_grad_scale_shift_buf_tensor) {
|
||||
dnnLayout_t mkl_lt_grad_scaleshift = nullptr;
|
||||
void* mkl_buf_grad_scaleshift = nullptr;
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_grad_scaleshift,
|
||||
mkl_prim_batchnorm_bwd,
|
||||
dnnResourceDiffScaleShift),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_grad_scale_shift_buf_tensor,
|
||||
mkl_lt_grad_scaleshift, &mkl_buf_grad_scaleshift);
|
||||
mkl_res_batchnorm_bwd[dnnResourceDiffScaleShift] =
|
||||
mkl_buf_grad_scaleshift;
|
||||
dnnLayoutDelete_F32(mkl_lt_grad_scaleshift);
|
||||
}
|
||||
} MklFusedBatchNormGradOpContext;
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklFusedBatchNormGrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklFusedBatchNormGradOp<CPUDevice, T>);
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
#undef REGISTER_MKL_CPU
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
722
tensorflow/core/kernels/mkl_lrn_op.cc
Normal file
722
tensorflow/core/kernels/mkl_lrn_op.cc
Normal file
@ -0,0 +1,722 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// LRN = Local Response Normalization
|
||||
// See docs in ../ops/nn_ops.cc. This opkernel uses MKL library, create MKL
|
||||
// layout and primitives, use MKL dnn primitives to compute local
|
||||
// response normalization
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#define EIGEN_USE_THREADS
|
||||
#include <vector>
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
#include "tensorflow/core/util/work_sharder.h"
|
||||
#endif
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
// Create a depth-by-depth band matrix with 1s along a swath of size (2 *
|
||||
// depth_radius + 1) around the diagonal.
|
||||
template <typename T>
|
||||
void GetBandMatrix(int depth, int depth_radius,
|
||||
Eigen::Tensor<T, 2, Eigen::RowMajor>* result) {
|
||||
result->setZero();
|
||||
for (int row = 0; row < depth; ++row) {
|
||||
const int begin = std::max<int>(0, row - depth_radius);
|
||||
const int end = std::min<int>(depth, row + depth_radius + 1);
|
||||
Eigen::DSizes<Eigen::DenseIndex, 2> start(row, begin);
|
||||
Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, end - begin);
|
||||
result->slice(start, sizes).setConstant(T(1));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
class MklLRNOp : public OpKernel {
|
||||
public:
|
||||
~MklLRNOp() {}
|
||||
|
||||
explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
int64 depth_radius64;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("depth_radius = ", depth_radius64,
|
||||
" larger than int max"));
|
||||
depth_radius_ = static_cast<size_t>(depth_radius64);
|
||||
|
||||
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
|
||||
workspace_enabled_ = false;
|
||||
context->GetAttr("workspace_enabled", &workspace_enabled_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
MklLRNOpContext mkl_context;
|
||||
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
GetMklShape(context, 0, &mkl_context.input_shape);
|
||||
bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
|
||||
|
||||
// Sanity checks
|
||||
mkl_context.in_dims = input_in_mkl_format
|
||||
? mkl_context.input_shape.GetDimension()
|
||||
: input.dims();
|
||||
OP_REQUIRES(context, mkl_context.in_dims == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional"));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(input.NumElements(), std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("argument to LRN too large"));
|
||||
|
||||
if (!input_in_mkl_format) {
|
||||
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
|
||||
beta_, input);
|
||||
return;
|
||||
}
|
||||
|
||||
if (input_in_mkl_format) {
|
||||
// MKL supports normalization over channel dimension only
|
||||
if (mkl_context.input_shape.tf_dim_idx(mkl_context.in_dims - 1) ==
|
||||
MklDims::C) {
|
||||
mkl_context.lt_input =
|
||||
static_cast<dnnLayout_t>(mkl_context.input_shape.GetCurLayout());
|
||||
workspace_enabled_ = true;
|
||||
} else {
|
||||
mkl_context.MklDefaultToEigen(context, depth_radius_, bias_, alpha_,
|
||||
beta_, input);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
int kernel_size = 2 * depth_radius_ + 1;
|
||||
|
||||
CHECK_EQ(dnnLRNCreateForward_F32(
|
||||
&mkl_context.lrn_fwd, NULL, mkl_context.lt_input, kernel_size,
|
||||
static_cast<float>(alpha_ * kernel_size), beta_, bias_),
|
||||
E_SUCCESS);
|
||||
|
||||
// Allocate output tensor and shape
|
||||
Tensor* output = nullptr;
|
||||
Tensor* workspace = nullptr;
|
||||
|
||||
// Convert Inputs if needed
|
||||
Tensor mkl_tmp_input_buf_tensor;
|
||||
mkl_context.MklPrepareLRNInputs(context, &mkl_tmp_input_buf_tensor);
|
||||
|
||||
// Allocate Layer Outputs
|
||||
mkl_context.MklAllocateOutputs(context, &output, &workspace,
|
||||
workspace_enabled_);
|
||||
|
||||
Tensor mkl_tmp_workspace_buf_tensor;
|
||||
mkl_context.MklPrepareLRNOutputs(context, output, workspace,
|
||||
&mkl_tmp_workspace_buf_tensor,
|
||||
workspace_enabled_);
|
||||
|
||||
// Execute LRN.
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.lrn_fwd, mkl_context.lrn_res),
|
||||
E_SUCCESS);
|
||||
|
||||
// Release MKL resources.
|
||||
mkl_context.MklCleanup();
|
||||
}
|
||||
|
||||
private:
|
||||
typedef struct {
|
||||
size_t in_dims;
|
||||
size_t in_sizes[4];
|
||||
size_t in_strides[4];
|
||||
size_t out_sizes[4];
|
||||
size_t out_strides[4];
|
||||
MklShape input_shape;
|
||||
dnnPrimitive_t lrn_fwd = nullptr;
|
||||
dnnPrimitive_t convert_input = nullptr;
|
||||
/* dnnPrimitive_t convert_output; */
|
||||
dnnLayout_t lt_input = nullptr;
|
||||
/* dnnLayout_t lt_output; */
|
||||
dnnLayout_t lt_internal_input = nullptr;
|
||||
dnnLayout_t lt_internal_workspace = nullptr;
|
||||
dnnLayout_t lt_internal_output = nullptr;
|
||||
void* lrn_res[dnnResourceNumber];
|
||||
|
||||
// Convert Inputs if needed
|
||||
void MklPrepareLRNInputs(OpKernelContext* context,
|
||||
Tensor* mkl_tmp_input_buf_tensor) {
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
void* mkl_buf_input =
|
||||
const_cast<void*>(static_cast<const void*>(input.flat<T>().data()));
|
||||
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_internal_input, lrn_fwd,
|
||||
dnnResourceSrc),
|
||||
E_SUCCESS);
|
||||
|
||||
void* mkl_buf_convert_input = nullptr;
|
||||
bool mkl_convert_input = false;
|
||||
mkl_convert_input = !dnnLayoutCompare_F32(lt_internal_input, lt_input);
|
||||
|
||||
if (mkl_convert_input) {
|
||||
CHECK_EQ(dnnConversionCreate_F32(&convert_input, lt_input,
|
||||
lt_internal_input),
|
||||
E_SUCCESS);
|
||||
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_internal_input,
|
||||
&mkl_buf_convert_input);
|
||||
CHECK_EQ(dnnConversionExecute_F32(convert_input, mkl_buf_input,
|
||||
mkl_buf_convert_input),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(convert_input);
|
||||
}
|
||||
|
||||
lrn_res[dnnResourceSrc] =
|
||||
(mkl_convert_input) ? mkl_buf_convert_input : mkl_buf_input;
|
||||
}
|
||||
|
||||
// Allocate Layer Outputs
|
||||
void MklAllocateOutputs(OpKernelContext* context, Tensor** output,
|
||||
Tensor** workspace, bool workspace_enabled_) {
|
||||
TensorShape mkl_output_tf_shape; /* First tensor */
|
||||
MklShape mkl_output_mkl_shape; /* Second tensor */
|
||||
|
||||
mkl_output_mkl_shape.SetMklTensor(true);
|
||||
mkl_output_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceDst);
|
||||
mkl_output_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
|
||||
input_shape.GetStrides());
|
||||
mkl_output_mkl_shape.SetTfDimOrder(in_dims,
|
||||
input_shape.GetTfToMklDimMap());
|
||||
mkl_output_tf_shape.AddDim(
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklShape(context, 0, output,
|
||||
mkl_output_tf_shape /* First tensor */,
|
||||
mkl_output_mkl_shape /* Second Tensor */);
|
||||
|
||||
if (workspace_enabled_) {
|
||||
TensorShape mkl_workspace_tf_shape; /* First tensor */
|
||||
MklShape mkl_workspace_mkl_shape; /* Second tensor */
|
||||
mkl_workspace_mkl_shape.SetMklTensor(false);
|
||||
mkl_workspace_mkl_shape.SetMklLayout(lrn_fwd, dnnResourceWorkspace);
|
||||
// Assumes workspace has same TF layout and TF dim order as input
|
||||
mkl_workspace_mkl_shape.SetTfLayout(in_dims, input_shape.GetSizes(),
|
||||
input_shape.GetStrides());
|
||||
mkl_workspace_mkl_shape.SetTfDimOrder(in_dims,
|
||||
input_shape.GetTfToMklDimMap());
|
||||
mkl_workspace_tf_shape.AddDim(
|
||||
dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_workspace_mkl_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklShape(context, 1, workspace,
|
||||
mkl_workspace_tf_shape /* First tensor */,
|
||||
mkl_workspace_mkl_shape /* Second Tensor */);
|
||||
}
|
||||
}
|
||||
|
||||
void MklPrepareLRNOutputs(OpKernelContext* context, Tensor* output,
|
||||
Tensor* workspace,
|
||||
Tensor* mkl_tmp_workspace_buf_tensor,
|
||||
bool workspace_enabled_) {
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_internal_workspace, lrn_fwd,
|
||||
dnnResourceWorkspace),
|
||||
E_SUCCESS);
|
||||
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_internal_output, lrn_fwd,
|
||||
dnnResourceDst),
|
||||
E_SUCCESS);
|
||||
|
||||
void* mkl_buf_output =
|
||||
const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
|
||||
lrn_res[dnnResourceDst] = mkl_buf_output;
|
||||
|
||||
void* mkl_buf_workspace = nullptr;
|
||||
if (workspace_enabled_) {
|
||||
mkl_buf_workspace = const_cast<void*>(
|
||||
static_cast<const void*>(workspace->flat<T>().data()));
|
||||
} else {
|
||||
AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor,
|
||||
lt_internal_workspace, &mkl_buf_workspace);
|
||||
}
|
||||
lrn_res[dnnResourceWorkspace] = mkl_buf_workspace;
|
||||
}
|
||||
|
||||
// Fallback implementation - Taken from lrn_op.cc
|
||||
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
|
||||
// copy.
|
||||
void MklDefaultToEigen(OpKernelContext* context, int depth_radius_,
|
||||
float bias_, float alpha_, float beta_,
|
||||
const Tensor& input) {
|
||||
const int batch = static_cast<int>(input.dim_size(0));
|
||||
const int rows = static_cast<int>(input.dim_size(1));
|
||||
const int cols = static_cast<int>(input.dim_size(2));
|
||||
const int depth = static_cast<int>(input.dim_size(3));
|
||||
const int nodes = cols * rows;
|
||||
|
||||
auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
|
||||
// Multiplying the input with the band matrix has the effect of reducing
|
||||
// the
|
||||
// correct patch along the depth.
|
||||
Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
|
||||
GetBandMatrix<T>(depth, depth_radius_, &multiplier);
|
||||
|
||||
Tensor *output, *workspace;
|
||||
MklShape mkl_output_mkl_shape, mkl_workspace_mkl_shape;
|
||||
mkl_output_mkl_shape.SetMklTensor(false);
|
||||
mkl_output_mkl_shape.SetDimensions(4);
|
||||
AllocateOutputSetMklShape(context, 0, &output, input.shape(),
|
||||
mkl_output_mkl_shape);
|
||||
|
||||
mkl_workspace_mkl_shape.SetMklTensor(false);
|
||||
mkl_workspace_mkl_shape.SetDimensions(4);
|
||||
AllocateOutputSetMklShape(context, 1, &workspace, input.shape(),
|
||||
mkl_workspace_mkl_shape);
|
||||
|
||||
auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
|
||||
Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
|
||||
auto tmp = in_shaped.square().contract(multiplier, dims) * alpha_ + bias_;
|
||||
if (beta_ == T(1)) {
|
||||
out_shaped.device(context->eigen_cpu_device()) =
|
||||
in_shaped * tmp.inverse();
|
||||
} else if (beta_ == T(0.5)) {
|
||||
out_shaped.device(context->eigen_cpu_device()) =
|
||||
in_shaped * tmp.rsqrt();
|
||||
} else {
|
||||
out_shaped.device(context->eigen_cpu_device()) =
|
||||
in_shaped * (tmp.log() * -beta_).exp();
|
||||
}
|
||||
}
|
||||
|
||||
// Release MKL resources.
|
||||
void MklCleanup() {
|
||||
dnnDelete_F32(lrn_fwd);
|
||||
dnnLayoutDelete_F32(lt_internal_input);
|
||||
dnnLayoutDelete_F32(lt_internal_workspace);
|
||||
dnnLayoutDelete_F32(lt_internal_output);
|
||||
}
|
||||
} MklLRNOpContext;
|
||||
|
||||
typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
|
||||
|
||||
bool workspace_enabled_;
|
||||
int depth_radius_;
|
||||
float bias_;
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class MklLRNGradOp : public OpKernel {
|
||||
public:
|
||||
explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
int64 depth_radius64;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("depth_radius = ", depth_radius64,
|
||||
" larger than int max"));
|
||||
depth_radius_ = static_cast<int>(depth_radius64);
|
||||
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("beta", &beta_));
|
||||
workspace_enabled_ = false;
|
||||
context->GetAttr("workspace_enabled", &workspace_enabled_);
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
MklLRNGradOpContext mkl_context;
|
||||
mkl_context.depth_radius_ = depth_radius_;
|
||||
mkl_context.bias_ = bias_;
|
||||
mkl_context.alpha_ = alpha_;
|
||||
mkl_context.beta_ = beta_;
|
||||
|
||||
const Tensor& in_grads = MklGetInput(context, 0);
|
||||
const Tensor& in_image = MklGetInput(context, 1);
|
||||
const Tensor& out_image = MklGetInput(context, 2);
|
||||
|
||||
GetMklShape(context, 0, &mkl_context.ingrad_shape);
|
||||
GetMklShape(context, 1, &mkl_context.inimage_shape);
|
||||
GetMklShape(context, 2, &mkl_context.outimage_shape);
|
||||
|
||||
bool ingrad_in_mkl_format = mkl_context.ingrad_shape.IsMklTensor();
|
||||
bool inimage_in_mkl_format = mkl_context.inimage_shape.IsMklTensor();
|
||||
bool outimage_in_mkl_format = mkl_context.outimage_shape.IsMklTensor();
|
||||
|
||||
mkl_context.in_dims = inimage_in_mkl_format
|
||||
? mkl_context.inimage_shape.GetDimension()
|
||||
: in_image.dims();
|
||||
OP_REQUIRES(context, mkl_context.in_dims == 4,
|
||||
errors::InvalidArgument("input images must be 4-dimensional"));
|
||||
|
||||
if (!workspace_enabled_) {
|
||||
mkl_context.MklDefaultToEigen(context);
|
||||
return;
|
||||
}
|
||||
if (ingrad_in_mkl_format || inimage_in_mkl_format) {
|
||||
const MklShape* tmp_mkl_shape = (ingrad_in_mkl_format)
|
||||
? &mkl_context.ingrad_shape
|
||||
: &mkl_context.inimage_shape;
|
||||
if (tmp_mkl_shape->tf_dim_idx(mkl_context.in_dims - 1) != MklDims::C) {
|
||||
// Fallback to eigen
|
||||
mkl_context.MklDefaultToEigen(context);
|
||||
return;
|
||||
} else { // MKL supports normalization over channel dimension only
|
||||
for (int i = 0; i < mkl_context.in_dims; i++) {
|
||||
mkl_context.in_sizes[i] = mkl_context.out_sizes[i] =
|
||||
tmp_mkl_shape->GetSizes()[i];
|
||||
mkl_context.in_strides[i] = mkl_context.out_strides[i] =
|
||||
tmp_mkl_shape->GetStrides()[i];
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to eigen
|
||||
mkl_context.MklDefaultToEigen(context);
|
||||
return;
|
||||
}
|
||||
|
||||
// Dimensions check for sanity purpose
|
||||
if (ingrad_in_mkl_format) {
|
||||
OP_REQUIRES(
|
||||
context, mkl_context.ingrad_shape.GetDimension() == 4,
|
||||
errors::InvalidArgument("input gradient must be 4-dimensional"));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
context, in_grads.dims() == 4,
|
||||
errors::InvalidArgument("input gradient must be 4-dimensional"));
|
||||
}
|
||||
|
||||
if (outimage_in_mkl_format) {
|
||||
OP_REQUIRES(
|
||||
context, mkl_context.outimage_shape.GetDimension() == 4,
|
||||
errors::InvalidArgument("Output image must be 4-dimensional"));
|
||||
} else {
|
||||
OP_REQUIRES(
|
||||
context, out_image.dims() == 4,
|
||||
errors::InvalidArgument("Output image must be 4-dimensional"));
|
||||
}
|
||||
|
||||
// Prepare mkl input layout
|
||||
mkl_context.MklPrepareLRNInputsLayouts(context);
|
||||
int ksize = 2 * depth_radius_ + 1;
|
||||
|
||||
CHECK_EQ(dnnLRNCreateBackward_F32(
|
||||
&mkl_context.lrn_bwd, NULL, mkl_context.lt_input,
|
||||
mkl_context.lt_output, ksize,
|
||||
static_cast<float>(alpha_ * ksize), beta_, bias_),
|
||||
E_SUCCESS);
|
||||
|
||||
// Allocate output tensor and shape.
|
||||
TensorShape mkl_output_tf_shape; /* First tensor */
|
||||
MklShape mkl_output_mkl_shape; /* Second tensor */
|
||||
mkl_output_mkl_shape.SetMklTensor(true);
|
||||
CHECK_NE(mkl_context.lrn_bwd, nullptr);
|
||||
mkl_output_mkl_shape.SetMklLayout(mkl_context.lrn_bwd, dnnResourceDiffSrc);
|
||||
mkl_output_mkl_shape.SetTfLayout(mkl_context.in_dims, mkl_context.out_sizes,
|
||||
mkl_context.out_strides);
|
||||
if (ingrad_in_mkl_format) {
|
||||
mkl_output_mkl_shape.SetTfDimOrder(
|
||||
mkl_context.in_dims, mkl_context.ingrad_shape.GetTfToMklDimMap());
|
||||
} else {
|
||||
mkl_output_mkl_shape.SetTfDimOrder(
|
||||
mkl_context.in_dims, mkl_context.inimage_shape.GetTfToMklDimMap());
|
||||
}
|
||||
mkl_output_tf_shape.AddDim(
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_output_mkl_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
Tensor* output = nullptr;
|
||||
AllocateOutputSetMklShape(context, 0, &output, mkl_output_tf_shape,
|
||||
mkl_output_mkl_shape);
|
||||
|
||||
// Get pointers to output data.
|
||||
void* user_output =
|
||||
const_cast<void*>(static_cast<const void*>(output->flat<T>().data()));
|
||||
|
||||
Tensor mkl_tmp_input_buf_tensor, mkl_tmp_image_buf_tensor,
|
||||
mkl_tmp_outimage_buf_tensor, mkl_tmp_workspace_buf_tensor;
|
||||
// Convert Inputs if needed
|
||||
mkl_context.MklPrepareLRNGradInput(
|
||||
context, &mkl_tmp_input_buf_tensor, &mkl_tmp_image_buf_tensor,
|
||||
&mkl_tmp_outimage_buf_tensor, &mkl_tmp_workspace_buf_tensor);
|
||||
|
||||
// We do not do any conversion for output. But we simply emit it
|
||||
// in MKL format.
|
||||
mkl_context.res_lrn_bwd[dnnResourceDiffSrc] = user_output;
|
||||
// Execute LRN backward using dnnExecute
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.lrn_bwd, mkl_context.res_lrn_bwd),
|
||||
E_SUCCESS);
|
||||
// Release MKL resources.
|
||||
mkl_context.Mklcleanup();
|
||||
}
|
||||
|
||||
private:
|
||||
typedef struct {
|
||||
int depth_radius_;
|
||||
float bias_;
|
||||
float alpha_;
|
||||
float beta_;
|
||||
size_t in_dims;
|
||||
size_t in_sizes[4];
|
||||
size_t in_strides[4];
|
||||
size_t out_sizes[4];
|
||||
size_t out_strides[4];
|
||||
MklShape ingrad_shape, inimage_shape, outimage_shape;
|
||||
dnnPrimitive_t lrn_bwd = nullptr;
|
||||
dnnPrimitive_t convert_input = nullptr;
|
||||
/* dnnPrimitive_t convert_output; */
|
||||
dnnLayout_t lt_input = nullptr;
|
||||
dnnLayout_t lt_output = nullptr;
|
||||
dnnLayout_t lt_bdw_input = nullptr;
|
||||
dnnLayout_t lt_workspace = nullptr;
|
||||
dnnLayout_t lt_internal_input = nullptr;
|
||||
/* dnnLayout_t lt_internal_workspace;
|
||||
dnnLayout_t lt_internal_output; */
|
||||
void* res_lrn_bwd[dnnResourceNumber];
|
||||
|
||||
// prepare mkl input
|
||||
void MklPrepareLRNInputsLayouts(OpKernelContext* context) {
|
||||
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
|
||||
bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
|
||||
if (!ingrad_in_mkl_format) {
|
||||
CHECK_EQ(dnnLayoutCreate_F32(<_input, in_dims, in_sizes, in_strides),
|
||||
E_SUCCESS);
|
||||
} else {
|
||||
lt_input = static_cast<dnnLayout_t>(ingrad_shape.GetCurLayout());
|
||||
}
|
||||
|
||||
if (!inimage_in_mkl_format) {
|
||||
CHECK_EQ(
|
||||
dnnLayoutCreate_F32(<_output, in_dims, out_sizes, out_strides),
|
||||
E_SUCCESS);
|
||||
} else {
|
||||
lt_output = static_cast<dnnLayout_t>(inimage_shape.GetCurLayout());
|
||||
}
|
||||
}
|
||||
|
||||
// convert input if needed
|
||||
void MklPrepareLRNGradInput(OpKernelContext* context,
|
||||
Tensor* mkl_tmp_input_buf_tensor,
|
||||
Tensor* mkl_tmp_image_buf_tensor,
|
||||
Tensor* mkl_tmp_outimage_buf_tensor,
|
||||
Tensor* mkl_tmp_workspace_buf_tensor) {
|
||||
const Tensor& in_grads = MklGetInput(context, 0);
|
||||
const Tensor& in_image = MklGetInput(context, 1);
|
||||
const Tensor& out_image = MklGetInput(context, 2);
|
||||
|
||||
void* user_input = const_cast<void*>(
|
||||
static_cast<const void*>(in_grads.flat<T>().data()));
|
||||
void* user_fwd_input = const_cast<void*>(
|
||||
static_cast<const void*>(in_image.flat<T>().data()));
|
||||
void* user_fwd_output = const_cast<void*>(
|
||||
static_cast<const void*>(out_image.flat<T>().data()));
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_workspace, lrn_bwd,
|
||||
dnnResourceWorkspace),
|
||||
E_SUCCESS);
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(<_bdw_input, lrn_bwd,
|
||||
dnnResourceDiffDst),
|
||||
E_SUCCESS);
|
||||
|
||||
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
|
||||
if (ingrad_in_mkl_format) {
|
||||
if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
|
||||
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
|
||||
&res_lrn_bwd[dnnResourceDiffDst]);
|
||||
ingrad_shape.GetConvertedFlatData(lt_bdw_input, user_input,
|
||||
res_lrn_bwd[dnnResourceDiffDst]);
|
||||
} else {
|
||||
res_lrn_bwd[dnnResourceDiffDst] = user_input;
|
||||
}
|
||||
} else {
|
||||
if (!dnnLayoutCompare_F32(lt_bdw_input, lt_input)) {
|
||||
CHECK_EQ(
|
||||
dnnConversionCreate_F32(&convert_input, lt_input, lt_bdw_input),
|
||||
E_SUCCESS);
|
||||
|
||||
AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, lt_bdw_input,
|
||||
&res_lrn_bwd[dnnResourceDiffDst]);
|
||||
CHECK_EQ(dnnConversionExecute_F32(convert_input, user_input,
|
||||
res_lrn_bwd[dnnResourceDiffDst]),
|
||||
E_SUCCESS);
|
||||
dnnDelete_F32(convert_input);
|
||||
} else {
|
||||
res_lrn_bwd[dnnResourceDiffDst] = user_input;
|
||||
}
|
||||
}
|
||||
|
||||
// Although MKL documentation for LRN does not specify setting/getting
|
||||
// of dnnResourceSrc and dnnResourceDst, Caffe code sets dnnResourceSrc.
|
||||
// So we set dnnResourceSrc here. But we do not know why we are setting
|
||||
// dnnResourceDst.
|
||||
#if 0
|
||||
// NOTE: The code below is kept just so that we know how we should handle
|
||||
// dnnResourceSrc if the primitive layout for dnnResourceSrc was supported.
|
||||
|
||||
if (!dnnLayoutCompare_F32(lt_internal_input,
|
||||
static_cast<dnnLayout_t>inimage_shape.GetCurLayout())) {
|
||||
AllocTmpBuffer(context, mkl_tmp_image_buf_tensor, lt_internal_input,
|
||||
&res_lrn_bwd[dnnResourceSrc]);
|
||||
inimage_shape.GetConvertedFlatData(lt_internal_input,
|
||||
user_fwd_input,
|
||||
res_lrn_bwd[dnnResourceSrc]);
|
||||
} else {
|
||||
res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
|
||||
}
|
||||
#endif
|
||||
|
||||
// Since we cannot get expected layout for dnnResourceSrc, we construct
|
||||
// buffer using
|
||||
// MKL format if input is in MKL format.
|
||||
if (inimage_shape.IsMklTensor()) {
|
||||
AllocTmpBuffer(context, mkl_tmp_image_buf_tensor,
|
||||
(dnnLayout_t)inimage_shape.GetCurLayout(),
|
||||
&res_lrn_bwd[dnnResourceSrc]);
|
||||
} else {
|
||||
res_lrn_bwd[dnnResourceSrc] = user_fwd_input;
|
||||
}
|
||||
|
||||
// Same comment as above.
|
||||
if (outimage_shape.IsMklTensor()) {
|
||||
AllocTmpBuffer(context, mkl_tmp_outimage_buf_tensor,
|
||||
(dnnLayout_t)outimage_shape.GetCurLayout(),
|
||||
&res_lrn_bwd[dnnResourceDst]);
|
||||
} else {
|
||||
res_lrn_bwd[dnnResourceDst] = user_fwd_output;
|
||||
}
|
||||
|
||||
// Allocate buffer for workspace.
|
||||
AllocTmpBuffer(context, mkl_tmp_workspace_buf_tensor, lt_workspace,
|
||||
&res_lrn_bwd[dnnResourceWorkspace]);
|
||||
}
|
||||
|
||||
// Fallback implementation - Taken from lrn_op.cc
|
||||
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
|
||||
// copy.
|
||||
void MklDefaultToEigen(OpKernelContext* context) {
|
||||
// CHECK(false);
|
||||
Tensor in_grads = MklGetInput(context, 0);
|
||||
Tensor in_image = MklGetInput(context, 1);
|
||||
Tensor out_image = MklGetInput(context, 2);
|
||||
|
||||
GetMklShape(context, 0, &ingrad_shape);
|
||||
GetMklShape(context, 1, &inimage_shape);
|
||||
GetMklShape(context, 2, &outimage_shape);
|
||||
|
||||
const int64 batch = static_cast<int64>(in_grads.dim_size(0));
|
||||
const int64 rows = static_cast<int64>(in_grads.dim_size(1));
|
||||
const int64 cols = static_cast<int64>(in_grads.dim_size(2));
|
||||
const int64 depth = static_cast<int64>(in_grads.dim_size(3));
|
||||
const auto nodes = cols * rows;
|
||||
|
||||
auto grads_shaped = in_grads.shaped<T, 2>({nodes * batch, depth});
|
||||
auto in_shaped = in_image.shaped<T, 2>({nodes * batch, depth});
|
||||
auto activations = out_image.shaped<T, 2>({nodes * batch, depth});
|
||||
|
||||
Tensor* output;
|
||||
MklShape mkl_output_mkl_shape;
|
||||
mkl_output_mkl_shape.SetMklTensor(false);
|
||||
mkl_output_mkl_shape.SetDimensions(4);
|
||||
AllocateOutputSetMklShape(context, 0, &output, in_grads.shape(),
|
||||
mkl_output_mkl_shape);
|
||||
|
||||
auto out_shaped = output->shaped<T, 2>({nodes * batch, depth});
|
||||
out_shaped.setZero();
|
||||
auto shard = [this, activations, in_shaped, grads_shaped, out_shaped,
|
||||
depth](int64 begin, int64 end) {
|
||||
for (int64 i = begin; i < end; ++i) {
|
||||
for (int64 j = 0; j < depth; ++j) {
|
||||
int64 depth_begin = std::max<int64>(0, j - depth_radius_);
|
||||
int64 depth_end = std::min<int64>(depth, j + depth_radius_ + 1);
|
||||
|
||||
T norm(0);
|
||||
for (int64 k = depth_begin; k < depth_end; ++k) {
|
||||
norm += in_shaped(i, k) * in_shaped(i, k);
|
||||
}
|
||||
norm = alpha_ * norm + bias_;
|
||||
DCHECK_GT(norm, T(1e-6));
|
||||
for (int64 k = depth_begin; k < depth_end; ++k) {
|
||||
T dyi = T(-2) * alpha_ * beta_ * in_shaped(i, k) *
|
||||
activations(i, j) / norm;
|
||||
if (k == j) {
|
||||
dyi += Eigen::numext::pow(norm, -beta_);
|
||||
}
|
||||
dyi *= grads_shaped(i, j);
|
||||
const_cast<typename TTypes<T, 2>::Tensor&>(out_shaped)(i, k) +=
|
||||
dyi;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
auto worker_threads =
|
||||
*(context->device()->tensorflow_cpu_worker_threads());
|
||||
Shard(worker_threads.num_threads, worker_threads.workers, nodes * batch,
|
||||
depth * depth, shard);
|
||||
}
|
||||
|
||||
// release mkl resources
|
||||
void Mklcleanup() {
|
||||
bool ingrad_in_mkl_format = ingrad_shape.IsMklTensor();
|
||||
bool inimage_in_mkl_format = inimage_shape.IsMklTensor();
|
||||
if (!ingrad_in_mkl_format) {
|
||||
CHECK_EQ(dnnLayoutDelete_F32(lt_input), E_SUCCESS);
|
||||
}
|
||||
|
||||
if (!inimage_in_mkl_format) {
|
||||
CHECK_EQ(dnnLayoutDelete_F32(lt_output), E_SUCCESS);
|
||||
}
|
||||
dnnDelete_F32(lrn_bwd);
|
||||
dnnLayoutDelete_F32(lt_bdw_input);
|
||||
dnnLayoutDelete_F32(lt_workspace);
|
||||
}
|
||||
} MklLRNGradOpContext;
|
||||
|
||||
typedef typename Eigen::Tensor<T, 1, Eigen::RowMajor>::DimensionPair DimPair;
|
||||
bool workspace_enabled_;
|
||||
int depth_radius_;
|
||||
float bias_;
|
||||
float alpha_;
|
||||
float beta_;
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_LRN_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklLRN") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklLRNOp<T>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklLRNGrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklLRNGradOp<T>);
|
||||
|
||||
TF_CALL_float(REGISTER_MKL_LRN_CPU);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
@ -83,10 +83,11 @@ class MklMaxPoolingOp : public OpKernel {
|
||||
ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
|
||||
|
||||
mkl_context.MklCreateLayoutsAndPrimitives(context);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
// Declare output tensor
|
||||
TensorShape tensor_out_shape;
|
||||
MklShape mkl_out_shape;
|
||||
MklShape mkl_out_shape, mkl_workspace_shape;
|
||||
mkl_out_shape.SetMklTensor(true);
|
||||
mkl_out_shape.SetMklLayout(mkl_context.prim_pooling_fwd, dnnResourceDst);
|
||||
mkl_out_shape.SetTfLayout(mkl_context.params.in_dim,
|
||||
@ -98,31 +99,22 @@ class MklMaxPoolingOp : public OpKernel {
|
||||
tensor_out_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_out_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &output_tensor, tensor_out_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output_tensor, tensor_out_shape,
|
||||
mkl_out_shape);
|
||||
|
||||
if (!workspace_enabled_) {
|
||||
mkl_out_shape.SetMklTensor(false);
|
||||
}
|
||||
|
||||
Tensor* workspace_tensor;
|
||||
void* workspace_buf = nullptr;
|
||||
if (workspace_enabled_) {
|
||||
TensorShape workspace_shape;
|
||||
workspace_shape.AddDim(
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_context.lt_workspace)) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 1, &workspace_tensor, workspace_shape,
|
||||
mkl_out_shape);
|
||||
mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
|
||||
static_cast<const void*>(workspace_tensor->flat<T>().data()));
|
||||
} else {
|
||||
AllocTmpBuffer(context, workspace_tensor, mkl_context.lt_workspace,
|
||||
&workspace_buf);
|
||||
mkl_context.pooling_res[dnnResourceWorkspace] = workspace_buf;
|
||||
}
|
||||
|
||||
TensorShape workspace_shape;
|
||||
mkl_workspace_shape.SetMklTensor(false);
|
||||
workspace_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_context.lt_workspace)) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklShape(context, 1, &workspace_tensor, workspace_shape,
|
||||
mkl_workspace_shape);
|
||||
|
||||
mkl_context.pooling_res[dnnResourceWorkspace] = const_cast<void*>(
|
||||
static_cast<const void*>(workspace_tensor->flat<T>().data()));
|
||||
mkl_context.pooling_res[dnnResourceSrc] =
|
||||
const_cast<void*>(static_cast<const void*>(tensor_in.flat<T>().data()));
|
||||
mkl_context.pooling_res[dnnResourceDst] = const_cast<void*>(
|
||||
@ -140,8 +132,8 @@ class MklMaxPoolingOp : public OpKernel {
|
||||
MklPoolingOpParams params;
|
||||
MklShape input_shape;
|
||||
void* pooling_res[dnnResourceNumber];
|
||||
dnnPrimitive_t prim_pooling_fwd;
|
||||
dnnLayout_t lt_user_input, lt_workspace;
|
||||
dnnPrimitive_t prim_pooling_fwd = nullptr;
|
||||
dnnLayout_t lt_user_input = nullptr, lt_workspace = nullptr;
|
||||
|
||||
void MklCreateLayoutsAndPrimitives(OpKernelContext* context) {
|
||||
bool input_in_mkl_format = input_shape.IsMklTensor();
|
||||
@ -256,8 +248,13 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
ExtractMklOpParams(context, data_format_, pool_params, &mkl_context.params);
|
||||
|
||||
mkl_context.MklCreateLayouts(context);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
mkl_context.MklCreatePrimitives(context, workspace_enabled_);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
mkl_context.MklPrepareInputs(context, workspace_enabled_);
|
||||
OP_REQUIRES_OK(context, context->status());
|
||||
|
||||
// Create shape for the input back prop output
|
||||
TensorShape mkl_input_backprop;
|
||||
@ -274,7 +271,7 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
dnnLayoutGetMemorySize_F32(
|
||||
static_cast<dnnLayout_t>(mkl_output_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &output_tensor, mkl_input_backprop,
|
||||
AllocateOutputSetMklShape(context, 0, &output_tensor, mkl_input_backprop,
|
||||
mkl_output_shape);
|
||||
mkl_context.pooling_res[dnnResourceDiffSrc] = const_cast<void*>(
|
||||
static_cast<const void*>(output_tensor->flat<T>().data()));
|
||||
@ -297,12 +294,15 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
MklShape input_shape, output_backprop_shape;
|
||||
void* pooling_resfwd[dnnResourceNumber];
|
||||
void* pooling_res[dnnResourceNumber];
|
||||
dnnPrimitive_t prim_pooling_fwd, prim_pooling_bwd, convert_input,
|
||||
convert_outbackprop;
|
||||
dnnLayout_t lt_outbackprop_user, lt_outbackprop_prim, lt_input_user,
|
||||
lt_input_prim;
|
||||
dnnPrimitive_t prim_pooling_fwd = nullptr, prim_pooling_bwd = nullptr,
|
||||
convert_input = nullptr, convert_outbackprop = nullptr;
|
||||
dnnLayout_t lt_outbackprop_user = nullptr, lt_outbackprop_prim = nullptr,
|
||||
lt_input_user = nullptr, lt_input_prim = nullptr;
|
||||
void* input_buf;
|
||||
void* outbackprop_buf;
|
||||
Tensor tmp_output_buf_tensor;
|
||||
Tensor workspace_buf_tensor;
|
||||
Tensor input_buf_tensor, outbackprop_buf_tensor;
|
||||
|
||||
void MklCreateLayouts(OpKernelContext* context) {
|
||||
bool input_in_mkl_format = input_shape.IsMklTensor();
|
||||
@ -351,9 +351,6 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
<_outbackprop_prim, prim_pooling_bwd, dnnResourceDiffDst),
|
||||
E_SUCCESS);
|
||||
|
||||
// Tensors needed to create temporary buffers
|
||||
Tensor input_buf_tensor, outbackprop_buf_tensor;
|
||||
|
||||
if (workspace_enabled == false) {
|
||||
CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
|
||||
<_input_prim, prim_pooling_fwd, dnnResourceSrc),
|
||||
@ -384,11 +381,8 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
bool input_in_mkl_format = input_shape.IsMklTensor();
|
||||
bool outbackprop_in_mkl_format = output_backprop_shape.IsMklTensor();
|
||||
|
||||
void* tmp_output_buf;
|
||||
Tensor tmp_output_buf_tensor;
|
||||
|
||||
void* workspace_buf;
|
||||
Tensor workspace_buf_tensor;
|
||||
void* tmp_output_buf = nullptr;
|
||||
void* workspace_buf = nullptr;
|
||||
|
||||
if (workspace_enabled == false) {
|
||||
if (convert_input != nullptr) {
|
||||
@ -490,16 +484,16 @@ class MklMaxPoolingGradOp : public OpKernel {
|
||||
bool workspace_enabled_;
|
||||
};
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MklMaxPool")
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.Label(mkl_layer_registry::kMklLayerLabel),
|
||||
.Label(mkl_op_registry::kMklOpLabel),
|
||||
MklMaxPoolingOp<CPUDevice, float>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("MklMaxPoolGrad")
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklMaxPoolGrad")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<float>("T")
|
||||
.Label(mkl_layer_registry::kMklLayerLabel),
|
||||
.Label(mkl_op_registry::kMklOpLabel),
|
||||
MklMaxPoolingGradOp<CPUDevice, float>);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -63,7 +63,7 @@ class MklReluOp : public OpKernel {
|
||||
const TensorShape& o_shape = input.shape();
|
||||
Tensor* out_tensor = nullptr;
|
||||
mkl_context.output_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &out_tensor, o_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &out_tensor, o_shape,
|
||||
mkl_context.output_shape);
|
||||
void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
|
||||
(static_cast<T*>(out_o))[0] =
|
||||
@ -114,12 +114,12 @@ class MklReluOp : public OpKernel {
|
||||
tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_context.output_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &output, tf_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, tf_shape,
|
||||
mkl_context.output_shape);
|
||||
} else {
|
||||
const TensorShape& o_shape = input.shape();
|
||||
mkl_context.output_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &output, o_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, o_shape,
|
||||
mkl_context.output_shape);
|
||||
}
|
||||
|
||||
@ -293,7 +293,7 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
|
||||
// Allocate space for g and
|
||||
const TensorShape& g_shape = g.shape();
|
||||
mkl_context.output_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &output, g_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, g_shape,
|
||||
mkl_context.output_shape);
|
||||
void* out_o = static_cast<void*>(output->flat<T>().data());
|
||||
(static_cast<T*>(out_o))[0] =
|
||||
@ -359,13 +359,13 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
|
||||
tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_context.output_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &output, tf_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, tf_shape,
|
||||
mkl_context.output_shape);
|
||||
|
||||
} else {
|
||||
const TensorShape& o_shape = g.shape();
|
||||
mkl_context.output_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &output, o_shape,
|
||||
AllocateOutputSetMklShape(context, 0, &output, o_shape,
|
||||
mkl_context.output_shape);
|
||||
}
|
||||
|
||||
@ -379,16 +379,16 @@ void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
|
||||
|
||||
/* Register DNN kernels for supported operations and supported types - right now
|
||||
* it is only Relu and f32*/
|
||||
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklRelu") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
MklReluOp<CPUDevice, type>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklReluGrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklRelu") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklReluOp<CPUDevice, type>); \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklReluGrad") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<type>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklReluGradOp<CPUDevice, type>);
|
||||
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);
|
||||
|
||||
|
149
tensorflow/core/kernels/mkl_reshape_op.cc
Normal file
149
tensorflow/core/kernels/mkl_reshape_op.cc
Normal file
@ -0,0 +1,149 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include <memory>
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
template <typename Device, typename T>
|
||||
class MklReshapeOp : public OpKernel {
|
||||
public:
|
||||
explicit MklReshapeOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
const Tensor& sizes = MklGetInput(context, 1);
|
||||
|
||||
// Preliminary validation of sizes.
|
||||
OP_REQUIRES(context, IsLegacyVector(sizes.shape()),
|
||||
errors::InvalidArgument("sizes input must be 1-D, not shape ",
|
||||
sizes.shape().DebugString()));
|
||||
const int64 num_dims = sizes.NumElements();
|
||||
|
||||
// Compute the output shape. Determine product of specified
|
||||
// dimensions, and find the index of the unspecified one.
|
||||
TensorShape shape;
|
||||
int64 product = 1;
|
||||
int unknown_index = -1;
|
||||
auto vec_size = sizes.flat<int32>();
|
||||
for (int d = 0; d < num_dims; ++d) {
|
||||
const int32 size = vec_size(d);
|
||||
if (size == -1) {
|
||||
OP_REQUIRES(
|
||||
context, unknown_index == -1,
|
||||
errors::InvalidArgument("only one input size may be -1, not both ",
|
||||
unknown_index, " and ", d));
|
||||
unknown_index = d;
|
||||
shape.AddDim(1);
|
||||
} else {
|
||||
OP_REQUIRES(context, size >= 0,
|
||||
errors::InvalidArgument(
|
||||
"size ", d, " must be non-negative, not ", size));
|
||||
shape.AddDim(size);
|
||||
product *= size;
|
||||
}
|
||||
}
|
||||
if (unknown_index != -1) {
|
||||
OP_REQUIRES(
|
||||
context, product > 0,
|
||||
errors::InvalidArgument("Reshape cannot infer the missing input size "
|
||||
"for an empty tensor unless all specified "
|
||||
"input sizes are non-zero"));
|
||||
const int64 missing = input.NumElements() / product;
|
||||
OP_REQUIRES(
|
||||
context, product * missing == input.NumElements(),
|
||||
errors::InvalidArgument(
|
||||
"Input to reshape is a tensor with ", input.NumElements(),
|
||||
" values, but the requested shape requires a multiple of ",
|
||||
product));
|
||||
shape.set_dim(unknown_index, missing);
|
||||
}
|
||||
OP_REQUIRES(context, shape.num_elements() == input.NumElements(),
|
||||
errors::InvalidArgument("Input to reshape is a tensor with ",
|
||||
input.NumElements(),
|
||||
" values, but the requested shape has ",
|
||||
shape.num_elements()));
|
||||
|
||||
MklShape mkl_shape_input;
|
||||
GetMklShape(context, 0, &mkl_shape_input);
|
||||
bool input_in_mkl_format = mkl_shape_input.IsMklTensor();
|
||||
if (input_in_mkl_format) {
|
||||
TensorShape& shape_to = shape;
|
||||
TensorShape shape_from;
|
||||
for (size_t i = 0; i < mkl_shape_input.GetDimension(); i++) {
|
||||
// Outermost to innermost dimension
|
||||
shape_from.AddDim(
|
||||
mkl_shape_input.GetSizes()[mkl_shape_input.tf_dim_idx(i)]);
|
||||
}
|
||||
|
||||
if (shape_from == shape_to) {
|
||||
CopyMklTensorInToOut(context, 0, 0);
|
||||
return;
|
||||
} else {
|
||||
// Allocate output tensor.
|
||||
Tensor* output_tensor = NULL;
|
||||
MklShape mkl_shape_output;
|
||||
mkl_shape_output.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
|
||||
mkl_shape_output);
|
||||
|
||||
// Get output layout pointer.
|
||||
dnnLayout_t output_layout =
|
||||
static_cast<dnnLayout_t>(mkl_shape_input.GetTfLayout());
|
||||
|
||||
// Execute DNNConversion.
|
||||
// Note: we assume an MKL tensor always have float as its data type.
|
||||
void* input_buffer =
|
||||
static_cast<void*>(const_cast<float*>(input.flat<float>().data()));
|
||||
void* output_buffer = static_cast<void*>(
|
||||
const_cast<float*>(output_tensor->flat<float>().data()));
|
||||
mkl_shape_input.GetConvertedFlatData(output_layout, input_buffer,
|
||||
output_buffer);
|
||||
|
||||
VLOG(1) << "MKLToTFConversion complete successfully.";
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
CopyTFTensorInToOut(context, 0, 0, shape);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_MKL_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("_MklReshape") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.HostMemory("shape") \
|
||||
.TypeConstraint<T>("T") \
|
||||
.TypeConstraint<int32>("Tshape") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklReshapeOp<CPUDevice, T>);
|
||||
TF_CALL_float(REGISTER_MKL_CPU);
|
||||
#undef REGISTER_MKL_CPU
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // INTEL_MKL
|
@ -105,11 +105,11 @@ class MklToTfOp : public OpKernel {
|
||||
// Register kernel
|
||||
///////////////////////////////////////////////////////////
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklToTf") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_layer_registry::kMklLayerLabel), \
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("MklToTf") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.Label(mkl_op_registry::kMklOpLabel), \
|
||||
MklToTfOp<CPUDevice, T>);
|
||||
|
||||
TF_CALL_float(REGISTER_CPU);
|
||||
|
@ -233,9 +233,9 @@ class Im2ColConvFunctor {
|
||||
int filter_top_offset;
|
||||
if (padding == VALID) {
|
||||
filter_left_offset =
|
||||
((output_width - 1) * stride + filter_width - input_width) / 2;
|
||||
((output_width - 1) * stride + filter_width - input_width + 1) / 2;
|
||||
filter_top_offset =
|
||||
((output_height - 1) * stride + filter_height - input_height) / 2;
|
||||
((output_height - 1) * stride + filter_height - input_height + 1) / 2;
|
||||
} else {
|
||||
filter_left_offset =
|
||||
((output_width - 1) * stride + filter_width - input_width) / 2;
|
||||
|
@ -29,7 +29,7 @@ namespace tensorflow {
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, typename T>
|
||||
template <typename Device, typename T, typename Tindices>
|
||||
class SparseTensorDenseMatMulOp : public OpKernel {
|
||||
public:
|
||||
explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx)
|
||||
@ -139,15 +139,14 @@ class SparseTensorDenseMatMulOp : public OpKernel {
|
||||
TensorShape({0}), &scratch));
|
||||
}
|
||||
|
||||
#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
|
||||
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
|
||||
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
|
||||
Device, T, ADJ_A, ADJ_B>::Compute(ctx->eigen_device<Device>(), \
|
||||
out->matrix<T>(), \
|
||||
a_indices->matrix<int64>(), \
|
||||
a_values->vec<T>(), b->matrix<T>(), \
|
||||
scratch.vec<T>()); \
|
||||
OP_REQUIRES_OK(ctx, functor_status); \
|
||||
#define MAYBE_ADJOINT(ADJ_A, ADJ_B) \
|
||||
if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \
|
||||
Status functor_status = functor::SparseTensorDenseMatMulFunctor< \
|
||||
Device, T, Tindices, ADJ_A, \
|
||||
ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(), \
|
||||
a_indices->matrix<Tindices>(), a_values->vec<T>(), \
|
||||
b->matrix<T>(), scratch.vec<T>()); \
|
||||
OP_REQUIRES_OK(ctx, functor_status); \
|
||||
}
|
||||
|
||||
MAYBE_ADJOINT(false, false);
|
||||
@ -163,53 +162,73 @@ class SparseTensorDenseMatMulOp : public OpKernel {
|
||||
bool adjoint_b_;
|
||||
};
|
||||
|
||||
#define REGISTER_CPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("a_shape"), \
|
||||
SparseTensorDenseMatMulOp<CPUDevice, T>);
|
||||
#define REGISTER_CPU(TypeT, TypeIndex) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("SparseTensorDenseMatMul") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<TypeT>("T") \
|
||||
.TypeConstraint<TypeIndex>("Tindices") \
|
||||
.HostMemory("a_shape"), \
|
||||
SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>);
|
||||
|
||||
REGISTER_CPU(float);
|
||||
REGISTER_CPU(double);
|
||||
REGISTER_CPU(int32);
|
||||
REGISTER_CPU(complex64);
|
||||
REGISTER_CPU(complex128);
|
||||
#define REGISTER_KERNELS_CPU(T) \
|
||||
REGISTER_CPU(T, int64); \
|
||||
REGISTER_CPU(T, int32)
|
||||
|
||||
REGISTER_KERNELS_CPU(float);
|
||||
REGISTER_KERNELS_CPU(double);
|
||||
REGISTER_KERNELS_CPU(int32);
|
||||
REGISTER_KERNELS_CPU(complex64);
|
||||
REGISTER_KERNELS_CPU(complex128);
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPEC(T, ADJ_A, ADJ_B) \
|
||||
template <> \
|
||||
Status SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B>::Compute( \
|
||||
const GPUDevice& d, typename TTypes<T>::Matrix out, \
|
||||
TTypes<int64>::ConstMatrix a_indices, \
|
||||
typename TTypes<T>::ConstVec a_values, \
|
||||
typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch); \
|
||||
extern template struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, \
|
||||
ADJ_B>;
|
||||
#define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B) \
|
||||
template <> \
|
||||
Status SparseTensorDenseMatMulFunctor< \
|
||||
GPUDevice, T, Tindices, ADJ_A, \
|
||||
ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out, \
|
||||
typename TTypes<Tindices>::ConstMatrix a_indices, \
|
||||
typename TTypes<T>::ConstVec a_values, \
|
||||
typename TTypes<T>::ConstMatrix b, \
|
||||
typename TTypes<T>::Vec scratch); \
|
||||
extern template struct SparseTensorDenseMatMulFunctor< \
|
||||
GPUDevice, T, Tindices, ADJ_A, ADJ_B>;
|
||||
|
||||
#define DECLARE_ADJOINT_GPU_SPEC(T) \
|
||||
DECLARE_GPU_SPEC(T, false, false) \
|
||||
DECLARE_GPU_SPEC(T, false, true) \
|
||||
DECLARE_GPU_SPEC(T, true, false) \
|
||||
DECLARE_GPU_SPEC(T, true, true)
|
||||
#define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B) \
|
||||
DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \
|
||||
DECLARE_GPU_SPEC(T, int64, ADJ_A, ADJ_B)
|
||||
|
||||
#define DECLARE_ADJOINT_GPU_SPEC(T) \
|
||||
REGISTER_GPU_SPEC(T, false, false) \
|
||||
REGISTER_GPU_SPEC(T, false, true) \
|
||||
REGISTER_GPU_SPEC(T, true, false) \
|
||||
REGISTER_GPU_SPEC(T, true, true)
|
||||
|
||||
DECLARE_ADJOINT_GPU_SPEC(float);
|
||||
#undef DECLARE_ADJOINT_GPU_SPEC
|
||||
#undef DECLARE_GPU_SPEC
|
||||
#undef REGISTER_GPU_SPEC
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER_GPU(T) \
|
||||
REGISTER_KERNEL_BUILDER(Name("SparseTensorDenseMatMul") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<T>("T") \
|
||||
.HostMemory("a_shape"), \
|
||||
SparseTensorDenseMatMulOp<GPUDevice, T>);
|
||||
#define REGISTER_GPU(TypeT, TypeIndex) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("SparseTensorDenseMatMul") \
|
||||
.Device(DEVICE_GPU) \
|
||||
.TypeConstraint<TypeT>("T") \
|
||||
.TypeConstraint<TypeIndex>("Tindices") \
|
||||
.HostMemory("a_shape"), \
|
||||
SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>);
|
||||
|
||||
REGISTER_GPU(float);
|
||||
#define REGISTER_KERNELS_GPU(T) \
|
||||
REGISTER_GPU(T, int64); \
|
||||
REGISTER_GPU(T, int32)
|
||||
|
||||
REGISTER_KERNELS_GPU(float);
|
||||
#undef REGISTER_GPU
|
||||
#undef REGISTER_KERNELS_GPU
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace functor {
|
||||
@ -228,13 +247,13 @@ Status MOutOfBoundsError(int64 m, std::size_t i, int lhs_index_a,
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T, bool ADJ_A, bool ADJ_B>
|
||||
struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
|
||||
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
|
||||
struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
|
||||
// Vectorize certain operations above this size.
|
||||
static const std::size_t kNumVectorize = 32;
|
||||
|
||||
static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
|
||||
TTypes<int64>::ConstMatrix a_indices,
|
||||
typename TTypes<Tindices>::ConstMatrix a_indices,
|
||||
typename TTypes<T>::ConstVec a_values,
|
||||
typename TTypes<T>::ConstMatrix b,
|
||||
typename TTypes<T>::Vec scratch) {
|
||||
@ -255,8 +274,8 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
|
||||
auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b);
|
||||
|
||||
for (std::size_t i = 0; i < nnz; ++i) {
|
||||
const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
|
||||
const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
|
||||
const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a));
|
||||
const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a));
|
||||
if (!FastBoundsCheck(k, lhs_right)) {
|
||||
return KOutOfBoundsError(k, i, rhs_index_a, lhs_right);
|
||||
}
|
||||
@ -273,19 +292,19 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, ADJ_A, ADJ_B> {
|
||||
// Vectorization via Eigen.
|
||||
const int b_chip_index = ADJ_B ? 1 : 0;
|
||||
|
||||
#define LOOP_NNZ(b_passed) \
|
||||
for (std::size_t i = 0; i < nnz; ++i) { \
|
||||
const int64 m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
|
||||
const int64 k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
|
||||
const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i); \
|
||||
if (!FastBoundsCheck(k, lhs_right)) { \
|
||||
return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \
|
||||
} \
|
||||
if (!FastBoundsCheck(m, out.dimension(0))) { \
|
||||
return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \
|
||||
} \
|
||||
out.template chip<0>(m) += \
|
||||
b_passed.template chip<b_chip_index>(k) * a_value; \
|
||||
#define LOOP_NNZ(b_passed) \
|
||||
for (std::size_t i = 0; i < nnz; ++i) { \
|
||||
const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \
|
||||
const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \
|
||||
const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i); \
|
||||
if (!FastBoundsCheck(k, lhs_right)) { \
|
||||
return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \
|
||||
} \
|
||||
if (!FastBoundsCheck(m, out.dimension(0))) { \
|
||||
return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \
|
||||
} \
|
||||
out.template chip<0>(m) += \
|
||||
b_passed.template chip<b_chip_index>(k) * a_value; \
|
||||
}
|
||||
|
||||
if (ADJ_B) {
|
||||
|
@ -25,11 +25,12 @@ namespace tensorflow {
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename Device, typename T, bool ADJ_A, bool ADJ_B>
|
||||
template <typename Device, typename T, typename Tindices, bool ADJ_A,
|
||||
bool ADJ_B>
|
||||
struct SparseTensorDenseMatMulFunctor {
|
||||
static EIGEN_ALWAYS_INLINE Status
|
||||
Compute(const Device& d, typename TTypes<T>::Matrix out,
|
||||
TTypes<int64>::ConstMatrix a_indices,
|
||||
typename TTypes<Tindices>::ConstMatrix a_indices,
|
||||
typename TTypes<T>::ConstVec a_values,
|
||||
typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch);
|
||||
};
|
||||
|
@ -27,12 +27,12 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
namespace generator {
|
||||
|
||||
template <typename T, bool ADJ_A, bool ADJ_B>
|
||||
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
|
||||
class SparseTensorDenseMatMulGPUGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE SparseTensorDenseMatMulGPUGenerator(
|
||||
typename TTypes<T, 2>::Tensor32Bit out,
|
||||
TTypes<const int64, 2>::Tensor32Bit a_indices,
|
||||
typename TTypes<const Tindices, 2>::Tensor32Bit a_indices,
|
||||
typename TTypes<const T, 1>::Tensor32Bit a_values,
|
||||
typename TTypes<const T, 2>::Tensor32Bit b)
|
||||
: out_(out),
|
||||
@ -77,7 +77,7 @@ class SparseTensorDenseMatMulGPUGenerator {
|
||||
mutable typename TTypes<T, 2>::Tensor32Bit out_;
|
||||
const int lhs_index_a_;
|
||||
const int rhs_index_a_;
|
||||
TTypes<const int64, 2>::Tensor32Bit a_indices_;
|
||||
typename TTypes<const Tindices, 2>::Tensor32Bit a_indices_;
|
||||
typename TTypes<const T, 1>::Tensor32Bit a_values_;
|
||||
const int lhs_right_size;
|
||||
functor::MaybeAdjoint<typename TTypes<const T, 2>::Tensor32Bit, ADJ_B>
|
||||
@ -88,14 +88,14 @@ class SparseTensorDenseMatMulGPUGenerator {
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T, bool ADJ_A, bool ADJ_B>
|
||||
struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B> {
|
||||
template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
|
||||
struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
|
||||
static EIGEN_ALWAYS_INLINE Status
|
||||
Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
|
||||
TTypes<int64>::ConstMatrix a_indices,
|
||||
typename TTypes<Tindices>::ConstMatrix a_indices,
|
||||
typename TTypes<T>::ConstVec a_values,
|
||||
typename TTypes<T>::ConstMatrix b, typename TTypes<T>::Vec scratch) {
|
||||
generator::SparseTensorDenseMatMulGPUGenerator<T, ADJ_A, ADJ_B>
|
||||
generator::SparseTensorDenseMatMulGPUGenerator<T, Tindices, ADJ_A, ADJ_B>
|
||||
sparse_tensor_dense_matmul_generator(To32Bit(out), To32Bit(a_indices),
|
||||
To32Bit(a_values), To32Bit(b));
|
||||
To32Bit(out).device(d) = To32Bit(out).constant(T(0));
|
||||
@ -146,17 +146,18 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, ADJ_A, ADJ_B> {
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#define DEFINE(T) \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, false, \
|
||||
false>; \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, false, \
|
||||
true>; \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, true, \
|
||||
false>; \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor<GPUDevice, T, true, \
|
||||
true>;
|
||||
#define DEFINE(T, Tindices) \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor< \
|
||||
GPUDevice, T, Tindices, false, false>; \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor< \
|
||||
GPUDevice, T, Tindices, false, true>; \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor< \
|
||||
GPUDevice, T, Tindices, true, false>; \
|
||||
template struct functor::SparseTensorDenseMatMulFunctor< \
|
||||
GPUDevice, T, Tindices, true, true>;
|
||||
|
||||
DEFINE(float);
|
||||
DEFINE(float, int32);
|
||||
DEFINE(float, int64);
|
||||
#undef DEFINE
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -394,6 +394,28 @@ output: A `Tensor` with the concatenation of values stacked along the
|
||||
in `concat_dim` where it has the sum of the sizes.
|
||||
)doc");
|
||||
|
||||
// TODO(vivek.v.rane@intel.com): Prefix the op names with underscore if the ops
|
||||
// are not to be made user-accessible.
|
||||
#ifdef INTEL_MKL
|
||||
REGISTER_OP("_MklConcatV2")
|
||||
.Input("values: N * T")
|
||||
.Input("axis: Tidx")
|
||||
.Input("mkl_values: N * uint8")
|
||||
.Input("mkl_axis: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Attr("N: int >= 2")
|
||||
.Attr("T: type")
|
||||
.Attr("Tidx: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn(shape_inference::ConcatV2Shape)
|
||||
.Doc(R"doc(
|
||||
MKL version of ConcatV2 operator. Uses MKL DNN APIs to perform concatenation.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
#endif
|
||||
|
||||
REGISTER_OP("ConcatOffset")
|
||||
.Input("concat_dim: int32")
|
||||
.Input("shape: N * int32")
|
||||
@ -1638,6 +1660,21 @@ reshape(t, []) ==> 7
|
||||
shape: Defines the shape of the output tensor.
|
||||
)Doc");
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
REGISTER_OP("_MklReshape")
|
||||
.Input("tensor: T")
|
||||
.Input("shape: Tshape")
|
||||
.Input("mkl_tensor: uint8")
|
||||
.Input("mkl_shape: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Attr("T: type")
|
||||
.Attr("Tshape: {int32, int64} = DT_INT32")
|
||||
.SetShapeFn([](InferenceContext* c) { return SetOutputShapeForReshape(c); })
|
||||
.Doc(R"Doc( MKL implementation of ReshapeOp.
|
||||
)Doc");
|
||||
#endif // INTEL_MKL
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("InvertPermutation")
|
||||
.Input("x: T")
|
||||
@ -4965,6 +5002,27 @@ backprop_wrt_max: Backpropagated gradients w.r.t. max parameter, shape `[d]`:
|
||||
`sum_per_d(gradients * (inputs > max))`.
|
||||
)doc");
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
REGISTER_OP("_MklConcat")
|
||||
.Input("concat_dim: int32")
|
||||
.Input("values: N * T")
|
||||
.Input("mkl_concat_dim: uint8")
|
||||
.Input("mkl_values: N * uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Attr("N: int >= 2")
|
||||
.Attr("T: type")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
return shape_inference::ConcatShape(c, c->num_inputs() - 3);
|
||||
})
|
||||
.Doc(R"doc(
|
||||
MKL version of Concat operator. Uses MKL DNN APIs to perform concatenation.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
#endif
|
||||
|
||||
// Deprecated op registrations:
|
||||
|
||||
// The following can be deleted after 10mar2017.
|
||||
|
@ -440,6 +440,7 @@ REGISTER_OP("FixedLengthRecordReader")
|
||||
.Attr("header_bytes: int = 0")
|
||||
.Attr("record_bytes: int")
|
||||
.Attr("footer_bytes: int = 0")
|
||||
.Attr("hop_bytes: int = 0")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
.SetIsStateful()
|
||||
@ -448,6 +449,11 @@ REGISTER_OP("FixedLengthRecordReader")
|
||||
A Reader that outputs fixed-length records from a file.
|
||||
|
||||
reader_handle: The handle to reference the Reader.
|
||||
header_bytes: Number of bytes in the header, defaults to 0.
|
||||
record_bytes: Number of bytes in the record.
|
||||
footer_bytes: Number of bytes in the footer, defaults to 0.
|
||||
hop_bytes: Number of bytes to hop before each read. Default of 0 means using
|
||||
record_bytes.
|
||||
container: If non-empty, this reader is placed in the given container.
|
||||
Otherwise, a default container is used.
|
||||
shared_name: If non-empty, this reader is named in the given bucket
|
||||
@ -459,6 +465,7 @@ REGISTER_OP("FixedLengthRecordReaderV2")
|
||||
.Attr("header_bytes: int = 0")
|
||||
.Attr("record_bytes: int")
|
||||
.Attr("footer_bytes: int = 0")
|
||||
.Attr("hop_bytes: int = 0")
|
||||
.Attr("container: string = ''")
|
||||
.Attr("shared_name: string = ''")
|
||||
.SetIsStateful()
|
||||
@ -467,6 +474,11 @@ REGISTER_OP("FixedLengthRecordReaderV2")
|
||||
A Reader that outputs fixed-length records from a file.
|
||||
|
||||
reader_handle: The handle to reference the Reader.
|
||||
header_bytes: Number of bytes in the header, defaults to 0.
|
||||
record_bytes: Number of bytes in the record.
|
||||
footer_bytes: Number of bytes in the footer, defaults to 0.
|
||||
hop_bytes: Number of bytes to hop before each read. Default of 0 means using
|
||||
record_bytes.
|
||||
container: If non-empty, this reader is placed in the given container.
|
||||
Otherwise, a default container is used.
|
||||
shared_name: If non-empty, this reader is named in the given bucket
|
||||
|
@ -2612,10 +2612,10 @@ scale_after_normalization: A bool indicating whether the resulted tensor
|
||||
)doc");
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
REGISTER_OP("MklConv2D")
|
||||
REGISTER_OP("_MklConv2D")
|
||||
.Input("input: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Input("filter: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Input("mkl_filter: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
@ -2632,12 +2632,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklConv2DWithBias")
|
||||
REGISTER_OP("_MklConv2DWithBias")
|
||||
.Input("input: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Input("filter: T")
|
||||
.Input("mkl_filter: uint8")
|
||||
.Input("bias: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Input("mkl_filter: uint8")
|
||||
.Input("mkl_bias: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
@ -2654,12 +2654,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklConv2DBackpropFilter")
|
||||
REGISTER_OP("_MklConv2DBackpropFilter")
|
||||
.Input("input: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Input("filter_sizes: int32")
|
||||
.Input("mkl_filter_size: uint8")
|
||||
.Input("out_backprop: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Input("mkl_filter_size: uint8")
|
||||
.Input("mkl_out_backprop: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
@ -2669,7 +2669,7 @@ REGISTER_OP("MklConv2DBackpropFilter")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
return InputTensorShapeOrUnknown(c, 2 /* input_idx */, 4 /* ndims */);
|
||||
return InputTensorShapeOrUnknown(c, 1 /* input_idx */, 4 /* ndims */);
|
||||
})
|
||||
.Doc(R"doc(
|
||||
MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
|
||||
@ -2679,7 +2679,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklConv2DWithBiasBackpropBias")
|
||||
REGISTER_OP("_MklConv2DWithBiasBackpropBias")
|
||||
.Input("out_backprop: T")
|
||||
.Input("mkl_out_backprop: uint8")
|
||||
.Output("output: T")
|
||||
@ -2695,12 +2695,12 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklConv2DBackpropInput")
|
||||
REGISTER_OP("_MklConv2DBackpropInput")
|
||||
.Input("input_sizes: int32")
|
||||
.Input("mkl_input_sizes: uint8")
|
||||
.Input("filter: T")
|
||||
.Input("mkl_filter: uint8")
|
||||
.Input("out_backprop: T")
|
||||
.Input("mkl_input_sizes: uint8")
|
||||
.Input("mkl_filter: uint8")
|
||||
.Input("mkl_out_backprop: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
@ -2720,7 +2720,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklRelu")
|
||||
REGISTER_OP("_MklRelu")
|
||||
.Input("features: T")
|
||||
.Input("mkl_features: uint8")
|
||||
.Output("activations: T")
|
||||
@ -2734,10 +2734,10 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklReluGrad")
|
||||
REGISTER_OP("_MklReluGrad")
|
||||
.Input("gradients: T")
|
||||
.Input("mkl_gradients: uint8")
|
||||
.Input("features: T")
|
||||
.Input("mkl_gradients: uint8")
|
||||
.Input("mkl_features: uint8")
|
||||
.Output("backprops: T")
|
||||
.Output("mkl_backprops: uint8")
|
||||
@ -2751,7 +2751,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklMaxPool")
|
||||
REGISTER_OP("_MklMaxPool")
|
||||
.Attr("T: {float, half} = DT_FLOAT")
|
||||
.Attr("ksize: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
@ -2761,8 +2761,8 @@ REGISTER_OP("MklMaxPool")
|
||||
.Input("input: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Output("workspace: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Output("mkl_workspace: uint8")
|
||||
.SetShapeFn(shape_inference::MaxPoolShape)
|
||||
.Doc(R"doc(
|
||||
@ -2773,7 +2773,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklMaxPoolGrad")
|
||||
REGISTER_OP("_MklMaxPoolGrad")
|
||||
.Attr("T: {float, half} = DT_FLOAT")
|
||||
.Attr("ksize: list(int) >= 4")
|
||||
.Attr("strides: list(int) >= 4")
|
||||
@ -2781,12 +2781,12 @@ REGISTER_OP("MklMaxPoolGrad")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr(GetConvnetDataFormatAttrString())
|
||||
.Input("orig_input: T")
|
||||
.Input("mkl_orig_input: uint8")
|
||||
.Input("orig_output: T")
|
||||
.Input("mkl_orig_output: uint8")
|
||||
.Input("grad: T")
|
||||
.Input("mkl_grad: uint8")
|
||||
.Input("workspace: T")
|
||||
.Input("mkl_orig_input: uint8")
|
||||
.Input("mkl_orig_output: uint8")
|
||||
.Input("mkl_grad: uint8")
|
||||
.Input("mkl_workspace: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
@ -2801,7 +2801,7 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklAvgPool")
|
||||
REGISTER_OP("_MklAvgPool")
|
||||
.Input("value: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Output("output: T")
|
||||
@ -2820,10 +2820,10 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklAvgPoolGrad")
|
||||
REGISTER_OP("_MklAvgPoolGrad")
|
||||
.Input("orig_input_shape: int32")
|
||||
.Input("mkl_orig_input: uint8")
|
||||
.Input("grad: T")
|
||||
.Input("mkl_orig_input: uint8")
|
||||
.Input("mkl_grad: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
@ -2843,7 +2843,212 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("MklToTf")
|
||||
REGISTER_OP("_MklLRN")
|
||||
.Input("input: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Output("output: T")
|
||||
.Output("workspace: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Output("mkl_workspace: uint8")
|
||||
.Attr("depth_radius: int = 5")
|
||||
.Attr("bias: float = 1.0")
|
||||
.Attr("alpha: float = 1.0")
|
||||
.Attr("beta: float = 0.5")
|
||||
.Attr("workspace_enabled: bool = false")
|
||||
.Attr("T: {float, half} = DT_FLOAT")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
return UnchangedShapeWithRank(c, 4);
|
||||
})
|
||||
.Doc(R"doc(
|
||||
MKL version of LRN operator. Uses MKL DNN APIs to perform local response
|
||||
normalization.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklLRNGrad")
|
||||
.Input("input_grads: T")
|
||||
.Input("input_image: T")
|
||||
.Input("output_image: T")
|
||||
.Input("workspace: T")
|
||||
.Input("mkl_input_grads: uint8")
|
||||
.Input("mkl_input_image: uint8")
|
||||
.Input("mkl_output_image: uint8")
|
||||
.Input("mkl_workspace: uint8")
|
||||
.Output("output: T")
|
||||
.Output("mkl_output: uint8")
|
||||
.Attr("depth_radius: int = 5")
|
||||
.Attr("bias: float = 1.0")
|
||||
.Attr("alpha: float = 1.0")
|
||||
.Attr("beta: float = 0.5")
|
||||
.Attr("workspace_enabled: bool = false")
|
||||
.Attr("T: {float, half} = DT_FLOAT")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle s;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
|
||||
TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
|
||||
TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
|
||||
c->set_output(0, s);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
|
||||
local response normalization.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklFusedBatchNorm")
|
||||
.Input("x: T")
|
||||
.Input("scale: T")
|
||||
.Input("offset: T")
|
||||
.Input("mean: T")
|
||||
.Input("variance: T")
|
||||
.Input("mkl_x: uint8")
|
||||
.Input("mkl_scale: uint8")
|
||||
.Input("mkl_offset: uint8")
|
||||
.Input("mkl_mean: uint8")
|
||||
.Input("mkl_variance: uint8")
|
||||
.Output("y: T")
|
||||
.Output("batch_mean: T")
|
||||
.Output("batch_variance: T")
|
||||
.Output("reserve_space_1: T")
|
||||
.Output("reserve_space_2: T")
|
||||
.Output("mkl_y: uint8")
|
||||
.Output("mkl_batch_mean: uint8")
|
||||
.Output("mkl_batch_variance: uint8")
|
||||
.Output("mkl_reserve_space_1: uint8")
|
||||
.Output("mkl_reserve_space_2: uint8")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr("data_format: string = 'NHWC'")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle x;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
|
||||
|
||||
bool is_training;
|
||||
c->GetAttr("is_training", &is_training);
|
||||
int number_inputs = (is_training) ? 3 : 5;
|
||||
string data_format;
|
||||
c->GetAttr("data_format", &data_format);
|
||||
DimensionHandle channel_dim =
|
||||
(data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
|
||||
|
||||
// covers scale, offset, and if is_training is false, mean, variance
|
||||
for (int i = 1; i < number_inputs; ++i) {
|
||||
ShapeHandle vec;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
|
||||
TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
|
||||
}
|
||||
|
||||
ShapeHandle y;
|
||||
if (data_format == "NHWC") {
|
||||
TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
|
||||
}
|
||||
c->set_output(0, y);
|
||||
ShapeHandle vector_shape = c->Vector(channel_dim);
|
||||
c->set_output(1, vector_shape);
|
||||
c->set_output(2, vector_shape);
|
||||
c->set_output(3, vector_shape);
|
||||
c->set_output(4, vector_shape);
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused
|
||||
batch normalization.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklFusedBatchNormGrad")
|
||||
.Input("y_backprop: T")
|
||||
.Input("x: T")
|
||||
.Input("scale: T")
|
||||
.Input("reserve_space_1: T")
|
||||
.Input("reserve_space_2: T")
|
||||
.Input("mkl_y_backprop: uint8")
|
||||
.Input("mkl_x: uint8")
|
||||
.Input("mkl_scale: uint8")
|
||||
.Input("mkl_reserve_space_1: uint8")
|
||||
.Input("mkl_reserve_space_2: uint8")
|
||||
.Output("x_backprop: T")
|
||||
.Output("scale_backprop: T")
|
||||
.Output("offset_backprop: T")
|
||||
.Output("reserve_space_3: T")
|
||||
.Output("reserve_space_4: T")
|
||||
.Output("mkl_x_backprop: uint8")
|
||||
.Output("mkl_scale_backprop: uint8")
|
||||
.Output("mkl_offset_backprop: uint8")
|
||||
.Output("mkl_reserve_space_3: uint8")
|
||||
.Output("mkl_reserve_space_4: uint8")
|
||||
.Attr("T: numbertype")
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
.Attr("data_format: string = 'NHWC'")
|
||||
.Attr("is_training: bool = true")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle y_backprop;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
|
||||
ShapeHandle x;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
|
||||
|
||||
bool is_training;
|
||||
string data_format;
|
||||
c->GetAttr("is_training", &is_training);
|
||||
c->GetAttr("data_format", &data_format);
|
||||
DimensionHandle channel_dim = (data_format == "NHWC")
|
||||
? c->Dim(y_backprop, 3)
|
||||
: c->Dim(y_backprop, 1);
|
||||
if (data_format == "NHWC") {
|
||||
TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
|
||||
}
|
||||
|
||||
// covers scale, mean (reserve_space_1), variance (reserve_space_2)
|
||||
for (int i = 2; i < 5; ++i) {
|
||||
ShapeHandle vec;
|
||||
TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
|
||||
TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
|
||||
}
|
||||
|
||||
ShapeHandle x_backprop;
|
||||
if (data_format == "NHWC") {
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
|
||||
}
|
||||
c->set_output(0, x_backprop);
|
||||
c->set_output(1, c->Vector(channel_dim));
|
||||
c->set_output(2, c->Vector(channel_dim));
|
||||
// Set the correct shapes for reserve_spaces
|
||||
// so that gradients can be performed when
|
||||
// the op is in a symbolic condition.
|
||||
if (is_training) {
|
||||
c->set_output(3, c->Vector(0));
|
||||
c->set_output(4, c->Vector(0));
|
||||
} else {
|
||||
c->set_output(3, c->Vector(channel_dim));
|
||||
c->set_output(4, c->Vector(channel_dim));
|
||||
}
|
||||
return Status::OK();
|
||||
})
|
||||
.Doc(R"doc(
|
||||
MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute
|
||||
gradients for fused batch normalization.
|
||||
|
||||
NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
|
||||
expected to invoke these operators.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("_MklToTf")
|
||||
.Input("input: T")
|
||||
.Input("mkl_input: uint8")
|
||||
.Output("output: T")
|
||||
|
@ -26416,6 +26416,59 @@ op {
|
||||
summary: "Computes the sum along segments of a tensor."
|
||||
description: "Read @{$math_ops#segmentation$the section on segmentation} for an explanation of\nsegments.\n\nComputes a tensor such that\n`(output[i] = sum_{j...} data[j...]` where the sum is over tuples `j...` such\nthat `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\nrange of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"https://www.tensorflow.org/images/UnsortedSegmentSum.png\" alt>\n</div>"
|
||||
}
|
||||
op {
|
||||
name: "UnsortedSegmentSum"
|
||||
input_arg {
|
||||
name: "data"
|
||||
type_attr: "T"
|
||||
}
|
||||
input_arg {
|
||||
name: "segment_ids"
|
||||
description: "A tensor whose shape is a prefix of `data.shape`."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
input_arg {
|
||||
name: "num_segments"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "Has same shape as data, except for the first `segment_ids.rank`\ndimensions, which are replaced with a single dimension which has size\n`num_segments`."
|
||||
type_attr: "T"
|
||||
}
|
||||
attr {
|
||||
name: "T"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
type: DT_DOUBLE
|
||||
type: DT_INT64
|
||||
type: DT_INT32
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_INT16
|
||||
type: DT_INT8
|
||||
type: DT_QINT8
|
||||
type: DT_QUINT8
|
||||
type: DT_QINT32
|
||||
type: DT_HALF
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Computes the max along segments of a tensor."
|
||||
description: "Read [the section on\nSegmentation](../../api_docs/python/math_ops.md#segmentation) for an explanation\nof segments.\n\nComputes a tensor such that\n\\\\(output_i = \\sum_j data_j\\\\) where sum is over `j` such\nthat `segment_ids[j] == i`. Unlike `SegmentSum`, `segment_ids`\nneed not be sorted and need not cover all values in the full\n range of valid values.\n\nIf the sum is empty for a given segment ID `i`, `output[i] = 0`.\n\n`num_segments` should equal the number of distinct segment IDs.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/UnsortedSegmentSum.png\" alt>\n</div>"
|
||||
}
|
||||
op {
|
||||
name: "Unstage"
|
||||
output_arg {
|
||||
|
@ -128,12 +128,13 @@ pair takes space.
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("SparseTensorDenseMatMul")
|
||||
.Input("a_indices: int64")
|
||||
.Input("a_indices: Tindices")
|
||||
.Input("a_values: T")
|
||||
.Input("a_shape: int64")
|
||||
.Input("b: T")
|
||||
.Output("product: T")
|
||||
.Attr("T: type")
|
||||
.Attr("Tindices: {int32,int64} = DT_INT64")
|
||||
.Attr("adjoint_a: bool = false")
|
||||
.Attr("adjoint_b: bool = false")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
|
@ -194,13 +194,15 @@ def tf_kernel_tests_linkstatic():
|
||||
|
||||
def tf_additional_lib_defines():
|
||||
return select({
|
||||
"//tensorflow:with_jemalloc": ["TENSORFLOW_USE_JEMALLOC"],
|
||||
"//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
|
||||
"//tensorflow:with_jemalloc_linux_ppc64le":["TENSORFLOW_USE_JEMALLOC"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_lib_deps():
|
||||
return select({
|
||||
"//tensorflow:with_jemalloc": ["@jemalloc"],
|
||||
"//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc"],
|
||||
"//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
@ -246,3 +248,9 @@ def tf_lib_proto_parsing_deps():
|
||||
":protos_all_cc",
|
||||
"//tensorflow/core/platform/default/build_config:proto_parsing",
|
||||
]
|
||||
|
||||
def tf_additional_verbs_lib_defines():
|
||||
return select({
|
||||
"//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
@ -22,3 +22,11 @@ def tf_additional_license_deps():
|
||||
"//tensorflow:with_xla_support": ["@llvm//:LICENSE.TXT"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
||||
def tf_additional_verbs_deps():
|
||||
return select({
|
||||
"//tensorflow:with_verbs_support": [
|
||||
"//tensorflow/contrib/verbs:verbs_server_lib",
|
||||
"//tensorflow/contrib/verbs:grpc_verbs_client"],
|
||||
"//conditions:default": [],
|
||||
})
|
||||
|
@ -20,11 +20,11 @@ limitations under the License.
|
||||
|
||||
#define TF_MAJOR_VERSION 1
|
||||
#define TF_MINOR_VERSION 1
|
||||
#define TF_PATCH_VERSION 0-rc1
|
||||
#define TF_PATCH_VERSION 0
|
||||
|
||||
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
|
||||
// "-beta", "-rc", "-rc.1")
|
||||
#define TF_VERSION_SUFFIX ""
|
||||
#define TF_VERSION_SUFFIX "-rc2"
|
||||
|
||||
#define TF_STR_HELPER(x) #x
|
||||
#define TF_STR(x) TF_STR_HELPER(x)
|
||||
|
@ -75,7 +75,6 @@ class MklShape {
|
||||
void SetTfLayout(const size_t dimension, const size_t* sizes,
|
||||
const size_t* strides) {
|
||||
dimension_ = dimension;
|
||||
|
||||
if (dimension > 0) { // MKl doesn't support zero dimension tensors
|
||||
sizes_ = new size_t[dimension];
|
||||
strides_ = new size_t[dimension];
|
||||
@ -140,6 +139,39 @@ class MklShape {
|
||||
const size_t* GetTfToMklDimMap() const { return tf_to_mkl_dim_map_; }
|
||||
size_t tf_dim_idx(int index) const { return tf_to_mkl_dim_map_[index]; }
|
||||
|
||||
// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
|
||||
// corresponds to MKL's Channel dimension.
|
||||
bool IsMklChannelDim(int d) const { return tf_dim_idx(d) == MklDims::C; }
|
||||
// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
|
||||
// corresponds to MKL's Batch dimension.
|
||||
bool IsMklBatchDim(int d) const { return tf_dim_idx(d) == MklDims::N; }
|
||||
// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
|
||||
// corresponds to MKL's Width dimension.
|
||||
bool IsMklWidthDim(int d) const { return tf_dim_idx(d) == MklDims::W; }
|
||||
// Query TF-MKL dimension ordering map and check if Tensorflow dimension 'd'
|
||||
// corresponds to MKL's Height dimension.
|
||||
bool IsMklHeightDim(int d) const { return tf_dim_idx(d) == MklDims::H; }
|
||||
|
||||
// Check if the TF-Mkl dimension ordering map specifies if the input
|
||||
// tensor is in NCHW format.
|
||||
bool IsTensorInNCHWFormat() const {
|
||||
TensorFormat data_format = FORMAT_NCHW;
|
||||
return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
|
||||
IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
|
||||
IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
|
||||
IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
|
||||
}
|
||||
|
||||
// Check if the TF-Mkl dimension ordering map specifies if the input
|
||||
// tensor is in NHWC format.
|
||||
bool IsTensorInNHWCFormat() const {
|
||||
TensorFormat data_format = FORMAT_NHWC;
|
||||
return (IsMklBatchDim(GetTensorDimIndex<2>(data_format, 'N')) &&
|
||||
IsMklChannelDim(GetTensorDimIndex<2>(data_format, 'C')) &&
|
||||
IsMklHeightDim(GetTensorDimIndex<2>(data_format, 'H')) &&
|
||||
IsMklWidthDim(GetTensorDimIndex<2>(data_format, 'W')));
|
||||
}
|
||||
|
||||
void GetConvertedFlatData(dnnLayout_t targetLayout, void* input,
|
||||
void* output) const {
|
||||
dnnLayout_t curLayout;
|
||||
@ -194,9 +226,9 @@ class MklShape {
|
||||
(STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_
|
||||
#define TF_LAYOUT_OFFSET(dims) \
|
||||
(MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_
|
||||
// Location of tf_to_mkl_dim_map_
|
||||
#define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
|
||||
(TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
|
||||
(TF_LAYOUT_OFFSET(dims) + \
|
||||
SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_
|
||||
|
||||
// TODO(agramesh1) make sure to create a const to share with rewrite pass
|
||||
// for min size of MKL metadata tensor.
|
||||
@ -265,45 +297,166 @@ class MklShape {
|
||||
size_t dimension_ = 0;
|
||||
size_t* sizes_ = nullptr; // Required by MKL for conversions
|
||||
size_t* strides_ = nullptr; // Required by MKL for conversions
|
||||
// TF dimension corresponding to this MKL dimension
|
||||
size_t* tf_to_mkl_dim_map_ = nullptr;
|
||||
size_t* tf_to_mkl_dim_map_ =
|
||||
nullptr; // TF dimension corresponding to this MKL dimension
|
||||
};
|
||||
|
||||
int inline GetTensorDataIndex(int n) {
|
||||
return 2 * n; // index corresponding to nth input/output tensor
|
||||
// List of MklShape objects. Used in Concat/Split layers.
|
||||
typedef std::vector<MklShape> MklShapeList;
|
||||
|
||||
// Check if all tensors specified by MklShapes are MKL tensors.
|
||||
inline bool AreAllMklTensors(const MklShapeList& shapes) {
|
||||
for (auto& s : shapes) {
|
||||
if (!s.IsMklTensor()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int inline GetTensorMetaDataIndex(int n) {
|
||||
// index corresponding to meta data of nth input/output tensor
|
||||
return 2 * n + 1;
|
||||
template <typename T>
|
||||
inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
|
||||
const MklShape& mkl_shape) {
|
||||
Tensor output_tensor;
|
||||
TensorShape output_shape;
|
||||
|
||||
for (size_t j = 0; j < mkl_shape.GetDimension(); j++) {
|
||||
// Outermost to innermost dimension
|
||||
output_shape.AddDim(mkl_shape.GetSizes()[mkl_shape.tf_dim_idx(j)]);
|
||||
}
|
||||
|
||||
// Allocate output tensor.
|
||||
context->allocate_temp(DataTypeToEnum<T>::v(), output_shape, &output_tensor);
|
||||
|
||||
dnnLayout_t output_layout = static_cast<dnnLayout_t>(mkl_shape.GetTfLayout());
|
||||
void* input_buffer = const_cast<T*>(mkl_tensor.flat<T>().data());
|
||||
void* output_buffer = const_cast<T*>(output_tensor.flat<T>().data());
|
||||
|
||||
if (mkl_tensor.NumElements() != 0) {
|
||||
mkl_shape.GetConvertedFlatData(output_layout, input_buffer, output_buffer);
|
||||
}
|
||||
|
||||
return output_tensor;
|
||||
}
|
||||
|
||||
// Since our ops are going to produce and also consume N addition tensors
|
||||
// (Mkl) for N Tensorflow tensors, we can have following different
|
||||
// orderings among these 2N tensors.
|
||||
//
|
||||
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
|
||||
// consume A_m, B_m, and C_m additionally.
|
||||
//
|
||||
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
|
||||
// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
|
||||
//
|
||||
// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
|
||||
// by N Mkl tensors. So for above example, the ordering looks
|
||||
// like: A, B, C, A_m, B_m, C_m
|
||||
//
|
||||
// Following APIs map index of original Tensorflow tensors to their appropriate
|
||||
// position based on selected ordering. For contiguous ordering, we need to know
|
||||
// the total number of tensors (parameter total).
|
||||
//
|
||||
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
|
||||
// NOTE: Currently, we use contiguous ordering. If you change this, then you
|
||||
// would need to change Mkl op definitions in nn_ops.cc.
|
||||
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
|
||||
|
||||
// Get index of MetaData tensor from index 'n' of Data tensor.
|
||||
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
|
||||
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
||||
// For interleaved ordering, Mkl tensor follows immediately after
|
||||
// Tensorflow tensor.
|
||||
return n + 1;
|
||||
} else {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
// For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
|
||||
return n + total_tensors / 2;
|
||||
}
|
||||
}
|
||||
|
||||
int inline GetTensorDataIndex(int n, int total_tensors) {
|
||||
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
|
||||
return 2 * n; // index corresponding to nth input/output tensor
|
||||
} else {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
return n;
|
||||
}
|
||||
}
|
||||
|
||||
int inline GetTensorMetaDataIndex(int n, int total_tensors) {
|
||||
// Get index for TensorData first and then use mapping function
|
||||
// to get TensorMetaData index from TensorData index.
|
||||
int tidx = GetTensorDataIndex(n, total_tensors);
|
||||
return DataIndexToMetaDataIndex(tidx, total_tensors);
|
||||
}
|
||||
|
||||
// Get the MKL shape from the second string tensor
|
||||
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
|
||||
mklshape->DeSerializeMklShape(
|
||||
ctext->input(GetTensorMetaDataIndex(n)).flat<uint8>().data(),
|
||||
ctext->input(GetTensorMetaDataIndex(n)).flat<uint8>().size() *
|
||||
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
|
||||
.flat<uint8>()
|
||||
.data(),
|
||||
ctext->input(GetTensorMetaDataIndex(n, ctext->num_inputs()))
|
||||
.flat<uint8>()
|
||||
.size() *
|
||||
sizeof(uint8));
|
||||
}
|
||||
|
||||
// Gets the actual input
|
||||
inline const Tensor& MklGetInput(OpKernelContext* ctext, int n) {
|
||||
return ctext->input(GetTensorDataIndex(n));
|
||||
return ctext->input(GetTensorDataIndex(n, ctext->num_inputs()));
|
||||
}
|
||||
|
||||
inline void GetMklInputList(OpKernelContext* ctext, StringPiece name,
|
||||
OpInputList* input_tensors) {
|
||||
CHECK_NOTNULL(input_tensors);
|
||||
ctext->input_list(name, input_tensors);
|
||||
}
|
||||
|
||||
inline void GetMklShapeList(OpKernelContext* ctext, StringPiece name,
|
||||
MklShapeList* mkl_shapes) {
|
||||
OpInputList input_mkl_tensors;
|
||||
GetMklInputList(ctext, strings::StrCat("mkl_", name), &input_mkl_tensors);
|
||||
|
||||
for (int i = 0; i < input_mkl_tensors.size(); i++) {
|
||||
(*mkl_shapes)[i].DeSerializeMklShape(
|
||||
input_mkl_tensors[i].flat<uint8>().data(),
|
||||
input_mkl_tensors[i].flat<uint8>().size() * sizeof(uint8));
|
||||
}
|
||||
}
|
||||
|
||||
// Allocate the second output tensor that will contain
|
||||
// the MKL shape serialized
|
||||
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
|
||||
const MklShape& mkl_shape) {
|
||||
Tensor* second_tensor = nullptr;
|
||||
TensorShape second_shape;
|
||||
second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
|
||||
OP_REQUIRES_OK(ctext, ctext->allocate_output(
|
||||
GetTensorMetaDataIndex(n, ctext->num_outputs()),
|
||||
second_shape, &second_tensor));
|
||||
mkl_shape.SerializeMklShape(
|
||||
second_tensor->flat<uint8>().data(),
|
||||
second_tensor->flat<uint8>().size() * sizeof(uint8));
|
||||
}
|
||||
|
||||
// Allocate the output tensor, create a second output tensor that will contain
|
||||
// the MKL shape serialized
|
||||
inline void AllocateOutputSetMklshape(OpKernelContext* ctext, int n,
|
||||
inline void AllocateOutputSetMklShape(OpKernelContext* ctext, int n,
|
||||
Tensor** output,
|
||||
const TensorShape& tfshape,
|
||||
const MklShape& mklshape) {
|
||||
const TensorShape& tf_shape,
|
||||
const MklShape& mkl_shape) {
|
||||
Tensor* second_tensor = nullptr;
|
||||
TensorShape second_shape;
|
||||
second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mklshape.GetDimension()));
|
||||
second_shape.AddDim(SIZE_OF_MKL_SERIAL_DATA(mkl_shape.GetDimension()));
|
||||
OP_REQUIRES_OK(
|
||||
ctext, ctext->allocate_output(GetTensorDataIndex(n), tfshape, output));
|
||||
OP_REQUIRES_OK(ctext, ctext->allocate_output(GetTensorMetaDataIndex(n),
|
||||
second_shape, &second_tensor));
|
||||
mklshape.SerializeMklShape(
|
||||
ctext, ctext->allocate_output(GetTensorDataIndex(n, ctext->num_outputs()),
|
||||
tf_shape, output));
|
||||
OP_REQUIRES_OK(ctext, ctext->allocate_output(
|
||||
GetTensorMetaDataIndex(n, ctext->num_outputs()),
|
||||
second_shape, &second_tensor));
|
||||
mkl_shape.SerializeMklShape(
|
||||
second_tensor->flat<uint8>().data(),
|
||||
second_tensor->flat<uint8>().size() * sizeof(uint8));
|
||||
}
|
||||
@ -342,12 +495,11 @@ inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
|
||||
|
||||
inline void MklSizesToTFSizes(OpKernelContext* context,
|
||||
TensorFormat data_format_,
|
||||
const MklShape& mklshape, TensorShape* tfshape) {
|
||||
size_t tf_dim = mklshape.GetDimension();
|
||||
const size_t* tf_sizes = mklshape.GetSizes();
|
||||
const MklShape& mkl_shape,
|
||||
TensorShape* tf_shape) {
|
||||
size_t tf_dim = mkl_shape.GetDimension();
|
||||
const size_t* tf_sizes = mkl_shape.GetSizes();
|
||||
|
||||
// TODO(agramesh1): check if this constraint is applicable in other cases
|
||||
// (besides BackpropInput, BackpropFilter).
|
||||
OP_REQUIRES(context, tf_dim == 4,
|
||||
errors::InvalidArgument("MKLSizesToTFSizes: size must be 4-dim"));
|
||||
std::vector<int32> sizes;
|
||||
@ -364,7 +516,7 @@ inline void MklSizesToTFSizes(OpKernelContext* context,
|
||||
sizes.push_back(tf_sizes[0]);
|
||||
}
|
||||
|
||||
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tfshape));
|
||||
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(sizes, tf_shape));
|
||||
}
|
||||
|
||||
inline int32 GetMklTensorDimIndex(char dimension) {
|
||||
@ -383,38 +535,71 @@ inline int32 GetMklTensorDimIndex(char dimension) {
|
||||
}
|
||||
}
|
||||
|
||||
inline int64 GetMklTensorDim(const MklShape& mklshape, char dimension) {
|
||||
inline int64 GetMklTensorDim(const MklShape& mkl_shape, char dimension) {
|
||||
int index = GetMklTensorDimIndex(dimension);
|
||||
CHECK(index >= 0 && index < mklshape.GetDimension())
|
||||
CHECK(index >= 0 && index < mkl_shape.GetDimension())
|
||||
<< "Invalid index from the dimension: " << index << ", " << dimension;
|
||||
return mklshape.dim_size(index);
|
||||
return mkl_shape.dim_size(index);
|
||||
}
|
||||
|
||||
namespace mkl_layer_registry {
|
||||
inline void CopyMklTensorInToOut(OpKernelContext* context, int idx_in,
|
||||
int idx_out) {
|
||||
int num_inputs = context->num_inputs();
|
||||
int num_outputs = context->num_outputs();
|
||||
int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
|
||||
int idx_meta_in = GetTensorMetaDataIndex(idx_in, num_inputs);
|
||||
int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
|
||||
int idx_meta_out = GetTensorMetaDataIndex(idx_out, num_outputs);
|
||||
|
||||
static const char* kMklLayerLabel = "MklLayer";
|
||||
static const char* kMklLayerLabelPattern = "label='MklLayer'";
|
||||
const Tensor& data = context->input(idx_data_in);
|
||||
const Tensor& meta = context->input(idx_meta_in);
|
||||
Tensor output(data.dtype());
|
||||
Tensor meta_output(meta.dtype());
|
||||
|
||||
// TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
|
||||
CHECK(output.CopyFrom(data, data.shape()));
|
||||
CHECK(meta_output.CopyFrom(meta, meta.shape()));
|
||||
context->set_output(idx_data_out, output);
|
||||
context->set_output(idx_meta_out, meta_output);
|
||||
}
|
||||
|
||||
inline void CopyTFTensorInToOut(OpKernelContext* context, int idx_in,
|
||||
int idx_out, const TensorShape& shape) {
|
||||
int num_inputs = context->num_inputs();
|
||||
int num_outputs = context->num_outputs();
|
||||
int idx_data_in = GetTensorDataIndex(idx_in, num_inputs);
|
||||
int idx_data_out = GetTensorDataIndex(idx_out, num_outputs);
|
||||
|
||||
const Tensor& data = context->input(idx_data_in);
|
||||
MklShape mkl_shape_output;
|
||||
mkl_shape_output.SetMklTensor(false);
|
||||
AllocateOutputSetMklShape(context, idx_out, mkl_shape_output);
|
||||
Tensor output(data.dtype());
|
||||
// TODO(intel_tf): alternatively, call forward_input_to_output_with_shape(...)
|
||||
CHECK(output.CopyFrom(data, shape));
|
||||
context->set_output(idx_data_out, output);
|
||||
}
|
||||
|
||||
namespace mkl_op_registry {
|
||||
static const char* kMklOpLabel = "MklOp";
|
||||
static const char* kMklOpLabelPattern = "label='MklOp'";
|
||||
|
||||
// Check whether opname with type T is registered as MKL-compliant.
|
||||
//
|
||||
// @input: name of the op
|
||||
// @input: T datatype to be used for checking op
|
||||
// @return: true if opname is registered as Mkl layer op
|
||||
static inline bool IsMklLayer(const std::string& op_name, DataType T) {
|
||||
// @return: true if opname is registered as Mkl op
|
||||
static inline bool IsMklOp(const std::string& op_name, DataType T) {
|
||||
string kernel = KernelsRegisteredForOp(op_name);
|
||||
// Currently, MKL only supports float type for ops. So we check if
|
||||
// the type is float. Actually, we should query kernel registration and
|
||||
// find out if op is supported for type T. But there is no API to query
|
||||
// kernel registration using name and type.
|
||||
bool result =
|
||||
(kernel.find(kMklLayerLabelPattern) != string::npos) && (T == DT_FLOAT);
|
||||
if (result == true) {
|
||||
VLOG(1) << "mkl_layer_registry::" << op_name << " is " << kMklLayerLabel;
|
||||
kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
|
||||
if (result) {
|
||||
VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace mkl_layer_registry
|
||||
} // namespace mkl_op_registry
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // INTEL_MKL
|
||||
|
@ -115,31 +115,31 @@ Example:
|
||||
|
||||
def my_op(tensor_in, other_tensor_in, my_param, other_param=0.5,
|
||||
output_collections=(), name=None):
|
||||
"""My operation that adds two tensors with given coefficients.
|
||||
"""My operation that adds two tensors with given coefficients.
|
||||
|
||||
Args:
|
||||
tensor_in: `Tensor`, input tensor.
|
||||
other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
|
||||
my_param: `float`, coefficient for `tensor_in`.
|
||||
other_param: `float`, coefficient for `other_tensor_in`.
|
||||
output_collections: `tuple` of `string`s, name of the collection to
|
||||
collect result of this op.
|
||||
name: `string`, name of the operation.
|
||||
Args:
|
||||
tensor_in: `Tensor`, input tensor.
|
||||
other_tensor_in: `Tensor`, same shape as `tensor_in`, other input tensor.
|
||||
my_param: `float`, coefficient for `tensor_in`.
|
||||
other_param: `float`, coefficient for `other_tensor_in`.
|
||||
output_collections: `tuple` of `string`s, name of the collection to
|
||||
collect result of this op.
|
||||
name: `string`, name of the operation.
|
||||
|
||||
Returns:
|
||||
`Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
|
||||
Returns:
|
||||
`Tensor` of same shape as `tensor_in`, sum of input values with coefficients.
|
||||
|
||||
Example:
|
||||
>>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
|
||||
output_collections=['MY_OPS'], name='add_t1t2')
|
||||
[2.3, 3.4]
|
||||
"""
|
||||
with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
|
||||
tensor_in = tf.convert_to_tensor(tensor_in)
|
||||
other_tensor_in = tf.convert_to_tensor(other_tensor_in)
|
||||
result = my_param * tensor_in + other_param * other_tensor_in
|
||||
tf.add_to_collections(output_collections, result)
|
||||
return result
|
||||
Example:
|
||||
>>> my_op([1., 2.], [3., 4.], my_param=0.5, other_param=0.6,
|
||||
output_collections=['MY_OPS'], name='add_t1t2')
|
||||
[2.3, 3.4]
|
||||
"""
|
||||
with tf.name_scope(name, "my_op", [tensor_in, other_tensor_in]):
|
||||
tensor_in = tf.convert_to_tensor(tensor_in)
|
||||
other_tensor_in = tf.convert_to_tensor(other_tensor_in)
|
||||
result = my_param * tensor_in + other_param * other_tensor_in
|
||||
tf.add_to_collection(output_collections, result)
|
||||
return result
|
||||
|
||||
Usage:
|
||||
|
||||
|
@ -121,16 +121,16 @@ class ZeroOutOp : public OpKernel {
|
||||
Tensor* output_tensor = NULL;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
|
||||
&output_tensor));
|
||||
auto output = output_tensor->flat<int32>();
|
||||
auto output_flat = output_tensor->flat<int32>();
|
||||
|
||||
// Set all but the first element of the output tensor to 0.
|
||||
const int N = input.size();
|
||||
for (int i = 1; i < N; i++) {
|
||||
output(i) = 0;
|
||||
output_flat(i) = 0;
|
||||
}
|
||||
|
||||
// Preserve the first input value if possible.
|
||||
if (N > 0) output(0) = input(0);
|
||||
if (N > 0) output_flat(0) = input(0);
|
||||
}
|
||||
};
|
||||
```
|
||||
|
@ -323,7 +323,7 @@ for i in range(1000):
|
||||
sess.run(train, {x:x_train, y:y_train})
|
||||
|
||||
# evaluate training accuracy
|
||||
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x:x_train, y:y_train})
|
||||
curr_W, curr_b, curr_loss = sess.run([W, b, loss], {x:x_train, y:y_train})
|
||||
print("W: %s b: %s loss: %s"%(curr_W, curr_b, curr_loss))
|
||||
```
|
||||
When run, it produces
|
||||
|
@ -282,18 +282,15 @@ validation_metrics = {
|
||||
"accuracy":
|
||||
tf.contrib.learn.MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_accuracy,
|
||||
prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
|
||||
CLASSES),
|
||||
prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
|
||||
"precision":
|
||||
tf.contrib.learn.MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_precision,
|
||||
prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
|
||||
CLASSES),
|
||||
prediction_key=tf.contrib.learn.PredictionKey.CLASSES),
|
||||
"recall":
|
||||
tf.contrib.learn.MetricSpec(
|
||||
metric_fn=tf.contrib.metrics.streaming_recall,
|
||||
prediction_key=tf.contrib.learn.prediction_key.PredictionKey.
|
||||
CLASSES)
|
||||
prediction_key=tf.contrib.learn.PredictionKey.CLASSES)
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -282,7 +282,7 @@ enough that it can be stored in @{tf.constant TensorFlow constants}. The
|
||||
following code produces the simplest possible input pipeline:
|
||||
|
||||
```python
|
||||
# Define the test inputs
|
||||
# Define the training inputs
|
||||
def get_train_inputs():
|
||||
x = tf.constant(training_set.data)
|
||||
y = tf.constant(training_set.target)
|
||||
|
@ -35,7 +35,7 @@ enable TensorFlow for C:
|
||||
OS="linux" # Change to "darwin" for Mac OS
|
||||
TARGET_DIRECTORY="/usr/local"
|
||||
curl -L \
|
||||
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
|
||||
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" |
|
||||
sudo tar -C $TARGET_DIRECTORY -xz
|
||||
|
||||
The `tar` command extracts the TensorFlow C library into the `lib`
|
||||
|
@ -35,7 +35,7 @@ steps to install this library and enable TensorFlow for Go:
|
||||
TF_TYPE="cpu" # Change to "gpu" for GPU support
|
||||
TARGET_DIRECTORY='/usr/local'
|
||||
curl -L \
|
||||
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc1.tar.gz" |
|
||||
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.1.0-rc2.tar.gz" |
|
||||
sudo tar -C $TARGET_DIRECTORY -xz
|
||||
|
||||
The `tar` command extracts the TensorFlow C library into the `lib`
|
||||
|
@ -34,7 +34,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs:
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow</artifactId>
|
||||
<version>1.1.0-rc1</version>
|
||||
<version>1.1.0-rc2</version>
|
||||
</dependency>
|
||||
```
|
||||
|
||||
@ -63,7 +63,7 @@ As an example, these steps will create a Maven project that uses TensorFlow:
|
||||
<dependency>
|
||||
<groupId>org.tensorflow</groupId>
|
||||
<artifactId>tensorflow</artifactId>
|
||||
<version>1.1.0-rc1</version>
|
||||
<version>1.1.0-rc2</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
</project>
|
||||
@ -122,7 +122,7 @@ refer to the simpler instructions above instead.
|
||||
Take the following steps to install TensorFlow for Java on Linux or Mac OS:
|
||||
|
||||
1. Download
|
||||
[libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar),
|
||||
[libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar),
|
||||
which is the TensorFlow Java Archive (JAR).
|
||||
|
||||
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
|
||||
@ -141,7 +141,7 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
|
||||
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
|
||||
mkdir -p ./jni
|
||||
curl -L \
|
||||
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc1.tar.gz" |
|
||||
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.1.0-rc2.tar.gz" |
|
||||
tar -xz -C ./jni
|
||||
|
||||
### Install on Windows
|
||||
@ -149,10 +149,10 @@ Take the following steps to install TensorFlow for Java on Linux or Mac OS:
|
||||
Take the following steps to install TensorFlow for Java on Windows:
|
||||
|
||||
1. Download
|
||||
[libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc1.jar),
|
||||
[libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.1.0-rc2.jar),
|
||||
which is the TensorFlow Java Archive (JAR).
|
||||
2. Download the following Java Native Interface (JNI) file appropriate for
|
||||
[TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc1.zip).
|
||||
[TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.1.0-rc2.zip).
|
||||
3. Extract this .zip file.
|
||||
|
||||
|
||||
@ -200,7 +200,7 @@ must be part of your `classpath`. For example, you can include the
|
||||
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
|
||||
as follows:
|
||||
|
||||
<pre><b>javac -cp libtensorflow-1.1.0-rc1.jar HelloTF.java</b></pre>
|
||||
<pre><b>javac -cp libtensorflow-1.1.0-rc2.jar HelloTF.java</b></pre>
|
||||
|
||||
|
||||
### Running
|
||||
@ -213,7 +213,7 @@ two files are available to the JVM:
|
||||
|
||||
For example, the following command line executes the `HelloTF` program:
|
||||
|
||||
<pre><b>java -cp libtensorflow-1.1.0-rc1.jar:. -Djava.library.path=./jni HelloTF</b></pre>
|
||||
<pre><b>java -cp libtensorflow-1.1.0-rc2.jar:. -Djava.library.path=./jni HelloTF</b></pre>
|
||||
|
||||
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
|
||||
installed TensorFlow for Java and are ready to use the API. If the program
|
||||
|
@ -165,8 +165,8 @@ Take the following steps to install TensorFlow with Virtualenv:
|
||||
issue the following command to install TensorFlow in the active
|
||||
virtualenv environment:
|
||||
|
||||
<pre> (tensorflow)$ <b>pip install --upgrade \\
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
|
||||
<pre>(tensorflow)$ <b>pip3 install --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
|
||||
|
||||
If you encounter installation problems, see
|
||||
[Common Installation Problems](#common_installation_problems).
|
||||
@ -269,8 +269,10 @@ take the following steps:
|
||||
install TensorFlow for Linux, Python 2.7, and CPU-only support, issue
|
||||
the following command:
|
||||
|
||||
<pre> $ <b>sudo pip install --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
|
||||
<pre>
|
||||
$ <b>sudo pip3 install --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b>
|
||||
</pre>
|
||||
|
||||
If this step fails, see
|
||||
[Common Installation Problems](#common_installation_problems).
|
||||
@ -456,7 +458,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
|
||||
|
||||
<pre>
|
||||
(tensorflow)$ <b>pip install --ignore-installed --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl</b></pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl</b></pre>
|
||||
|
||||
|
||||
<a name="ValidateYourInstallation"></a>
|
||||
@ -624,14 +626,14 @@ This section documents the relevant values for Linux installations.
|
||||
CPU only:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp27-none-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp27-none-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
|
||||
GPU support:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp27-none-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp27-none-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
Note that GPU support requires the NVIDIA hardware and software described in
|
||||
@ -643,14 +645,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
|
||||
CPU only:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp34-cp34m-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
|
||||
GPU support:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp34-cp34m-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp34-cp34m-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
Note that GPU support requires the NVIDIA hardware and software described in
|
||||
@ -662,14 +664,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
|
||||
CPU only:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp35-cp35m-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
|
||||
GPU support:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp35-cp35m-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp35-cp35m-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
|
||||
@ -681,14 +683,14 @@ Note that GPU support requires the NVIDIA hardware and software described in
|
||||
CPU only:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc1-cp36-cp36m-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.1.0rc2-cp36-cp36m-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
|
||||
GPU support:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc1-cp36-cp36m-linux_x86_64.whl
|
||||
https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.1.0rc2-cp36-cp36m-linux_x86_64.whl
|
||||
</pre>
|
||||
|
||||
|
||||
|
@ -163,7 +163,7 @@ Take the following steps to install TensorFlow with Virtualenv:
|
||||
TensorFlow in the active Virtualenv is as follows:
|
||||
|
||||
<pre> $ <b>pip3 install --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b></pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b></pre>
|
||||
|
||||
If you encounter installation problems, see
|
||||
[Common Installation Problems](#CommonInstallationProblems).
|
||||
@ -286,7 +286,7 @@ take the following steps:
|
||||
support, issue the following command:
|
||||
|
||||
<pre> $ <b>sudo pip3 install --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b> </pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b> </pre>
|
||||
|
||||
If the preceding command fails, see
|
||||
[Common installation problems](#CommonInstallationProblems).
|
||||
@ -398,7 +398,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
|
||||
TensorFlow for Python 2.7:
|
||||
|
||||
<pre> (tensorflow)$ <b>pip install --ignore-installed --upgrade \
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl</b></pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl</b></pre>
|
||||
|
||||
|
||||
<a name="ValidateYourInstallation"></a>
|
||||
@ -604,13 +604,13 @@ This section documents the relevant values for Mac OS installations.
|
||||
CPU only:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py2-none-any.whl
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py2-none-any.whl
|
||||
</pre>
|
||||
|
||||
GPU support:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc1-py2-none-any.whl
|
||||
https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc2-py2-none-any.whl
|
||||
</pre>
|
||||
|
||||
Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
|
||||
@ -622,13 +622,13 @@ Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
|
||||
CPU only:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc1-py3-none-any.whl
|
||||
https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.1.0rc2-py3-none-any.whl
|
||||
</pre>
|
||||
|
||||
GPU support:
|
||||
|
||||
<pre>
|
||||
https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc1-py3-none-any.whl
|
||||
https://storage.googleapis.com/tensorflow/mac/gpu/tensorflow_gpu-1.1.0rc2-py3-none-any.whl
|
||||
</pre>
|
||||
|
||||
Requires CUDA toolkit 8.0 and CuDNN v5. For other versions, see
|
||||
|
@ -319,10 +319,11 @@ $ <b>bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pk
|
||||
Invoke `pip install` to install that pip package.
|
||||
The filename of the `.whl` file depends on your platform.
|
||||
For example, the following command will install the pip package
|
||||
for TensorFlow 1.1.0rc1 on Linux:
|
||||
|
||||
for TensorFlow 1.1.0rc2 on Linux:
|
||||
|
||||
<pre>
|
||||
$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.1.0rc1-py2-none-any.whl</b>
|
||||
$ <b>sudo pip install /tmp/tensorflow_pkg/tensorflow-1.1.0rc2-py2-none-any.whl</b>
|
||||
</pre>
|
||||
|
||||
## Validate your installation
|
||||
|
@ -114,12 +114,12 @@ Take the following steps to install TensorFlow in an Anaconda environment:
|
||||
environment. To install the CPU-only version of TensorFlow, enter the
|
||||
following command:
|
||||
|
||||
<pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.1.0rc1-cp35-cp35m-win_amd64.whl</b> </pre>
|
||||
<pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/cpu/tensorflow-1.1.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
|
||||
|
||||
To install the GPU version of TensorFlow, enter the following command
|
||||
(on a single line):
|
||||
|
||||
<pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.1.0rc1-cp35-cp35m-win_amd64.whl</b> </pre>
|
||||
<pre>(tensorflow)C:\> <b>pip install --ignore-installed --upgrade https://storage.googleapis.com/tensorflow/windows/gpu/tensorflow_gpu-1.1.0rc2-cp35-cp35m-win_amd64.whl</b> </pre>
|
||||
|
||||
## Validate your installation
|
||||
|
||||
|
@ -88,7 +88,8 @@ def run_training():
|
||||
saver = tf.train.Saver()
|
||||
|
||||
# Create the op for initializing variables.
|
||||
init_op = tf.global_variables_initializer()
|
||||
init_op = tf.group(tf.global_variables_initializer(),
|
||||
tf.local_variables_initializer())
|
||||
|
||||
# Create a session for running Ops on the Graph.
|
||||
sess = tf.Session()
|
||||
|
@ -135,7 +135,8 @@ def train():
|
||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||
tf.summary.scalar('accuracy', accuracy)
|
||||
|
||||
# Merge all the summaries and write them out to /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
|
||||
# Merge all the summaries and write them out to
|
||||
# /tmp/tensorflow/mnist/logs/mnist_with_summaries (by default)
|
||||
merged = tf.summary.merge_all()
|
||||
train_writer = tf.summary.FileWriter(FLAGS.log_dir + '/train', sess.graph)
|
||||
test_writer = tf.summary.FileWriter(FLAGS.log_dir + '/test')
|
||||
@ -196,9 +197,15 @@ if __name__ == '__main__':
|
||||
help='Initial learning rate')
|
||||
parser.add_argument('--dropout', type=float, default=0.9,
|
||||
help='Keep probability for training dropout.')
|
||||
parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
parser.add_argument('--log_dir', type=str, default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
|
||||
help='Summaries log directory')
|
||||
parser.add_argument(
|
||||
'--data_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/input_data',
|
||||
help='Directory for storing input data')
|
||||
parser.add_argument(
|
||||
'--log_dir',
|
||||
type=str,
|
||||
default='/tmp/tensorflow/mnist/logs/mnist_with_summaries',
|
||||
help='Summaries log directory')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -25,6 +25,7 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library_py
|
||||
load("//tensorflow/core:platform/default/build_config.bzl", "tf_additional_lib_deps")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_plugin_deps")
|
||||
load("//tensorflow/python:build_defs.bzl", "tf_gen_op_wrapper_private_py")
|
||||
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_verbs_deps")
|
||||
|
||||
py_library(
|
||||
name = "python",
|
||||
@ -2610,7 +2611,9 @@ tf_py_wrap_cc(
|
||||
"//tensorflow/tools/graph_transforms:transform_graph_lib",
|
||||
"//tensorflow/tools/tfprof/internal:print_model_analysis",
|
||||
"//util/python:python_headers",
|
||||
] + tf_additional_lib_deps() + tf_additional_plugin_deps(),
|
||||
] + (tf_additional_lib_deps() +
|
||||
tf_additional_plugin_deps() +
|
||||
tf_additional_verbs_deps()),
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tensorflow.python.framework.importer."""
|
||||
"""Tests for tensorflow.python.framework.dtypes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -352,9 +352,20 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
self._record_bytes = 3
|
||||
self._footer_bytes = 2
|
||||
|
||||
self._hop_bytes = 2
|
||||
self._num_overlapped_records = 3
|
||||
|
||||
def _Record(self, f, r):
|
||||
return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
|
||||
|
||||
def _OverlappedRecord(self, f, r):
|
||||
record_str = "".join([
|
||||
str(i)[0]
|
||||
for i in range(r * self._hop_bytes,
|
||||
r * self._hop_bytes + self._record_bytes)
|
||||
])
|
||||
return compat.as_bytes(record_str)
|
||||
|
||||
def _CreateFiles(self):
|
||||
filenames = []
|
||||
for i in range(self._num_files):
|
||||
@ -367,6 +378,23 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
f.write(b"F" * self._footer_bytes)
|
||||
return filenames
|
||||
|
||||
def _CreateOverlappedRecordFiles(self):
|
||||
filenames = []
|
||||
for i in range(self._num_files):
|
||||
fn = os.path.join(self.get_temp_dir(),
|
||||
"fixed_length_overlapped_record.%d.txt" % i)
|
||||
filenames.append(fn)
|
||||
with open(fn, "wb") as f:
|
||||
f.write(b"H" * self._header_bytes)
|
||||
all_records_str = "".join([
|
||||
str(i)[0]
|
||||
for i in range(self._record_bytes + self._hop_bytes *
|
||||
(self._num_overlapped_records - 1))
|
||||
])
|
||||
f.write(compat.as_bytes(all_records_str))
|
||||
f.write(b"F" * self._footer_bytes)
|
||||
return filenames
|
||||
|
||||
def testOneEpoch(self):
|
||||
files = self._CreateFiles()
|
||||
with self.test_session() as sess:
|
||||
@ -374,6 +402,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
header_bytes=self._header_bytes,
|
||||
record_bytes=self._record_bytes,
|
||||
footer_bytes=self._footer_bytes,
|
||||
hop_bytes=0,
|
||||
name="test_reader")
|
||||
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
|
||||
key, value = reader.read(queue)
|
||||
@ -390,6 +419,31 @@ class FixedLengthRecordReaderTest(test.TestCase):
|
||||
"\\(requested 1, current size 0\\)"):
|
||||
k, v = sess.run([key, value])
|
||||
|
||||
def testOneEpochWithHopBytes(self):
|
||||
files = self._CreateOverlappedRecordFiles()
|
||||
with self.test_session() as sess:
|
||||
reader = io_ops.FixedLengthRecordReader(
|
||||
header_bytes=self._header_bytes,
|
||||
record_bytes=self._record_bytes,
|
||||
footer_bytes=self._footer_bytes,
|
||||
hop_bytes=self._hop_bytes,
|
||||
name="test_reader")
|
||||
queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=())
|
||||
key, value = reader.read(queue)
|
||||
|
||||
queue.enqueue_many([files]).run()
|
||||
queue.close().run()
|
||||
for i in range(self._num_files):
|
||||
for j in range(self._num_overlapped_records):
|
||||
k, v = sess.run([key, value])
|
||||
print(v)
|
||||
self.assertAllEqual("%s:%d" % (files[i], j), compat.as_text(k))
|
||||
self.assertAllEqual(self._OverlappedRecord(i, j), v)
|
||||
|
||||
with self.assertRaisesOpError("is closed and has insufficient elements "
|
||||
"\\(requested 1, current size 0\\)"):
|
||||
k, v = sess.run([key, value])
|
||||
|
||||
|
||||
class TFRecordReaderTest(test.TestCase):
|
||||
|
||||
|
@ -30,34 +30,44 @@ from tensorflow.python.platform import test
|
||||
|
||||
class SparseTensorDenseMatMulGradientTest(test.TestCase):
|
||||
|
||||
def _sparsify(self, x):
|
||||
def _sparsify(self, x, indices_dtype=np.int64):
|
||||
x[x < 0.5] = 0
|
||||
|
||||
non_zero = np.where(x)
|
||||
x_indices = np.vstack(non_zero).astype(np.int64).T
|
||||
x_indices = np.vstack(non_zero).astype(indices_dtype).T
|
||||
x_values = x[non_zero]
|
||||
x_shape = x.shape
|
||||
|
||||
return sparse_tensor.SparseTensor(
|
||||
indices=x_indices, values=x_values, dense_shape=x_shape), len(x_values)
|
||||
|
||||
def _randomTensor(self, size, np_dtype, adjoint=False, sparse=False):
|
||||
def _randomTensor(self,
|
||||
size,
|
||||
values_dtype,
|
||||
adjoint=False,
|
||||
sparse=False,
|
||||
indices_dtype=np.int64):
|
||||
n, m = size
|
||||
x = np.random.randn(n, m).astype(np_dtype)
|
||||
x = np.random.randn(n, m).astype(values_dtype)
|
||||
|
||||
if adjoint:
|
||||
x = x.transpose()
|
||||
|
||||
if sparse:
|
||||
return self._sparsify(x)
|
||||
return self._sparsify(x, indices_dtype=indices_dtype)
|
||||
else:
|
||||
return constant_op.constant(x, dtype=np_dtype)
|
||||
return constant_op.constant(x, dtype=values_dtype)
|
||||
|
||||
def _testGradients(self, adjoint_a, adjoint_b, name, np_dtype):
|
||||
def _testGradients(self, adjoint_a, adjoint_b, name, values_dtype,
|
||||
indices_dtype):
|
||||
n, k, m = np.random.randint(1, 10, size=3)
|
||||
sp_t, nnz = self._randomTensor(
|
||||
[n, k], np_dtype, adjoint=adjoint_a, sparse=True)
|
||||
dense_t = self._randomTensor([k, m], np_dtype, adjoint=adjoint_b)
|
||||
[n, k],
|
||||
values_dtype,
|
||||
adjoint=adjoint_a,
|
||||
sparse=True,
|
||||
indices_dtype=indices_dtype)
|
||||
dense_t = self._randomTensor([k, m], values_dtype, adjoint=adjoint_b)
|
||||
|
||||
matmul = sparse_ops.sparse_tensor_dense_matmul(
|
||||
sp_t, dense_t, adjoint_a=adjoint_a, adjoint_b=adjoint_b, name=name)
|
||||
@ -71,17 +81,19 @@ class SparseTensorDenseMatMulGradientTest(test.TestCase):
|
||||
print("%s gradient err = %s" % (name, err))
|
||||
self.assertLess(err, 1e-3)
|
||||
|
||||
def _testGradientsType(self, np_dtype):
|
||||
def _testGradientsType(self, values_dtype, indices_dtype):
|
||||
for adjoint_a in [True, False]:
|
||||
for adjoint_b in [True, False]:
|
||||
name = "sparse_tensor_dense_matmul_%s_%s_%s" % (adjoint_a, adjoint_b,
|
||||
np_dtype.__name__)
|
||||
self._testGradients(adjoint_a, adjoint_b, name, np_dtype)
|
||||
name = "sparse_tensor_dense_matmul_%s_%s_%s_%s" % (
|
||||
adjoint_a, adjoint_b, values_dtype.__name__, indices_dtype.__name__)
|
||||
self._testGradients(adjoint_a, adjoint_b, name, values_dtype,
|
||||
indices_dtype)
|
||||
|
||||
def testGradients(self):
|
||||
np.random.seed(5) # Fix seed to avoid flakiness
|
||||
self._testGradientsType(np.float32)
|
||||
self._testGradientsType(np.float64)
|
||||
self._testGradientsType(np.float32, np.int64)
|
||||
self._testGradientsType(np.float64, np.int64)
|
||||
self._testGradientsType(np.float32, np.int32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -45,7 +45,12 @@ def _maybe_complex(x):
|
||||
|
||||
class SparseTensorDenseMatMulTest(test.TestCase):
|
||||
|
||||
def _testMatmul(self, x, y, adjoint_a=False, adjoint_b=False):
|
||||
def _testMatmul(self,
|
||||
x,
|
||||
y,
|
||||
adjoint_a=False,
|
||||
adjoint_b=False,
|
||||
indices_dtype=np.int64):
|
||||
x_mat = np.matrix(x)
|
||||
if adjoint_a:
|
||||
x_mat = x_mat.H
|
||||
@ -55,7 +60,7 @@ class SparseTensorDenseMatMulTest(test.TestCase):
|
||||
|
||||
np_ans = x_mat * y_mat
|
||||
|
||||
x_indices = np.vstack(np.where(x)).astype(np.int64).T
|
||||
x_indices = np.vstack(np.where(x)).astype(indices_dtype).T
|
||||
x_values = x[np.where(x)]
|
||||
x_shape = x.shape
|
||||
|
||||
@ -82,13 +87,13 @@ class SparseTensorDenseMatMulTest(test.TestCase):
|
||||
else:
|
||||
self.assertAllClose(np_ans, out, rtol=1e-4, atol=1e-4)
|
||||
|
||||
def _testBasic(self, np_dtype):
|
||||
x = _maybe_complex(np.random.rand(10, 10).astype(np_dtype))
|
||||
def _testBasic(self, value_dtype, indices_dtype=np.int64):
|
||||
x = _maybe_complex(np.random.rand(10, 10).astype(value_dtype))
|
||||
x[np.abs(x) < 0.5] = 0 # Make it sparse
|
||||
|
||||
y = _maybe_complex(np.random.randn(10, 20).astype(np_dtype))
|
||||
y = _maybe_complex(np.random.randn(10, 20).astype(value_dtype))
|
||||
|
||||
self._testMatmul(x, y)
|
||||
self._testMatmul(x, y, indices_dtype=indices_dtype)
|
||||
|
||||
def testBasic(self):
|
||||
np.random.seed(127) # Repeatable results
|
||||
@ -97,6 +102,8 @@ class SparseTensorDenseMatMulTest(test.TestCase):
|
||||
self._testBasic(np.float64)
|
||||
self._testBasic(np.complex64)
|
||||
self._testBasic(np.complex128)
|
||||
self._testBasic(np.int32, indices_dtype=np.int32)
|
||||
self._testBasic(np.float32, indices_dtype=np.int32)
|
||||
|
||||
def testShapeInference(self):
|
||||
x = np.random.rand(10, 10)
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.layers.core."""
|
||||
"""Tests for tf.layers.convolutional."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.layers.core."""
|
||||
"""Tests for tf.layers.normalization."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for tf.layers.core."""
|
||||
"""Tests for tf.layers.utils."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
|
@ -198,7 +198,7 @@ class BatchNormBenchmark(test.Benchmark):
|
||||
if FLAGS.use_gpu:
|
||||
t1 = self._run_graph("gpu", shape, axes, 10, "op", True, True, 50)
|
||||
t2 = self._run_graph("gpu", shape, axes, 10, "py", True, True, 50)
|
||||
t2 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
|
||||
t3 = self._run_graph("gpu", shape, axes, 10, "slow", True, True, 50)
|
||||
print_difference("op vs py", t1, t2)
|
||||
print_difference("py vs slow", t2, t3)
|
||||
print("Forward convolution (higher layers).")
|
||||
|
@ -391,7 +391,11 @@ class FixedLengthRecordReader(ReaderBase):
|
||||
"""
|
||||
# TODO(josh11b): Support serializing and restoring state.
|
||||
|
||||
def __init__(self, record_bytes, header_bytes=None, footer_bytes=None,
|
||||
def __init__(self,
|
||||
record_bytes,
|
||||
header_bytes=None,
|
||||
footer_bytes=None,
|
||||
hop_bytes=None,
|
||||
name=None):
|
||||
"""Create a FixedLengthRecordReader.
|
||||
|
||||
@ -399,11 +403,15 @@ class FixedLengthRecordReader(ReaderBase):
|
||||
record_bytes: An int.
|
||||
header_bytes: An optional int. Defaults to 0.
|
||||
footer_bytes: An optional int. Defaults to 0.
|
||||
hop_bytes: An optional int. Defaults to 0.
|
||||
name: A name for the operation (optional).
|
||||
"""
|
||||
rr = gen_io_ops._fixed_length_record_reader_v2(
|
||||
record_bytes=record_bytes, header_bytes=header_bytes,
|
||||
footer_bytes=footer_bytes, name=name)
|
||||
record_bytes=record_bytes,
|
||||
header_bytes=header_bytes,
|
||||
footer_bytes=footer_bytes,
|
||||
hop_bytes=hop_bytes,
|
||||
name=name)
|
||||
super(FixedLengthRecordReader, self).__init__(rr)
|
||||
|
||||
|
||||
|
@ -639,18 +639,22 @@ def moments(x, axes, shift=None, name=None, keep_dims=False):
|
||||
math_ops.reduce_mean(y, axes, keep_dims=True))
|
||||
else:
|
||||
shift = math_ops.cast(shift, y.dtype)
|
||||
counts, m_ss, v_ss, shift = sufficient_statistics(
|
||||
y, axes, shift=shift, keep_dims=keep_dims, name=name)
|
||||
# Reshape shift as needed.
|
||||
shift = array_ops.reshape(shift, array_ops.shape(m_ss))
|
||||
shift.set_shape(m_ss.get_shape())
|
||||
with ops.control_dependencies([counts, m_ss, v_ss]):
|
||||
mean, variance = normalize_moments(counts, m_ss, v_ss, shift, name=name)
|
||||
if x.dtype == dtypes.float16:
|
||||
return (math_ops.cast(mean, dtypes.float16),
|
||||
math_ops.cast(variance, dtypes.float16))
|
||||
else:
|
||||
return (mean, variance)
|
||||
shifted_mean = math_ops.reduce_mean(
|
||||
math_ops.subtract(y, shift), axes, keep_dims=True, name="shifted_mean")
|
||||
variance = math_ops.subtract(
|
||||
math_ops.reduce_mean(
|
||||
math_ops.squared_difference(y, shift), axes, keep_dims=True),
|
||||
math_ops.square(shifted_mean),
|
||||
name="variance")
|
||||
mean = math_ops.add(shifted_mean, shift, name="mean")
|
||||
if not keep_dims:
|
||||
mean = array_ops.squeeze(mean, axes)
|
||||
variance = array_ops.squeeze(variance, axes)
|
||||
if x.dtype == dtypes.float16:
|
||||
return (math_ops.cast(mean, dtypes.float16), math_ops.cast(
|
||||
variance, dtypes.float16))
|
||||
else:
|
||||
return (mean, variance)
|
||||
|
||||
|
||||
def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False):
|
||||
|
@ -136,12 +136,13 @@ def _SparseTensorDenseMatMulGrad(op, grad):
|
||||
Raises:
|
||||
TypeError: When the two operands don't have the same type.
|
||||
"""
|
||||
sp_t = sparse_tensor.SparseTensor(*op.inputs[:3])
|
||||
a_indices, a_values, a_shape = op.inputs[:3]
|
||||
b = op.inputs[3]
|
||||
adj_a = op.get_attr("adjoint_a")
|
||||
adj_b = op.get_attr("adjoint_b")
|
||||
|
||||
a_type = sp_t.values.dtype.base_dtype
|
||||
b_type = op.inputs[3].dtype.base_dtype
|
||||
a_type = a_values.dtype.base_dtype
|
||||
b_type = b.dtype.base_dtype
|
||||
if a_type != b_type:
|
||||
raise TypeError("SparseTensorDenseMatMul op received operands with "
|
||||
"different types: ", a_type, " and ", b_type)
|
||||
@ -150,15 +151,12 @@ def _SparseTensorDenseMatMulGrad(op, grad):
|
||||
"complex gradients.")
|
||||
|
||||
# gradient w.r.t. dense
|
||||
b_grad = sparse_ops.sparse_tensor_dense_matmul(sp_t, grad,
|
||||
adjoint_a=not adj_a)
|
||||
b_grad = gen_sparse_ops._sparse_tensor_dense_mat_mul( # pylint: disable=protected-access
|
||||
a_indices, a_values, a_shape, grad, adjoint_a=not adj_a)
|
||||
if adj_b:
|
||||
b_grad = array_ops.transpose(b_grad)
|
||||
|
||||
# gradient w.r.t. sparse values
|
||||
a_indices = op.inputs[0]
|
||||
b = op.inputs[3]
|
||||
|
||||
rows = a_indices[:, 0]
|
||||
cols = a_indices[:, 1]
|
||||
|
||||
|
@ -1239,7 +1239,7 @@ def sparse_tensor_dense_matmul(sp_a,
|
||||
A should be sorted in order of increasing dimension 1 (i.e., "column major"
|
||||
order instead of "row major" order).
|
||||
|
||||
Deciding when to use sparse_tensor_dense_matmul vs. matmul(sp_a=True):
|
||||
Deciding when to use sparse_tensor_dense_matmul vs. matmul(a_is_sparse=True):
|
||||
|
||||
There are a number of questions to ask in the decision process, including:
|
||||
|
||||
@ -1249,14 +1249,14 @@ def sparse_tensor_dense_matmul(sp_a,
|
||||
|
||||
If the answer to several of these questions is yes, consider
|
||||
converting the `SparseTensor` to a dense one and using `tf.matmul` with
|
||||
`sp_a=True`.
|
||||
`a_is_sparse=True`.
|
||||
|
||||
This operation tends to perform well when A is more sparse, if the column size
|
||||
of the product is small (e.g. matrix-vector multiplication), if
|
||||
`sp_a.dense_shape` takes on large values.
|
||||
|
||||
Below is a rough speed comparison between sparse_tensor_dense_matmul,
|
||||
labelled 'sparse', and matmul(sp_a=True), labelled 'dense'. For purposes of
|
||||
labelled 'sparse', and matmul(a_is_sparse=True), labelled 'dense'. For purposes of
|
||||
the comparison, the time spent converting from a SparseTensor to a dense
|
||||
Tensor is not included, so it is overly conservative with respect to
|
||||
the time ratio.
|
||||
|
@ -319,7 +319,7 @@ class StreamExecutorInterface {
|
||||
// Creates a new DnnSupport object, ownership is transferred to the caller.
|
||||
// If SupportsDnn() is false, this will always return null.
|
||||
//
|
||||
// If SupportsDnn() is true, this may return null, for example, if the RNG
|
||||
// If SupportsDnn() is true, this may return null, for example, if the DNN
|
||||
// initialization fails.
|
||||
virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user