Merge commit for internal changes

Manually fixed conflicts by accepting --ours:
	tensorflow/contrib/slim/README.md

Manually fixed conflicts by accepting --theirs:
	tensorflow/contrib/learn/python/learn/datasets/mnist.py
	tensorflow/contrib/verbs/BUILD
	tensorflow/contrib/verbs/grpc_verbs_client.cc
	tensorflow/contrib/verbs/grpc_verbs_client.h
	tensorflow/contrib/verbs/grpc_verbs_service.cc
	tensorflow/contrib/verbs/grpc_verbs_service.h
	tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
	tensorflow/contrib/verbs/grpc_verbs_service_impl.h
	tensorflow/contrib/verbs/rdma.cc
	tensorflow/contrib/verbs/rdma.h
	tensorflow/contrib/verbs/rdma_mgr.cc
	tensorflow/contrib/verbs/rdma_mgr.h
	tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
	tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
	tensorflow/contrib/verbs/verbs_server_lib.cc
	tensorflow/contrib/verbs/verbs_server_lib.h
	tensorflow/contrib/verbs/verbs_service.proto
	tensorflow/contrib/verbs/verbs_util.cc
	tensorflow/contrib/verbs/verbs_util.h
	tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
	tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
	tensorflow/core/framework/function_testlib.cc
	tensorflow/core/graph/mkl_layout_pass.cc
	tensorflow/core/graph/mkl_layout_pass_test.cc
	tensorflow/core/graph/mkl_tfconversion_pass.cc
	tensorflow/core/graph/mkl_tfconversion_pass_test.cc
	tensorflow/core/kernels/fixed_length_record_reader_op.cc
	tensorflow/core/kernels/mkl_concat_op.cc
	tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
	tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
	tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
	tensorflow/core/kernels/mkl_conv_ops.cc
	tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
	tensorflow/core/kernels/mkl_lrn_op.cc
	tensorflow/core/kernels/mkl_relu_op.cc
	tensorflow/core/kernels/mkl_reshape_op.cc
	tensorflow/core/kernels/mkl_tfconv_op.cc
	tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
	tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
	tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
	tensorflow/core/ops/array_ops.cc
	tensorflow/core/ops/nn_ops.cc
	tensorflow/core/util/mkl_util.h
	tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
	tensorflow/python/kernel_tests/reader_ops_test.py
	tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
	tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
	tensorflow/python/ops/io_ops.py
	tensorflow/python/ops/nn_impl.py
	tensorflow/python/ops/sparse_grad.py
	tensorflow/tensorboard/gulp_tasks/vulcanize.js
	third_party/jemalloc.BUILD
Shanqing Cai 2017-04-22 14:10:11 -04:00
commit 92675de336
263 changed files with 17130 additions and 2510 deletions

@ -53,7 +53,7 @@ $ python
>>> hello = tf.constant('Hello, TensorFlow!')
>>> sess = tf.Session()
>>> sess.run(hello)
Hello, TensorFlow!
'Hello, TensorFlow!'
>>> a = tf.constant(10)
>>> b = tf.constant(32)
>>> sess.run(a+b)

@ -3,6 +3,19 @@
## Major Features and Improvements
* Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
* Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described
in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
it caches its scope. All future uses of the RNNCell will reuse variables from
that same scope. This is a breaking change from the behavior of RNNCells
in TensorFlow versions <= 1.0.1. TensorFlow 1.1 had checks in place to
ensure old code works correctly with the new semantics; this version
allows more flexible uses of RNNCell but can lead to subtle errors if
using code meant for TensorFlow <= 1.0.1. For example, writing:
`MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each
layer shares the **same** parameters. To get 5 layers each with their own
parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`.
If at all unsure, first test your code with TF 1.1; ensure it raises no
errors, and then upgrade to TF 1.2.
# Release 1.1.0
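The RNNCell note above, as a short runnable sketch (module paths assume the TF 1.2-era `tf.contrib.rnn` API; `num_units` and the layer count are arbitrary):

```python
import tensorflow as tf

num_units = 128

# Under the new scope-caching semantics, reusing one cell object shares a
# single set of LSTM parameters across all five layers.
lstm = tf.contrib.rnn.LSTMCell(num_units)
shared_stack = tf.contrib.rnn.MultiRNNCell([lstm] * 5)

# To give each layer its own parameters, build a fresh cell per layer.
independent_stack = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(num_units) for _ in range(5)])
```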

configure
@ -86,15 +86,18 @@ while true; do
PYTHON_BIN_PATH=""
# Retry
done
export PYTHON_BIN_PATH
write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
# TODO(ngiraldo): allow the user to optionally set PYTHON_INCLUDE_PATH and NUMPY_INCLUDE_PATH
## Set up MKL related environment settings
if false; then # Disable building with MKL for now
while [ "$TF_NEED_MKL" == "" ]; do
fromuser=""
read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
fromuser="1"
case $INPUT in
[Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
[Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
"" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
* ) echo "Invalid selection: " $INPUT;;
@ -261,7 +264,7 @@ if [[ "$TF_NEED_VERBS" == "1" ]]; then
fi
# Invoke python_config and set up symlinks to python includes
./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
./util/python/python_config.sh "$PYTHON_BIN_PATH"
# Append CC optimization flags to bazel.rc
echo >> tools/bazel.rc

@ -278,6 +278,7 @@ filegroup(
"//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
"//tensorflow/contrib/training:all_files",
"//tensorflow/contrib/util:all_files",
"//tensorflow/contrib/verbs:all_files",
"//tensorflow/contrib/xla_tf_graph:all_files",
"//tensorflow/core:all_files",
"//tensorflow/core/debug:all_files",
@ -326,6 +327,7 @@ filegroup(
"//tensorflow/tensorboard/components/vz_line_chart:all_files",
"//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
"//tensorflow/tensorboard/components/vz_projector:all_files",
"//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
"//tensorflow/tensorboard/components/vz_sorting:all_files",
"//tensorflow/tensorboard/components/vz_sorting/test:all_files",
"//tensorflow/tensorboard/lib:all_files",

@ -38,7 +38,6 @@ static void AllocateFlags() {
flags = new GpuCompilerFlags;
flags->xla_gpu_embed_ir = false;
flags->xla_cuda_data_dir = "./cuda_sdk_lib";
flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas";
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag(
"xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,

@ -649,4 +649,39 @@ ReferenceUtil::ReduceToRowArray2D(
return result;
}
/* static */ Array4D<float> ReferenceUtil::PadArray4D(
const Array4D<float>& operand, const PaddingConfig& padding,
const float pad) {
CHECK_EQ(padding.dimensions_size(), 4);
const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
operand.n3(), operand.n4()};
std::vector<int64> pad_low(4);
std::vector<int64> pad_high(4);
std::vector<int64> output_bounds(4);
for (int64 i = 0; i < 4; ++i) {
pad_low[i] = padding.dimensions(i).edge_padding_low();
pad_high[i] = padding.dimensions(i).edge_padding_high();
CHECK_EQ(padding.dimensions(i).interior_padding(), 0) << "not implemented";
output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i];
}
Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
output_bounds[3]);
result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
for (int i = 0; i < 4; ++i) {
bool in_low_padding = indices[i] < pad_low[i];
bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
if (in_low_padding || in_high_padding) {
*value = pad;
return;
}
}
*value = operand(indices[0] - pad_low[0], indices[1] - pad_low[1],
indices[2] - pad_low[2], indices[3] - pad_low[3]);
});
return result;
}
} // namespace xla

@ -395,6 +395,11 @@ class ReferenceUtil {
const Array2D<float>& operand, const PaddingConfig& padding,
const float pad);
// Returns the result of a 4D pad on an input array.
static Array4D<float> PadArray4D(const Array4D<float>& operand,
const PaddingConfig& padding,
const float pad);
private:
TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
};

@ -409,7 +409,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// operand copy insertion above (which will share an allocation).
TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
liveness.get(), computation->parameter_instruction(0)));
} else if (copy_param_and_const_) {
} else {
// Record root indices to copy for general computations.
TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
liveness->points_to_analysis()));

@ -32,9 +32,6 @@ namespace xla {
// different lifetimes than computation results.
class CopyInsertion : public HloPassInterface {
public:
explicit CopyInsertion(bool copy_param_and_const = true)
: copy_param_and_const_(copy_param_and_const) {}
~CopyInsertion() override {}
tensorflow::StringPiece name() const override { return "copy-insertion"; }
// Run the pass on the given module. Returns whether the module was changed
@ -46,10 +43,6 @@ class CopyInsertion : public HloPassInterface {
// duplicate copies.
StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
// Determines whether to insert copies if the root instruction is, or
// points-to, any constant or parameter instruction.
const bool copy_param_and_const_;
// A map containing all copies inserted during the copy insertion pass. The
// key is the copied instruction and the value is the copy.
std::unordered_map<HloInstruction*, HloInstruction*> inserted_copies_;

@ -187,8 +187,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
// Invokes the ptxas tool on the given PTX string, and dumps its output.
void DumpPtxasInfo(const string& ptx) {
legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
const string ptxas_path = flags->xla_ptxas_path;
const string ptxas_path =
tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas");
// Do not log PTX stats if ptxas is not found at the given path.
if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) {
LOG(WARNING)

@ -70,6 +70,7 @@ string HloExecutionProfile::ToString(
string result;
const int64 total_cycles = total_cycles_executed(computation);
double clock_rate_ghz = device_description.clock_rate_ghz();
CHECK_GE(clock_rate_ghz, 1e-9);
const auto cycles_to_microseconds = [&](double cycles) {
return cycles / clock_rate_ghz / 1000.0;
@ -80,14 +81,19 @@ string HloExecutionProfile::ToString(
double nsecs = cycles / clock_rate_ghz;
string bytes_per_sec;
string bytes_per_cycle;
if (bytes_accessed >= 0) {
if (cycles <= 0 || bytes_accessed < 0) {
bytes_per_sec = "<unknown>";
bytes_per_cycle = "<unknown>";
} else {
bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
bytes_accessed / (nsecs / 1e9));
bytes_per_cycle =
tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
} else {
bytes_per_sec = "<unknown>";
bytes_per_cycle = "<unknown>";
}
double cycles_percent = 0;
if (total_cycles > 0) {
cycles_percent = cycles / static_cast<double>(total_cycles) * 100;
}
tensorflow::strings::StrAppend(
@ -97,8 +103,7 @@ string HloExecutionProfile::ToString(
":: "
"%12s/cycle :: "
"%s",
cycles, cycles / static_cast<double>(total_cycles) * 100,
cycles_to_microseconds(cycles),
cycles, cycles_percent, cycles_to_microseconds(cycles),
flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
};
@ -114,26 +119,30 @@ string HloExecutionProfile::ToString(
for (const auto& item : items) {
const HloInstruction* hlo = item.first;
tensorflow::strings::StrAppend(&result, "\n\t");
int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo);
int64 bytes_accessed =
hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo);
string display = hlo == nullptr ? "<none>" : hlo->ToString();
const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo);
const int64 bytes_accessed =
(hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo);
const string display = (hlo == nullptr) ? "<none>" : hlo->ToString();
append_item(item.second, flops, bytes_accessed, display);
}
MetricTableReport table;
table.SetMetricName("microseconds");
table.SetEntryName("ops");
table.SetShowCategoryTable();
for (const auto& item : items) {
MetricTableReport::Entry entry;
entry.text = item.first->ToString();
entry.short_text = item.first->ToString(/*compact_operands=*/true);
entry.category_text = item.first->ToCategory();
entry.metric = cycles_to_microseconds(item.second);
table.AddEntry(std::move(entry));
if (total_cycles <= 0) {
result += "****** 0 total cycles ******\n";
} else {
MetricTableReport table;
table.SetMetricName("microseconds");
table.SetEntryName("ops");
table.SetShowCategoryTable();
for (const auto& item : items) {
MetricTableReport::Entry entry;
entry.text = item.first->ToString();
entry.short_text = item.first->ToString(/*compact_operands=*/true);
entry.category_text = item.first->ToCategory();
entry.metric = cycles_to_microseconds(item.second);
table.AddEntry(std::move(entry));
}
result += table.MakeReport(cycles_to_microseconds(total_cycles));
}
result += table.MakeReport(cycles_to_microseconds(total_cycles));
return result;
}

@ -309,6 +309,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return InvalidArgument(
"the rank of the operand and the padding configuration do not match.");
}
if (operand_shape.element_type() != padding_value_shape.element_type()) {
return InvalidArgument(
"the element types of the operands to pad do not match");
}
std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
dimensions[i] = operand_shape.dimensions(i) +
@ -338,7 +342,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
// Check if both element types are the same.
if (lhs.element_type() != rhs.element_type()) {
return fail("element types mismatch");
return fail("element types do not match");
}
if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||

@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
@ -200,6 +201,46 @@ int64 PositionInContainer(const Container& container, int64 value) {
std::find(container.begin(), container.end(), value));
}
// Formats the container as a comma-separated string. StrAppend must support
// appending the elements of the container. Prefix is prepended and suffix is
// appended to the returned string.
template <typename Container>
string CommaSeparatedString(const Container& c, const char* prefix = "",
const char* suffix = "") {
// Not using Join() since the implementation here is simple anyway and this
// avoids copying the string to append prefix.
string comma_separated = prefix;
const char* separator = "";
for (const auto& entry : c) {
tensorflow::strings::StrAppend(&comma_separated, separator, entry);
separator = ", ";
}
comma_separated += suffix;
return comma_separated;
}
// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
string CommaSeparatedString(const std::initializer_list<T>& c,
const char* prefix = "", const char* suffix = "") {
return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
}
// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
// 7). StrAppend must support appending the elements of c.
template <typename Container>
string VectorString(const Container& c) {
return CommaSeparatedString(c, "(", ")");
}
// Overload needed to allow the container to be an initializer list. The default
// type for T makes an empty initializer list work as well.
template <typename T = int>
string VectorString(const std::initializer_list<T>& c) {
return VectorString<std::initializer_list<T>>(c);
}
// Returns a PaddingConfig object that represents no padding for the given rank.
PaddingConfig MakeNoPaddingConfig(int64 rank);

@ -80,6 +80,26 @@ TEST(UtilTest, HumanReadableNumFlopsExample) {
ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
}
TEST(UtilTest, CommaSeparatedString) {
EXPECT_EQ(CommaSeparatedString({}), "");
EXPECT_EQ(CommaSeparatedString({"hello world"}), "hello world");
EXPECT_EQ(CommaSeparatedString({1, 57, 2}, "foo", "bar"), "foo1, 57, 2bar");
}
TEST(UtilTest, VectorString) {
std::list<int64> empty_list;
EXPECT_EQ(VectorString(empty_list), "()");
std::vector<float> float_vector = {5.5};
EXPECT_EQ(VectorString(float_vector), "(5.5)");
std::set<const char*> string_set = {"a", "b"};
EXPECT_EQ(VectorString(string_set), "(a, b)");
EXPECT_EQ(VectorString({}), "()");
EXPECT_EQ(VectorString({1, 57, 2}), "(1, 57, 2)");
}
TEST(UtilTest, LogLines) {
// Just make sure this code runs (not verifying the output).
LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);

@ -20,7 +20,6 @@ from __future__ import print_function
import abc
import contextlib
import inspect
import types
import numpy as np
@ -33,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect
_DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
@ -154,12 +154,12 @@ class _DistributionMeta(abc.ABCMeta):
if class_special_attr_value is None:
# No _special method available, no need to update the docstring.
continue
class_special_attr_docstring = inspect.getdoc(class_special_attr_value)
class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
if not class_special_attr_docstring:
# No docstring to append.
continue
class_attr_value = _copy_fn(base_attr_value)
class_attr_docstring = inspect.getdoc(base_attr_value)
class_attr_docstring = tf_inspect.getdoc(base_attr_value)
if class_attr_docstring is None:
raise ValueError(
"Expected base class fn to contain a docstring: %s.%s"

@ -44,7 +44,7 @@ class _Gumbel(distribution.Distribution):
where `loc = mu` and `scale = sigma`.
The cumulative densifyt function of this distribution is,
The cumulative density function of this distribution is,
```cdf(x; mu, sigma) = exp(-exp(-(x - mu) / sigma))```

@ -18,12 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import tf_inspect
_DIVERGENCES = {}
@ -31,8 +30,8 @@ _DIVERGENCES = {}
def _registered_kl(type_a, type_b):
"""Get the KL function registered for classes a and b."""
hierarchy_a = inspect.getmro(type_a)
hierarchy_b = inspect.getmro(type_b)
hierarchy_a = tf_inspect.getmro(type_a)
hierarchy_b = tf_inspect.getmro(type_b)
dist_to_children = None
kl_fn = None
for mro_to_a, parent_a in enumerate(hierarchy_a):

@ -61,8 +61,9 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import functools
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
__all__ = ['arg_scope',
'add_arg_scope',
@ -106,7 +107,7 @@ def _add_op(op):
_DECORATED_OPS[key_op] = _kwarg_names(op)
@contextlib.contextmanager
@tf_contextlib.contextmanager
def arg_scope(list_ops_or_scope, **kwargs):
"""Stores the default arguments for the given set of list_ops.
@ -170,7 +171,6 @@ def add_arg_scope(func):
Returns:
A tuple with the decorated function func_with_args().
"""
@functools.wraps(func)
def func_with_args(*args, **kwargs):
current_scope = _current_arg_scope()
current_args = kwargs
@ -181,8 +181,7 @@ def add_arg_scope(func):
return func(*args, **current_args)
_add_op(func)
setattr(func_with_args, '_key_op', _key_op(func))
setattr(func_with_args, '__doc__', func.__doc__)
return func_with_args
return tf_decorator.make_decorator(func, func_with_args)
def has_arg_scope(func):

@ -18,12 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import numpy as np
from tensorflow.contrib.keras.python import keras
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
def compare_single_input_op_to_numpy(keras_op,
@ -207,7 +206,7 @@ class BackendLinearAlgebraTest(test.TestCase):
compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
keras_kwargs={'axis': -1},
np_kwargs={'axis': -1})
if 'keepdims' in inspect.getargspec(keras_op).args:
if 'keepdims' in tf_inspect.getargspec(keras_op).args:
compare_single_input_op_to_numpy(keras_op, np_op,
input_shape=(4, 7, 5),
keras_kwargs={'axis': 1,

@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
import json
import os
import re
@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils import conv_utils
from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import tf_inspect
# pylint: disable=g-import-not-at-top
@ -584,7 +584,7 @@ class Layer(object):
user_kwargs = copy.copy(kwargs)
if not _is_all_none(previous_mask):
# The previous layer generated a mask.
if 'mask' in inspect.getargspec(self.call).args:
if 'mask' in tf_inspect.getargspec(self.call).args:
if 'mask' not in kwargs:
# If mask is explicitly passed to __call__,
# we should override the default mask.
@ -2166,7 +2166,7 @@ class Container(Layer):
kwargs = {}
if len(computed_data) == 1:
computed_tensor, computed_mask = computed_data[0]
if 'mask' in inspect.getargspec(layer.call).args:
if 'mask' in tf_inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_mask
output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
@ -2177,7 +2177,7 @@ class Container(Layer):
else:
computed_tensors = [x[0] for x in computed_data]
computed_masks = [x[1] for x in computed_data]
if 'mask' in inspect.getargspec(layer.call).args:
if 'mask' in tf_inspect.getargspec(layer.call).args:
if 'mask' not in kwargs:
kwargs['mask'] = computed_masks
output_tensors = _to_list(layer.call(computed_tensors, **kwargs))

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
import types as python_types
import numpy as np
@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import tf_inspect
class Masking(Layer):
@ -595,7 +595,7 @@ class Lambda(Layer):
def call(self, inputs, mask=None):
arguments = self.arguments
arg_spec = inspect.getargspec(self.function)
arg_spec = tf_inspect.getargspec(self.function)
if 'mask' in arg_spec.args:
arguments['mask'] = mask
return self.function(inputs, **arguments)

@ -20,12 +20,12 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
from tensorflow.contrib.keras.python.keras import backend as K
from tensorflow.contrib.keras.python.keras.engine import InputSpec
from tensorflow.contrib.keras.python.keras.engine import Layer
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import tf_inspect
class Wrapper(Layer):
@ -284,7 +284,7 @@ class Bidirectional(Wrapper):
def call(self, inputs, training=None, mask=None):
kwargs = {}
func_args = inspect.getargspec(self.layer.call).args
func_args = tf_inspect.getargspec(self.layer.call).args
if 'training' in func_args:
kwargs['training'] = training
if 'mask' in func_args:

@ -18,11 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import numpy as np
from tensorflow.contrib.keras.python import keras
from tensorflow.python.util import tf_inspect
def get_test_data(train_samples,
@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
layer.set_weights(weights)
# test and instantiation from weights
if 'weights' in inspect.getargspec(layer_cls.__init__):
if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
kwargs['weights'] = weights
layer = layer_cls(**kwargs)

@ -17,7 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import marshal
import sys
import time
@ -26,6 +25,8 @@ import types as python_types
import numpy as np
import six
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
_GLOBAL_CUSTOM_OBJECTS = {}
@ -116,6 +117,7 @@ def get_custom_objects():
def serialize_keras_object(instance):
_, instance = tf_decorator.unwrap(instance)
if instance is None:
return None
if hasattr(instance, 'get_config'):
@ -149,7 +151,7 @@ def deserialize_keras_object(identifier,
if cls is None:
raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
if hasattr(cls, 'from_config'):
arg_spec = inspect.getargspec(cls.from_config)
arg_spec = tf_inspect.getargspec(cls.from_config)
if 'custom_objects' in arg_spec.args:
custom_objects = custom_objects or {}
return cls.from_config(

@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
import copy
import inspect
import types
import numpy as np
from tensorflow.contrib.keras.python.keras.models import Sequential
from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
from tensorflow.python.util import tf_inspect
class BaseWrapper(object):
@ -97,7 +97,7 @@ class BaseWrapper(object):
legal_params = []
for fn in legal_params_fns:
legal_params += inspect.getargspec(fn)[0]
legal_params += tf_inspect.getargspec(fn)[0]
legal_params = set(legal_params)
for params_name in params:
@ -182,7 +182,7 @@ class BaseWrapper(object):
"""
override = override or {}
res = {}
fn_args = inspect.getargspec(fn)[0]
fn_args = tf_inspect.getargspec(fn)[0]
for name, value in self.sk_params.items():
if name in fn_args:
res.update({name: value})

@ -24,9 +24,9 @@ from __future__ import print_function
import collections
import functools
import inspect
import re
from tensorflow.python.util import tf_inspect
# used for register_type_abbreviation and _type_repr below.
_TYPE_ABBREVIATIONS = {}
@ -230,7 +230,7 @@ def accepts(*types):
def check_accepts(f):
"""Check the types."""
spec = inspect.getargspec(f)
spec = tf_inspect.getargspec(f)
num_function_arguments = len(spec.args)
if len(types) != num_function_arguments:

@ -24,11 +24,12 @@ from abc import abstractmethod
from abc import abstractproperty
import collections
import inspect
from .series import Series
from .series import TransformedSeries
from tensorflow.python.util import tf_inspect
def _make_list_of_series(x):
"""Converts `x` into a list of `Series` if possible.
@ -120,7 +121,7 @@ class Transform(object):
def parameters(self):
"""A dict of names to values of properties marked with `@parameter`."""
property_param_names = [name
for name, func in inspect.getmembers(type(self))
for name, func in tf_inspect.getmembers(type(self))
if (hasattr(func, "fget") and hasattr(
getattr(func, "fget"), "is_parameter"))]
return {name: getattr(self, name) for name in property_param_names}

@ -218,7 +218,8 @@ def read_data_sets(train_dir,
if fake_data:
def fake():
return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
return DataSet(
[], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
train = fake()
validation = fake()
@ -260,13 +261,16 @@ def read_data_sets(train_dir,
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]
train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
validation = DataSet(validation_images,
validation_labels,
dtype=dtype,
reshape=reshape,
seed=seed)
test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
train = DataSet(
train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
validation = DataSet(
validation_images,
validation_labels,
dtype=dtype,
reshape=reshape,
seed=seed)
test = DataSet(
test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
return base.Datasets(train=train, validation=validation, test=test)

@ -21,7 +21,6 @@ from __future__ import print_function
import abc
import copy
import inspect
import os
import tempfile
@ -70,6 +69,8 @@ from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import summary_io
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
AS_ITERABLE_DATE = '2016-09-15'
@ -185,14 +186,15 @@ def _model_fn_args(fn):
Raises:
ValueError: if partial function has positionally bound arguments
"""
_, fn = tf_decorator.unwrap(fn)
if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
# Handle functools.partial and similar objects.
return tuple([
arg for arg in inspect.getargspec(fn.func).args[len(fn.args):]
arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
if arg not in set(fn.keywords.keys())
])
# Handle function.
return tuple(inspect.getargspec(fn).args)
return tuple(tf_inspect.getargspec(fn).args)
def _get_replica_device_setter(config):
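A minimal standalone sketch of the `functools.partial` handling in `_model_fn_args` above, using the standard `inspect` module in place of `tf_inspect` (the `model_fn` signature and the bound keyword are hypothetical):

```python
import functools
import inspect


def model_fn(features, labels, mode, params):
    """Hypothetical model_fn, used only to illustrate the argspec filtering."""
    del features, labels, mode, params


fn = functools.partial(model_fn, params={"learning_rate": 0.1})

# Drop positionally bound args, then filter out keyword-bound ones, mirroring
# _model_fn_args above (getfullargspec avoids the getargspec deprecation on
# newer Pythons).
args = tuple(
    arg for arg in inspect.getfullargspec(fn.func).args[len(fn.args):]
    if arg not in set(fn.keywords.keys()))

print(args)  # ('features', 'labels', 'mode')
```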

@ -52,7 +52,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables as variables_lib
@ -63,7 +62,6 @@ from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import monitored_session
from tensorflow.python.training import queue_runner_impl
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
from tensorflow.python.util import tf_inspect
def assert_estimator_contract(tester, estimator_class):
@ -31,7 +31,7 @@ def assert_estimator_contract(tester, estimator_class):
tester: A tf.test.TestCase.
estimator_class: 'type' object of pre-canned estimator.
"""
attributes = inspect.getmembers(estimator_class)
attributes = tf_inspect.getmembers(estimator_class)
attribute_names = [a[0] for a in attributes]
tester.assertTrue('config' in attribute_names)

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import abc
import inspect
import six
@ -38,14 +37,17 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
class Head(object):
@ -1663,12 +1665,10 @@ def _compute_weighted_loss(loss_unweighted, weight, name="loss"):
if weight is None:
loss = math_ops.reduce_mean(loss_unweighted, name=name_scope)
return loss, loss
weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted)
with ops.name_scope(None, "weighted_loss",
(loss_unweighted, weight)) as name:
# TODO(ptucker): Support weight broadcasting, or switch to tf.losses.
weighted_loss = math_ops.multiply(
array_ops.reshape(loss_unweighted, shape=(-1,)),
array_ops.reshape(weight, shape=(-1,)), name=name)
weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name)
weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope)
weighted_loss_normalized = math_ops.div(
math_ops.reduce_sum(weighted_loss),
@ -1697,9 +1697,10 @@ def _check_mode_valid(mode):
def _get_arguments(func):
"""Returns a spec of given func."""
_, func = tf_decorator.unwrap(func)
if hasattr(func, "__code__"):
# Regular function.
return inspect.getargspec(func)
return tf_inspect.getargspec(func)
elif hasattr(func, "__call__"):
# Callable object.
return _get_arguments(func.__call__)
@ -1802,8 +1803,13 @@ def _float_weights_or_none(weights):
def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
labels = ops.convert_to_tensor(labels)
labels = math_ops.to_float(labels)
weights = _float_weights_or_none(weights)
if weights is not None:
weights = weights_broadcast_ops.broadcast_weights(weights, labels)
if class_id is not None:
if weights is not None:
weights = weights[:, class_id]
labels = labels[:, class_id]
return metrics_lib.streaming_mean(labels, weights=weights)
@ -1811,11 +1817,13 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
def _predictions_streaming_mean(predictions,
weights=None,
class_id=None):
predictions = ops.convert_to_tensor(predictions)
predictions = math_ops.to_float(predictions)
weights = _float_weights_or_none(weights)
if weights is not None:
weights = ops.convert_to_tensor(weights)
weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
if class_id is not None:
if weights is not None:
weights = weights[:, class_id]
predictions = predictions[:, class_id]
return metrics_lib.streaming_mean(predictions, weights=weights)
@ -1850,16 +1858,21 @@ def _class_labels_streaming_mean(labels, weights, class_id):
def _streaming_auc(predictions, labels, weights=None, class_id=None,
curve="ROC"):
predictions = ops.convert_to_tensor(predictions)
labels = ops.convert_to_tensor(labels)
# pylint: disable=missing-docstring
predictions = math_ops.to_float(predictions)
if labels.dtype.base_dtype != dtypes.bool:
logging.warning("Casting %s labels to bool.", labels.dtype)
labels = math_ops.cast(labels, dtypes.bool)
weights = _float_weights_or_none(weights)
if weights is not None:
weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
if class_id is not None:
if weights is not None:
weights = weights[:, class_id]
predictions = predictions[:, class_id]
labels = labels[:, class_id]
return metrics_lib.streaming_auc(
predictions,
math_ops.cast(labels, dtypes.bool),
weights=_float_weights_or_none(weights),
curve=curve)
predictions, labels, weights=weights, curve=curve)
def _assert_class_id(class_id, num_classes=None):
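The `_compute_weighted_loss` and streaming-metric changes above move from manual reshaping to weight broadcasting; a minimal NumPy sketch of the shape behavior this relies on (all values hypothetical):

```python
import numpy as np

# Per-example, per-class unweighted loss: shape (batch, n_classes) = (2, 3).
loss_unweighted = np.array([[1.0, 2.0, 3.0],
                            [4.0, 5.0, 6.0]])

# A per-example weight with an explicit trailing dim, shape (2, 1), broadcasts
# across the class dimension; a bare shape-(2,) weight would not. The SVM head
# test below adds an extra dim to its weights for exactly this reason.
weight = np.array([[0.1],
                   [1.0]])

weighted_loss = loss_unweighted * weight   # shape (2, 3)
weighted_loss_mean = weighted_loss.mean()  # analogous to reduce_mean; ~2.6
```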

@ -36,7 +36,6 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import test
# pylint: enable=g-bad-todo,g-import-not-at-top
def _assert_variables(test_case,
@ -260,8 +259,10 @@ class RegressionHeadTest(test.TestCase):
),
expected_trainable=("regression_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
self, ["regression_head/loss", "regression_head/centered_bias/bias_0"])
_assert_summary_tags(self, [
"regression_head/loss",
"regression_head/centered_bias/bias_0"
])
_assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
def testRegressionErrorInSparseTensorLabels(self):
@ -541,7 +542,26 @@ class MultiLabelHeadTest(test.TestCase):
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, .089985214,
self._expected_eval_metrics(2.69956), model_fn_ops)
self._expected_eval_metrics(.89985214), model_fn_ops)
def testMultiLabelWithMultiDimensionalWeight(self):
n_classes = 3
head = head_lib.multi_label_head(
n_classes=n_classes,
weight_column_name="label_weight",
metric_class_ids=range(n_classes))
with ops.Graph().as_default(), session.Session():
model_fn_ops = head.create_model_fn_ops(
features={"label_weight": ((.1, .1, .1),)},
labels=self._labels,
mode=model_fn.ModeKeys.TRAIN,
train_op_fn=head_lib.no_op_train_fn,
logits=self._logits)
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, .089985214,
self._expected_eval_metrics(.89985214), model_fn_ops)
def testMultiLabelWithCustomLoss(self):
n_classes = 3
@ -560,8 +580,9 @@ class MultiLabelHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["multi_label_head/loss"])
_assert_metrics(self, 0.089985214,
self._expected_eval_metrics(0.089985214), model_fn_ops)
expected_loss = .089985214
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
def testMultiLabelWithCenteredBias(self):
n_classes = 3
@ -910,9 +931,10 @@ class BinaryClassificationHeadTest(test.TestCase):
"Adagrad:0"),),
expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
self, ["binary_logistic_head/loss",
"binary_logistic_head/centered_bias/bias_0"])
_assert_summary_tags(self, [
"binary_logistic_head/loss",
"binary_logistic_head/centered_bias/bias_0"
])
expected_loss = .81326175
_assert_metrics(self, expected_loss,
self._expected_eval_metrics(expected_loss), model_fn_ops)
@ -1416,7 +1438,8 @@ class BinarySvmHeadTest(test.TestCase):
with ops.Graph().as_default(), session.Session():
weights = (7., 11.)
model_fn_ops = head.create_model_fn_ops(
features={"weights": weights},
# We have to add an extra dim here for weights broadcasting to work.
features={"weights": tuple([(w,) for w in weights])},
mode=model_fn.ModeKeys.TRAIN,
labels=self._labels,
train_op_fn=head_lib.no_op_train_fn,
@ -1424,11 +1447,10 @@ class BinarySvmHeadTest(test.TestCase):
self._assert_output_alternatives(model_fn_ops)
_assert_no_variables(self)
_assert_summary_tags(self, ["binary_svm_head/loss"])
expected_weighted_sum = np.sum(
np.multiply(weights, self._expected_losses))
_assert_metrics(self, expected_weighted_sum / len(weights), {
expected_weighted_losses = np.multiply(weights, self._expected_losses)
_assert_metrics(self, np.mean(expected_weighted_losses), {
"accuracy": 1.,
"loss": expected_weighted_sum / np.sum(weights),
"loss": np.sum(expected_weighted_losses) / np.sum(weights),
}, model_fn_ops)
def testBinarySVMWithCenteredBias(self):
@ -1450,9 +1472,10 @@ class BinarySvmHeadTest(test.TestCase):
),
expected_trainable=("binary_svm_head/centered_bias_weight:0",))
variables.global_variables_initializer().run()
_assert_summary_tags(
self, ["binary_svm_head/loss",
"binary_svm_head/centered_bias/bias_0"])
_assert_summary_tags(self, [
"binary_svm_head/loss",
"binary_svm_head/centered_bias/bias_0"
])
expected_loss = np.average(self._expected_losses)
_assert_metrics(self, expected_loss, {
"accuracy": 1.,

@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""ExportStrategy class represents different flavors of model export."""
from __future__ import absolute_import
@ -20,13 +19,14 @@ from __future__ import division
from __future__ import print_function
import collections
import inspect
from tensorflow.python.util import tf_inspect
__all__ = ['ExportStrategy']
class ExportStrategy(collections.namedtuple('ExportStrategy',
['name', 'export_fn'])):
class ExportStrategy(
collections.namedtuple('ExportStrategy', ['name', 'export_fn'])):
"""A class representing a type of model export.
Typically constructed by a utility function specific to the exporter, such as
@ -74,7 +74,7 @@ class ExportStrategy(collections.namedtuple('ExportStrategy',
"""
# don't break existing export_fns that don't accept checkpoint_path and
# eval_result
export_fn_args = inspect.getargspec(self.export_fn).args
export_fn_args = tf_inspect.getargspec(self.export_fn).args
kwargs = {}
if 'checkpoint_path' in export_fn_args:
kwargs['checkpoint_path'] = checkpoint_path

@ -18,10 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import six
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_inspect
def _assert_named_args(sentinel):
@ -43,11 +43,11 @@ def _args(fn):
if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
# Handle functools.partial and similar objects.
return tuple([
arg for arg in inspect.getargspec(fn.func).args
arg for arg in tf_inspect.getargspec(fn.func).args
if arg not in set(fn.keywords.keys())
])
# Handle function.
return tuple(inspect.getargspec(fn).args)
return tuple(tf_inspect.getargspec(fn).args)
_CANONICAL_LABELS_ARG = 'labels'

@ -35,7 +35,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
import os
import time
@ -53,6 +52,7 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import summary_io
from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
# TODO(ptucker): Split each monitor class into a separate file.
@ -1164,7 +1164,7 @@ class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):
def end(self, session):
self._last_step = None
for m in self._monitors:
if "session" in inspect.getargspec(m.end).args:
if "session" in tf_inspect.getargspec(m.end).args:
m.end(session=session)
else:
m.end()

@ -51,6 +51,7 @@ tf_custom_op_py_library(
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:partitioned_variables",

@ -369,28 +369,28 @@ class RNNCellTest(test.TestCase):
self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
def testUsingSecondCellInScopeWithExistingVariablesFails(self):
# This test should go away when this behavior is no longer an
# error (Approx. May 2017)
cell1 = core_rnn_cell_impl.LSTMCell(3)
cell2 = core_rnn_cell_impl.LSTMCell(3)
x = array_ops.zeros([1, 3])
m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
cell1(x, m)
with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
cell2(x, m)
# def testUsingSecondCellInScopeWithExistingVariablesFails(self):
# # This test should go away when this behavior is no longer an
# # error (Approx. May 2017)
# cell1 = core_rnn_cell_impl.LSTMCell(3)
# cell2 = core_rnn_cell_impl.LSTMCell(3)
# x = array_ops.zeros([1, 3])
# m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
# cell1(x, m)
# with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
# cell2(x, m)
def testUsingCellInDifferentScopeFromFirstCallFails(self):
# This test should go away when this behavior is no longer an
# error (Approx. May 2017)
cell = core_rnn_cell_impl.LSTMCell(3)
x = array_ops.zeros([1, 3])
m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
with variable_scope.variable_scope("scope1"):
cell(x, m)
with variable_scope.variable_scope("scope2"):
with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
cell(x, m)
# def testUsingCellInDifferentScopeFromFirstCallFails(self):
# # This test should go away when this behavior is no longer an
# # error (Approx. May 2017)
# cell = core_rnn_cell_impl.LSTMCell(3)
# x = array_ops.zeros([1, 3])
# m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
# with variable_scope.variable_scope("scope1"):
# cell(x, m)
# with variable_scope.variable_scope("scope2"):
# with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
# cell(x, m)
def testEmbeddingWrapper(self):
with self.test_session() as sess:

@ -521,7 +521,7 @@ class LSTMTest(test.TestCase):
input_value = np.random.randn(batch_size, input_size)
sess.run(outputs, feed_dict={inputs[0]: input_value})
def testStateTupleWithProjAndSequenceLength(self):
def _testStateTupleWithProjAndSequenceLength(self):
num_units = 3
input_size = 5
batch_size = 2

@ -569,7 +569,7 @@ class RNNCellTest(test.TestCase):
self.assertTrue(
float(np.linalg.norm((state[0, :] - state[i, :]))) > 1e-6)
def testAttentionCellWrapperCorrectResult(self):
def _testAttentionCellWrapperCorrectResult(self):
num_units = 4
attn_length = 6
batch_size = 2

@ -108,11 +108,11 @@ class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""
def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
super(BasicRNNCell, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._activation = activation
self._reuse = reuse
@property
def state_size(self):
@ -122,11 +122,9 @@ class BasicRNNCell(RNNCell):
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse):
output = self._activation(
_linear([inputs, state], self._num_units, True))
output = self._activation(_linear([inputs, state], self._num_units, True))
return output, output
@ -134,11 +132,11 @@ class GRUCell(RNNCell):
"""Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""
def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
super(GRUCell, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
self._num_units = num_units
self._activation = activation
self._reuse = reuse
@property
def state_size(self):
@ -148,21 +146,15 @@ class GRUCell(RNNCell):
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Gated recurrent unit (GRU) with nunits cells."""
with _checked_scope(self, scope or "gru_cell", reuse=self._reuse):
with vs.variable_scope("gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
value = sigmoid(_linear(
[inputs, state], 2 * self._num_units, True, 1.0))
r, u = array_ops.split(
value=value,
num_or_size_splits=2,
axis=1)
with vs.variable_scope("candidate"):
c = self._activation(_linear([inputs, r * state],
self._num_units, True))
new_h = u * state + (1 - u) * c
with vs.variable_scope("gates"): # Reset gate and update gate.
# We start with bias of 1.0 to not reset and not update.
value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, 1.0))
r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
with vs.variable_scope("candidate"):
c = self._activation(_linear([inputs, r * state], self._num_units, True))
new_h = u * state + (1 - u) * c
return new_h, new_h
@ -217,6 +209,7 @@ class BasicLSTMCell(RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(BasicLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
@ -226,7 +219,6 @@ class BasicLSTMCell(RNNCell):
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation
self._reuse = reuse
@property
def state_size(self):
@ -237,28 +229,28 @@ class BasicLSTMCell(RNNCell):
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Long short-term memory cell (LSTM)."""
with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
concat = _linear([inputs, h], 4 * self._num_units, True)
# Parameters of gates are concatenated into one multiply for efficiency.
if self._state_is_tuple:
c, h = state
else:
c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
concat = _linear([inputs, h], 4 * self._num_units, True)
new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
new_c = (
c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
new_h = self._activation(new_c) * sigmoid(o)
if self._state_is_tuple:
new_state = LSTMStateTuple(new_c, new_h)
else:
new_state = array_ops.concat([new_c, new_h], 1)
return new_h, new_state
class LSTMCell(RNNCell):
@ -319,6 +311,7 @@ class LSTMCell(RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(LSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
@ -341,7 +334,6 @@ class LSTMCell(RNNCell):
self._forget_bias = forget_bias
self._state_is_tuple = state_is_tuple
self._activation = activation
self._reuse = reuse
if num_proj:
self._state_size = (
@ -362,7 +354,7 @@ class LSTMCell(RNNCell):
def output_size(self):
return self._output_size
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of LSTM.
Args:
@ -371,7 +363,6 @@ class LSTMCell(RNNCell):
`2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
`m_state`.
scope: VariableScope for the created subgraph; defaults to "lstm_cell".
Returns:
A tuple containing:
@ -400,9 +391,8 @@ class LSTMCell(RNNCell):
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
with _checked_scope(self, scope or "lstm_cell",
initializer=self._initializer,
reuse=self._reuse) as unit_scope:
scope = vs.get_variable_scope()
with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
if self._num_unit_shards is not None:
unit_scope.set_partitioner(
partitioned_variables.fixed_size_partitioner(
@ -481,13 +471,13 @@ class OutputProjectionWrapper(RNNCell):
TypeError: if cell is not an RNNCell.
ValueError: if output_size is not positive.
"""
super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
if output_size < 1:
raise ValueError("Parameter output_size must be > 0: %d." % output_size)
self._cell = cell
self._output_size = output_size
self._reuse = reuse
self._activation = activation
@property
@ -502,15 +492,12 @@ class OutputProjectionWrapper(RNNCell):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run the cell and output projection on inputs, starting from state."""
output, res_state = self._cell(inputs, state)
# Default scope: "OutputProjectionWrapper"
with _checked_scope(self, scope or "output_projection_wrapper",
reuse=self._reuse):
projected = _linear(output, self._output_size, True)
if self._activation:
projected = self._activation(projected)
projected = _linear(output, self._output_size, True)
if self._activation:
projected = self._activation(projected)
return projected, res_state
@ -522,7 +509,8 @@ class InputProjectionWrapper(RNNCell):
do the projection on this batch-concatenated sequence, then split it.
"""
def __init__(self, cell, num_proj, activation=None, input_size=None):
def __init__(self, cell, num_proj, activation=None, input_size=None,
reuse=None):
"""Create a cell with input projection.
Args:
@ -530,10 +518,14 @@ class InputProjectionWrapper(RNNCell):
num_proj: Python integer. The dimension to project to.
activation: (optional) an optional activation function.
input_size: Deprecated and unused.
reuse: (optional) Python boolean describing whether to reuse variables
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
Raises:
TypeError: if cell is not an RNNCell.
"""
super(InputProjectionWrapper, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
if not isinstance(cell, RNNCell):
@ -554,13 +546,12 @@ class InputProjectionWrapper(RNNCell):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run the input projection and then the cell."""
# Default scope: "InputProjectionWrapper"
with vs.variable_scope(scope or "input_projection_wrapper"):
projected = _linear(inputs, self._num_proj, True)
if self._activation:
projected = self._activation(projected)
projected = _linear(inputs, self._num_proj, True)
if self._activation:
projected = self._activation(projected)
return self._cell(projected, state)
@ -847,6 +838,7 @@ class EmbeddingWrapper(RNNCell):
TypeError: if cell is not an RNNCell.
ValueError: if embedding_classes is not positive.
"""
super(EmbeddingWrapper, self).__init__(_reuse=reuse)
if not isinstance(cell, RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
if embedding_classes <= 0 or embedding_size <= 0:
@ -856,7 +848,6 @@ class EmbeddingWrapper(RNNCell):
self._embedding_classes = embedding_classes
self._embedding_size = embedding_size
self._initializer = initializer
self._reuse = reuse
@property
def state_size(self):
@ -870,31 +861,31 @@ class EmbeddingWrapper(RNNCell):
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run the cell on embedded inputs."""
with _checked_scope(self, scope or "embedding_wrapper", reuse=self._reuse):
with ops.device("/cpu:0"):
if self._initializer:
initializer = self._initializer
elif vs.get_variable_scope().initializer:
initializer = vs.get_variable_scope().initializer
else:
# Default initializer for embeddings should have variance=1.
sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
with ops.device("/cpu:0"):
if self._initializer:
initializer = self._initializer
elif vs.get_variable_scope().initializer:
initializer = vs.get_variable_scope().initializer
else:
# Default initializer for embeddings should have variance=1.
sqrt3 = math.sqrt(3) # Uniform(-sqrt(3), sqrt(3)) has variance=1.
initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
if type(state) is tuple:
data_type = state[0].dtype
else:
data_type = state.dtype
if type(state) is tuple:
data_type = state[0].dtype
else:
data_type = state.dtype
embedding = vs.get_variable(
"embedding", [self._embedding_classes, self._embedding_size],
initializer=initializer,
dtype=data_type)
embedded = embedding_ops.embedding_lookup(
embedding, array_ops.reshape(inputs, [-1]))
return self._cell(embedded, state)
embedding = vs.get_variable(
"embedding", [self._embedding_classes, self._embedding_size],
initializer=initializer,
dtype=data_type)
embedded = embedding_ops.embedding_lookup(embedding,
array_ops.reshape(inputs, [-1]))
return self._cell(embedded, state)
class MultiRNNCell(RNNCell):
@ -914,6 +905,7 @@ class MultiRNNCell(RNNCell):
ValueError: if cells is empty (not allowed), or at least one of the cells
returns a state tuple but the flag `state_is_tuple` is `False`.
"""
super(MultiRNNCell, self).__init__()
if not cells:
raise ValueError("Must specify at least one cell for MultiRNNCell.")
if not nest.is_sequence(cells):
@ -948,28 +940,29 @@ class MultiRNNCell(RNNCell):
# presumably does not contain TensorArrays or anything else fancy
return super(MultiRNNCell, self).zero_state(batch_size, dtype)
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run this multi-layer cell on inputs, starting from state."""
with vs.variable_scope(scope or "multi_rnn_cell"):
cur_state_pos = 0
cur_inp = inputs
new_states = []
for i, cell in enumerate(self._cells):
with vs.variable_scope("cell_%d" % i):
if self._state_is_tuple:
if not nest.is_sequence(state):
raise ValueError(
"Expected state to be a tuple of length %d, but received: %s"
% (len(self.state_size), state))
cur_state = state[i]
else:
cur_state = array_ops.slice(
state, [0, cur_state_pos], [-1, cell.state_size])
cur_state_pos += cell.state_size
cur_inp, new_state = cell(cur_inp, cur_state)
new_states.append(new_state)
cur_state_pos = 0
cur_inp = inputs
new_states = []
for i, cell in enumerate(self._cells):
with vs.variable_scope("cell_%d" % i):
if self._state_is_tuple:
if not nest.is_sequence(state):
raise ValueError(
"Expected state to be a tuple of length %d, but received: %s" %
(len(self.state_size), state))
cur_state = state[i]
else:
cur_state = array_ops.slice(state, [0, cur_state_pos],
[-1, cell.state_size])
cur_state_pos += cell.state_size
cur_inp, new_state = cell(cur_inp, cur_state)
new_states.append(new_state)
new_states = (tuple(new_states) if self._state_is_tuple else
array_ops.concat(new_states, 1))
return cur_inp, new_states
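For reference, the per-layer loop above slices the incoming state one entry per cell when `state_is_tuple=True` (the `state[i]` branch) and by contiguous columns otherwise. A short construction sketch of the stacked cell this method serves; the layer count and sizes are illustrative, not from this commit:

import tensorflow as tf

# Two stacked LSTM layers: with state_is_tuple=True the MultiRNNCell state is a
# tuple with one LSTMStateTuple per layer, matching the state[i] lookup above.
cells = [tf.contrib.rnn.BasicLSTMCell(128) for _ in range(2)]
stack = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

inputs = tf.placeholder(tf.float32, [None, 20, 50])   # [batch, time, features]
outputs, final_state = tf.nn.dynamic_rnn(stack, inputs, dtype=tf.float32)
# final_state[0] and final_state[1] are the layer-0 and layer-1 LSTMStateTuples.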
View File
@ -138,6 +138,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
logging.warn(
"%s: Using a concatenated state is slower and will soon be "
@ -173,7 +174,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
def output_size(self):
return self._output_size
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of LSTM.
Args:
@ -182,7 +183,6 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
`2-D, batch x state_size`. If `state_is_tuple` is True, this must be a
tuple of state Tensors, both `2-D`, with column sizes `c_state` and
`m_state`.
scope: VariableScope for the created subgraph; defaults to "LSTMCell".
Returns:
A tuple containing:
@ -212,51 +212,49 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
with _checked_scope(self, scope or "coupled_input_forget_gate_lstm_cell",
initializer=self._initializer, reuse=self._reuse):
concat_w = _get_concat_variable(
"W", [input_size.value + num_proj, 3 * self._num_units],
dtype, self._num_unit_shards)
concat_w = _get_concat_variable(
"W", [input_size.value + num_proj, 3 * self._num_units],
dtype, self._num_unit_shards)
b = vs.get_variable(
"B",
shape=[3 * self._num_units],
initializer=init_ops.zeros_initializer(),
dtype=dtype)
b = vs.get_variable(
"B",
shape=[3 * self._num_units],
initializer=init_ops.zeros_initializer(),
dtype=dtype)
# j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat([inputs, m_prev], 1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
# j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat([inputs, m_prev], 1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
# Diagonal connections
if self._use_peepholes:
w_f_diag = vs.get_variable(
"W_F_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = vs.get_variable(
"W_O_diag", shape=[self._num_units], dtype=dtype)
# Diagonal connections
if self._use_peepholes:
w_f_diag = vs.get_variable(
"W_F_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = vs.get_variable(
"W_O_diag", shape=[self._num_units], dtype=dtype)
if self._use_peepholes:
f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
else:
f_act = sigmoid(f + self._forget_bias)
c = (f_act * c_prev + (1 - f_act) * self._activation(j))
if self._use_peepholes:
f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
else:
f_act = sigmoid(f + self._forget_bias)
c = (f_act * c_prev + (1 - f_act) * self._activation(j))
if self._use_peepholes:
m = sigmoid(o + w_o_diag * c) * self._activation(c)
else:
m = sigmoid(o) * self._activation(c)
if self._use_peepholes:
m = sigmoid(o + w_o_diag * c) * self._activation(c)
else:
m = sigmoid(o) * self._activation(c)
if self._num_proj is not None:
concat_w_proj = _get_concat_variable(
"W_P", [self._num_units, self._num_proj],
dtype, self._num_proj_shards)
if self._num_proj is not None:
concat_w_proj = _get_concat_variable(
"W_P", [self._num_units, self._num_proj],
dtype, self._num_proj_shards)
m = math_ops.matmul(m, concat_w_proj)
if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
m = math_ops.matmul(m, concat_w_proj)
if self._proj_clip is not None:
# pylint: disable=invalid-unary-operand-type
m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
# pylint: enable=invalid-unary-operand-type
new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
array_ops.concat([c, m], 1))
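The coupling the class name refers to is visible in the update above: there is no separate input gate, the candidate is weighted by `1 - f_act`, so the new cell state is a convex combination of the previous state and the transformed input. A small NumPy illustration of that update (made-up values, peepholes omitted):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

c_prev = np.array([0.2, -0.5])    # previous cell state
j = np.array([1.0, 0.3])          # candidate input (pre-activation)
f = np.array([2.0, -1.0])         # forget-gate logits
forget_bias = 1.0

f_act = sigmoid(f + forget_bias)
c = f_act * c_prev + (1.0 - f_act) * np.tanh(j)   # coupled input/forget update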
@ -301,6 +299,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._use_peepholes = use_peepholes
self._cell_clip = cell_clip
@ -321,14 +320,12 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
def state_size(self):
return self._state_size
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of LSTM.
Args:
inputs: input Tensor, 2D, batch x num_units.
state: state Tensor, 2D, batch x state_size.
scope: VariableScope for the created subgraph; defaults to
"TimeFreqLSTMCell".
Returns:
A tuple containing:
@ -347,63 +344,63 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
freq_inputs = self._make_tf_features(inputs)
dtype = inputs.dtype
actual_input_size = freq_inputs[0].get_shape().as_list()[1]
with _checked_scope(self, scope or "time_freq_lstm_cell",
initializer=self._initializer, reuse=self._reuse):
concat_w = _get_concat_variable(
"W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
dtype, self._num_unit_shards)
b = vs.get_variable(
"B",
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(),
dtype=dtype)
# Diagonal connections
concat_w = _get_concat_variable(
"W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
dtype, self._num_unit_shards)
b = vs.get_variable(
"B",
shape=[4 * self._num_units],
initializer=init_ops.zeros_initializer(),
dtype=dtype)
# Diagonal connections
if self._use_peepholes:
w_f_diag = vs.get_variable(
"W_F_diag", shape=[self._num_units], dtype=dtype)
w_i_diag = vs.get_variable(
"W_I_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = vs.get_variable(
"W_O_diag", shape=[self._num_units], dtype=dtype)
# initialize the first freq state to be zero
m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
self._num_units], dtype)
for fq in range(len(freq_inputs)):
c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
[-1, self._num_units])
m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
[-1, self._num_units])
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
if self._use_peepholes:
w_f_diag = vs.get_variable(
"W_F_diag", shape=[self._num_units], dtype=dtype)
w_i_diag = vs.get_variable(
"W_I_diag", shape=[self._num_units], dtype=dtype)
w_o_diag = vs.get_variable(
"W_O_diag", shape=[self._num_units], dtype=dtype)
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
sigmoid(i + w_i_diag * c_prev) * tanh(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
# initialize the first freq state to be zero
m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
self._num_units], dtype)
for fq in range(len(freq_inputs)):
c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
[-1, self._num_units])
m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
[-1, self._num_units])
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
1)
lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
i, j, f, o = array_ops.split(
value=lstm_matrix, num_or_size_splits=4, axis=1)
if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
# pylint: enable=invalid-unary-operand-type
if self._use_peepholes:
c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
sigmoid(i + w_i_diag * c_prev) * tanh(j))
else:
c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))
if self._cell_clip is not None:
# pylint: disable=invalid-unary-operand-type
c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
# pylint: enable=invalid-unary-operand-type
if self._use_peepholes:
m = sigmoid(o + w_o_diag * c) * tanh(c)
else:
m = sigmoid(o) * tanh(c)
m_prev_freq = m
if fq == 0:
state_out = array_ops.concat([c, m], 1)
m_out = m
else:
state_out = array_ops.concat([state_out, c, m], 1)
m_out = array_ops.concat([m_out, m], 1)
if self._use_peepholes:
m = sigmoid(o + w_o_diag * c) * tanh(c)
else:
m = sigmoid(o) * tanh(c)
m_prev_freq = m
if fq == 0:
state_out = array_ops.concat([c, m], 1)
m_out = m
else:
state_out = array_ops.concat([state_out, c, m], 1)
m_out = array_ops.concat([m_out, m], 1)
return m_out, state_out
def _make_tf_features(self, input_feat):
@ -499,6 +496,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
Raises:
ValueError: if the num_frequency_blocks list is not specified
"""
super(GridLSTMCell, self).__init__(_reuse=reuse)
if not state_is_tuple:
logging.warn("%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
@ -550,15 +548,13 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
def state_tuple_type(self):
return self._state_tuple_type
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of LSTM.
Args:
inputs: input Tensor, 2D, [batch, feature_size].
state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
flag self._state_is_tuple.
scope: (optional) VariableScope for the created subgraph; if None, it
defaults to "GridLSTMCell".
Returns:
A tuple containing:
@ -573,21 +569,19 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
"""
batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
freq_inputs = self._make_tf_features(inputs)
with _checked_scope(self, scope or "grid_lstm_cell",
initializer=self._initializer, reuse=self._reuse):
m_out_lst = []
state_out_lst = []
for block in range(len(freq_inputs)):
m_out_lst_current, state_out_lst_current = self._compute(
freq_inputs[block], block, state, batch_size,
state_is_tuple=self._state_is_tuple)
m_out_lst.extend(m_out_lst_current)
state_out_lst.extend(state_out_lst_current)
if self._state_is_tuple:
state_out = self._state_tuple_type(*state_out_lst)
else:
state_out = array_ops.concat(state_out_lst, 1)
m_out = array_ops.concat(m_out_lst, 1)
m_out_lst = []
state_out_lst = []
for block in range(len(freq_inputs)):
m_out_lst_current, state_out_lst_current = self._compute(
freq_inputs[block], block, state, batch_size,
state_is_tuple=self._state_is_tuple)
m_out_lst.extend(m_out_lst_current)
state_out_lst.extend(state_out_lst_current)
if self._state_is_tuple:
state_out = self._state_tuple_type(*state_out_lst)
else:
state_out = array_ops.concat(state_out_lst, 1)
m_out = array_ops.concat(m_out_lst, 1)
return m_out, state_out
def _compute(self, freq_inputs, block, state, batch_size,
@ -974,14 +968,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
*([num_units, num_units] * self._total_blocks * 2))
self._output_size = 2 * num_units * self._total_blocks * 2
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of LSTM.
Args:
inputs: input Tensor, 2D, [batch, num_units].
state: tuple of Tensors, 2D, [batch, state_size].
scope: (optional) VariableScope for the created subgraph; if None, it
defaults to "BidirectionalGridLSTMCell".
Returns:
A tuple containing:
@ -1002,29 +994,27 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
bwd_inputs = fwd_inputs
# Forward processing
with _checked_scope(self, scope or "bidirectional_grid_lstm_cell",
initializer=self._initializer, reuse=self._reuse):
with vs.variable_scope("fwd"):
fwd_m_out_lst = []
fwd_state_out_lst = []
for block in range(len(fwd_inputs)):
fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
fwd_inputs[block], block, state, batch_size,
state_prefix="fwd_state", state_is_tuple=True)
fwd_m_out_lst.extend(fwd_m_out_lst_current)
fwd_state_out_lst.extend(fwd_state_out_lst_current)
# Backward processing
bwd_m_out_lst = []
bwd_state_out_lst = []
with vs.variable_scope("bwd"):
for block in range(len(bwd_inputs)):
# Reverse the blocks
bwd_inputs_reverse = bwd_inputs[block][::-1]
bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
bwd_inputs_reverse, block, state, batch_size,
state_prefix="bwd_state", state_is_tuple=True)
bwd_m_out_lst.extend(bwd_m_out_lst_current)
bwd_state_out_lst.extend(bwd_state_out_lst_current)
with vs.variable_scope("fwd"):
fwd_m_out_lst = []
fwd_state_out_lst = []
for block in range(len(fwd_inputs)):
fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
fwd_inputs[block], block, state, batch_size,
state_prefix="fwd_state", state_is_tuple=True)
fwd_m_out_lst.extend(fwd_m_out_lst_current)
fwd_state_out_lst.extend(fwd_state_out_lst_current)
# Backward processing
bwd_m_out_lst = []
bwd_state_out_lst = []
with vs.variable_scope("bwd"):
for block in range(len(bwd_inputs)):
# Reverse the blocks
bwd_inputs_reverse = bwd_inputs[block][::-1]
bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
bwd_inputs_reverse, block, state, batch_size,
state_prefix="bwd_state", state_is_tuple=True)
bwd_m_out_lst.extend(bwd_m_out_lst_current)
bwd_state_out_lst.extend(bwd_state_out_lst_current)
state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
# Outputs are always concated as it is never used separately.
m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
@ -1069,6 +1059,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
ValueError: if cell returns a state tuple but the flag
`state_is_tuple` is `False` or if attn_length is zero or less.
"""
super(AttentionCellWrapper, self).__init__(_reuse=reuse)
if not isinstance(cell, core_rnn_cell.RNNCell):
raise TypeError("The parameter cell is not RNNCell.")
if nest.is_sequence(cell.state_size) and not state_is_tuple:
@ -1107,42 +1098,40 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
def output_size(self):
return self._attn_size
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Long short-term memory cell with attention (LSTMA)."""
with _checked_scope(self, scope or "attention_cell_wrapper",
reuse=self._reuse):
if self._state_is_tuple:
state, attns, attn_states = state
else:
states = state
state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
attns = array_ops.slice(
states, [0, self._cell.state_size], [-1, self._attn_size])
attn_states = array_ops.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
attn_states = array_ops.reshape(attn_states,
[-1, self._attn_length, self._attn_size])
input_size = self._input_size
if input_size is None:
input_size = inputs.get_shape().as_list()[1]
inputs = _linear([inputs, attns], input_size, True)
lstm_output, new_state = self._cell(inputs, state)
if self._state_is_tuple:
new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
else:
new_state_cat = new_state
new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
with vs.variable_scope("attn_output_projection"):
output = _linear([lstm_output, new_attns], self._attn_size, True)
new_attn_states = array_ops.concat(
[new_attn_states, array_ops.expand_dims(output, 1)], 1)
new_attn_states = array_ops.reshape(
new_attn_states, [-1, self._attn_length * self._attn_size])
new_state = (new_state, new_attns, new_attn_states)
if not self._state_is_tuple:
new_state = array_ops.concat(list(new_state), 1)
return output, new_state
if self._state_is_tuple:
state, attns, attn_states = state
else:
states = state
state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
attns = array_ops.slice(
states, [0, self._cell.state_size], [-1, self._attn_size])
attn_states = array_ops.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
attn_states = array_ops.reshape(attn_states,
[-1, self._attn_length, self._attn_size])
input_size = self._input_size
if input_size is None:
input_size = inputs.get_shape().as_list()[1]
inputs = _linear([inputs, attns], input_size, True)
lstm_output, new_state = self._cell(inputs, state)
if self._state_is_tuple:
new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
else:
new_state_cat = new_state
new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
with vs.variable_scope("attn_output_projection"):
output = _linear([lstm_output, new_attns], self._attn_size, True)
new_attn_states = array_ops.concat(
[new_attn_states, array_ops.expand_dims(output, 1)], 1)
new_attn_states = array_ops.reshape(
new_attn_states, [-1, self._attn_length * self._attn_size])
new_state = (new_state, new_attns, new_attn_states)
if not self._state_is_tuple:
new_state = array_ops.concat(list(new_state), 1)
return output, new_state
def _attention(self, query, attn_states):
conv2d = nn_ops.conv2d
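For context, the wrapper above is applied around an ordinary cell and carries the attention window in its state. A minimal construction sketch; the attention length and tensor sizes are illustrative, not part of this commit:

import tensorflow as tf

base_cell = tf.contrib.rnn.BasicLSTMCell(128)
attn_cell = tf.contrib.rnn.AttentionCellWrapper(
    base_cell, attn_length=10, state_is_tuple=True)

inputs = tf.placeholder(tf.float32, [None, 40, 64])   # [batch, time, features]
outputs, state = tf.nn.dynamic_rnn(attn_cell, inputs, dtype=tf.float32)
# With state_is_tuple=True, `state` unpacks as (cell_state, attns, attn_states),
# the same triple handled at the top of call() above.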
@ -1213,6 +1202,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)
if input_size is not None:
logging.warn("%s: The input_size parameter is deprecated.", self)
@ -1256,34 +1246,31 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
out = nn_ops.bias_add(out, bias)
return out
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""LSTM cell with layer normalization and recurrent dropout."""
c, h = state
args = array_ops.concat([inputs, h], 1)
concat = self._linear(args)
with _checked_scope(self, scope or "layer_norm_basic_lstm_cell",
reuse=self._reuse):
c, h = state
args = array_ops.concat([inputs, h], 1)
concat = self._linear(args)
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._layer_norm:
i = self._norm(i, "input")
j = self._norm(j, "transform")
f = self._norm(f, "forget")
o = self._norm(o, "output")
i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
if self._layer_norm:
i = self._norm(i, "input")
j = self._norm(j, "transform")
f = self._norm(f, "forget")
o = self._norm(o, "output")
g = self._activation(j)
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
g = self._activation(j)
if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+ math_ops.sigmoid(i) * g)
if self._layer_norm:
new_c = self._norm(new_c, "state")
new_h = self._activation(new_c) * math_ops.sigmoid(o)
new_c = (c * math_ops.sigmoid(f + self._forget_bias)
+ math_ops.sigmoid(i) * g)
if self._layer_norm:
new_c = self._norm(new_c, "state")
new_h = self._activation(new_c) * math_ops.sigmoid(o)
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
return new_h, new_state
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
return new_h, new_state
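A short usage sketch of the cell whose step is implemented above, showing the layer-normalization and recurrent-dropout options; the sizes and keep probability are illustrative, not from this commit:

import tensorflow as tf

cell = tf.contrib.rnn.LayerNormBasicLSTMCell(
    num_units=128, layer_norm=True, dropout_keep_prob=0.9)

inputs = tf.placeholder(tf.float32, [None, 30, 64])   # [batch, time, features]
outputs, state = tf.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)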
class NASCell(core_rnn_cell.RNNCell):
@ -1313,6 +1300,7 @@ class NASCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(NASCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._num_proj = num_proj
self._use_biases = use_biases
@ -1333,14 +1321,13 @@ class NASCell(core_rnn_cell.RNNCell):
def output_size(self):
return self._output_size
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of NAS Cell.
Args:
inputs: input Tensor, 2D, batch x num_units.
state: This must be a tuple of state Tensors, both `2-D`, with column
sizes `c_state` and `m_state`.
scope: VariableScope for the created subgraph; defaults to "nas_rnn".
Returns:
A tuple containing:
@ -1368,71 +1355,70 @@ class NASCell(core_rnn_cell.RNNCell):
input_size = inputs.get_shape().with_rank(2)[1]
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
with _checked_scope(self, scope or "nas_rnn", reuse=self._reuse):
# Variables for the NAS cell. W_m is all matrices multiplying the
# hiddenstate and W_inputs is all matrices multiplying the inputs.
concat_w_m = vs.get_variable(
"recurrent_weights", [num_proj, 8 * self._num_units],
dtype)
concat_w_inputs = vs.get_variable(
"weights", [input_size.value, 8 * self._num_units],
# Variables for the NAS cell. W_m is all matrices multiplying the
# hiddenstate and W_inputs is all matrices multiplying the inputs.
concat_w_m = vs.get_variable(
"recurrent_weights", [num_proj, 8 * self._num_units],
dtype)
concat_w_inputs = vs.get_variable(
"weights", [input_size.value, 8 * self._num_units],
dtype)
m_matrix = math_ops.matmul(m_prev, concat_w_m)
inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
if self._use_biases:
b = vs.get_variable(
"bias",
shape=[8 * self._num_units],
initializer=init_ops.zeros_initializer(),
dtype=dtype)
m_matrix = nn_ops.bias_add(m_matrix, b)
# The NAS cell branches into 8 different splits for both the hiddenstate
# and the input
m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
value=m_matrix)
inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
value=inputs_matrix)
# First layer
layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
# Second layer
l2_0 = tanh(layer1_0 * layer1_1)
l2_1 = tanh(layer1_2 + layer1_3)
l2_2 = tanh(layer1_4 * layer1_5)
l2_3 = sigmoid(layer1_6 + layer1_7)
# Inject the cell
l2_0 = tanh(l2_0 + c_prev)
# Third layer
l3_0_pre = l2_0 * l2_1
new_c = l3_0_pre # create new cell
l3_0 = l3_0_pre
l3_1 = tanh(l2_2 + l2_3)
# Final layer
new_m = tanh(l3_0 * l3_1)
# Projection layer if specified
if self._num_proj is not None:
concat_w_proj = vs.get_variable(
"projection_weights", [self._num_units, self._num_proj],
dtype)
new_m = math_ops.matmul(new_m, concat_w_proj)
m_matrix = math_ops.matmul(m_prev, concat_w_m)
inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)
if self._use_biases:
b = vs.get_variable(
"bias",
shape=[8 * self._num_units],
initializer=init_ops.zeros_initializer(),
dtype=dtype)
m_matrix = nn_ops.bias_add(m_matrix, b)
# The NAS cell branches into 8 different splits for both the hiddenstate
# and the input
m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
value=m_matrix)
inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
value=inputs_matrix)
# First layer
layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])
# Second layer
l2_0 = tanh(layer1_0 * layer1_1)
l2_1 = tanh(layer1_2 + layer1_3)
l2_2 = tanh(layer1_4 * layer1_5)
l2_3 = sigmoid(layer1_6 + layer1_7)
# Inject the cell
l2_0 = tanh(l2_0 + c_prev)
# Third layer
l3_0_pre = l2_0 * l2_1
new_c = l3_0_pre # create new cell
l3_0 = l3_0_pre
l3_1 = tanh(l2_2 + l2_3)
# Final layer
new_m = tanh(l3_0 * l3_1)
# Projection layer if specified
if self._num_proj is not None:
concat_w_proj = vs.get_variable(
"projection_weights", [self._num_units, self._num_proj],
dtype)
new_m = math_ops.matmul(new_m, concat_w_proj)
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
return new_m, new_state
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
return new_m, new_state
class UGRNNCell(core_rnn_cell.RNNCell):
@ -1467,6 +1453,7 @@ class UGRNNCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(UGRNNCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._initializer = initializer
self._forget_bias = forget_bias
@ -1481,13 +1468,12 @@ class UGRNNCell(core_rnn_cell.RNNCell):
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of UGRNN.
Args:
inputs: input Tensor, 2D, batch x input size.
state: state Tensor, 2D, batch x num units.
scope: VariableScope for the created subgraph; defaults to "ugrnn_cell".
Returns:
new_output: batch x num units, Tensor representing the output of the UGRNN
@ -1506,8 +1492,8 @@ class UGRNNCell(core_rnn_cell.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
with _checked_scope(self, scope or "ugrnn_cell",
initializer=self._initializer, reuse=self._reuse):
with vs.variable_scope(vs.get_variable_scope(),
initializer=self._initializer):
cell_inputs = array_ops.concat([inputs, state], 1)
rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True)
@ -1567,6 +1553,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(IntersectionRNNCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._initializer = initializer
self._forget_bias = forget_bias
@ -1582,14 +1569,12 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
def output_size(self):
return self._num_units
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Run one step of the Intersection RNN.
Args:
inputs: input Tensor, 2D, batch x input size.
state: state Tensor, 2D, batch x num units.
scope: VariableScope for the created subgraph; defaults to
"intersection_rnn_cell"
Returns:
new_y: batch x num units, Tensor representing the output of the +RNN
@ -1610,8 +1595,8 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
if input_size.value is None:
raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
with _checked_scope(self, scope or "intersection_rnn_cell",
initializer=self._initializer, reuse=self._reuse):
with vs.variable_scope(vs.get_variable_scope(),
initializer=self._initializer):
# read-in projections (should be used for first layer in deep +RNN
# to transform size of inputs from I --> N)
if input_size.value != self._num_units:
@ -1683,7 +1668,7 @@ class CompiledWrapper(core_rnn_cell.RNNCell):
return not _REGISTERED_OPS[node_def.op].is_stateful
with jit.experimental_jit_scope(compile_ops=compile_ops):
return self._cell(inputs, state, scope=scope)
return self._cell(inputs, state, scope)
def _random_exp_initializer(minval,
@ -1753,6 +1738,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
in an existing scope. If not `True`, and the existing scope already has
the given variables, an error is raised.
"""
super(PhasedLSTMCell, self).__init__(_reuse=reuse)
self._num_units = num_units
self._use_peepholes = use_peepholes
self._leak = leak
@ -1782,7 +1768,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
return math_ops.cast(cycle_ratio, dtype=dtypes.float32)
def __call__(self, inputs, state, scope=None):
def call(self, inputs, state):
"""Phased LSTM Cell.
Args:
@ -1792,7 +1778,6 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
The second Tensor has shape [batch, features_size], and type float32.
It stores the features.
state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
scope: string, id of the variable scope.
Returns:
A tuple containing:
@ -1801,61 +1786,60 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
- A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
[batch_size, num_units], representing the new state and the output.
"""
with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse):
(c_prev, h_prev) = state
(time, x) = inputs
(c_prev, h_prev) = state
(time, x) = inputs
in_mask_gates = [x, h_prev]
if self._use_peepholes:
in_mask_gates.append(c_prev)
in_mask_gates = [x, h_prev]
if self._use_peepholes:
in_mask_gates.append(c_prev)
with vs.variable_scope("mask_gates"):
mask_gates = math_ops.sigmoid(
_linear(in_mask_gates, 2 * self._num_units, True))
[input_gate, forget_gate] = array_ops.split(
axis=1, num_or_size_splits=2, value=mask_gates)
with vs.variable_scope("mask_gates"):
mask_gates = math_ops.sigmoid(
_linear(in_mask_gates, 2 * self._num_units, True))
[input_gate, forget_gate] = array_ops.split(
axis=1, num_or_size_splits=2, value=mask_gates)
with vs.variable_scope("new_input"):
new_input = math_ops.tanh(
_linear([x, h_prev], self._num_units, True))
with vs.variable_scope("new_input"):
new_input = math_ops.tanh(
_linear([x, h_prev], self._num_units, True))
new_c = (c_prev * forget_gate + input_gate * new_input)
new_c = (c_prev * forget_gate + input_gate * new_input)
in_out_gate = [x, h_prev]
if self._use_peepholes:
in_out_gate.append(new_c)
in_out_gate = [x, h_prev]
if self._use_peepholes:
in_out_gate.append(new_c)
with vs.variable_scope("output_gate"):
output_gate = math_ops.sigmoid(
_linear(in_out_gate, self._num_units, True))
with vs.variable_scope("output_gate"):
output_gate = math_ops.sigmoid(
_linear(in_out_gate, self._num_units, True))
new_h = math_ops.tanh(new_c) * output_gate
new_h = math_ops.tanh(new_c) * output_gate
period = vs.get_variable(
"period", [self._num_units],
initializer=_random_exp_initializer(
self._period_init_min, self._period_init_max))
phase = vs.get_variable(
"phase", [self._num_units],
initializer=init_ops.random_uniform_initializer(
0., period.initial_value))
ratio_on = vs.get_variable(
"ratio_on", [self._num_units],
initializer=init_ops.constant_initializer(self._ratio_on),
trainable=self._trainable_ratio_on)
period = vs.get_variable(
"period", [self._num_units],
initializer=_random_exp_initializer(
self._period_init_min, self._period_init_max))
phase = vs.get_variable(
"phase", [self._num_units],
initializer=init_ops.random_uniform_initializer(
0., period.initial_value))
ratio_on = vs.get_variable(
"ratio_on", [self._num_units],
initializer=init_ops.constant_initializer(self._ratio_on),
trainable=self._trainable_ratio_on)
cycle_ratio = self._get_cycle_ratio(time, phase, period)
cycle_ratio = self._get_cycle_ratio(time, phase, period)
k_up = 2 * cycle_ratio / ratio_on
k_down = 2 - k_up
k_closed = self._leak * cycle_ratio
k_up = 2 * cycle_ratio / ratio_on
k_down = 2 - k_up
k_closed = self._leak * cycle_ratio
k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
new_c = k * new_c + (1 - k) * c_prev
new_h = k * new_h + (1 - k) * h_prev
new_c = k * new_c + (1 - k) * c_prev
new_h = k * new_h + (1 - k) * h_prev
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
return new_h, new_state
return new_h, new_state
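As the docstring above notes, this cell consumes a pair of tensors per step: a per-example timestamp of shape [batch, 1] and the features. A single-step sketch; the sizes and placeholders are illustrative, not part of this commit:

import tensorflow as tf

cell = tf.contrib.rnn.PhasedLSTMCell(num_units=32)

t = tf.placeholder(tf.float32, [None, 1])    # current time, one scalar per example
x = tf.placeholder(tf.float32, [None, 16])   # input features
state = cell.zero_state(tf.shape(x)[0], tf.float32)

output, new_state = cell((t, x), state)      # inputs are the (time, features) pair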
View File
@ -56,14 +56,19 @@ class GatherTreeOp : public OpKernel {
errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
step_ids_shape.DebugString()));
OP_REQUIRES(
ctx, TensorShapeUtils::IsVector(sequence_length.shape()),
errors::InvalidArgument("sequence_length must be a vector, saw shape: ",
ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
sequence_length.shape().DebugString()));
OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
errors::InvalidArgument(
"Inconsistent batch sizes: sequence_length.shape[1] (",
"Inconsistent batch sizes: sequence_length.shape[0] (",
sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
step_ids_shape.dim_size(0), ")"));
step_ids_shape.dim_size(1), ")"));
OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
errors::InvalidArgument(
"Inconsistent batch sizes: sequence_length.shape[1] (",
sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
step_ids_shape.dim_size(2), ")"));
OP_REQUIRES(
ctx, step_ids_shape == parent_ids.shape(),
errors::InvalidArgument(
@ -74,7 +79,7 @@ class GatherTreeOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
typename TTypes<T>::ConstVec seq_len_t = sequence_length.vec<T>();
typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
seq_len_t, beams_t);
@ -96,7 +101,7 @@ struct GatherTree<CPUDevice, int32> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
typename TTypes<int32, 3>::ConstTensor step_ids,
typename TTypes<int32, 3>::ConstTensor parent_ids,
typename TTypes<int32>::ConstVec sequence_length,
typename TTypes<int32>::ConstMatrix sequence_length,
typename TTypes<int32, 3>::Tensor beams) {
const int64 max_time = parent_ids.dimension(0);
const int64 batch_size = parent_ids.dimension(1);
@ -104,15 +109,10 @@ struct GatherTree<CPUDevice, int32> {
beams.setConstant(-1);
auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
int32 seq_len_b = -1;
int32 old_batch = -1;
for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
if (batch != old_batch) {
seq_len_b = sequence_length(batch);
old_batch = batch;
}
int32 seq_len_b = sequence_length(batch, beam);
if (seq_len_b == 0) {
continue;
}
@ -148,14 +148,14 @@ struct GatherTree<CPUDevice, int32> {
#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
typename TTypes<T>::ConstVec sequence_length, \
typename TTypes<T, 3>::Tensor beams); \
#define DECLARE_GPU_SPEC(T) \
template <> \
void GatherTree<GPUDevice, T>::operator()( \
OpKernelContext* ctx, const GPUDevice& d, \
typename TTypes<T, 3>::ConstTensor step_ids, \
typename TTypes<T, 3>::ConstTensor parent_ids, \
typename TTypes<T>::ConstMatrix sequence_length, \
typename TTypes<T, 3>::Tensor beams); \
extern template struct GatherTree<GPUDevice, T>;
DECLARE_GPU_SPEC(int32);
View File
@ -31,7 +31,7 @@ struct GatherTree {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstVec sequence_length,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams);
};
View File
@ -33,7 +33,7 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
const int32 batch = i / beam_width;
const int32 beam = i % beam_width;
const int32 seq_len_b = ldg(sequence_length + batch);
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
#define GET_IX(time_ix, beam_ix) \
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
@ -59,7 +59,7 @@ struct GatherTree<GPUDevice, T> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstVec sequence_length,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
View File
@ -32,17 +32,20 @@ REGISTER_OP("GatherTree")
ShapeHandle step_ids, parent_ids, sequence_length;
// step_ids, parent_ids, and output are all shaped:
// [batch_size, max_time, beam_width].
// sequence_length is shaped [batch_size].
// [max_time, batch_size, beam_width].
// sequence_length is shaped [batch_size, beam_width].
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));
DimensionHandle batch_size = c->Dim(step_ids, 1);
DimensionHandle beam_width = c->Dim(step_ids, 2);
TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
TF_RETURN_IF_ERROR(
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(
c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));
c->set_output(0, step_ids);
return tensorflow::Status::OK();
@ -58,7 +61,7 @@ TODO(ebrevdo): fill in
step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
sequence_length: `[batch_size]`.
sequence_length: `[batch_size, beam_width]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");
View File
@ -109,7 +109,7 @@ class AttentionWrapperTest(test.TestCase):
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
final_outputs, final_state = decoder.dynamic_decode(my_decoder)
final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)
self.assertTrue(
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
View File
@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -41,24 +42,32 @@ class TestGatherTree(test.TestCase):
"""Tests the gather_tree function."""
def test_gather_tree(self):
predicted_ids = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[2, 3, 4], [5, 6, 7],
[8, 9, 10]]]).transpose([1, 0, 2])
parent_ids = np.array([
[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
[[0, 0, 0], [1, 2, 0], [2, 1, 1]],
]).transpose([1, 0, 2])
expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6],
[8, 9, 10]]]).transpose([1, 0, 2])
# (max_time = 3, batch_size = 2, beam_width = 3)
res = beam_search_decoder._gather_tree(
ops.convert_to_tensor(predicted_ids), ops.convert_to_tensor(parent_ids))
# create (batch_size, max_time, beam_width) matrix and transpose it
predicted_ids = np.array(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
dtype=np.int32).transpose([1, 0, 2])
parent_ids = np.array(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
[[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])
# sequence_lengths is shaped (batch_size = 2, beam_width = 3)
sequence_lengths = [[3, 3, 3], [3, 3, 3]]
expected_result = np.array(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])
res = beam_search_ops.gather_tree(
predicted_ids, parent_ids, sequence_lengths)
with self.test_session() as sess:
res_ = sess.run(res)
np.testing.assert_array_equal(expected_result, res_)
self.assertAllEqual(expected_result, res_)
class TestEosMasking(test.TestCase):
@ -80,18 +89,18 @@ class TestEosMasking(test.TestCase):
probs = sess.run(probs)
masked = sess.run(masked)
np.testing.assert_array_equal(probs[0][0], masked[0][0])
np.testing.assert_array_equal(probs[0][2], masked[0][2])
np.testing.assert_array_equal(probs[1][0], masked[1][0])
self.assertAllEqual(probs[0][0], masked[0][0])
self.assertAllEqual(probs[0][2], masked[0][2])
self.assertAllEqual(probs[1][0], masked[1][0])
np.testing.assert_equal(masked[0][1][0], 0)
np.testing.assert_equal(masked[1][1][0], 0)
np.testing.assert_equal(masked[1][2][0], 0)
self.assertEqual(masked[0][1][0], 0)
self.assertEqual(masked[1][1][0], 0)
self.assertEqual(masked[1][2][0], 0)
for i in range(1, 5):
np.testing.assert_approx_equal(masked[0][1][i], np.finfo('float32').min)
np.testing.assert_approx_equal(masked[1][1][i], np.finfo('float32').min)
np.testing.assert_approx_equal(masked[1][2][i], np.finfo('float32').min)
self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
self.assertAllClose(masked[1][2][i], np.finfo('float32').min)
class TestBeamStep(test.TestCase):
@ -142,12 +151,11 @@ class TestBeamStep(test.TestCase):
outputs_, next_state_, state_, log_probs_ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])
np.testing.assert_array_equal(outputs_.predicted_ids, [[3, 3, 2], [2, 2,
1]])
np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
np.testing.assert_array_equal(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
np.testing.assert_array_equal(next_state_.finished, [[False, False, False],
[False, False, False]])
self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
self.assertAllEqual(next_state_.finished, [[False, False, False],
[False, False, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@ -158,7 +166,7 @@ class TestBeamStep(test.TestCase):
expected_log_probs[1][0] += log_probs_[1, 2, 2]
expected_log_probs[1][1] += log_probs_[1, 1, 2]
expected_log_probs[1][2] += log_probs_[1, 0, 1]
np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
self.assertAllEqual(next_state_.log_probs, expected_log_probs)
def test_step_with_eos(self):
dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@ -197,12 +205,11 @@ class TestBeamStep(test.TestCase):
outputs_, next_state_, state_, log_probs_ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])
np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
np.testing.assert_array_equal(outputs_.predicted_ids, [[0, 3, 2], [2, 0,
1]])
np.testing.assert_array_equal(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
np.testing.assert_array_equal(next_state_.finished, [[True, False, False],
[False, True, False]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
self.assertAllEqual(next_state_.finished, [[True, False, False],
[False, True, False]])
expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@ -211,7 +218,7 @@ class TestBeamStep(test.TestCase):
expected_log_probs[0][2] += log_probs_[0, 0, 2]
expected_log_probs[1][0] += log_probs_[1, 1, 2]
expected_log_probs[1][2] += log_probs_[1, 0, 1]
np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
self.assertAllEqual(next_state_.log_probs, expected_log_probs)
class BeamSearchDecoderTest(test.TestCase):
@ -259,8 +266,9 @@ class BeamSearchDecoderTest(test.TestCase):
output_layer=output_layer,
length_penalty_weight=0.0)
final_outputs, final_state = decoder.dynamic_decode(
bsd, output_time_major=time_major, maximum_iterations=max_out)
final_outputs, final_state, final_sequence_lengths = (
decoder.dynamic_decode(
bsd, output_time_major=time_major, maximum_iterations=max_out))
def _t(shape):
if time_major:
@ -284,16 +292,18 @@ class BeamSearchDecoderTest(test.TestCase):
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
'final_outputs': final_outputs,
'final_state': final_state
'final_state': final_state,
'final_sequence_lengths': final_sequence_lengths
})
# Mostly a smoke test
time_steps = max_out
max_sequence_length = np.max(sess_results['final_sequence_lengths'])
# A smoke test
self.assertEqual(
_t((batch_size, time_steps, beam_width)),
_t((batch_size, max_sequence_length, beam_width)),
sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
self.assertEqual(
_t((batch_size, time_steps, beam_width)), sess_results[
_t((batch_size, max_sequence_length, beam_width)), sess_results[
'final_outputs'].beam_search_decoder_output.predicted_ids.shape)
def testDynamicDecodeRNNBatchMajorNoAttention(self):
View File
@ -38,7 +38,7 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
sequence_length = [[3, 3, 3]]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
beams = beam_search_ops.gather_tree(
@ -54,7 +54,7 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
sequence_length = [[3, 3, 3]]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
@ -73,7 +73,7 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
sequence_length = [[3, 3, 3]]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/gpu:0"):
@ -84,7 +84,8 @@ class GatherTreeTest(test.TestCase):
self.assertAllEqual(expected_result, beams.eval())
def testGatherTreeBatch(self):
sequence_length = [0, 1, 2, 3]
# sequence_length is [batch_size, beam_width] = [4, 5]
sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5]
with self.test_session(use_gpu=True):
# (max_time = 4, batch_size = 4, beam_width = 5)

View File
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))
final_outputs, final_state = decoder.dynamic_decode(
my_decoder, output_time_major=time_major,
maximum_iterations=maximum_iterations)
final_outputs, final_state, final_sequence_length = (
decoder.dynamic_decode(my_decoder, output_time_major=time_major,
maximum_iterations=maximum_iterations))
def _t(shape):
if time_major:
@ -73,6 +73,9 @@ class DynamicDecodeRNNTest(test.TestCase):
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))
self.assertEqual(
(batch_size,),
tuple(final_sequence_length.get_shape().as_list()))
self.assertEqual(
_t((batch_size, None, cell_depth)),
tuple(final_outputs.rnn_output.get_shape().as_list()))
@ -83,7 +86,8 @@ class DynamicDecodeRNNTest(test.TestCase):
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"final_outputs": final_outputs,
"final_state": final_state
"final_state": final_state,
"final_sequence_length": final_sequence_length,
})
# Mostly a smoke test
@ -131,7 +135,7 @@ class DynamicDecodeRNNTest(test.TestCase):
# Match the variable scope of dynamic_rnn below so we end up
# using the same variables
with vs.variable_scope("root") as scope:
final_decoder_outputs, final_decoder_state = decoder.dynamic_decode(
final_decoder_outputs, final_decoder_state, _ = decoder.dynamic_decode(
my_decoder,
# impute_finished=True ensures outputs and final state
# match those of dynamic_rnn called with sequence_length not None
View File
@ -454,6 +454,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
up to the next cell in an RNN stack or to the top RNN output.
name: Name to use when creating ops.
"""
super(AttentionWrapper, self).__init__()
if not isinstance(cell, core_rnn_cell.RNNCell):
raise TypeError(
"cell must be an RNNCell, saw type: %s" % type(cell).__name__)
@ -515,7 +516,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
dtype),
alignment_history=alignment_history)
def __call__(self, inputs, state, tiling_factor=1, scope=None):
def __call__(self, inputs, state, tiling_factor=1):
"""Perform a step of attention-wrapped RNN.
- Step 1: Mix the `inputs` and previous step's `attention` output via
@ -536,7 +537,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
tensors from the previous time step.
tiling_factor: An integer factor for which to tile the batch dimension.
Used with BeamSearchDecoder.
scope: Must be `None`.
Returns:
A tuple `(attention_or_cell_output, next_state)`, where:
@ -548,50 +548,46 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
Raises:
NotImplementedError: if `scope` is not `None`.
"""
if scope is not None:
raise NotImplementedError("scope not None is not supported")
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
with variable_scope.variable_scope("attention"):
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
score = self._attention_mechanism(cell_output, tiling_factor)
alignments = self._probability_fn(score)
score = self._attention_mechanism(cell_output, tiling_factor)
alignments = self._probability_fn(score)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units].
# we then squeeze out the singleton dim.
attention_mechanism_values = _maybe_tile_batch(
self._attention_mechanism.values, tiling_factor)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units].
# we then squeeze out the singleton dim.
attention_mechanism_values = _maybe_tile_batch(
self._attention_mechanism.values, tiling_factor)
context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
context = array_ops.squeeze(context, [1])
context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
context = array_ops.squeeze(context, [1])
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
if self._alignment_history:
alignment_history = state.alignment_history.write(
state.time, alignments)
else:
alignment_history = ()
if self._alignment_history:
alignment_history = state.alignment_history.write(
state.time, alignments)
else:
alignment_history = ()
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignment_history=alignment_history)
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignment_history=alignment_history)
if self._output_attention:
return attention, next_state
View File
@ -19,9 +19,9 @@ from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import nest
@ -202,20 +201,24 @@ class BeamSearchDecoder(decoder.Decoder):
return (finished, start_inputs, initial_state)
def finalize(self, outputs, final_state):
def finalize(self, outputs, final_state, sequence_lengths):
"""Finalize and return the predicted_ids.
Args:
outputs: An instance of BeamSearchDecoderOutput.
final_state: An instance of BeamSearchDecoderState. Passed through to the
output.
sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
The sequence lengths determined for each beam during decode.
Returns:
outputs: An instance of FinalBeamSearchDecoderOutput where the
predicted_ids are the result of calling _gather_tree.
final_state: The same input instance of BeamSearchDecoderState.
"""
predicted_ids = _gather_tree(outputs.predicted_ids, outputs.parent_ids)
predicted_ids = beam_search_ops.gather_tree(
outputs.predicted_ids, outputs.parent_ids,
sequence_length=sequence_lengths)
outputs = FinalBeamSearchDecoderOutput(
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
return outputs, final_state
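The lengths handed to `finalize` are produced by `dynamic_decode`, which (per the decoder.py hunk further below) now returns them as a third value. A hedged end-to-end sketch following the pattern of the decoder tests in this commit; the sizes are illustrative and the import paths are the internal ones those tests use:

import tensorflow as tf
from tensorflow.contrib.seq2seq.python.ops import basic_decoder
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py

batch_size, max_time, depth = 8, 12, 16
cell = tf.contrib.rnn.BasicLSTMCell(depth)
inputs = tf.placeholder(tf.float32, [batch_size, max_time, depth])
lengths = tf.fill([batch_size], max_time)

my_decoder = basic_decoder.BasicDecoder(
    cell=cell,
    helper=helper_py.TrainingHelper(inputs, lengths, time_major=False),
    initial_state=cell.zero_state(batch_size, tf.float32))

# dynamic_decode now also returns the per-example decode lengths; these are what
# BeamSearchDecoder.finalize receives as `sequence_lengths`.
final_outputs, final_state, final_sequence_lengths = decoder.dynamic_decode(
    my_decoder, maximum_iterations=max_time)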
@ -536,42 +539,6 @@ def _mask_probs(probs, eos_token, finished):
return finished_examples + non_finished_examples
def _gather_tree_py(values, parents):
"""Gathers path through a tree backwards from the leave nodes.
Used to reconstruct beams given their parents.
Args:
values: A [T, batch_size, beam_width] tensor of indices.
parents: A [T, batch_size, beam_width] tensor of parent beam ids.
Returns:
The [T, batch_size, beam_width] numpy array of paths. For a given batch
entry b, the best path is given by ret[:, b, 0].
"""
num_timesteps = values.shape[0]
num_beams = values.shape[2]
batch_size = values.shape[1]
ret = np.zeros_like(values) # [T, MB, BW]
ret[-1, :, :] = values[-1, :, :]
for beam_id in range(num_beams):
for batch in range(batch_size):
parent = parents[-1][batch][beam_id]
for timestep in reversed(range(num_timesteps - 1)):
ret[timestep, batch, beam_id] = values[timestep][batch][parent]
parent = parents[timestep][batch][parent]
# now we are going to return ret as a [ts, mb, bw] tensor
return np.array(ret).astype(values.dtype)
def _gather_tree(values, parents):
"""Tensor version of _gather_tree_py."""
ret = script_ops.py_func(
func=_gather_tree_py, inp=[values, parents], Tout=values.dtype)
ret.set_shape(values.get_shape().as_list())
return ret
def _tensor_gather_helper(gather_indices, gather_from, range_input, range_size,
final_shape):
range_ = array_ops.expand_dims(math_ops.range(range_input) * range_size, 1)

View File

@ -154,11 +154,11 @@ def dynamic_decode(decoder,
scope: Optional variable scope to use.
Returns:
`(final_outputs, final_state)`.
`(final_outputs, final_state, final_sequence_lengths)`.
Raises:
TypeError: if `decoder` is not an instance of `Decoder`.
ValueError: if maximum_iterations is provided but is not a scalar.
ValueError: if `maximum_iterations` is provided but is not a scalar.
"""
if not isinstance(decoder, Decoder):
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
@ -184,6 +184,8 @@ def dynamic_decode(decoder,
if maximum_iterations is not None:
initial_finished = math_ops.logical_or(
initial_finished, 0 >= maximum_iterations)
initial_sequence_lengths = array_ops.zeros_like(
initial_finished, dtype=dtypes.int32)
initial_time = constant_op.constant(0, dtype=dtypes.int32)
def _shape(batch_size, from_shape):
@ -206,10 +208,10 @@ def dynamic_decode(decoder,
decoder.output_dtype)
def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
finished):
finished, unused_sequence_lengths):
return math_ops.logical_not(math_ops.reduce_all(finished))
def body(time, outputs_ta, state, inputs, finished):
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
"""Internal while_loop body.
Args:
@ -217,10 +219,13 @@ def dynamic_decode(decoder,
outputs_ta: structure of TensorArray.
state: (structure of) state tensors and TensorArrays.
inputs: (structure of) input tensors.
finished: 1-D bool tensor.
finished: bool tensor (keeping track of what's finished).
sequence_lengths: int32 tensor (keeping track of time of finish).
Returns:
`(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
`(time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)`.
"""
(next_outputs, decoder_state, next_inputs,
decoder_finished) = decoder.step(time, inputs, state)
@ -228,6 +233,10 @@ def dynamic_decode(decoder,
if maximum_iterations is not None:
next_finished = math_ops.logical_or(
next_finished, time + 1 >= maximum_iterations)
next_sequence_lengths = array_ops.where(
math_ops.logical_and(math_ops.logical_not(finished), next_finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)
nest.assert_same_structure(state, decoder_state)
nest.assert_same_structure(outputs_ta, next_outputs)
@ -260,26 +269,30 @@ def dynamic_decode(decoder,
outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
outputs_ta, emit)
return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)
res = control_flow_ops.while_loop(
condition,
body,
loop_vars=[
initial_time, initial_outputs_ta, initial_state, initial_inputs,
initial_finished
initial_finished, initial_sequence_lengths,
],
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)
final_outputs_ta = res[1]
final_state = res[2]
final_sequence_lengths = res[5]
final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)
if hasattr(decoder, "finalize"):
final_outputs, final_state = decoder.finalize(
final_outputs, final_state, final_sequence_lengths)
if not output_time_major:
final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)
if hasattr(decoder, "finalize"):
final_outputs, final_state = decoder.finalize(final_outputs, final_state)
return final_outputs, final_state
return final_outputs, final_state, final_sequence_lengths
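A minimal usage sketch of the new three-value return; the decoder object and the iteration limit below are hypothetical:
    outputs, final_state, sequence_lengths = dynamic_decode(
        my_decoder, maximum_iterations=100)
    # sequence_lengths is an int32 tensor shaped like the decoder's finished
    # flags, recording the step at which each batch entry first finished.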

View File

@ -19,13 +19,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import inspect
from six import exec_
from tensorflow.contrib.specs.python import params_ops
from tensorflow.contrib.specs.python import specs_lib
from tensorflow.contrib.specs.python import specs_ops
from tensorflow.python.util import tf_inspect
def eval_params(params, environment=None):
@ -44,7 +42,8 @@ def eval_params(params, environment=None):
"""
specs_lib.check_keywords(params)
bindings = {}
if environment: bindings.update(environment)
if environment:
bindings.update(environment)
exec_(params, vars(params_ops), bindings) # pylint: disable=exec-used
return bindings
@ -71,7 +70,8 @@ def eval_spec(spec, environment=None):
"""
specs_lib.check_keywords(spec)
bindings = {}
if environment: bindings.update(environment)
if environment:
bindings.update(environment)
exec_(spec, vars(specs_ops), bindings) # pylint: disable=exec-used
return bindings
@ -141,7 +141,7 @@ class LocalImport(object):
self.names = names
def __enter__(self):
self.frame = inspect.currentframe()
self.frame = tf_inspect.currentframe()
bindings = self.frame.f_back.f_globals
self.old = {k: bindings.get(k, None) for k in self.names.keys()}
bindings.update(self.names)
@ -151,7 +151,9 @@ class LocalImport(object):
bindings = self.frame.f_back.f_globals
bindings.update(self.old)
for k, v in self.old.items():
if v is None: del bindings[k]
if v is None:
del bindings[k]
del self.frame
ops = LocalImport(specs_ops)

View File

@ -1,6 +1,10 @@
# Description:
# Verbs RDMA communication interfaces and implementations for TensorFlow.
package(default_visibility = [
"//tensorflow:__subpackages__",
])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
@ -31,13 +35,10 @@ load(
"tf_proto_library_cc",
)
package(default_visibility = [
"//tensorflow:__subpackages__",
])
tf_proto_library_cc(
name = "verbs_service_proto",
srcs = ["verbs_service.proto"],
has_services = 1,
cc_api_version = 2,
visibility = [
"//tensorflow:__subpackages__",

View File

@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
@ -31,14 +31,14 @@ Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
}
Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
GetRemoteAddressResponse* response) {
CallOptions call_options;
call_options.SetTimeout(-1); // no time out
call_options.SetTimeout(-1); // no time out
return GetRemoteAddress(&call_options, request, response);
}
void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
int64 time_in_ms) {
void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
int64 time_in_ms) {
if (time_in_ms > 0) {
ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
}

View File

@ -16,11 +16,11 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
namespace tensorflow {
@ -28,24 +28,23 @@ namespace tensorflow {
class GrpcVerbsClient {
public:
explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
: stub_(grpc::VerbsService::NewStub(client_channel)) {}
: stub_(grpc::VerbsService::NewStub(client_channel)) {}
~GrpcVerbsClient() {}
Status GetRemoteAddress(CallOptions* call_options,
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);
Status GetRemoteAddress(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);
GetRemoteAddressResponse* response);
private:
std::unique_ptr<grpc::VerbsService::Stub> stub_;
void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_

View File

@ -26,10 +26,10 @@ limitations under the License.
namespace tensorflow {
GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder)
: is_shutdown_(false), worker_env_(worker_env) {
builder->RegisterService(&verbs_service_);
cq_ = builder->AddCompletionQueue().release();
::grpc::ServerBuilder* builder)
: is_shutdown_(false), worker_env_(worker_env) {
builder->RegisterService(&verbs_service_);
cq_ = builder->AddCompletionQueue().release();
}
GrpcVerbsService::~GrpcVerbsService() {
@ -52,7 +52,7 @@ void GrpcVerbsService::Shutdown() {
new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
}
}
// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
// `this->cq_`.
@ -64,17 +64,17 @@ void GrpcVerbsService::Shutdown() {
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method, supports_cancel) \
do { \
mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \
Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&verbs_service_, cq_, \
&grpc::VerbsService::AsyncService::Request##method, \
&GrpcVerbsService::method##Handler, \
(supports_cancel)); \
} \
#define ENQUEUE_REQUEST(method, supports_cancel) \
do { \
mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \
Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&verbs_service_, cq_, \
&grpc::VerbsService::AsyncService::Request##method, \
&GrpcVerbsService::method##Handler, \
(supports_cancel)); \
} \
} while (0)
// This method blocks forever handling requests from the completion queue.
@ -97,8 +97,8 @@ void GrpcVerbsService::HandleRPCsLoop() {
}
}
void GrpcVerbsService::GetRemoteAddressHandler(WorkerCall
<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
void GrpcVerbsService::GetRemoteAddressHandler(
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
Status s = GetRemoteAddressSync(&call->request, &call->response);
call->SendResponse(ToGrpcStatus(s));
ENQUEUE_REQUEST(GetRemoteAddress, false);
@ -106,8 +106,8 @@ void GrpcVerbsService::GetRemoteAddressHandler(WorkerCall
// synchronous method
Status GrpcVerbsService::GetRemoteAddressSync(
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
// analyzing request
// the channel setting part is redundant.
const string remote_host_name = request->host_name();
@ -115,7 +115,7 @@ Status GrpcVerbsService::GetRemoteAddressSync(
CHECK(rc);
RdmaAddress ra;
ra.lid = request->channel().lid();
ra.qpn = request->channel().qpn();
ra.qpn = request->channel().qpn();
ra.psn = request->channel().psn();
rc->SetRemoteAddress(ra, false);
rc->Connect();
@ -140,8 +140,8 @@ Status GrpcVerbsService::GetRemoteAddressSync(
CHECK(i == RdmaChannel::kNumMessageBuffers);
// setting up response
response->set_host_name(worker_env_->session_mgr->
LegacySession()->worker_name);
response->set_host_name(
worker_env_->session_mgr->LegacySession()->worker_name);
Channel* channel_info = response->mutable_channel();
channel_info->set_lid(rc->self().lid);
channel_info->set_qpn(rc->self().qpn);
@ -151,12 +151,12 @@ Status GrpcVerbsService::GetRemoteAddressSync(
mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
mr->set_rkey(mb[i]->self()->rkey);
}
return Status::OK();
return Status::OK();
}
// Create a GrpcVerbsService, then assign it to a given handle.
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder) {
::grpc::ServerBuilder* builder) {
*handle = new GrpcVerbsService(worker_env, builder);
}

View File

@ -18,12 +18,12 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/lib/core/refcount.h"
namespace grpc {
class ServerBuilder;
@ -44,27 +44,27 @@ class GrpcVerbsService : public AsyncServiceInterface {
private:
template <class RequestMessage, class ResponseMessage>
using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
RequestMessage, ResponseMessage>;
void GetRemoteAddressHandler(WorkerCall
<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
RequestMessage, ResponseMessage>;
void GetRemoteAddressHandler(
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);
::grpc::ServerCompletionQueue* cq_;
GetRemoteAddressResponse* response);
::grpc::ServerCompletionQueue* cq_;
grpc::VerbsService::AsyncService verbs_service_;
mutex shutdown_mu_;
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
::grpc::Alarm* shutdown_alarm_;
// not owned
RdmaMgr* rdma_mgr_;
const WorkerEnv* const worker_env_;
const WorkerEnv* const worker_env_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
};
// Create a GrpcVerbsService, then assign it to a given handle.
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder);
::grpc::ServerBuilder* builder);
} // namespace tensorflow

View File

@ -43,7 +43,7 @@ VerbsService::Stub::Stub(
const std::shared_ptr< ::grpc::ChannelInterface>& channel)
: channel_(channel),
rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
::grpc::RpcMethod::NORMAL_RPC, channel) {}
::grpc::RpcMethod::NORMAL_RPC, channel) {}
::grpc::Status VerbsService::Stub::GetRemoteAddress(
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,

View File

@ -48,16 +48,16 @@ class VerbsService GRPC_FINAL {
class StubInterface {
public:
virtual ~StubInterface() {}
virtual ::grpc::Status GetRemoteAddress(::grpc::ClientContext* context,
const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) = 0;
virtual ::grpc::Status GetRemoteAddress(
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) = 0;
};
class Stub GRPC_FINAL : public StubInterface {
public:
Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
::grpc::Status GetRemoteAddress(::grpc::ClientContext* context,
const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) GRPC_OVERRIDE;
::grpc::Status GetRemoteAddress(
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) GRPC_OVERRIDE;
private:
std::shared_ptr< ::grpc::ChannelInterface> channel_;

View File

@ -15,16 +15,16 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include <cstdlib>
#include "tensorflow/contrib/verbs/rdma.h"
#include <cstdlib>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
@ -35,12 +35,12 @@ namespace tensorflow {
namespace {
// hash name to 32-bit integer
uint32_t NameHash(const string& name) {
return Hash32(name.data(), name.size(), 0x1234ABCD);
return Hash32(name.data(), name.size(), 0x1234ABCD);
}
// convenience function for printing message
string MessageTypeToString(RdmaMessageType rmt) {
switch(rmt){
switch (rmt) {
case RDMA_MESSAGE_ACK:
return "RDMA_MESSAGE_ACK";
break;
@ -59,11 +59,11 @@ string MessageTypeToString(RdmaMessageType rmt) {
case RDMA_MESSAGE_TENSOR_WRITE:
return "RDMA_MESSAGE_TENSOR_WRITE";
break;
default:
default:
return "UNKNOWN MESSAGE";
}
}
}
} // namespace
ibv_context* open_default_device() {
ibv_device** dev_list;
@ -89,29 +89,28 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
worker_env_(worker_env) {
event_channel_ = ibv_create_comp_channel(context_);
CHECK(event_channel_) << "Failed to create completion channel";
cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_, 0);
cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
0);
CHECK(cq_) << "Failed to create completion queue";
CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
polling_thread_.reset(Env::Default()->StartThread(
ThreadOptions(), "RdmaAdapterCQThread",
[this] {Process_CQ(); }));
VLOG(2) << "Start RdmaAdapter: " << name();
ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
VLOG(2) << "Start RdmaAdapter: " << name();
}
RdmaAdapter::~RdmaAdapter() {
polling_thread_.reset();
CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
CHECK(!ibv_destroy_comp_channel(event_channel_)) << "Failed to destroy channel";
CHECK(!ibv_destroy_comp_channel(event_channel_))
<< "Failed to destroy channel";
CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
CHECK(!ibv_close_device(context_)) << "Failed to release context";
}
string RdmaAdapter::name() const {
return string(context_->device->name);
}
string RdmaAdapter::name() const { return string(context_->device->name); }
// Function to process incoming messages
// There are two types of messages:
// There are two types of messages:
// 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
// 2. IBV_WC_RDMA_WRITE (send))
void RdmaAdapter::Process_CQ() {
@ -123,15 +122,14 @@ void RdmaAdapter::Process_CQ() {
ibv_ack_cq_events(cq, 1);
CHECK(!ibv_req_notify_cq(cq_, 0));
int ne = ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2,
static_cast<ibv_wc*>(wc_));
int ne =
ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS) << "Failed status \n"
<< ibv_wc_status_str(wc_[i].status)
<< " " << wc_[i].status << " "
<< static_cast<int>(wc_[i].wr_id)
<< " "<< wc_[i].vendor_err;
CHECK(wc_[i].status == IBV_WC_SUCCESS)
<< "Failed status \n"
<< ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
<< static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@ -142,8 +140,8 @@ void RdmaAdapter::Process_CQ() {
RdmaMessage rm;
RdmaMessage::ParseMessage(rm, rb->buffer_);
VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);
if (rm.type_ == RDMA_MESSAGE_ACK) {
if (rm.type_ == RDMA_MESSAGE_ACK) {
// receive an ack to a message
rb = rc->tx_message_buffer_;
rb->SetBufferStatus(remote, idle);
@ -155,12 +153,12 @@ void RdmaAdapter::Process_CQ() {
ab->SendNextItem();
// find or create buffer
RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
tb->EnqueueItem(key_with_step_id);
// send the next tensor
worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();});
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
// receive tensor-buffer-ready message
// send ack to release remote tx message buffer
RdmaBuffer* ab = rc->tx_ack_buffer_;
@ -168,7 +166,7 @@ void RdmaAdapter::Process_CQ() {
// find buffer
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
tb->SetBufferStatus(remote, idle);
worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();});
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
// remote host requests to create a tensor buffer;
// send ack to release remote tx message buffer
@ -194,31 +192,30 @@ void RdmaAdapter::Process_CQ() {
mb->EnqueueItem(message);
mb->SendNextItem();
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
// remote creates a buffer and responds
// remote creates a buffer and responds
// send ack to release remote tx message buffer
RdmaBuffer* ab = rc->tx_ack_buffer_;
ab->SendNextItem();
// find buffer
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
CHECK(rm.buffer_size_ == tb->size_)
<< "rm.buffer_size = " << rm.buffer_size_
<< "tb->size_ = " << tb->size_
<< "rm.name_ = " << rm.name_;
CHECK(rm.buffer_size_ == tb->size_)
<< "rm.buffer_size = " << rm.buffer_size_
<< "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
RemoteMR rmr;
rmr.remote_addr = rm.remote_addr_;
rmr.rkey = rm.rkey_;
tb->SetRemoteMR(rmr, true);
tb->SetBufferStatus(local, idle);
tb->SetBufferStatus(remote, idle);
worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();});
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
// tensor RDMA write completed
worker_env_->compute_pool->Schedule([rm, rc](){
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
worker_env_->compute_pool->Schedule([rm, rc]() {
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
rc->RunRecvCallback(key_with_step_id);
});
}
});
}
} else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
rb->SetBufferStatus(local, idle);
@ -226,7 +223,7 @@ void RdmaAdapter::Process_CQ() {
RdmaMessage::ParseMessage(rm, rb->buffer_);
VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
if (rm.type_ != RDMA_MESSAGE_ACK) {
worker_env_->compute_pool->Schedule([rb](){rb->SendNextItem();});
worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
}
}
}
@ -235,9 +232,7 @@ void RdmaAdapter::Process_CQ() {
RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name)
: adapter_(adapter),
local_name_(local_name),
remote_name_(remote_name) {
: adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
// Create queue pair
{
struct ibv_qp_init_attr attr;
@ -263,21 +258,21 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
attr.port_num = 1;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;
int mask = IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT
| IBV_QP_ACCESS_FLAGS;
int mask =
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
}
// Local address
{
struct ibv_port_attr attr;
CHECK(!ibv_query_port(adapter_->context_, (uint8_t) 1, &attr))
CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
<< "Query port";
self_.lid = attr.lid;
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
}
// create message and ack buffers, then initialize the tables.
{
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
@ -303,7 +298,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
buffer_index_name_table_.insert({index, buffer_names[i]});
buffer_name_index_table_.insert({buffer_names[i], index});
}
// Initiate recv
for (int i = 0; i < 100; i++) {
Recv();
@ -320,17 +315,17 @@ RdmaChannel::~RdmaChannel() {
}
void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
mutex_lock lock{mu_};
if ((override) || (!remote_set_)) {
remote_.lid = ra.lid;
remote_.qpn = ra.qpn;
remote_.psn = ra.psn;
remote_set_ = true;
} else {
CHECK(remote_.lid == ra.lid);
CHECK(remote_.qpn == ra.qpn);
CHECK(remote_.psn == ra.psn);
}
mutex_lock lock{mu_};
if ((override) || (!remote_set_)) {
remote_.lid = ra.lid;
remote_.qpn = ra.qpn;
remote_.psn = ra.psn;
remote_set_ = true;
} else {
CHECK(remote_.lid == ra.lid);
CHECK(remote_.qpn == ra.qpn);
CHECK(remote_.psn == ra.psn);
}
}
// Adding tokens to the completion queue
@ -338,7 +333,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
wr.wr_id = (uint64_t) this;
wr.wr_id = (uint64_t)this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@ -347,12 +342,11 @@ void RdmaChannel::Recv() {
// Args:
// buffer_name: name of the buffer
// Returns:
// 32-bit index
uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name){
// 32-bit index
uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
mutex_lock lock{bt_mu_};
BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(
buffer_name);
BufferNameIndexTable::iterator iter =
buffer_name_index_table_.find(buffer_name);
CHECK(iter != buffer_name_index_table_.end());
return iter->second;
}
@ -380,14 +374,14 @@ RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
}
// Find a buffer if it exists, otherwise create one.
// The memory inside the created buffer is not allocated.
// Args:
// The memory inside the created buffer is not allocated.
// Args:
// name: the name of the buffer
// buffer_type: TENSOR, MESSAGE or ACK.
// Returns:
// the named buffer
RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
BufferType buffer_type) {
RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
BufferType buffer_type) {
mutex_lock lock{bt_mu_};
RdmaBuffer* rb;
// find index
@ -405,7 +399,7 @@ RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
} else if (buffer_type == MESSAGE) {
rb = new RdmaMessageBuffer(this, name);
} else if (buffer_type == ACK) {
rb = new RdmaAckBuffer(this, name);
rb = new RdmaAckBuffer(this, name);
}
buffer_name_index_table_.insert({name, index});
buffer_index_name_table_.insert({index, name});
@ -417,20 +411,19 @@ RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
// Insert callback to the callback_table.
// The callback is activated when the corresponding tensor is received.
// Arg:
// Arg:
// key: the name of the tensor
// recv_done: the callback associated with the tensor.
// Returns:
// None
void RdmaChannel::InsertRecvCallback(const string& key,
std::function<void()> recv_done) {
void RdmaChannel::InsertRecvCallback(const string& key,
std::function<void()> recv_done) {
mutex_lock lock{ct_mu_};
callback_table_.insert({key, recv_done});
}
// Remove callback from the callback_table.
// Arg:
// Arg:
// key: the name of the tensor
// Returns:
// None
@ -440,7 +433,7 @@ void RdmaChannel::RemoveRecvCallback(const string& key) {
}
// Run named callback in the callback_table.
// Arg:
// Arg:
// key: the name of the tensor
// Returns:
// None
@ -484,17 +477,15 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
attr.ah_attr.port_num = 1;
int r;
CHECK(!(r = ibv_modify_qp(qp_, &attr,
IBV_QP_STATE |
IBV_QP_AV |
IBV_QP_PATH_MTU |
IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC |
IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r;
IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC |
IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;
memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = self_.psn;
@ -502,15 +493,13 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.retry_cnt = 7;
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;
CHECK(!(r = ibv_modify_qp(qp_, &attr,
IBV_QP_STATE |
IBV_QP_TIMEOUT |
IBV_QP_RETRY_CNT |
IBV_QP_RNR_RETRY |
IBV_QP_SQ_PSN |
IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r;
IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;
connected_ = true;
} else {
LOG(INFO) << "channel already connected";
@ -518,7 +507,7 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
}
RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
: channel_(channel), name_(name) {}
: channel_(channel), name_(name) {}
RdmaBuffer::~RdmaBuffer() {
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
@ -528,9 +517,9 @@ RdmaBuffer::~RdmaBuffer() {
void RdmaBuffer::FreeBuffer() {
if ((buffer_ != nullptr) && buffer_on_host_) {
free(buffer_);
}
}
// TODO
// release buffer if it is on device.
// release buffer if it is on device.
// We don't support RDMABuffer on device at this moment.
}
@ -548,14 +537,12 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
if (local_status_ != none) {
// delete existing buffer
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
FreeBuffer();
FreeBuffer();
}
size_ = size;
buffer_ = malloc(size_);
self_ = ibv_reg_mr(channel_->adapter_->pd_,
buffer_, size_,
IBV_ACCESS_LOCAL_WRITE |
IBV_ACCESS_REMOTE_WRITE);
self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
CHECK(self_) << "Failed to register memory region";
buffer_on_host_ = true;
local_status_ = idle;
@ -572,53 +559,52 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
// None
void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
mutex_lock lock{mu_};
if ((override) || (remote_status_ == none)) {
if ((override) || (remote_status_ == none)) {
remote_.remote_addr = rmr.remote_addr;
remote_.rkey = rmr.rkey;
remote_status_ = idle;
} else {
CHECK(remote_.remote_addr == rmr.remote_addr);
CHECK(remote_.rkey == rmr.rkey);
}
CHECK(remote_.rkey == rmr.rkey);
}
}
// Put a task in the buffer's job queue
void RdmaBuffer::EnqueueItem(string item){
void RdmaBuffer::EnqueueItem(string item) {
mutex_lock lock{mu_};
queue_.push(item);
}
// Rdma-Write the content of the buffer
void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
struct ibv_sge list;
list.addr = (uint64_t) buffer_;
list.addr = (uint64_t)buffer_;
list.length = buffer_size;
list.lkey = self_->lkey;
struct ibv_send_wr wr;
memset(&wr, 0, sizeof(wr));
wr.wr_id = (uint64_t) this;
wr.wr_id = (uint64_t)this;
wr.sg_list = &list;
wr.num_sge = 1;
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
wr.send_flags = IBV_SEND_SIGNALED;
wr.imm_data = imm_data;
wr.wr.rdma.remote_addr = (uint64_t) remote_.remote_addr;
wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
wr.wr.rdma.rkey = remote_.rkey;
struct ibv_send_wr *bad_wr;
struct ibv_send_wr* bad_wr;
CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
}
RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
: RdmaBuffer(channel, name) {}
: RdmaBuffer(channel, name) {}
RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
: RdmaBuffer(channel, name) {}
: RdmaBuffer(channel, name) {}
RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
: RdmaBuffer(channel, name) {}
: RdmaBuffer(channel, name) {}
// Send the next ack from the buffer's job queue.
void RdmaAckBuffer::SendNextItem() {
@ -636,13 +622,12 @@ void RdmaAckBuffer::SendNextItem() {
void RdmaMessageBuffer::SendNextItem() {
uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
mu_.lock();
if (!queue_.empty() && (local_status_ == idle)
&& (remote_status_ == idle)) {
if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
local_status_ = busy;
remote_status_= busy;
remote_status_ = busy;
string message = queue_.front();
queue_.pop();
// local/remote_status_ won't be set back to idle
// local/remote_status_ won't be set back to idle
// until Write() is successful
mu_.unlock();
memcpy(buffer_, message.data(), message.size());
@ -665,61 +650,56 @@ void RdmaTensorBuffer::SendNextItem() {
}
// send the tensor if a key is acquired.
if (key_with_step_id != "") {
VLOG(2) << "try to send tensor: " << key_with_step_id;
VLOG(2) << "try to send tensor: " << key_with_step_id;
string key;
int64 step_id;
VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
CHECK(key.compare(name_) == 0);
Rendezvous::ParsedKey parsed;
Rendezvous::ParseKey(key, &parsed);
Rendezvous::DoneCallback cb = [this, key_with_step_id, key,
step_id, parsed](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
CHECK(status.ok()) << "RecvLocalAsync was not ok, key"
<< key_with_step_id
<< " error message: " << status.error_message();
Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
parsed](const Status& status,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& in, bool is_dead) {
CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
<< " error message: " << status.error_message();
size_t buffer_size = RdmaMessage::kMessageTotalBytes;
size_t tensor_bytes = 0;
TensorProto proto;
// Figures out which device the tensor is hosted on.
Device* src_dev = nullptr;
Status s =
channel_->adapter_->worker_env_->
device_mgr->LookupDevice(parsed.src_device, &src_dev);
Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
parsed.src_device, &src_dev);
CHECK(s.ok()) << "src device not found";
// Does the device have the right incarnation number we expect?
CHECK(src_dev->attributes().incarnation() ==
parsed.src_incarnation)
<< "RecvTensor expects a different device incarnation: "
<< parsed.src_incarnation
<< " vs. "
<< src_dev->attributes().incarnation()
<< ". Your worker job was probably restarted. Check your "
<< "worker job for the reason why it was restarted.";
CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
<< "RecvTensor expects a different device incarnation: "
<< parsed.src_incarnation << " vs. "
<< src_dev->attributes().incarnation()
<< ". Your worker job was probably restarted. Check your "
<< "worker job for the reason why it was restarted.";
Device* dst_dev = nullptr;
// destination is on CPU.
s = channel_->adapter_->worker_env_->
device_mgr->LookupDevice("CPU:0", &dst_dev);
CHECK(s.ok())<< "dst device not found";
s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
&dst_dev);
CHECK(s.ok()) << "dst device not found";
AllocatorAttributes dst_alloc_attr;
dst_alloc_attr.set_on_host(true);
// string tensor needs to be serialized
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
if (src_dev->tensorflow_gpu_device_info() &&
(!send_args.alloc_attrs.on_host())) {
CHECK(send_args.device_context)
<< "send dev name: " << src_dev->name()
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
// "val" is on a GPU. Uses GPUUtil to fill the proto.
s = VerbsUtil::SetProtoFromGPUSync(in, src_dev,
send_args.device_context,
&proto, is_dead);
CHECK(s.ok()) << "set proto from gpu sync";
s = VerbsUtil::SetProtoFromGPUSync(
in, src_dev, send_args.device_context, &proto, is_dead);
CHECK(s.ok()) << "set proto from gpu sync";
} else {
// tensor is in CPU memory.
in.AsProtoTensorContent(&proto);
}
}
tensor_bytes = proto.ByteSize();
// maybe some margin for string tensor?
buffer_size += tensor_bytes;
@ -734,13 +714,12 @@ void RdmaTensorBuffer::SendNextItem() {
rm.tensor_bytes_ = tensor_bytes;
rm.buffer_size_ = buffer_size;
mu_.lock();
if (local_status_ == none ||
(buffer_size > size_ &&
local_status_ == idle &&
if (local_status_ == none ||
(buffer_size > size_ && local_status_ == idle &&
remote_status_ == idle)) {
if ((local_status_ != none) && (buffer_size > size_)) {
CHECK(rm.data_type_ == DT_STRING)
<< "Only string tensor allows to change size";
CHECK(rm.data_type_ == DT_STRING)
<< "Only string tensor allows to change size";
}
CreateCPUBuffer(buffer_size, false);
mu_.unlock();
@ -752,29 +731,29 @@ void RdmaTensorBuffer::SendNextItem() {
rm.rkey_ = self_->rkey;
string message = RdmaMessage::CreateMessage(rm);
channel_->tx_message_buffer_->EnqueueItem(message);
channel_->tx_message_buffer_->SendNextItem();
} else if((local_status_ == idle) && (remote_status_ == idle)) {
channel_->tx_message_buffer_->SendNextItem();
} else if ((local_status_ == idle) && (remote_status_ == idle)) {
// both buffers are ready, send the tensor
local_status_ = busy;
remote_status_ = busy;
// local/remote_status_ won't be set back to idle
// local/remote_status_ won't be set back to idle
// until Write() is successful
mu_.unlock();
CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
(buffer_size <= size_ && rm.data_type_ == DT_STRING))
<< "tensor and buffer size do not agree!"
<< " buffer_size = " << size_
<< " requested tensor size = " << buffer_size
<< in.DebugString();
<< "tensor and buffer size do not agree!"
<< " buffer_size = " << size_
<< " requested tensor size = " << buffer_size << in.DebugString();
uint32_t imm_data = LookupBufferIndex(key);
rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
string message = RdmaMessage::CreateMessage(rm);
memcpy(buffer_, message.data(), message.size());
if (!is_dead) {
// copy the tensor buffer content
void* output = static_cast<void*>(static_cast<char*>(
buffer_) + RdmaMessage::kTensorBufferStartIndex);
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
void* output =
static_cast<void*>(static_cast<char*>(buffer_) +
RdmaMessage::kTensorBufferStartIndex);
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
proto.SerializeToArray(output, tensor_bytes);
} else {
buffer_size = RdmaMessage::kMessageTotalBytes;
@ -789,8 +768,8 @@ void RdmaTensorBuffer::SendNextItem() {
// Use default session (legacy_session_)
// TODO use WorkerSessionForSession
// need to pass in session handle
channel_->adapter_->worker_env_->session_mgr->
LegacySession()->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
channel_->adapter_->worker_env_->session_mgr->LegacySession()
->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
}
}
@ -811,8 +790,10 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
// TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
// |data_type|tensor_shape|tensor_bytes
// BUFFER_IDLE: type|name_size|buffer_name
// BUFFER_REQUEST: type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
// BUFFER_RESPONSE: type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
// BUFFER_REQUEST:
// type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
// BUFFER_RESPONSE:
// type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
char message[kMessageTotalBytes];
// type
message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
@ -821,32 +802,32 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
// name
memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
// buffer_size, remote_addr, rkey
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
(rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
sizeof(rm.buffer_size_));
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
sizeof(rm.remote_addr_));
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
sizeof(rm.buffer_size_));
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
sizeof(rm.remote_addr_));
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
}
// step_id
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
(rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
}
// is_dead, data_type, tensor_shape, tensor_bytes
if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
sizeof(rm.data_type_));
memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
sizeof(rm.tensor_shape_));
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
sizeof(rm.tensor_bytes_));
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
sizeof(rm.data_type_));
memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
sizeof(rm.tensor_shape_));
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
sizeof(rm.tensor_bytes_));
}
return string(message, kMessageTotalBytes);
}
}
// Parse a RdmaMessage according to the pre-defined format
// Args:
@ -865,27 +846,26 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
// buffer_size, remote_addr, rkey
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
(rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
sizeof(rm.buffer_size_));
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
sizeof(rm.remote_addr_));
memcpy(&rm.rkey_, &message[kRkeyStartIndex],
sizeof(rm.rkey_));
memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
sizeof(rm.buffer_size_));
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
sizeof(rm.remote_addr_));
memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
}
// step_id
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
(rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
}
// data_type, tensor_bytes, tensor_shape, is_dead
// data_type, tensor_bytes, tensor_shape, is_dead
if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
sizeof(rm.data_type_));
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
sizeof(rm.data_type_));
memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
sizeof(rm.tensor_shape_));
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
sizeof(rm.tensor_bytes_));
sizeof(rm.tensor_shape_));
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
sizeof(rm.tensor_bytes_));
}
}
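Taken together, the branches of Process_CQ above implement a small peer-to-peer message protocol. A rough summary (not part of the source), reconstructed from the comments in the diff; the reply sent for BUFFER_REQUEST is partly elided above, so treat that entry as an assumption:
    on_receive = {
        "RDMA_MESSAGE_ACK":             "mark tx_message_buffer idle, send next queued message",
        "RDMA_MESSAGE_TENSOR_REQUEST":  "ack, enqueue key on the named tensor buffer, send next tensor",
        "RDMA_MESSAGE_BUFFER_IDLE":     "ack, mark the named buffer's remote side idle, send next item",
        "RDMA_MESSAGE_BUFFER_REQUEST":  "ack, (re)create the buffer, reply (assumed BUFFER_RESPONSE)",
        "RDMA_MESSAGE_BUFFER_RESPONSE": "ack, record remote addr/rkey, mark idle, send next item",
        "RDMA_MESSAGE_TENSOR_WRITE":    "run the recv callback registered for key|step_id",
    }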

View File

@ -19,11 +19,11 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include <infiniband/verbs.h>
#include <memory> // for shared_ptr
#include <cstring> // for memset
#include <cstring> // for memset
#include <functional>
#include <memory> // for shared_ptr
#include <queue>
#include <string>
#include <functional>
#include <unordered_map>
#include <vector>
@ -37,43 +37,46 @@ namespace tensorflow {
// structure to save the address of remote channels.
struct RdmaAddress {
uint32_t lid;
uint32_t qpn;
uint32_t psn;
uint32_t lid;
uint32_t qpn;
uint32_t psn;
};
// structure to save information for remote memory regions.
struct RemoteMR{
uint64_t remote_addr;
uint32_t rkey;
struct RemoteMR {
uint64_t remote_addr;
uint32_t rkey;
};
enum BufferStatus { none, idle, busy };
enum Location { local, remote };
enum BufferType { ACK, MESSAGE, TENSOR };
enum RdmaMessageType {
RDMA_MESSAGE_ACK,
RDMA_MESSAGE_BUFFER_IDLE,
RDMA_MESSAGE_BUFFER_REQUEST,
RDMA_MESSAGE_BUFFER_RESPONSE,
RDMA_MESSAGE_TENSOR_REQUEST,
RDMA_MESSAGE_TENSOR_WRITE
};
enum BufferStatus {none, idle, busy};
enum Location {local, remote};
enum BufferType {ACK, MESSAGE, TENSOR};
enum RdmaMessageType {RDMA_MESSAGE_ACK,
RDMA_MESSAGE_BUFFER_IDLE,
RDMA_MESSAGE_BUFFER_REQUEST,
RDMA_MESSAGE_BUFFER_RESPONSE,
RDMA_MESSAGE_TENSOR_REQUEST,
RDMA_MESSAGE_TENSOR_WRITE};
class RdmaBuffer;
// Class that represents the Rdma Adapter.
// Responsible for creation of the completion queue, and handling
// of work completions.
// of work completions.
class RdmaAdapter {
friend class RdmaChannel;
friend class RdmaBuffer;
friend class RdmaAckBuffer;
friend class RdmaMessageBuffer;
friend class RdmaTensorBuffer;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
friend class RdmaChannel;
friend class RdmaBuffer;
friend class RdmaAckBuffer;
friend class RdmaMessageBuffer;
friend class RdmaTensorBuffer;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
public:
RdmaAdapter(const WorkerEnv* worker_env);
~RdmaAdapter();
// Adapter name, e.g. mlx5_0.
string name() const;
void Process_CQ();
protected:
static const int MAX_CONCURRENT_WRITES = 1000;
ibv_context* context_;
@ -94,36 +97,39 @@ class RdmaAdapter {
// Class that represents a connection to a remote Rdma peer.
// Responsible for connecting queue pairs.
class RdmaChannel {
friend class RdmaAdapter;
friend class RdmaBuffer;
friend class RdmaAckBuffer;
friend class RdmaMessageBuffer;
friend class RdmaTensorBuffer;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
friend class RdmaAdapter;
friend class RdmaBuffer;
friend class RdmaAckBuffer;
friend class RdmaMessageBuffer;
friend class RdmaTensorBuffer;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
public:
explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name_);
explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name_);
~RdmaChannel();
inline const RdmaAddress& self() { return self_; }
RdmaAddress address() const;
inline const std::vector<RdmaBuffer*>& message_buffers() const {
return message_buffers_;}
return message_buffers_;
}
void Connect(const RdmaAddress& remoteAddr);
void Connect();
void Recv();
RdmaBuffer* FindBuffer(const uint32_t index);
RdmaBuffer* FindBuffer(const string& name);
RdmaBuffer* FindOrCreateBuffer(const string& name,
RdmaBuffer* FindOrCreateBuffer(const string& name,
BufferType buffer_type = TENSOR);
uint32_t LookupBufferIndex (const string& buffer_name);
uint32_t LookupBufferIndex(const string& buffer_name);
void SetRemoteAddress(const RdmaAddress& ra, bool override);
void InsertRecvCallback(const string& key, std::function<void()> recv_done);
void RemoveRecvCallback(const string& key);
void RunRecvCallback(const string& key);
static const int kNumMessageBuffers = 4;
protected:
const RdmaAdapter* adapter_;
const RdmaAdapter* adapter_;
RdmaAddress self_;
string local_name_;
string remote_name_;
@ -151,10 +157,11 @@ class RdmaChannel {
// Class that represents a buffer for Rdma writes and reads.
class RdmaBuffer {
friend class RdmaChannel;
friend class RdmaAdapter;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
friend class RdmaChannel;
friend class RdmaAdapter;
friend class RdmaMgr;
friend class RdmaRemoteRendezvous;
public:
explicit RdmaBuffer(RdmaChannel* channel, string name);
virtual ~RdmaBuffer();
@ -173,10 +180,11 @@ class RdmaBuffer {
void FreeBuffer();
void EnqueueItem(string Item);
virtual void SendNextItem(){};
void CreateCPUBuffer(size_t size, bool lock=true);
void CreateCPUBuffer(size_t size, bool lock = true);
void SetRemoteMR(RemoteMR rmi, bool override);
uint32_t LookupBufferIndex (const string& buffer_name) {
return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);}
uint32_t LookupBufferIndex(const string& buffer_name) {
return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
}
void Write(uint32_t imm_data, size_t buffer_size);
protected:
@ -188,7 +196,7 @@ class RdmaBuffer {
ibv_mr* self_ = nullptr;
mutex mu_;
RemoteMR remote_;
std::queue <string> queue_ GUARDED_BY(mu_);
std::queue<string> queue_ GUARDED_BY(mu_);
BufferStatus local_status_ GUARDED_BY(mu_) = none;
BufferStatus remote_status_ GUARDED_BY(mu_) = none;
};
@ -201,8 +209,9 @@ class RdmaAckBuffer : public RdmaBuffer {
};
class RdmaMessageBuffer : public RdmaBuffer {
friend class RdmaChannel;
friend class RdmaAapater;
friend class RdmaChannel;
friend class RdmaAapater;
public:
explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
virtual ~RdmaMessageBuffer() override {}
@ -228,40 +237,41 @@ struct RdmaMessage {
DataType data_type_;
TensorShape tensor_shape_;
size_t tensor_bytes_;
// type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
// 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
// ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
// ...| XB | XB | 8B |...
//
// type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
// 1B| 2B | 512| 8B | 8B | 8B | 4B | 1B |...
// ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
// ...| XB | XB | 8B |...
//
static const size_t kNameCapacity = 512;
static const size_t kTypeStartIndex = 0;
static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
static const size_t kNameStartIndex = kNameSizeStartIndex + sizeof(name_size_);
static const size_t kNameStartIndex =
kNameSizeStartIndex + sizeof(name_size_);
static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
static const size_t kBufferSizeStartIndex = kStepIdStartIndex
+ sizeof(step_id_);
static const size_t kRemoteAddrStartIndex = kBufferSizeStartIndex
+ sizeof(buffer_size_);
static const size_t kRkeyStartIndex = kRemoteAddrStartIndex
+ sizeof(remote_addr_);
static const size_t kBufferSizeStartIndex =
kStepIdStartIndex + sizeof(step_id_);
static const size_t kRemoteAddrStartIndex =
kBufferSizeStartIndex + sizeof(buffer_size_);
static const size_t kRkeyStartIndex =
kRemoteAddrStartIndex + sizeof(remote_addr_);
static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
static const size_t kDataTypeStartIndex = kIsDeadStartIndex
+ sizeof(is_dead_);
static const size_t kTensorShapeStartIndex = kDataTypeStartIndex
+ sizeof(data_type_);
static const size_t kTensorBytesStartIndex = kTensorShapeStartIndex
+ sizeof(TensorShape);
static const size_t kTensorBufferStartIndex = kTensorBytesStartIndex
+ sizeof(tensor_bytes_);
static const size_t kDataTypeStartIndex =
kIsDeadStartIndex + sizeof(is_dead_);
static const size_t kTensorShapeStartIndex =
kDataTypeStartIndex + sizeof(data_type_);
static const size_t kTensorBytesStartIndex =
kTensorShapeStartIndex + sizeof(TensorShape);
static const size_t kTensorBufferStartIndex =
kTensorBytesStartIndex + sizeof(tensor_bytes_);
static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
static string CreateMessage(const RdmaMessage & rm);
static string CreateMessage(const RdmaMessage& rm);
static void ParseMessage(RdmaMessage& rm, void* buffer);
};
} // namespace tensorflow
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
#endif // TENSORFLOW_USE_VERBS
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
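The start-index constants in RdmaMessage above simply chain one field after the next, following the layout comment (1B type, 2B name_size, 512B name, ...). A small illustrative sketch of that chaining; the exact widths of step_id_, data_type_, TensorShape and tensor_bytes_ come from the C++ build, so they are passed in rather than assumed:
    def rdma_message_offsets(field_sizes):
      # field_sizes: ordered (field, size_in_bytes) pairs, e.g.
      # [("type", 1), ("name_size", 2), ("name", 512), ("step_id", 8), ...]
      offsets, start = {}, 0
      for field, size in field_sizes:
        offsets[field] = start            # corresponds to k<Field>StartIndex
        start += size
      offsets["tensor_buffer"] = start    # kTensorBufferStartIndex == kMessageTotalBytes
      return offsets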

View File

@ -15,8 +15,8 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include<vector>
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
@ -25,7 +25,7 @@ limitations under the License.
namespace tensorflow {
RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
GrpcChannelCache* const channel_cache)
: worker_env_(worker_env), channel_cache_(channel_cache) {
rdma_adapter_ = new RdmaAdapter(worker_env_);
@ -34,14 +34,15 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
// need to pass in session handle
local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
std::vector<string> workers;
worker_env_->session_mgr->LegacySession()->
worker_cache->ListWorkers(&workers);
num_remote_workers_ = workers.size()-1;
worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
&workers);
num_remote_workers_ = workers.size() - 1;
VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
for (size_t i = 0; i < workers.size(); i++) {
if (local_worker_.compare(workers[i]) != 0) {
channel_table_.insert({workers[i], new RdmaChannel(rdma_adapter_,
local_worker_, workers[i])});
channel_table_.insert(
{workers[i],
new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
}
}
}
@ -49,16 +50,16 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
// Setup Rdma channels between peers.
// This is done at the beginning of the server setup.
void RdmaMgr::SetupChannels() {
void RdmaMgr::SetupChannels() {
for (const auto& p : channel_table_) {
string worker_name = p.first;
LOG(INFO) << "connecting to remote node " << worker_name;
RdmaChannel* rc = p.second;
GetRemoteAddressRequest req;
GetRemoteAddressResponse resp;
GetRemoteAddressResponse resp;
// get the channel cache
SharedGrpcChannelPtr client_channel = channel_cache_
->FindWorkerChannel(worker_name);
SharedGrpcChannelPtr client_channel =
channel_cache_->FindWorkerChannel(worker_name);
GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
CHECK(client != nullptr) << "No worker known as " << worker_name;
@ -70,8 +71,8 @@ void RdmaMgr::SetupChannels() {
channel_info->set_psn(rc->self_.psn);
for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
MemoryRegion* mr = req.add_mr();
mr->set_remote_addr(reinterpret_cast<uint64_t>(
rc->message_buffers_[i]->buffer_));
mr->set_remote_addr(
reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
}
// synchronous call
@ -79,10 +80,10 @@ void RdmaMgr::SetupChannels() {
// save obtained remote addresses
// connect to the remote channel
if (s.ok()) {
CHECK(worker_name.compare(resp.host_name())==0);
CHECK(worker_name.compare(resp.host_name()) == 0);
RdmaAddress ra;
ra.lid = resp.channel().lid();
ra.qpn = resp.channel().qpn();
ra.qpn = resp.channel().qpn();
ra.psn = resp.channel().psn();
rc->SetRemoteAddress(ra, false);
rc->Connect();
@ -112,7 +113,7 @@ void RdmaMgr::SetupChannels() {
RdmaMgr::~RdmaMgr() {
for (const auto& p : channel_table_) delete p.second;
channel_table_.clear();
channel_table_.clear();
delete rdma_adapter_;
}

View File

@ -28,9 +28,8 @@ limitations under the License.
namespace tensorflow {
class RdmaMgr {
public:
explicit RdmaMgr(const WorkerEnv* const worker_env,
explicit RdmaMgr(const WorkerEnv* const worker_env,
GrpcChannelCache* const channel_cache);
~RdmaMgr();
RdmaChannel* FindChannel(const string& key);
@ -45,11 +44,11 @@ class RdmaMgr {
RdmaAdapter* rdma_adapter_;
typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
ChannelTable channel_table_;
TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
};
} // namespace tensorflow
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
#endif // TENSORFLOW_USE_VERBS
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_

View File

@ -15,8 +15,8 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include <unordered_set>
#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
#include <unordered_set>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -29,14 +29,16 @@ namespace tensorflow {
class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
public:
RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
int64 step_id, RdmaMgr* rdma_mgr)
: BaseRemoteRendezvous(env, worker_name, step_id, true),
: BaseRemoteRendezvous(env, worker_name, step_id, true),
rdma_mgr_(rdma_mgr) {}
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
DoneCallback done) override;
private:
~RdmaRemoteRendezvous() override {}
RdmaMgr* rdma_mgr_;
@ -45,13 +47,13 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
};
void RdmaRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
Status s;
// parse src_name and dst_name
string src_name, dst_name, unused;
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device,
&src_name, &unused)) {
if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
&unused)) {
s = errors::Internal("Could not parse src name.");
}
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
@ -59,8 +61,8 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
done(s, Args(), recv_args, Tensor{}, false);
return;
}
if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device,
&dst_name, &unused)) {
if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
&unused)) {
s = errors::Internal("Could not parse dst name.");
}
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
@ -73,52 +75,52 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
string key(std::move(parsed.FullKey().ToString()));
string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
// insert callback
rc->InsertRecvCallback(key_with_step_id,
[this, key, key_with_step_id, rc, recv_args, parsed, done](){
Status s;
Device* src_dev;
s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), true);
return;
}
Device* dst_dev;
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), true);
return;
}
RdmaBuffer* rb = rc->FindBuffer(key);
RdmaMessage rm;
CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
RdmaMessage::ParseMessage(rm, rb->buffer_);
CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
Tensor val;
if (!rm.is_dead_) {
void* input = static_cast<char*>(rb->buffer_) +
RdmaMessage::kTensorBufferStartIndex;
TensorProto proto;
CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <= rb->size_);
CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
<< "fail to parse proto from array";
s = dst_dev->MakeTensorFromProto(proto,
recv_args.alloc_attrs, &val);
}
rc->RemoveRecvCallback(key_with_step_id);
// create message
RdmaMessage br;
br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
br.name_size_ = key.size();
br.name_ = key;
string message = RdmaMessage::CreateMessage(br);
RdmaBuffer* tb = rc->tx_message_buffer_;
tb->EnqueueItem(message);
tb->SendNextItem();
done(s, Args(), recv_args, val, rm.is_dead_);
});
rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
recv_args, parsed, done]() {
Status s;
Device* src_dev;
s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), true);
return;
}
Device* dst_dev;
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), true);
return;
}
RdmaBuffer* rb = rc->FindBuffer(key);
RdmaMessage rm;
CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
RdmaMessage::ParseMessage(rm, rb->buffer_);
CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
Tensor val;
if (!rm.is_dead_) {
void* input = static_cast<char*>(rb->buffer_) +
RdmaMessage::kTensorBufferStartIndex;
TensorProto proto;
CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
rb->size_);
CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
<< "fail to parse proto from array";
s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
}
rc->RemoveRecvCallback(key_with_step_id);
// create message
RdmaMessage br;
br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
br.name_size_ = key.size();
br.name_ = key;
string message = RdmaMessage::CreateMessage(br);
RdmaBuffer* tb = rc->tx_message_buffer_;
tb->EnqueueItem(message);
tb->SendNextItem();
done(s, Args(), recv_args, val, rm.is_dead_);
});
// append key to message queue
RdmaBuffer* rb = rc->tx_message_buffer_;
RdmaMessage rm;
@ -141,7 +143,7 @@ BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
const string& worker_name) {
return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_);
}
} // end namespace tensorflow
#endif
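
The rewritten receive callback above parses an RdmaMessage header out of the channel's receive buffer, rebuilds the tensor, and then queues an RDMA_MESSAGE_BUFFER_IDLE reply on the transmit message buffer so the sender can reuse the slot. A toy, self-contained sketch of that parse-then-acknowledge flow follows; the wire layout and field names are invented for illustration and are much simpler than the real RdmaMessage.

    #include <cstdint>
    #include <cstring>
    #include <iostream>
    #include <string>
    #include <vector>

    // Illustrative wire layout: [type:1][name_size:4][name:name_size]
    struct Message { uint8_t type; std::string name; };
    constexpr uint8_t kTensorWrite = 1;
    constexpr uint8_t kBufferIdle = 2;

    Message Parse(const std::vector<char>& buf) {
      Message m;
      m.type = static_cast<uint8_t>(buf[0]);
      uint32_t n;
      std::memcpy(&n, buf.data() + 1, sizeof(n));
      m.name.assign(buf.data() + 5, n);
      return m;
    }

    std::vector<char> Create(const Message& m) {
      std::vector<char> buf(5 + m.name.size());
      buf[0] = static_cast<char>(m.type);
      uint32_t n = static_cast<uint32_t>(m.name.size());
      std::memcpy(buf.data() + 1, &n, sizeof(n));
      std::memcpy(buf.data() + 5, m.name.data(), n);
      return buf;
    }

    int main() {
      // Pretend this arrived in the RDMA receive buffer.
      std::vector<char> rx = Create({kTensorWrite, "edge_0;tensor_a"});
      Message rm = Parse(rx);
      // After consuming the payload, tell the sender the buffer is free again.
      std::vector<char> ack = Create({kBufferIdle, rm.name});
      std::cout << "acked " << rm.name << " (" << ack.size() << " bytes)\n";
      return 0;
    }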

View File

@ -47,12 +47,12 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr {
public:
explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* worker_cache);
void SetRdmaMgr(RdmaMgr* rdma_mgr) {
rdma_mgr_ = rdma_mgr;
}
void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }
protected:
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
const string& worker_name) override;
const string& worker_name) override;
private:
RdmaMgr* rdma_mgr_;
TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
@ -60,5 +60,5 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr {
} // end namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
#endif // TENSORFLOW_USE_VERBS
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_

View File

@ -27,8 +27,9 @@ namespace tensorflow {
namespace {
// static utility function
RendezvousMgrInterface* NewRdmaRendezvousMgr(const WorkerEnv* env,
const string& worker_name, WorkerCacheInterface* worker_cache) {
RendezvousMgrInterface* NewRdmaRendezvousMgr(
const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* worker_cache) {
return new RdmaRendezvousMgr(env, worker_name, worker_cache);
}
@ -46,7 +47,7 @@ VerbsServer::~VerbsServer() {
}
Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
GrpcChannelCache** channel_cache) {
GrpcChannelCache** channel_cache) {
string name_prefix =
strings::StrCat("/job:", server_def.job_name(), "/replica:0",
"/task:", server_def.task_index());
@ -54,41 +55,43 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
*channel_cache = NewGrpcChannelCache(channel_spec,
GetChannelCreationFunction(server_def));
*channel_cache =
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def));
const string host_port = (*channel_cache)->TranslateTask(name_prefix);
int requested_port;
if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
&requested_port)) {
return errors::Internal("Could not parse port for local server from \"",
(*channel_cache)->TranslateTask(name_prefix), "\".");
(*channel_cache)->TranslateTask(name_prefix),
"\".");
}
if (requested_port != bound_port()) {
return errors::InvalidArgument("Requested port ", requested_port,
" differs from expected port ", bound_port());
" differs from expected port ",
bound_port());
}
return Status::OK();
}
Status VerbsServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func) {
RendezvousMgrCreationFunction rendezvous_mgr_func) {
Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
{
mutex_lock l(mu_);
CHECK_EQ(verbs_state_, DISCONNECTED);
CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
// set rdma_mgr for verbs_service and rdma_rendezvous_mgr
// set rdma_mgr for verbs_service and rdma_rendezvous_mgr
verbs_service_->SetRdmaMgr(rdma_mgr_);
// hardcoded to default session (legacy_session_)
// TODO: use WorkerSessionForSession
// need to pass in session handle
dynamic_cast<RdmaRendezvousMgr*>(worker_env()->session_mgr->
LegacySession()->rendezvous_mgr.get())
->SetRdmaMgr(rdma_mgr_);
dynamic_cast<RdmaRendezvousMgr*>(
worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get())
->SetRdmaMgr(rdma_mgr_);
}
return s;
}
@ -100,9 +103,9 @@ Status VerbsServer::Start() {
if (verbs_state_ == DISCONNECTED) {
// verbs_thread needs to be initiated
// before rdma_mgr sets up the rdma channels.
verbs_thread_.reset(
worker_env()->env->StartThread(ThreadOptions(), "TF_verbs_service",
[this] { verbs_service_->HandleRPCsLoop(); }));
verbs_thread_.reset(worker_env()->env->StartThread(
ThreadOptions(), "TF_verbs_service",
[this] { verbs_service_->HandleRPCsLoop(); }));
rdma_mgr_->SetupChannels();
verbs_state_ = CONNECTED;
}
@ -124,10 +127,10 @@ Status VerbsServer::Join() {
/* static */
Status VerbsServer::Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server) {
std::unique_ptr<ServerInterface>* out_server) {
std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder) {
::grpc::ServerBuilder* builder) {
return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
};
TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
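
VerbsServer::Create() above wires two hooks into the generic gRPC server: a ServiceInitFunction that registers the verbs RPC service and a RendezvousMgrCreationFunction (NewRdmaRendezvousMgr) that substitutes the RDMA rendezvous manager, and Start() brings up the polling thread before the channels are connected. A standalone sketch of that dependency-injection shape, with toy types in place of the real server classes:

    #include <functional>
    #include <iostream>
    #include <string>

    // Illustrative stand-ins for ServiceInitFunction / RendezvousMgrCreationFunction.
    using ServiceInit = std::function<void(const std::string& worker)>;
    using RendezvousFactory = std::function<std::string(const std::string& worker)>;

    class ToyServer {
     public:
      void Init(ServiceInit service_func, RendezvousFactory rendezvous_func) {
        service_func(worker_);                   // register the extra RPC service
        rendezvous_ = rendezvous_func(worker_);  // plug in the custom rendezvous
      }
      void Start() {
        // Like VerbsServer::Start(): polling thread first, then connect channels,
        // then mark the server CONNECTED.
        std::cout << "polling thread up; channels connected for " << rendezvous_ << "\n";
      }
     private:
      std::string worker_ = "/job:worker/task:0";
      std::string rendezvous_;
    };

    int main() {
      ToyServer server;
      server.Init(
          [](const std::string& w) { std::cout << "verbs service for " << w << "\n"; },
          [](const std::string& w) { return "rdma_rendezvous(" + w + ")"; });
      server.Start();
      return 0;
    }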

View File

@ -18,8 +18,8 @@ limitations under the License.
#ifdef TENSORFLOW_USE_VERBS
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
namespace tensorflow {
@ -27,7 +27,7 @@ namespace tensorflow {
class VerbsServer : public GrpcServer {
protected:
VerbsServer(const ServerDef& server_def, Env* env);
public:
static Status Create(const ServerDef& server_def, Env* env,
std::unique_ptr<ServerInterface>* out_server);
@ -39,21 +39,22 @@ class VerbsServer : public GrpcServer {
// Implementations of ServerInterface methods.
Status Start() override;
Status Join() override;
protected:
Status Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func);
Status ChannelCacheFactory(const ServerDef& server_def,
GrpcChannelCache** channel_cache);
private:
RdmaMgr* rdma_mgr_;
// Guards state transitions.
mutex mu_;
enum State { DISCONNECTED, CONNECTED };
State verbs_state_ GUARDED_BY(mu_);
GrpcVerbsService* verbs_service_ = nullptr;
std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
GrpcChannelCache* channel_cache_ = nullptr;
@ -61,5 +62,5 @@ class VerbsServer : public GrpcServer {
} // namespace tensorflow
#endif // TENSORFLOW_USE_VERBS
#endif // TENSORFLOW_USE_VERBS
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_

View File

@ -29,7 +29,7 @@ option java_package = "org.tensorflow.contrib.verbs";
message Channel {
int32 lid = 1;
int32 qpn = 2;
int32 psn = 3;
int32 psn = 3;
}
message MemoryRegion {
@ -39,15 +39,14 @@ message MemoryRegion {
message GetRemoteAddressRequest {
string host_name = 1;
Channel channel = 2;
repeated MemoryRegion mr = 3;
repeated MemoryRegion mr = 3;
}
message GetRemoteAddressResponse {
string host_name = 1;
Channel channel = 2;
repeated MemoryRegion mr = 3;
}
repeated MemoryRegion mr = 3;
}
////////////////////////////////////////////////////////////////////////////////
//
@ -56,5 +55,6 @@ message GetRemoteAddressResponse {
////////////////////////////////////////////////////////////////////////////////
service VerbsService {
rpc GetRemoteAddress(GetRemoteAddressRequest) returns (GetRemoteAddressResponse);
rpc GetRemoteAddress(GetRemoteAddressRequest)
returns (GetRemoteAddressResponse);
}

View File

@ -22,30 +22,27 @@ namespace tensorflow {
// static sync wrapper:
Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
const DeviceContext* device_context,
TensorProto* proto, bool is_dead) {
const DeviceContext* device_context,
TensorProto* proto, bool is_dead) {
Notification n;
Status status;
GPUUtil::SetProtoFromGPU(tensor, dev,
device_context,
proto, is_dead,
[&n, &status](const Status& s) {
status = s;
n.Notify();
});
GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
[&n, &status](const Status& s) {
status = s;
n.Notify();
});
n.WaitForNotification();
return status;
}
//static
string VerbsUtil::AppendStepidToKey(const string& key,
int64 step_id) {
// static
string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
return strings::StrCat(key, ";", step_id);
}
// static
void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id,
string& key, int64& step_id) {
void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
int64& step_id) {
StringPiece s(key_with_step_id);
// a key (with step_id) has exact 6 parts if split by ";"
// part 1: src_device;
@ -55,10 +52,10 @@ void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id,
// part 5: frame_iter.frame_id:frame_iter.iter_id
// part 6: step_id
std::vector<string> parts = str_util::Split(s, ';');
CHECK(parts.size()==6) << "Key with step_id must have 6 parts";
CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
strings::safe_strto64(parts[5], &step_id);
parts.pop_back(); // remove step_id
key.assign(str_util::Join(parts, ";")); // stitch them together
parts.pop_back(); // remove step_id
key.assign(str_util::Join(parts, ";")); // stitch them together
}
} // namespace tensorflow
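
GetKeyAndStepId() above relies on the rendezvous key having exactly five ';'-separated parts before the step id is appended. A standalone sketch of the same round trip using only the standard library (the example key and the split/join helpers are illustrative; the real code uses str_util and safe_strto64):

    #include <cstdint>
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    std::string AppendStepidToKey(const std::string& key, int64_t step_id) {
      return key + ";" + std::to_string(step_id);
    }

    void GetKeyAndStepId(const std::string& key_with_step_id,
                         std::string* key, int64_t* step_id) {
      std::vector<std::string> parts;
      std::stringstream ss(key_with_step_id);
      std::string part;
      while (std::getline(ss, part, ';')) parts.push_back(part);
      // A rendezvous key plus step id has exactly 6 ';'-separated parts:
      // src_device;src_incarnation;dst_device;name;frame_id:iter_id;step_id
      if (parts.size() != 6) { *step_id = -1; return; }
      *step_id = std::stoll(parts.back());
      parts.pop_back();                       // remove step_id
      std::string joined;
      for (size_t i = 0; i < parts.size(); ++i) {
        if (i) joined += ";";
        joined += parts[i];                   // stitch the key back together
      }
      *key = joined;
    }

    int main() {
      std::string key =
          "/job:a/task:0/cpu:0;0000000000000001;/job:b/task:0/cpu:0;edge_1_x;0:0";
      std::string with_step = AppendStepidToKey(key, 42);
      std::string key_out;
      int64_t step = 0;
      GetKeyAndStepId(with_step, &key_out, &step);
      std::cout << (key_out == key) << " step=" << step << "\n";  // 1 step=42
      return 0;
    }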

View File

@ -28,14 +28,13 @@ class TensorProto;
class VerbsUtil {
public:
// synchronous wrapper of SetProtoFromGPU
static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
const DeviceContext* device_context,
TensorProto* proto, bool is_dead);
const DeviceContext* device_context,
TensorProto* proto, bool is_dead);
static string AppendStepidToKey(const string& key, int64 step_id);
static void GetKeyAndStepId(const string& key_with_step_id, string& key,
int64& step_id);
static void GetKeyAndStepId(const string& key_with_step_id, string& key,
int64& step_id);
};
} // namespace tensorflow

View File

@ -1570,6 +1570,7 @@ tf_cuda_library(
":lib_internal",
":proto_text",
":protos_all_cc",
"//tensorflow/core/kernels:function_ops",
],
alwayslink = 1,
)

View File

@ -30,6 +30,11 @@ struct BuildGraphOptions {
// the former via "ref" fetch_endpoints.
std::vector<string> target_nodes;
// If `true`, uses Arg/Retval to implement feeds/fetches; otherwise
// uses Recv/Send to implement feeds/fetches.
// TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
bool use_function_convention = false;
string DebugString() const;
};

View File

@ -43,7 +43,7 @@ namespace tensorflow {
namespace {
bool IsConstantFoldable(const Node* n,
std::function<bool(const Node*)> consider) {
const std::function<bool(const Node*)>& consider) {
if (n->op_def().is_stateful()) {
return false;
}

View File

@ -71,7 +71,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
if (ri.sender_device_type == src_device_type &&
ri.receiver_device_type == dst_device_type) {
ri.copy_function(send_dev_context, recv_dev_context, src, dst,
src_alloc_attr, dst_alloc_attr, input, output, done);
src_alloc_attr, dst_alloc_attr, input, output,
std::move(done));
return;
}
}

View File

@ -361,7 +361,6 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
return Status::OK();
}
// TODO(yuanbyu): Simplify by treating Run() as "PRunSetup(); PRun()".
Status DirectSession::Run(const NamedTensorList& inputs,
const std::vector<string>& output_names,
const std::vector<string>& target_nodes,
@ -426,13 +425,34 @@ Status DirectSession::Run(const RunOptions& run_options,
executor_step_count, input_tensor_names, output_names, target_nodes));
}
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
FunctionCallFrame call_frame(executors_and_keys->input_types,
executors_and_keys->output_types);
gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
for (const auto& it : inputs) {
if (it.second.dtype() == DT_RESOURCE) {
Tensor tensor_from_handle;
TF_RETURN_IF_ERROR(
ResourceHandleToInputTensor(it.second, &tensor_from_handle));
feed_args[executors_and_keys->input_name_to_index[it.first]] =
tensor_from_handle;
} else {
feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
}
}
Status s = call_frame.SetArgs(feed_args);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
return s;
}
// Create a run state and start execution.
RunState run_state(args.step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
CancellationManager step_cancellation_manager;
// Send inputs.
TF_RETURN_IF_ERROR(SendInputs(inputs, executors_and_keys, run_state.rendez));
args.call_frame = &call_frame;
// Start parallel Executors.
const size_t num_executors = executors_and_keys->items.size();
@ -535,8 +555,22 @@ Status DirectSession::Run(const RunOptions& run_options,
}
// Receive outputs.
TF_RETURN_IF_ERROR(
RecvOutputs(output_names, executors_and_keys, &run_state, outputs));
if (outputs) {
std::vector<Tensor> sorted_outputs;
Status s = call_frame.ConsumeRetvals(&sorted_outputs);
if (errors::IsInternal(s)) {
return errors::InvalidArgument(s.error_message());
} else if (!s.ok()) {
return s;
}
outputs->clear();
outputs->reserve(sorted_outputs.size());
for (const string& output_name : output_names) {
outputs->emplace_back(
std::move(sorted_outputs[executors_and_keys
->output_name_to_index[output_name]]));
}
}
// Save the output tensors of this run we choose to keep.
TF_RETURN_IF_ERROR(
@ -706,11 +740,11 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
CheckFetch(inputs, output_names, executors_and_keys, run_state));
// Send inputs.
Status s = SendInputs(inputs, executors_and_keys, run_state->rendez);
Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);
// Receive outputs.
if (s.ok()) {
s = RecvOutputs(output_names, executors_and_keys, run_state, outputs);
s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
}
// Save the output tensors of this run we choose to keep.
@ -770,16 +804,17 @@ Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
}
}
Status DirectSession::SendInputs(const NamedTensorList& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez) {
Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez) {
Status s;
Rendezvous::ParsedKey parsed;
// Insert the input tensors into the local rendezvous by their
// rendezvous key.
for (const auto& input : inputs) {
auto it = executors_and_keys->input_keys.find(input.first);
if (it == executors_and_keys->input_keys.end()) {
auto it =
executors_and_keys->input_name_to_rendezvous_key.find(input.first);
if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
return errors::Internal("'", input.first, "' is not a pre-defined feed.");
}
const string& input_key = it->second;
@ -808,10 +843,10 @@ Status DirectSession::SendInputs(const NamedTensorList& inputs,
return Status::OK();
}
Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
const ExecutorsAndKeys* executors_and_keys,
RunState* run_state,
std::vector<Tensor>* outputs) {
Status DirectSession::RecvPRunOutputs(
const std::vector<string>& output_names,
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
std::vector<Tensor>* outputs) {
Status s;
if (!output_names.empty()) {
outputs->resize(output_names.size());
@ -822,8 +857,9 @@ Status DirectSession::RecvOutputs(const std::vector<string>& output_names,
for (size_t output_offset = 0; output_offset < output_names.size();
++output_offset) {
const string& output_name = output_names[output_offset];
auto it = executors_and_keys->output_keys.find(output_name);
if (it == executors_and_keys->output_keys.end()) {
auto it =
executors_and_keys->output_name_to_rendezvous_key.find(output_name);
if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
return errors::Internal("'", output_name,
"' is not a pre-defined fetch.");
}
@ -987,14 +1023,16 @@ Status DirectSession::GetOrCreateExecutors(
options.feed_endpoints = inputs_sorted;
options.fetch_endpoints = outputs_sorted;
options.target_nodes = tn_sorted;
options.use_function_convention = !run_state_args->is_partial_run;
std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
// The executor_lock_ is intentionally released while executor is
// being created.
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
TF_RETURN_IF_ERROR(
CreateGraphs(options, &graphs, &ek->flib_def, run_state_args));
TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
run_state_args, &ek->input_types,
&ek->output_types));
if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph);
@ -1079,17 +1117,37 @@ Status DirectSession::GetOrCreateExecutors(
item->executor.reset(executor);
}
// Compute the rendezvous keys to avoid recomputing them every time.
//
// We always use the first device as the device name portion of the
// key, even if we're feeding another graph.
for (const string& input : inputs) {
ek->input_keys[input] = GetRendezvousKey(
input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
}
for (const string& output : outputs) {
ek->output_keys[output] = GetRendezvousKey(
output, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
// Cache the mapping from input/output names to graph elements to
// avoid recomputing it every time.
if (!run_state_args->is_partial_run) {
// For regular `Run()`, we use the function calling convention, and so
// maintain a mapping from input/output names to
// argument/return-value ordinal index.
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
const string& input = inputs_sorted[i];
ek->input_name_to_index[input] = i;
}
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
const string& output = outputs_sorted[i];
ek->output_name_to_index[output] = i;
}
} else {
// For `PRun()`, we use the rendezvous calling convention, and so
// maintain a mapping from input/output names to rendezvous keys.
//
// We always use the first device as the device name portion of the
// key, even if we're feeding another graph.
for (size_t i = 0; i < inputs_sorted.size(); ++i) {
const string& input = inputs_sorted[i];
ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
}
for (size_t i = 0; i < outputs_sorted.size(); ++i) {
const string& output = outputs_sorted[i];
ek->output_name_to_rendezvous_key[output] =
GetRendezvousKey(output, device_set_.client_device()->attributes(),
FrameAndIter(0, 0));
}
}
// Reacquire the lock, try to insert into the map.
@ -1110,7 +1168,8 @@ Status DirectSession::CreateGraphs(
const BuildGraphOptions& subgraph_options,
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args) {
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types) {
mutex_lock l(graph_def_lock_);
std::unique_ptr<SimpleClientGraph> client_graph;
@ -1135,6 +1194,23 @@ Status DirectSession::CreateGraphs(
execution_state->BuildGraph(subgraph_options, &client_graph));
}
if (subgraph_options.feed_endpoints.size() !=
client_graph->feed_types.size()) {
return errors::Internal(
"Graph pruning failed: requested number of feed endpoints = ",
subgraph_options.feed_endpoints.size(),
" versus number of pruned feed endpoints = ",
client_graph->feed_types.size());
}
if (subgraph_options.fetch_endpoints.size() !=
client_graph->fetch_types.size()) {
return errors::Internal(
"Graph pruning failed: requested number of fetch endpoints = ",
subgraph_options.fetch_endpoints.size(),
" versus number of pruned fetch endpoints = ",
client_graph->fetch_types.size());
}
auto current_stateful_placements = execution_state->GetStatefulPlacements();
// Update our current state based on the execution_state's
// placements. If there are any mismatches for a node,
@ -1240,6 +1316,8 @@ Status DirectSession::CreateGraphs(
}
}
*flib_def = std::move(client_graph->flib_def);
std::swap(*input_types, client_graph->feed_types);
std::swap(*output_types, client_graph->fetch_types);
return s;
}
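
The Run() changes above replace the per-feed rendezvous sends with a FunctionCallFrame: feeds are written into argument slots chosen by the cached input_name_to_index map, and fetches are moved back out of the frame and reordered to match the caller's output names. A standalone sketch of that index-based feed/fetch plumbing, with doubles standing in for tensors and all graph names invented:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    int main() {
      // GetOrCreateExecutors(): cache ordinal positions for the sorted feed/fetch names.
      std::vector<std::string> inputs_sorted = {"a:0", "b:0"};
      std::vector<std::string> outputs_sorted = {"prod:0", "sum:0"};
      std::unordered_map<std::string, size_t> input_name_to_index, output_name_to_index;
      for (size_t i = 0; i < inputs_sorted.size(); ++i)
        input_name_to_index[inputs_sorted[i]] = i;
      for (size_t i = 0; i < outputs_sorted.size(); ++i)
        output_name_to_index[outputs_sorted[i]] = i;

      // Run(): drop each feed into the call-frame slot named by the map,
      // instead of sending it through a rendezvous.
      std::vector<std::pair<std::string, double>> feeds = {{"b:0", 32.0}, {"a:0", 10.0}};
      std::vector<double> feed_args(feeds.size());
      for (const auto& it : feeds) feed_args[input_name_to_index[it.first]] = it.second;

      // Stand-in for the executors: _Arg slots in, _Retval slots out.
      std::vector<double> sorted_retvals = {feed_args[0] * feed_args[1],   // prod:0
                                            feed_args[0] + feed_args[1]};  // sum:0

      // Reorder retvals into the order the caller asked for.
      std::vector<std::string> requested = {"sum:0", "prod:0"};
      std::vector<double> outputs;
      for (const auto& name : requested)
        outputs.push_back(std::move(sorted_retvals[output_name_to_index[name]]));

      std::cout << outputs[0] << " " << outputs[1] << "\n";  // 42 320
      return 0;
    }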

View File

@ -132,8 +132,13 @@ class DirectSession : public Session {
NameNodeMap name_to_node;
std::unique_ptr<FunctionLibraryDefinition> flib_def;
std::vector<PerPartitionExecutorsAndLib> items;
std::unordered_map<string, string> input_keys;
std::unordered_map<string, string> output_keys;
std::unordered_map<string, size_t> input_name_to_index;
std::unordered_map<string, string> input_name_to_rendezvous_key;
std::unordered_map<string, size_t> output_name_to_index;
std::unordered_map<string, string> output_name_to_rendezvous_key;
DataTypeVector input_types;
DataTypeVector output_types;
};
// For each live partial execution, the session maintains a RunState.
@ -187,7 +192,8 @@ class DirectSession : public Session {
const BuildGraphOptions& options,
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args);
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types);
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
@ -196,17 +202,17 @@ class DirectSession : public Session {
const Tensor& resource_tensor, Tensor* retrieved_tensor);
// Feeds more inputs to the executors, triggering further execution.
::tensorflow::Status SendInputs(
::tensorflow::Status SendPRunInputs(
const std::vector<std::pair<string, Tensor>>& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez);
// Fetches more outputs from the executors. It waits until the output
// tensors are computed.
::tensorflow::Status RecvOutputs(const std::vector<string>& output_names,
const ExecutorsAndKeys* executors_and_keys,
RunState* run_state,
std::vector<Tensor>* outputs);
::tensorflow::Status RecvPRunOutputs(
const std::vector<string>& output_names,
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
std::vector<Tensor>* outputs);
// Check if the specified fetches can be computed from the feeds
// that we have already provided.

View File

@ -1434,7 +1434,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
} else {
num_outstanding_ops_ = ready.size();
root_frame_->iterations[0]->outstanding_ops = ready.size();
done_cb_ = done;
done_cb_ = std::move(done);
// Schedule to run all the ready ops in thread pool.
ScheduleReady(ready, nullptr);
}
@ -2560,7 +2560,7 @@ bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
}
void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState(args, this))->RunAsync(done);
(new ExecutorState(args, this))->RunAsync(std::move(done));
}
} // end namespace
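
Several hunks in this commit (CopyTensor::ViaDMA, ExecutorState::RunAsync above, and the rendezvous and function-runtime code below) forward by-value std::function callbacks with std::move instead of copying them. A small standalone illustration of why that matters when the closure captures something expensive:

    #include <functional>
    #include <iostream>
    #include <utility>
    #include <vector>

    using DoneCallback = std::function<void(int)>;

    // Takes the callback by value; forwarding it with std::move hands the captured
    // state to the stored copy instead of duplicating it.
    struct Runner {
      void RunAsync(DoneCallback done) { done_ = std::move(done); }
      void Finish(int status) { if (done_) done_(status); }
      DoneCallback done_;
    };

    int main() {
      std::vector<int> big(1 << 20, 7);  // expensive state captured by the lambda
      DoneCallback cb = [big](int s) {
        std::cout << "done, status=" << s << ", captured " << big.size() << "\n";
      };
      Runner r;
      r.RunAsync(std::move(cb));  // moves the closure; no second copy of `big`
      r.Finish(0);
      return 0;
    }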

View File

@ -604,7 +604,7 @@ struct CustomCreatorSingleton {
void Set(CustomKernelCreator cb) {
mutex_lock l(mu);
custom_creator = cb;
custom_creator = std::move(cb);
}
CustomKernelCreator Get() {
@ -621,7 +621,7 @@ CustomCreatorSingleton* GetCustomCreatorSingleton() {
} // end namespace
void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
GetCustomCreatorSingleton()->Set(cb);
GetCustomCreatorSingleton()->Set(std::move(cb));
}
FunctionLibraryRuntime* NewFunctionLibraryRuntime(
@ -631,7 +631,7 @@ FunctionLibraryRuntime* NewFunctionLibraryRuntime(
CustomKernelCreator custom_kernel_creator) {
return new FunctionLibraryRuntimeImpl(dmgr, env, device, graph_def_version,
lib_def, optimizer_options,
custom_kernel_creator);
std::move(custom_kernel_creator));
}
FunctionLibraryRuntime* NewFunctionLibraryRuntime(

View File

@ -44,7 +44,7 @@ Status GetOpSig(const string& op, const OpDef** sig) {
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
w->Schedule(fn);
w->Schedule(std::move(fn));
}
void HasError(const Status& s, const string& substr) {
@ -654,7 +654,8 @@ namespace {
bool DoNothing(Graph* g) { return false; }
string Optimize(std::function<bool(Graph* g)> pass, const FunctionDef& fdef) {
string Optimize(const std::function<bool(Graph* g)>& pass,
const FunctionDef& fdef) {
InstantiationResult result;
InstantiateAttrValueMap empty;
TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));

View File

@ -130,9 +130,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
}
// Call RewriteGraphForExecution
subgraph::RewriteGraphMetadata metadata;
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
cpu_device_->attributes()));
cpu_device_->attributes(), false /* use_function_convention */,
&metadata));
// Create the local executor and the Rendezvous for fetching back the
// constants.

View File

@ -106,7 +106,7 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
recv_args.device_context, src_device, dst_device,
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
done);
std::move(done));
}
void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
@ -132,7 +132,8 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
};
if (status.ok() && in.IsInitialized()) {
SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback);
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
std::move(final_callback));
} else {
final_callback(status);
}

View File

@ -21,9 +21,9 @@ limitations under the License.
namespace tensorflow {
namespace {
// Replaces ReadVariableOp nodes which are only used by Sends and sinks with
// _UnsafeReadVariable nodes, as this transforamtion is safe and will improve
// performance.
// Replaces ReadVariableOp nodes which are only used by Sends, sinks,
// and function Retvals with _UnsafeReadVariable nodes, as this
// transformation is safe and will improve performance.
class ResourceVariableReadPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override {
@ -43,7 +43,8 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
if (n->type_string() == "ReadVariableOp") {
bool skip = false;
for (const Edge* e : n->out_edges()) {
if (!e->dst()->IsSend() && e->dst()->name() != "_SINK") {
if (!e->dst()->IsSend() && e->dst()->type_string() != "_Retval" &&
e->dst()->name() != "_SINK") {
skip = true;
}
}
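
The widened condition above lets a ReadVariableOp take the unsafe fast path when every consumer is a Send, a _Retval, or the sink. A toy standalone version of that consumer check; the node representation and the "_Send"/"_HostSend" type strings are simplifications of the real Node::IsSend() test:

    #include <iostream>
    #include <string>
    #include <vector>

    struct Node {
      std::string name;
      std::string type;
      std::vector<const Node*> out;  // consumers of this node's output
    };

    // True if every consumer is a Send, a _Retval, or the graph sink.
    bool OnlyFeedsSendRetvalOrSink(const Node& read) {
      for (const Node* dst : read.out) {
        bool is_send = dst->type == "_Send" || dst->type == "_HostSend";
        if (!is_send && dst->type != "_Retval" && dst->name != "_SINK") return false;
      }
      return true;
    }

    int main() {
      Node send{"send_0", "_Send", {}};
      Node ret{"ret_0", "_Retval", {}};
      Node add{"add_0", "Add", {}};
      Node ok{"read_a", "ReadVariableOp", {&send, &ret}};       // may be rewritten
      Node not_ok{"read_b", "ReadVariableOp", {&send, &add}};   // must stay safe
      std::cout << OnlyFeedsSendRetvalOrSink(ok) << " "
                << OnlyFeedsSendRetvalOrSink(not_ok) << "\n";   // 1 0
      return 0;
    }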

View File

@ -284,9 +284,11 @@ Status SimpleGraphExecutionState::InitBaseGraph(
if (session_options_ &&
session_options_->config.graph_options().place_pruned_graph()) {
// Rewrite the graph before placement.
rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
options.target_nodes, device_set_->client_device()->attributes()));
options.target_nodes, device_set_->client_device()->attributes(),
options.use_function_convention, rewrite_metadata_.get()));
}
// Save stateful placements before placing.
@ -333,15 +335,26 @@ Status SimpleGraphExecutionState::BuildGraph(
std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
CopyGraph(*graph_, ng.get());
subgraph::RewriteGraphMetadata rewrite_metadata;
if (session_options_ == nullptr ||
!session_options_->config.graph_options().place_pruned_graph()) {
// Extract the subset of the graph that needs to be run, adding feed/fetch
// ops as needed.
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
ng.get(), options.feed_endpoints, options.fetch_endpoints,
options.target_nodes, device_set_->client_device()->attributes()));
options.target_nodes, device_set_->client_device()->attributes(),
options.use_function_convention, &rewrite_metadata));
} else {
// This SimpleGraphExecutionState represents a graph that was
// pruned when this was constructed, so we copy the metadata from
// a member variable.
CHECK(rewrite_metadata_);
rewrite_metadata = *rewrite_metadata_;
}
CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());
// Make a fresh copy of the function library for the client graph.
std::unique_ptr<FunctionLibraryDefinition> flib(
new FunctionLibraryDefinition(*flib_def_));
@ -363,7 +376,8 @@ Status SimpleGraphExecutionState::BuildGraph(
// since the local CostModel used to record its stats is sized by
// the largest node id.
std::unique_ptr<SimpleClientGraph> dense_copy(
new SimpleClientGraph(std::move(flib)));
new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
rewrite_metadata.fetch_types));
CopyGraph(*ng, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.

View File

@ -39,6 +39,10 @@ struct SessionOptions;
class StepStats;
class Timeline;
namespace subgraph {
struct RewriteGraphMetadata;
}
struct SimpleGraphExecutionStateOptions {
const DeviceSet* device_set = nullptr;
const SessionOptions* session_options = nullptr;
@ -50,13 +54,19 @@ struct SimpleGraphExecutionStateOptions {
// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
// BuildGraphOptions.
struct SimpleClientGraph {
explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib)
: flib_def(std::move(flib)), graph(flib_def.get()) {}
explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
DataTypeVector feed_types,
DataTypeVector fetch_types)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
fetch_types(std::move(fetch_types)) {}
// Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph;
int32 placement_version;
DataTypeVector feed_types;
DataTypeVector fetch_types;
};
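
The extended SimpleClientGraph constructor above carries the pruned graph's feed and fetch dtypes alongside the graph so DirectSession can size and type-check its call frame. A standalone sketch of that "types travel with the client graph" shape, with strings standing in for DataTypeVector entries and a string for the graph itself:

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    using TypeVector = std::vector<std::string>;  // stand-in for DataTypeVector

    struct ToyClientGraph {
      ToyClientGraph(std::string graph, TypeVector feed_types, TypeVector fetch_types)
          : graph(std::move(graph)),
            feed_types(std::move(feed_types)),
            fetch_types(std::move(fetch_types)) {}
      std::string graph;      // stand-in for the pruned Graph
      TypeVector feed_types;  // dtype of each feed endpoint, in order
      TypeVector fetch_types; // dtype of each fetch endpoint, in order
    };

    int main() {
      ToyClientGraph cg("pruned-graph", {"DT_FLOAT", "DT_INT32"}, {"DT_FLOAT"});
      // The session can now verify the number of requested feeds/fetches and
      // build a call frame with matching argument and return-value types.
      std::cout << cg.feed_types.size() << " feeds, "
                << cg.fetch_types.size() << " fetch\n";
      return 0;
    }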
// SimpleGraphExecutionState is responsible for generating an
@ -190,6 +200,10 @@ class SimpleGraphExecutionState {
// and may be updated by a graph optimization pass.
std::unique_ptr<FunctionLibraryDefinition> flib_def_;
// `rewrite_metadata_` is only set for SimpleGraphExecutionState
// objects created by `MakeForPrunedGraph()`.
std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;
// The dataflow graph owned by this object.
Graph* graph_;

View File

@ -63,8 +63,9 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
};
// static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env,
const string& worker_name, WorkerCacheInterface* worker_cache) {
RendezvousMgrInterface* NewRpcRendezvousMgr(
const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* worker_cache) {
return new RpcRendezvousMgr(env, worker_name, worker_cache);
}
@ -76,7 +77,7 @@ GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
GrpcServer::~GrpcServer() {
TF_CHECK_OK(Stop());
TF_CHECK_OK(Join());
delete master_service_;
delete worker_service_;
@ -100,7 +101,7 @@ GrpcServer::~GrpcServer() {
}
Status GrpcServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendevous_mgr_func) {
RendezvousMgrCreationFunction rendevous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@ -193,6 +194,8 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
// Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
rendevous_mgr_func == nullptr ?
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
@ -222,6 +225,10 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK();
}
Status GrpcServer::Init() {
return Init(nullptr, nullptr);
}
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) {

View File

@ -37,14 +37,15 @@ class GrpcWorker;
class Master;
// function that creates a RendezvousMgr.
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*,
const std::string& worker_name, WorkerCacheInterface* worker_cache)>
RendezvousMgrCreationFunction;
typedef std::function<RendezvousMgrInterface*(
const WorkerEnv*, const std::string& worker_name,
WorkerCacheInterface* worker_cache)>
RendezvousMgrCreationFunction;
// function that registers a service to the server. The service needs to
// be registered before builder.BuildAndStart().
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
ServiceInitFunction;
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
ServiceInitFunction;
class GrpcServer : public ServerInterface {
protected:
@ -68,6 +69,8 @@ class GrpcServer : public ServerInterface {
Status Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func);
Status Init();
// A subclass can override this method to support secure credentials.
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const;
@ -90,7 +93,7 @@ class GrpcServer : public ServerInterface {
int bound_port() const { return bound_port_; }
WorkerEnv* worker_env() { return &worker_env_; }
const ServerDef& server_def() const { return server_def_; }
private:
@ -115,7 +118,7 @@ class GrpcServer : public ServerInterface {
// Stop(), Join()
enum State { NEW, STARTED, STOPPED };
State state_ GUARDED_BY(mu_);
// Implementation of a TensorFlow master, and RPC polling thread.
MasterEnv master_env_;
std::unique_ptr<Master> master_impl_;

View File

@ -789,7 +789,7 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
rets->clear();
rets->reserve(rets_.size());
for (size_t i = 0; i < rets_.size(); ++i) {
auto item = rets_[i];
const auto& item = rets_[i];
if (item.has_val) {
rets->push_back(item.val);
} else {
@ -799,6 +799,19 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
return Status::OK();
}
Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
rets->clear();
rets->reserve(rets_.size());
for (size_t i = 0; i < rets_.size(); ++i) {
if (rets_[i].has_val) {
rets->emplace_back(std::move(rets_[i].val));
} else {
return errors::Internal("Retval[", i, "] does not have value");
}
}
return Status::OK();
}
Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
return errors::InvalidArgument("GetArg ", index, " is not within [0, ",

View File

@ -259,6 +259,7 @@ class FunctionCallFrame {
// Caller methods.
Status SetArgs(gtl::ArraySlice<Tensor> args);
Status GetRetvals(std::vector<Tensor>* rets) const;
Status ConsumeRetvals(std::vector<Tensor>* rets);
// Callee methods.
Status GetArg(int index, Tensor* val) const;
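
ConsumeRetvals, declared above, differs from GetRetvals in that it moves the stored return values out of the frame rather than copying them, which is what lets DirectSession::Run() hand fetched tensors back without an extra copy. A toy caller-side sketch of that contract (this is not the real FunctionCallFrame; a string buffer stands in for a Tensor):

    #include <iostream>
    #include <string>
    #include <utility>
    #include <vector>

    struct ToyCallFrame {
      std::vector<std::pair<bool, std::string>> rets;  // (has_val, value)

      bool SetRetval(size_t i, std::string val) {
        if (i >= rets.size()) return false;
        rets[i] = {true, std::move(val)};
        return true;
      }

      // Moves every return value out; the frame gives up ownership afterwards.
      bool ConsumeRetvals(std::vector<std::string>* out) {
        out->clear();
        out->reserve(rets.size());
        for (auto& r : rets) {
          if (!r.first) return false;  // a retval was never produced
          out->push_back(std::move(r.second));
        }
        return true;
      }
    };

    int main() {
      ToyCallFrame frame;
      frame.rets.resize(2);
      frame.SetRetval(0, std::string(1 << 20, 'x'));  // large buffer, moved not copied
      frame.SetRetval(1, "small");
      std::vector<std::string> outs;
      std::cout << frame.ConsumeRetvals(&outs) << " " << outs[0].size() << "\n";
      return 0;
    }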

View File

@ -126,29 +126,33 @@ FunctionDef XTimes16() {
{{"y", "y:y:0"}});
}
FunctionDef WXPlusB() {
return FDH::Define(
// Name
"WXPlusB",
// Args
{"w: T", "x: T", "b: T"},
// Return values
{"y: T"},
// Attr def
{"T: {float, double}"},
// Nodes
{{{"mm"},
"MatMul",
{"w", "x"},
{{"T", "$T"},
{"transpose_a", false},
{"transpose_b", false},
#ifdef INTEL_MKL
}},
FunctionDef WXPlusB(){return FDH::Define(
// Name
"WXPlusB",
// Args
{"w: T", "x: T", "b: T"},
// Return values
{"y: T"},
// Attr def
{"T: {float, double}"},
// Nodes
{
{{"mm"},
"MatMul",
{"w", "x"},
{
{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
#ifdef INTEL_MKL
}},
#else
{"_kernel", "eigen"}}},
#endif
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
{
{"y"}, "Add", {"mm", "b"}, {
{ "T", "$T" }
}
}
});
}
FunctionDef Swap() {

View File

@ -63,7 +63,7 @@ namespace tensorflow {
// P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
// P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
@ -115,7 +115,7 @@ namespace tensorflow {
// Since every rewritten node generates twice the number of inputs and
// outputs, one could imagine various orderings among Tensorflow tensors
// and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
// inputs, then the new op 'MklConv2D' can take inputs A, B, A_m and B_m
// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
// in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
// order. Among N inputs one can get N! permutations.
//
@ -239,15 +239,15 @@ namespace tensorflow {
// -------------------------------------------
// Consider BiasAddGrad op as:
//
// O = MklConv2D(A, B, C, A_m, B_m, C_m)
// O = _MklConv2D(A, B, C, A_m, B_m, C_m)
// P = BiasAddGrad(O)
//
// Then we rewrite it as:
//
// P = Conv2DWithBiasBackpropBias(O, O_m)
//
// 'Distance' between input of BiasAddGrad and MklConv2D in terms of hops is
// the context matching depth. If MklConv2DWithBias is not within the context
// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
// the context matching depth. If _MklConv2DWithBias is not within the context
// matching depth, then we do not rewrite BiasAddGrad.
// How many hops do we search for matching node in the backward dataflow graph?
@ -261,74 +261,66 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
// NOTE: names are alphabetically sorted.
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
csinfo_.concatv2 = "ConcatV2";
csinfo_.conv2d = "Conv2D";
csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
csinfo_.concatv2 = "ConcatV2";
csinfo_.conv2d = "Conv2D";
csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
csinfo_.mkl_conv2d = "MklConv2D";
csinfo_.mkl_conv2d_with_bias = "MklConv2DWithBias";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.reshape = "Reshape";
csinfo_.relu_grad = "ReluGrad";
csinfo_.split = "Split";
"_MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.reshape = "Reshape";
csinfo_.relu_grad = "ReluGrad";
csinfo_.split = "Split";
// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.avg_pool,
GetMklOpName(csinfo_.avg_pool),
1, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avg_pool_grad,
GetMklOpName(csinfo_.avg_pool_grad),
2, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
GetMklOpName(csinfo_.concat),
0, CopyAttrsConcat, AlwaysRewrite});
rinfo_.push_back({csinfo_.concatv2,
GetMklOpName(csinfo_.concatv2),
0, CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d,
GetMklOpName(csinfo_.conv2d),
2, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
CopyAttrsConcat, AlwaysRewrite});
rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
GetMklOpName(csinfo_.conv2d_grad_filter),
3, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2d_grad_filter), 3,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
GetMklOpName(csinfo_.conv2d_grad_input),
3, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2d_grad_input), 3,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
GetMklOpName(csinfo_.fused_batch_norm),
5, CopyAttrsFusedBatchNorm, AlwaysRewrite});
GetMklOpName(csinfo_.fused_batch_norm), 5,
CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
GetMklOpName(csinfo_.fused_batch_norm_grad),
5, CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn,
GetMklOpName(csinfo_.lrn),
1, CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
GetMklOpName(csinfo_.lrn_grad),
3, CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool,
GetMklOpName(csinfo_.max_pool),
1, CopyAttrsPooling, AlwaysRewrite});
GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool_grad,
GetMklOpName(csinfo_.max_pool_grad),
3, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.relu,
GetMklOpName(csinfo_.relu),
1, CopyAttrsRelu, AlwaysRewrite});
GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
CopyAttrsRelu, AlwaysRewrite});
rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
CopyAttrsReshape, AlwaysRewrite});
@ -339,8 +331,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
minfo_.push_back(
{csinfo_.mkl_conv2d, csinfo_.bias_add, 0, csinfo_.mkl_conv2d_with_bias});
minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
csinfo_.mkl_conv2d_with_bias});
// We use maxhop of 10 based on empirical observations. Also, these are
// maxhops in backward data-flow graph. Since input of forward nodes
@ -374,7 +366,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// A function handler to copy attributes from an old node to a new node.
std::function<void(const Node*, NodeBuilder*)> copy_attrs;
std::function<bool(const Node*)> rewrite_rule; // A rule under which to
// rewrite this node.
// rewrite this node.
} RewriteInfo;
/// Structure to specify a forward op, a backward op, and the slot numbers
@ -477,7 +469,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// Concat, Split are vararg nodes.
inline bool IsVarArgNode(Node* n) {
if (n->type_string() == csinfo_.concat ||
if (n->type_string() == csinfo_.concat ||
n->type_string() == csinfo_.concatv2 ||
n->type_string() == csinfo_.split) {
return true;
@ -496,9 +488,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
CHECK_EQ(ArgIsList(arg), true);
int N = 0;
const string attr_name = !arg.type_list_attr().empty() ?
arg.type_list_attr() :
arg.number_attr();
const string attr_name = !arg.type_list_attr().empty()
? arg.type_list_attr()
: arg.number_attr();
if (!arg.type_list_attr().empty()) {
std::vector<DataType> value;
TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
@ -514,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// TODO(nhasabni) We should move this to mkl_util.h.
inline string GetMklOpName(const string& name) const {
// Prefix that we add to Tensorflow op name to construct Mkl op name.
const char* const kMklOpPrefix = "Mkl";
const char* const kMklOpPrefix = "_Mkl";
return string(kMklOpPrefix) + name;
}
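
GetMklOpName() above now prepends "_Mkl" (previously "Mkl") to the TensorFlow op name, and the rinfo_ table earlier in this file maps each eligible op to that prefixed replacement together with its input count and attribute-copying rule; as the header comments note, each rewritten node ends up with twice as many inputs once the Mkl metadata tensors are added. A standalone sketch of the naming helper plus a minimal table-driven lookup; the table entries here are illustrative:

    #include <iostream>
    #include <string>
    #include <vector>

    std::string GetMklOpName(const std::string& name) {
      static const char* const kMklOpPrefix = "_Mkl";  // was "Mkl" before this change
      return std::string(kMklOpPrefix) + name;
    }

    struct ToyRewriteInfo {
      std::string name;      // original TensorFlow op
      std::string new_name;  // MKL replacement
      int num_ins;           // data inputs; each also gets an Mkl metadata input
    };

    int main() {
      std::vector<ToyRewriteInfo> rinfo = {
          {"Conv2D", GetMklOpName("Conv2D"), 2},
          {"Relu", GetMklOpName("Relu"), 1},
      };
      for (const auto& r : rinfo) {
        std::cout << r.name << " -> " << r.new_name << " ("
                  << 2 * r.num_ins << " inputs after adding Mkl metadata)\n";
      }
      return 0;
    }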
@ -598,9 +590,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// @return None
void GetNodesProducingTFTensorList(
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
// Get nodes that will feed a list of Mkl tensors to the new
// node that we are constructing.
@ -616,10 +608,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
void GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
@ -635,7 +628,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// will feed the tensor
// @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@ -648,11 +642,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// Returns Status::OK() if setting up inputs is successful, otherwise
// returns appropriate status code.
int SetUpContiguousInputs(std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node,
std::vector<NodeBuilder::NodeOut>* workspace_tensors,
bool are_workspace_tensors_available);
int SetUpContiguousInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node,
std::vector<NodeBuilder::NodeOut>* workspace_tensors,
bool are_workspace_tensors_available);
// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'orig_node'.
@ -672,8 +667,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// tensors, if they need to be added, will be set into these tensors.
// If we set workspace tensors, then are_ws_tensors_added should be true.
void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
NodeBuilder* nb, std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added);
NodeBuilder* nb,
std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added);
// Functions specific to operators to copy attributes
// We need operator-specific function to copy attributes because the framework
@ -732,9 +728,8 @@ static void FillInputs(const Node* n,
}
void MklLayoutRewritePass::GetNodesProducingTFTensorList(
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes) {
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@ -767,34 +762,33 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList(
}
// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyMklTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
Node** out, Node* orig_node) {
// We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
// dummy Mkl tensor. 8 = 2*size_t.
const DataType dt = DataTypeToEnum<uint8>::v();
TensorProto proto;
proto.set_dtype(dt);
uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
proto.set_tensor_content(const_cast<const void*>(
static_cast<void*>(&zero)), 8);
proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
8);
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
void MklLayoutRewritePass::GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes) {
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@ -819,8 +813,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// If it is a list, then create a list of Mkl dummy nodes.
for (int j = 0; j < N; j++) {
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
mkl_node_output_slot));
output_nodes->push_back(
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
}
(*input_idx)++;
list_length -= N;
@ -829,8 +823,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
mkl_node_output_slot));
output_nodes->push_back(
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
(*input_idx)++;
list_length--;
}
@ -841,9 +835,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
Node* n,
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
void MklLayoutRewritePass::GetNodeProducingMklTensor(
std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
@ -859,8 +853,8 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
// output slot number for Mkl tensor would be N+slot number of TensorFlow
// tensor, where N is total number of TensorFlow tensors.
*mkl_node = n;
*mkl_node_output_slot = GetTensorMetaDataIndex(n_output_slot,
n->num_outputs());
*mkl_node_output_slot =
GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
} else {
// If we have not visited the node and rewritten it, then we need
// to create a dummy node that will feed a dummy Mkl tensor to this node.
@ -872,7 +866,8 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
}
}
int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr<Graph>* g,
int MklLayoutRewritePass::SetUpContiguousInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node,
std::vector<NodeBuilder::NodeOut>* workspace_tensors,
@ -931,16 +926,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr<Graph>* g,
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingMklTensorList(g, old_node_inputs, &iidx,
N, &new_node_inputs);
GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
&new_node_inputs);
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
old_node_inputs[iidx].second,
&mkl_node, &mkl_node_output_slot);
old_node_inputs[iidx].second, &mkl_node,
&mkl_node_output_slot);
nb->Input(mkl_node, mkl_node_output_slot);
iidx++;
nn_slot_idx++;
@ -961,7 +956,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr<Graph>* g,
return nn_slot_idx;
}
Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr<Graph>* g,
Status MklLayoutRewritePass::SetUpInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node) {
// Let's check if we need to add workspace tensors for this node.
@ -975,13 +971,14 @@ Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr<Graph>* g,
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
// TODO(nhasabni): implement this function just for same of completion.
// We do not use interleaved ordering right now.
return Status(error::Code::UNIMPLEMENTED,
"Interleaved ordering of tensors is currently not supported.");
return Status(
error::Code::UNIMPLEMENTED,
"Interleaved ordering of tensors is currently not supported.");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
new_node_input_slots = SetUpContiguousInputs(g, old_node_inputs, nb,
old_node, &workspace_tensors,
are_workspace_tensors_available);
new_node_input_slots = SetUpContiguousInputs(
g, old_node_inputs, nb, old_node, &workspace_tensors,
are_workspace_tensors_available);
}
// Sanity check
@ -1023,20 +1020,19 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
Node* orig_node, NodeBuilder* nb,
std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added) {
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
bool workspace_edge_added = false; // Default initializer
CHECK_NOTNULL(are_ws_tensors_added);
*are_ws_tensors_added = false; // Default initializer
@ -1071,8 +1067,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
mkl_op_registry::IsMklOp(
GetMklOpName(orig_node->type_string()), T)) {
mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()),
T)) {
// If this op is a bwd op, then we need to add workspace edge and
// its Mkl tensor edge between its corresponding fwd op and this
// op. Corresponding fwd op is specified in 'fwd_op' field of
@ -1094,8 +1090,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// Add workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
// Add Mkl tensor edge for workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(e->src(),
DataIndexToMetaDataIndex(ws.ws_fwd_slot, e->src()->num_outputs())));
ws_tensors->push_back(NodeBuilder::NodeOut(
e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
e->src()->num_outputs())));
*are_ws_tensors_added = true;
// In terms of input ordering, we add these calls to add Input
// here because workspace edge (and its Mkl tensor) is the last
@ -1154,8 +1151,8 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu",
&use_cudnn_on_gpu));
TF_CHECK_OK(
GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
// Add attributes to new node.
nb->Attr("T", T);
@ -1307,14 +1304,14 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
}
void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
NodeBuilder* nb) {
NodeBuilder* nb) {
DataType T;
DataType Tshape;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));
// Add attributes to new node.
nb->Attr("T", T);
nb->Attr("Tshape", Tshape);
@ -1435,7 +1432,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// 2. Get inputs from both the nodes.
// Find the 2 inputs from the conv and the bias from the add Bias.
// Get operand 0, 1 of conv2D and their Mkl tensors.
CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs.
CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs.
// Get operand 1 of add_bias
// BiasAdd must have 2 inputs: Conv, bias
CHECK_EQ(succ->in_edges().size(), 2);
@ -1538,15 +1535,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
DataType orig_T, ctx_T;
string orig_data_format, ctx_data_format;
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format",
&orig_data_format));
TF_CHECK_OK(
GetNodeAttr(orig_node->def(), "data_format", &orig_data_format));
TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T));
TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "data_format",
&ctx_data_format));
TF_CHECK_OK(
GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format));
if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
orig_node->assigned_device_name() !=
fwd_node->assigned_device_name() ||
fwd_node->assigned_device_name() ||
orig_node->def().device() != fwd_node->def().device()) {
return Status(
error::Code::INVALID_ARGUMENT,
@ -1613,9 +1610,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
if (e->src_output() < 0) {
(*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
} else {
(*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
e->src()->num_outputs()),
e->dst(), e->dst_input());
(*g)->AddEdge(
new_node,
GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
e->dst(), e->dst_input());
}
}

View File

@ -110,13 +110,11 @@ class MklLayoutPassTest : public ::testing::Test {
};
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float")
.Attr("N: int")
.SetIsStateful();
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
/////////////////////////////////////////////////////////////////////
// Unit tests related to node merge optimization
@ -137,16 +135,16 @@ TEST_F(MklLayoutPassTest, Basic) {
// Test set 1: Conv2D + AddBias
// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
// C=MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -163,22 +161,22 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
"M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
"DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
}
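A note on the expectation-string format used throughout these tests (a reading inferred from the graphs in this file, not an authoritative specification): the text before '|' appears to list the rewritten graph's nodes as name(op), and the text after '|' its edges as src[:output_slot]->dst[:input_slot], where an omitted slot means slot 0. Applied to the expectation just above:
// "A->E"     A's output 0 feeds E's input 0 (the data input of _MklConv2DWithBias)
// "B->E:1"   B's output 0 feeds E's input 1 (the filter)
// "D->E:2"   D's output 0 feeds E's input 2 (the bias)
// "M->E:3"   M's output 0 feeds E's input 3 (the first Mkl metadata input)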
// C=MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
// C=MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
// Test for correct output slots selected
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput2'}"
"node { name: 'N' op: 'MklInput2'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput2'}"
"node { name: 'N' op: '_MklInput2'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -195,15 +193,15 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
"M(MklInput2);N(MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
"DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
}
// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
// This is a case of node rewrite followed by node merge.
// We will first rewrite Conv2D to MklConv2D, and then merge MklConv2D
// with BiasAdd to produce MklConv2DWithBias.
// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
// with BiasAdd to produce _MklConv2DWithBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
@ -227,19 +225,19 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(MklConv2DWithBias);Y(Input);Z(Sub)|"
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
"A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
"E->Z;Y->Z:1");
}
// Graph contains only MklConv2D, no AddBias.
// Graph contains only _MklConv2D, no AddBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -247,18 +245,18 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
"A->C;B->C:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
"A->C;B->C:1;M->C:2;N->C:3");
}
// MklConv2D output does not go to BiasAdd.
// _MklConv2D output does not go to BiasAdd.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -270,21 +268,21 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
"node { name: 'F' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd.
" input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd.
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
"M(MklInput);N(MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
"M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
}
// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// Merge should not be done in such case.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -302,8 +300,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
"G(Add);M(MklInput);N(MklInput)|A->C;B->C:1;C->G;D->F;"
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
"G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
"E->F:1;E->G:1;M->C:2;N->C:3");
}
@ -313,9 +311,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -328,26 +326,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
" attr { key: 'data_format' value { s: 'NHCW' } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
"N(MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
"N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
}
// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
// of BiasAddGrad into BackpropBias
#if 0
// Test set 2: MklConv2D..BiasAddGrad -> MklConv2DWithBiasBackpropBias
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
// rewrite tests
// D=MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'O' op: 'MklInput'}"
"node { name: 'D' op: 'MklConv2DWithBias'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'O' op: '_MklInput'}"
"node { name: 'D' op: '_MklConv2DWithBias'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -362,25 +360,25 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MklConv2DWithBias);DMT/_0(Const);"
"E(Sub);F(MklConv2DWithBiasBackpropBias);M(MklInput);N(MklInput);"
"O(MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
"M->D:3;N->D:4;O->D:5");
}
#endif
// No MklConv2D in context, but Conv2D in context.
// Only Conv2D would be rewritten to MklConv2D, but no rewrite
// No _MklConv2D in context, but Conv2D in context.
// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
// for BiasAddGrad should happen.
// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
// C=MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@ -395,8 +393,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Sub);E(BiasAddGrad);"
"M(MklInput);N(MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
"A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);"
"M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
"M->C:2;N->C:3");
}
@ -509,7 +507,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
"A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}
@ -536,7 +534,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);"
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
"C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
@ -578,7 +576,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(MklConcat);DMT/_0(Const);"
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
@ -617,8 +615,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(MklConv2D);"
"F(MklConv2D);G(Const);H(MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
}
@ -652,8 +650,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(MklConv2D);F(Mul);G(Const);"
"H(MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
"G->H;H->I:1");
}
@ -678,7 +676,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(MklConcat);DMT/_0(Const);"
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
@ -719,8 +717,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(MklConv2D);"
"F(MklConv2D);G(Const);H(MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
"DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
}
@ -755,8 +753,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(MklConv2D);F(Mul);G(Const);"
"H(MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
"G->H:2;H->I:1");
}
@ -804,9 +802,10 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['H', 'G'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(MklMaxPoolGrad);F(Input);G(MklLRNGrad);H(Input);I(Mul)|"
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
"A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
"C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
"DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
@ -837,8 +836,8 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(MklLRNGrad);F(Mul)|"
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
"A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
@ -858,7 +857,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(Mul);DMT/_0(Const)|"
"A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
}
@ -879,7 +878,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MklLRNGrad);DMT/_0(Const);"
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
@ -919,9 +918,9 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
"DMT/_6(Const);E(MklLRNGrad);F(MklLRNGrad);G(Mul)|A->B;B->E:2;"
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
"D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
@ -950,8 +949,8 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(MklMaxPoolGrad);F(Mul)|"
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
"A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
@ -972,7 +971,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklMaxPool);C(Mul);DMT/_0(Const)|"
"A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
}
@ -994,7 +993,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MklMaxPoolGrad);DMT/_0(Const);"
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");

View File

@ -123,22 +123,24 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
if (src_datatype != dst_datatype) {
string err_msg = "T attribute of " + src->name() + " and " +
dst->name() + " do not match. Will not insert" +
string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
" do not match. Will not insert" +
" MklToTf node in such case.";
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
}
// Build the conversion node and specify src as input.
TF_CHECK_OK(NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf")
.Input(src, e->src_output())
.Input(src, DataIndexToMetaDataIndex(
e->src_output(), src->num_outputs())) // Get an Mkl tensor slot
// from the Tf tensor slot.
.Device(src->def().device()) // We want the conversion node
// on the same device as the source node.
.Attr("T", src_datatype)
.Finalize(&**g, &conversion_node));
TF_CHECK_OK(
NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
.Input(src, e->src_output())
.Input(src, DataIndexToMetaDataIndex(
e->src_output(),
src->num_outputs())) // Get an Mkl tensor slot
// from the Tf tensor slot.
.Device(src->def().device()) // We want the conversion node
// on the same device as the source node.
.Attr("T", src_datatype)
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
@ -191,8 +193,8 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
// We skip adding MklToTf on an edge between X->MklToTf or
// MklToTf->X, where X is any node.
if (src->type_string().compare("MklToTf") == 0 ||
dst->type_string().compare("MklToTf") == 0) {
if (src->type_string().compare("_MklToTf") == 0 ||
dst->type_string().compare("_MklToTf") == 0) {
continue;
}
@ -246,8 +248,7 @@ bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
return MklToTfConversionPass().RunPass(g);
}
Status MklToTfConversionPass::Run(
const GraphOptimizationPassOptions& options) {
Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
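A worked reading of the conversion-node wiring constructed earlier in this file (inferred from the contiguous-ordering test below, not an additional code change): the _MklToTf node receives the TensorFlow data tensor on input 0 and its paired Mkl metadata tensor, located via DataIndexToMetaDataIndex, on input 1.
// For the _MklConv2D node C in the Positive test (output 0 = data, output 1 = metadata):
//   C->Mkl2Tf/_0        data tensor into conversion input 0
//   C:1->Mkl2Tf/_0:1    metadata tensor into conversion input 1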

View File

@ -15,8 +15,8 @@ limitations under the License.
#ifdef INTEL_MKL
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/util/mkl_util.h"
#include <algorithm>
#include <string>
@ -110,7 +110,7 @@ class MklToTfConversionPass : public ::testing::Test {
REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
TEST_F(MklToTfConversionPass, Basic) {
InitGraph(
@ -131,47 +131,49 @@ TEST_F(MklToTfConversionPass, Basic) {
TEST_F(MklToTfConversionPass, Positive) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
"Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
"Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
}
}
@ -182,55 +184,57 @@ TEST_F(MklToTfConversionPass, Positive) {
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
"F(Sub);M(MklInput);N(MklInput)|"
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
"F(Sub);M(_MklInput);N(_MklInput)|"
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
"F(Sub);M(MklInput);N(MklInput)|"
"A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
"F(Sub);M(_MklInput);N(_MklInput)|"
"A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
}
}
@ -258,7 +262,7 @@ TEST_F(MklToTfConversionPass, Negative_NoMklLayer) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|"
"A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
"A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
}
static void BM_RunMklToTfConversionPass(int iters, int op_nodes) {

View File

@ -55,8 +55,13 @@ namespace {
// state).
static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
const gtl::ArraySlice<string>& fed_outputs,
subgraph::NameIndex* name_index) {
for (const string& t : fed_outputs) {
bool use_function_convention,
subgraph::NameIndex* name_index,
DataTypeVector* out_feed_types) {
out_feed_types->clear();
out_feed_types->reserve(fed_outputs.size());
for (size_t i = 0; i < fed_outputs.size(); ++i) {
const string& t = fed_outputs[i];
TensorId id(ParseTensorName(t));
auto iter = name_index->find(id.first);
@ -71,17 +76,31 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
}
Node* recv_node;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
"_Recv")
.Attr("tensor_type", BaseType(n->output_type(id.second)))
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &recv_node));
if (!use_function_convention) {
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
"_Recv")
.Attr("tensor_type", BaseType(n->output_type(id.second)))
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &recv_node));
} else {
// NOTE(mrry): We must include the index as part of the node
// name, because _Arg is a "stateful" kernel and therefore
// its name must uniquely identify a kernel instance across all
// graphs in the same session.
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_",
id.second, "_", i),
"_Arg")
.Attr("T", BaseType(n->output_type(id.second)))
.Attr("index", static_cast<int32>(i))
.Finalize(g, &recv_node));
}
recv_node->set_assigned_device_name(device_info.name());
// Copy the _output_shapes from the original node to the feed node,
@ -130,6 +149,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
}
g->RemoveEdge(e);
}
out_feed_types->push_back(BaseType(n->output_type(id.second)));
}
return Status::OK();
}
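As an illustration of the naming scheme built above (the feed name is borrowed from the subgraph tests later in this diff, not defined in this function): when use_function_convention is set, the fed tensor id.first:id.second at feed index i becomes an _Arg node whose name is StrCat("_arg_", id.first, "_", id.second, "_", i).
// Feeding "input:1" as the first fed tensor (i == 0):
//   node name: "_arg_input_1_0"
//   attrs:     T = BaseType(n->output_type(1)), index = 0
The fetch path later in this file mirrors this convention for _Retval nodes.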
@ -181,9 +201,14 @@ namespace subgraph {
Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
const gtl::ArraySlice<string>& fetch_outputs,
NameIndex* name_index, std::vector<Node*>* fetch_nodes) {
fetch_nodes->clear();
for (const string& t : fetch_outputs) {
bool use_function_convention, NameIndex* name_index,
std::vector<Node*>* out_fetch_nodes,
DataTypeVector* out_fetch_types) {
out_fetch_nodes->clear();
out_fetch_nodes->reserve(fetch_outputs.size());
for (size_t i = 0; i < fetch_outputs.size(); ++i) {
const string& t = fetch_outputs[i];
// Parse t into node_name and output_index.
TensorId id(ParseTensorName(t));
@ -213,25 +238,39 @@ Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
// Create the fetch Node and connect it up
Node* send_node;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
"_Send")
.Input(n, id.second)
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &send_node));
if (!use_function_convention) {
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
"_Send")
.Input(n, id.second)
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &send_node));
} else {
// NOTE(mrry): We must include the index as part of the node
// name, because _Retval is a "stateful" kernel and therefore
// its name must uniquely identify a kernel instance across all
// graphs in the same session.
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_",
id.second, "_", i),
"_Retval")
.Input(n, id.second)
.Attr("T", BaseType(n->output_type(id.second)))
.Attr("index", static_cast<int32>(i))
.Finalize(g, &send_node));
}
send_node->set_assigned_device_name(device_info.name());
VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());
// Update the index.
(*name_index)[send_node->name()] = send_node;
g->AddControlEdge(send_node, g->sink_node());
fetch_nodes->push_back(send_node);
out_fetch_nodes->push_back(send_node);
out_fetch_types->push_back(BaseType(n->output_type(id.second)));
}
return Status::OK();
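As with the feeds above, a hedged illustration of the fetch naming (the fetch name comes from the subgraph tests, not this function): the fetched tensor at fetch index i becomes a _Retval node named "_retval_<node>_<slot>_<i>".
// Fetching "t3_a" (output slot 0) as the only fetch (i == 0):
//   node name: "_retval_t3_a_0_0"
//   attrs:     T = BaseType(n->output_type(0)), index = 0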
@ -241,7 +280,8 @@ Status RewriteGraphForExecution(
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
const gtl::ArraySlice<string>& fetch_outputs,
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info) {
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata) {
if (fetch_outputs.empty() && target_node_names.empty()) {
return errors::InvalidArgument(
"Must specify at least one target to fetch or execute.");
@ -274,18 +314,21 @@ Status RewriteGraphForExecution(
// currently listed in "fetch_nodes". We pass "name_index" so the index is
// kept up to date.
if (!fed_outputs.empty()) {
TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs,
use_function_convention, &name_index,
&out_metadata->feed_types));
}
// Add the fetch nodes, also updating "name_index".
std::vector<Node*> fetch_nodes;
if (!fetch_outputs.empty()) {
TF_RETURN_IF_ERROR(
FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs,
use_function_convention, &name_index,
&fetch_nodes, &out_metadata->fetch_types));
}
// Prune the graph to only compute what is needed for the fetch nodes and the
// targets nodes.
// target nodes.
if (!fetch_nodes.empty() || !target_node_names.empty()) {
TF_RETURN_IF_ERROR(
PruneForTargets(g, name_index, fetch_nodes, target_node_names));

View File

@ -26,6 +26,18 @@ limitations under the License.
namespace tensorflow {
namespace subgraph {
// Information about a graph rewritten by `RewriteGraphForExecution()`.
struct RewriteGraphMetadata {
// The element type of each tensor fed to this subgraph. The order
// of types corresponds to the order of tensor names in
// `fed_outputs` when calling `RewriteGraphForExecution()`.
DataTypeVector feed_types;
// The element type of each tensor fetched from this subgraph. The
// order of types corresponds to the order of tensor names in
// `fetch_outputs` when calling `RewriteGraphForExecution()`.
DataTypeVector fetch_types;
};
// Rewrite the graph structure of "*g" to deal with feeding node
// outputs, fetching node outputs, and only running a subset of the
// graph. "fed_outputs" and "fetch_outputs" are both lists of
@ -56,7 +68,8 @@ Status RewriteGraphForExecution(
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
const gtl::ArraySlice<string>& fetch_outputs,
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info);
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata);
typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;
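A minimal usage sketch for the extended interface, assuming a Graph* g and a DeviceAttributes device_info already exist; the tensor names are illustrative (they mirror the subgraph tests elsewhere in this diff) rather than defined by this header:
std::vector<string> fed = {"input:1"};
std::vector<string> fetch = {"t2:0"};
std::vector<string> targets = {"t2"};
subgraph::RewriteGraphMetadata metadata;
TF_CHECK_OK(subgraph::RewriteGraphForExecution(
    g, fed, fetch, targets, device_info,
    true /* use_function_convention */, &metadata));
// metadata.feed_types now holds one DataType per entry of `fed`, and
// metadata.fetch_types one per entry of `fetch`, in the same order.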

View File

@ -104,7 +104,8 @@ class SubgraphTest : public ::testing::Test {
}
string Subgraph(const string& fed_str, const string& fetch_str,
const string& targets_str) {
const string& targets_str,
bool use_function_convention = false) {
Graph* subgraph = new Graph(OpRegistry::Global());
CopyGraph(*g_, subgraph);
std::vector<string> fed =
@ -114,13 +115,18 @@ class SubgraphTest : public ::testing::Test {
std::vector<string> targets =
str_util::Split(targets_str, ',', str_util::SkipEmpty());
Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, targets,
device_info_);
subgraph::RewriteGraphMetadata metadata;
Status s = subgraph::RewriteGraphForExecution(
subgraph, fed, fetch, targets, device_info_, use_function_convention,
&metadata);
if (!s.ok()) {
delete subgraph;
return s.ToString();
}
EXPECT_EQ(fed.size(), metadata.feed_types.size());
EXPECT_EQ(fetch.size(), metadata.fetch_types.size());
// Replace the graph with the subgraph for the rest of the display program
g_.reset(subgraph);
return "OK";
@ -178,6 +184,20 @@ TEST_F(SubgraphTest, FedOutputs1) {
ExpectNodes("W1,W2,_recv_input_1,t1,t2");
}
TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
EXPECT_EQ("OK",
Subgraph("input:1", "", "t2", true /* use_function_convention */));
ExpectNodes("W1,W2,_arg_input_1_0,t1,t2");
}
TEST_F(SubgraphTest, FedRefNode) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
@ -189,7 +209,19 @@ TEST_F(SubgraphTest, FedRefNode) {
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
}
TEST_F(SubgraphTest, FedOutputs2) {
TEST_F(SubgraphTest, FedRefNode_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
EXPECT_EQ("OK",
Subgraph("W1:0", "", "t1", true /* use_function_convention */));
ExpectNodes("_arg_W1_0_0,W2,t1");
Node* n = FindNode("_arg_W1_0_0");
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
}
TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
@ -200,8 +232,9 @@ TEST_F(SubgraphTest, FedOutputs2) {
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
// We feed input:1, but nothing connects to it, so the _recv(input:1)
// node also disappears.
EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2",
true /* use_function_convention */));
ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2");
}
TEST_F(SubgraphTest, FetchOutputs1) {
@ -218,6 +251,22 @@ TEST_F(SubgraphTest, FetchOutputs1) {
"W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
}
TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2",
true /* use_function_convention */));
ExpectNodes(
"W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_"
"retval_t2_0_3");
}
TEST_F(SubgraphTest, FetchOutputs2) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
@ -231,6 +280,20 @@ TEST_F(SubgraphTest, FetchOutputs2) {
ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
}
TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
EXPECT_EQ("OK",
Subgraph("", "t3_a", "t2", true /* use_function_convention */));
ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0");
}
TEST_F(SubgraphTest, ChainOfFools) {
ExpectOK(
"node { name: 'a' op: 'TestParams' }"
@ -315,7 +378,8 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
REGISTER_OP("In").Output("o: float");
REGISTER_OP("Op").Input("i: float").Output("o: float");
static void BM_Subgraph(int iters, int num_nodes) {
static void BM_SubgraphHelper(int iters, int num_nodes,
bool use_function_convention) {
DeviceAttributes device_info;
device_info.set_name("/job:a/replica:0/task:0/cpu:0");
device_info.set_device_type(DeviceType(DEVICE_CPU).type());
@ -347,12 +411,26 @@ static void BM_Subgraph(int iters, int num_nodes) {
while (--iters > 0) {
Graph* subgraph = new Graph(OpRegistry::Global());
CopyGraph(g, subgraph);
TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
targets, device_info));
subgraph::RewriteGraphMetadata metadata;
TF_CHECK_OK(subgraph::RewriteGraphForExecution(
subgraph, fed, fetch, targets, device_info, use_function_convention,
&metadata));
delete subgraph;
}
}
static void BM_Subgraph(int iters, int num_nodes) {
BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */);
}
static void BM_SubgraphFunctionConvention(int iters, int num_nodes) {
BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */);
}
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
BENCHMARK(BM_SubgraphFunctionConvention)
->Arg(100)
->Arg(1000)
->Arg(10000)
->Arg(100000);
} // namespace
} // namespace tensorflow

View File

@ -18,6 +18,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsDequeueOp(const NodeDef& node) {
static const std::set<std::string> dequeue_ops = {
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
"QueueDequeue"};
return dequeue_ops.count(node.op()) > 0;
}
bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2";
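A small, hedged sketch of how these predicates can be used when scanning a graph; graph_def here stands for any GraphDef and is not a name from this file:
for (const NodeDef& node : graph_def.node()) {
  if (IsPlaceholder(node) || IsDequeueOp(node)) {
    // Likely a model input: a Placeholder or a queue dequeue op.
  }
}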

View File

@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
bool IsDequeueOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsVariable(const NodeDef& node);

Some files were not shown because too many files have changed in this diff.