Merge commit for internal changes

Manually fixed conflicts by accepting --ours:
  tensorflow/contrib/slim/README.md

Manually fixed conflicts by accepting --theirs:
  tensorflow/contrib/learn/python/learn/datasets/mnist.py
  tensorflow/contrib/verbs/BUILD
  tensorflow/contrib/verbs/grpc_verbs_client.cc
  tensorflow/contrib/verbs/grpc_verbs_client.h
  tensorflow/contrib/verbs/grpc_verbs_service.cc
  tensorflow/contrib/verbs/grpc_verbs_service.h
  tensorflow/contrib/verbs/grpc_verbs_service_impl.cc
  tensorflow/contrib/verbs/grpc_verbs_service_impl.h
  tensorflow/contrib/verbs/rdma.cc
  tensorflow/contrib/verbs/rdma.h
  tensorflow/contrib/verbs/rdma_mgr.cc
  tensorflow/contrib/verbs/rdma_mgr.h
  tensorflow/contrib/verbs/rdma_rendezvous_mgr.cc
  tensorflow/contrib/verbs/rdma_rendezvous_mgr.h
  tensorflow/contrib/verbs/verbs_server_lib.cc
  tensorflow/contrib/verbs/verbs_server_lib.h
  tensorflow/contrib/verbs/verbs_service.proto
  tensorflow/contrib/verbs/verbs_util.cc
  tensorflow/contrib/verbs/verbs_util.h
  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
  tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
  tensorflow/core/framework/function_testlib.cc
  tensorflow/core/graph/mkl_layout_pass.cc
  tensorflow/core/graph/mkl_layout_pass_test.cc
  tensorflow/core/graph/mkl_tfconversion_pass.cc
  tensorflow/core/graph/mkl_tfconversion_pass_test.cc
  tensorflow/core/kernels/fixed_length_record_reader_op.cc
  tensorflow/core/kernels/mkl_concat_op.cc
  tensorflow/core/kernels/mkl_conv_grad_bias_ops.cc
  tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
  tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
  tensorflow/core/kernels/mkl_conv_ops.cc
  tensorflow/core/kernels/mkl_fused_batch_norm_op.cc
  tensorflow/core/kernels/mkl_lrn_op.cc
  tensorflow/core/kernels/mkl_relu_op.cc
  tensorflow/core/kernels/mkl_reshape_op.cc
  tensorflow/core/kernels/mkl_tfconv_op.cc
  tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc
  tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
  tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc
  tensorflow/core/ops/array_ops.cc
  tensorflow/core/ops/nn_ops.cc
  tensorflow/core/util/mkl_util.h
  tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
  tensorflow/python/kernel_tests/reader_ops_test.py
  tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_grad_test.py
  tensorflow/python/kernel_tests/sparse_tensor_dense_matmul_op_test.py
  tensorflow/python/ops/io_ops.py
  tensorflow/python/ops/nn_impl.py
  tensorflow/python/ops/sparse_grad.py
  tensorflow/tensorboard/gulp_tasks/vulcanize.js
  third_party/jemalloc.BUILD
commit 92675de336
README.md
@@ -53,7 +53,7 @@ $ python
 >>> hello = tf.constant('Hello, TensorFlow!')
 >>> sess = tf.Session()
 >>> sess.run(hello)
-Hello, TensorFlow!
+'Hello, TensorFlow!'
 >>> a = tf.constant(10)
 >>> b = tf.constant(32)
 >>> sess.run(a+b)
RELEASE.md
@@ -3,6 +3,19 @@
 ## Major Features and Improvements
 * Added `tf.Session.make_callable()`, which provides a lower overhead means of running a similar step multiple times.
 * Added ibverbs-based RDMA support to contrib (courtesy @junshi15 from Yahoo).
+* `RNNCell` objects now subclass `tf.layers._Layer`. The strictness described
+  in the TensorFlow 1.1 release is gone: The first time an RNNCell is used,
+  it caches its scope. All future uses of the RNNCell will reuse variables from
+  that same scope. This is a breaking change from the behavior of RNNCells
+  in TensorFlow versions <= 1.0.1. TensorFlow 1.1 had checks in place to
+  ensure old code works correctly with the new semantics; this version
+  allows more flexible uses of RNNCell but can lead to subtle errors if
+  using code meant for TensorFlow <= 1.0.1. For example, writing:
+  `MultiRNNCell([lstm] * 5)` will now build a 5-layer LSTM stack where each
+  layer shares the **same** parameters. To get 5 layers each with their own
+  parameters, write: `MultiRNNCell([LSTMCell(...) for _ in range(5)])`.
+  If at all unsure, first test your code with TF 1.1; ensure it raises no
+  errors, and then upgrade to TF 1.2.
 
 
 # Release 1.1.0
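The `MultiRNNCell` pitfall in the note above is easy to reproduce; a minimal sketch of the two constructions, assuming the TF 1.2-era `tf.contrib.rnn` API:

```python
import tensorflow as tf

# Shares ONE set of weights across all five layers: the same cell object is
# reused, and under the new semantics it reuses its cached variable scope.
lstm = tf.contrib.rnn.LSTMCell(128)
shared_stack = tf.contrib.rnn.MultiRNNCell([lstm] * 5)

# Five distinct cell objects, so each layer gets its own weights.
independent_stack = tf.contrib.rnn.MultiRNNCell(
    [tf.contrib.rnn.LSTMCell(128) for _ in range(5)])
```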
configure
@@ -86,15 +86,18 @@ while true; do
     PYTHON_BIN_PATH=""
     # Retry
 done
 export PYTHON_BIN_PATH
+write_action_env_to_bazelrc "PYTHON_BIN_PATH" "$PYTHON_BIN_PATH"
+# TODO(ngiraldo): allow the user to optionally set PYTHON_INCLUDE_PATH and NUMPY_INCLUDE_PATH
 
 ## Set up MKL related environment settings
+if false; then # Disable building with MKL for now
 while [ "$TF_NEED_MKL" == "" ]; do
   fromuser=""
-  read -p "Do you wish to build TensorFlow with MKL support? [y/N] " INPUT
+  read -p "Do you wish to build TensorFlow with MKL support (experimental)? [y/N] " INPUT
   fromuser="1"
   case $INPUT in
-    [Yy]* ) echo "MKL support will be enabled for TensorFlow"; TF_NEED_MKL=1;;
+    [Yy]* ) echo "MKL support (experimental) (will be enabled for TensorFlow"; TF_NEED_MKL=1;;
     [Nn]* ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
     "" ) echo "No MKL support will be enabled for TensorFlow"; TF_NEED_MKL=0;;
     * ) echo "Invalid selection: " $INPUT;;
@@ -261,7 +264,7 @@ if [[ "$TF_NEED_VERBS" == "1" ]]; then
 fi
 
 # Invoke python_config and set up symlinks to python includes
-./util/python/python_config.sh --setup "$PYTHON_BIN_PATH"
+./util/python/python_config.sh "$PYTHON_BIN_PATH"
 
 # Append CC optimization flags to bazel.rc
 echo >> tools/bazel.rc
tensorflow/BUILD
@@ -278,6 +278,7 @@ filegroup(
         "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files",
         "//tensorflow/contrib/training:all_files",
         "//tensorflow/contrib/util:all_files",
+        "//tensorflow/contrib/verbs:all_files",
         "//tensorflow/contrib/xla_tf_graph:all_files",
         "//tensorflow/core:all_files",
         "//tensorflow/core/debug:all_files",
@@ -326,6 +327,7 @@ filegroup(
         "//tensorflow/tensorboard/components/vz_line_chart:all_files",
         "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
         "//tensorflow/tensorboard/components/vz_projector:all_files",
+        "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
         "//tensorflow/tensorboard/components/vz_sorting:all_files",
         "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
         "//tensorflow/tensorboard/lib:all_files",
@@ -38,7 +38,6 @@ static void AllocateFlags() {
   flags = new GpuCompilerFlags;
   flags->xla_gpu_embed_ir = false;
   flags->xla_cuda_data_dir = "./cuda_sdk_lib";
-  flags->xla_ptxas_path = "/usr/local/cuda/bin/ptxas";
   flag_list = new std::vector<tensorflow::Flag>({
       tensorflow::Flag(
           "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir,
@@ -649,4 +649,39 @@ ReferenceUtil::ReduceToRowArray2D(
   return result;
 }
 
+/* static */ Array4D<float> ReferenceUtil::PadArray4D(
+    const Array4D<float>& operand, const PaddingConfig& padding,
+    const float pad) {
+  CHECK_EQ(padding.dimensions_size(), 4);
+
+  const std::vector<int64> input_bounds = {operand.n1(), operand.n2(),
+                                           operand.n3(), operand.n4()};
+  std::vector<int64> pad_low(4);
+  std::vector<int64> pad_high(4);
+  std::vector<int64> output_bounds(4);
+  for (int64 i = 0; i < 4; ++i) {
+    pad_low[i] = padding.dimensions(i).edge_padding_low();
+    pad_high[i] = padding.dimensions(i).edge_padding_high();
+    CHECK_EQ(padding.dimensions(i).interior_padding(), 0) << "not implemented";
+
+    output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i];
+  }
+
+  Array4D<float> result(output_bounds[0], output_bounds[1], output_bounds[2],
+                        output_bounds[3]);
+  result.Each([&](tensorflow::gtl::ArraySlice<int64> indices, float* value) {
+    for (int i = 0; i < 4; ++i) {
+      bool in_low_padding = indices[i] < pad_low[i];
+      bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
+      if (in_low_padding || in_high_padding) {
+        *value = pad;
+        return;
+      }
+    }
+    *value = operand(indices[0] - pad_low[0], indices[1] - pad_low[1],
+                     indices[2] - pad_low[2], indices[3] - pad_low[3]);
+  });
+  return result;
+}
+
 }  // namespace xla
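For intuition, a minimal NumPy sketch of the edge-padding semantics `PadArray4D` implements above (interior padding unsupported, matching the CHECK; the helper name is hypothetical):

```python
import numpy as np

def pad_array_4d(operand, pad_low, pad_high, pad_value):
    # Allocate the padded bounds filled with the pad value, then copy the
    # operand into the interior region, mirroring the Each() loop above.
    out_shape = [l + n + h for l, n, h in zip(pad_low, operand.shape, pad_high)]
    out = np.full(out_shape, pad_value, dtype=operand.dtype)
    out[tuple(slice(l, l + n) for l, n in zip(pad_low, operand.shape))] = operand
    return out

x = np.ones((1, 2, 2, 1), dtype=np.float32)
y = pad_array_4d(x, pad_low=(0, 1, 1, 0), pad_high=(0, 1, 1, 0), pad_value=0.0)
assert y.shape == (1, 4, 4, 1)
```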
@@ -395,6 +395,11 @@ class ReferenceUtil {
       const Array2D<float>& operand, const PaddingConfig& padding,
       const float pad);
 
+  // Returns the result of a 4D pad on an input array.
+  static Array4D<float> PadArray4D(const Array4D<float>& operand,
+                                   const PaddingConfig& padding,
+                                   const float pad);
+
  private:
   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
 };
@@ -409,7 +409,7 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
       // operand copy insertion above (which will share an allocation).
       TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
           liveness.get(), computation->parameter_instruction(0)));
-    } else if (copy_param_and_const_) {
+    } else {
       // Record root indices to copy for general computations.
       TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
           liveness->points_to_analysis()));
@@ -32,9 +32,6 @@ namespace xla {
 // different lifetimes than computation results.
 class CopyInsertion : public HloPassInterface {
  public:
-  explicit CopyInsertion(bool copy_param_and_const = true)
-      : copy_param_and_const_(copy_param_and_const) {}
-
   ~CopyInsertion() override {}
   tensorflow::StringPiece name() const override { return "copy-insertion"; }
 
 // Run the pass on the given module. Returns whether the module was changed
@@ -46,10 +43,6 @@ class CopyInsertion : public HloPassInterface {
   // duplicate copies.
   StatusOr<HloInstruction*> FindOrInsertCopy(HloInstruction* hlo);
 
-  // Determines whether to insert copies if the root instruction is, or
-  // points-to, any constant or parameter instruction.
-  const bool copy_param_and_const_;
-
   // A map containing all copies inserted during the copy insertion pass. The
   // key is the copied instruction and the value is the copy.
   std::unordered_map<HloInstruction*, HloInstruction*> inserted_copies_;
@@ -187,8 +187,8 @@ tensorflow::Status PrepareHloModuleForIrEmitting(
 
 // Invokes the ptxas tool on the given PTX string, and dumps its output.
 void DumpPtxasInfo(const string& ptx) {
-  legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags();
-  const string ptxas_path = flags->xla_ptxas_path;
+  const string ptxas_path =
+      tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas");
   // Do not log PTX stats if ptxas is not found at the given path.
   if (!tensorflow::Env::Default()->FileExists(ptxas_path).ok()) {
     LOG(WARNING)
@@ -70,6 +70,7 @@ string HloExecutionProfile::ToString(
   string result;
   const int64 total_cycles = total_cycles_executed(computation);
   double clock_rate_ghz = device_description.clock_rate_ghz();
+  CHECK_GE(clock_rate_ghz, 1e-9);
 
   const auto cycles_to_microseconds = [&](double cycles) {
     return cycles / clock_rate_ghz / 1000.0;
@@ -80,14 +81,19 @@ string HloExecutionProfile::ToString(
     double nsecs = cycles / clock_rate_ghz;
     string bytes_per_sec;
     string bytes_per_cycle;
-    if (bytes_accessed >= 0) {
+    if (cycles <= 0 || bytes_accessed < 0) {
+      bytes_per_sec = "<unknown>";
+      bytes_per_cycle = "<unknown>";
+    } else {
       bytes_per_sec = tensorflow::strings::HumanReadableNumBytes(
           bytes_accessed / (nsecs / 1e9));
       bytes_per_cycle =
          tensorflow::strings::HumanReadableNumBytes(bytes_accessed / cycles);
-    } else {
-      bytes_per_sec = "<unknown>";
-      bytes_per_cycle = "<unknown>";
     }
 
+    double cycles_percent = 0;
+    if (total_cycles > 0) {
+      cycles_percent = cycles / static_cast<double>(total_cycles) * 100;
+    }
+
     tensorflow::strings::StrAppend(
@@ -97,8 +103,7 @@ string HloExecutionProfile::ToString(
         ":: "
         "%12s/cycle :: "
         "%s",
-        cycles, cycles / static_cast<double>(total_cycles) * 100,
-        cycles_to_microseconds(cycles),
+        cycles, cycles_percent, cycles_to_microseconds(cycles),
         flops <= 0 ? "<none>" : HumanReadableNumFlops(flops, nsecs).c_str(),
         bytes_per_sec.c_str(), bytes_per_cycle.c_str(), name.c_str()));
   };
@@ -114,26 +119,30 @@ string HloExecutionProfile::ToString(
   for (const auto& item : items) {
     const HloInstruction* hlo = item.first;
     tensorflow::strings::StrAppend(&result, "\n\t");
-    int64 flops = hlo == nullptr ? -1 : cost_analysis.flop_count(*hlo);
-    int64 bytes_accessed =
-        hlo == nullptr ? -1 : cost_analysis.bytes_accessed(*hlo);
-    string display = hlo == nullptr ? "<none>" : hlo->ToString();
+    const int64 flops = (hlo == nullptr) ? -1 : cost_analysis.flop_count(*hlo);
+    const int64 bytes_accessed =
+        (hlo == nullptr) ? -1 : cost_analysis.bytes_accessed(*hlo);
+    const string display = (hlo == nullptr) ? "<none>" : hlo->ToString();
     append_item(item.second, flops, bytes_accessed, display);
   }
 
-  MetricTableReport table;
-  table.SetMetricName("microseconds");
-  table.SetEntryName("ops");
-  table.SetShowCategoryTable();
-  for (const auto& item : items) {
-    MetricTableReport::Entry entry;
-    entry.text = item.first->ToString();
-    entry.short_text = item.first->ToString(/*compact_operands=*/true);
-    entry.category_text = item.first->ToCategory();
-    entry.metric = cycles_to_microseconds(item.second);
-    table.AddEntry(std::move(entry));
+  if (total_cycles <= 0) {
+    result += "****** 0 total cycles ******\n";
+  } else {
+    MetricTableReport table;
+    table.SetMetricName("microseconds");
+    table.SetEntryName("ops");
+    table.SetShowCategoryTable();
+    for (const auto& item : items) {
+      MetricTableReport::Entry entry;
+      entry.text = item.first->ToString();
+      entry.short_text = item.first->ToString(/*compact_operands=*/true);
+      entry.category_text = item.first->ToCategory();
+      entry.metric = cycles_to_microseconds(item.second);
+      table.AddEntry(std::move(entry));
+    }
+    result += table.MakeReport(cycles_to_microseconds(total_cycles));
   }
-  result += table.MakeReport(cycles_to_microseconds(total_cycles));
 
   return result;
 }
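The unit conversions in `ToString` are terse; a quick Python sketch with invented numbers:

```python
# Invented values; mirrors the arithmetic in HloExecutionProfile::ToString.
clock_rate_ghz = 1.5                 # cycles per nanosecond
cycles = 3e6
bytes_accessed = 48e6

nsecs = cycles / clock_rate_ghz                 # 2e6 ns
usecs = cycles / clock_rate_ghz / 1000.0        # cycles_to_microseconds -> 2000.0
bytes_per_sec = bytes_accessed / (nsecs / 1e9)  # 2.4e10 B/s, ~24 GB/s
bytes_per_cycle = bytes_accessed / cycles       # 16.0 bytes per cycle
```

The new `CHECK_GE(clock_rate_ghz, 1e-9)` and `total_cycles > 0` guards exist precisely because these divisions are undefined for zero cycles or an unset clock rate.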
@@ -309,6 +309,10 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
     return InvalidArgument(
         "the rank of the operand and the padding configuration do not match.");
   }
+  if (operand_shape.element_type() != padding_value_shape.element_type()) {
+    return InvalidArgument(
+        "the element types of the operands to pad do not match");
+  }
   std::vector<int64> dimensions(ShapeUtil::Rank(operand_shape));
   for (int64 i = 0; i < operand_shape.dimensions_size(); ++i) {
     dimensions[i] = operand_shape.dimensions(i) +
@@ -338,7 +342,7 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
 
   // Check if both element types are the same.
   if (lhs.element_type() != rhs.element_type()) {
-    return fail("element types mismatch");
+    return fail("element types do not match");
   }
 
   if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
@@ -31,6 +31,7 @@ limitations under the License.
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/lib/math/math_util.h"
+#include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/platform/protobuf.h"
@@ -200,6 +201,46 @@ int64 PositionInContainer(const Container& container, int64 value) {
                          std::find(container.begin(), container.end(), value));
 }
 
+// Formats the container as a comma-separated string. StrAppend must support
+// appending the elements of the container. Prefix is prepended and suffix is
+// appended to the returned string.
+template <typename Container>
+string CommaSeparatedString(const Container& c, const char* prefix = "",
+                            const char* suffix = "") {
+  // Not using Join() since the implementation here is simple anyway and this
+  // avoids copying the string to append prefix.
+  string comma_separated = prefix;
+  const char* separator = "";
+  for (const auto& entry : c) {
+    tensorflow::strings::StrAppend(&comma_separated, separator, entry);
+    separator = ", ";
+  }
+  comma_separated += suffix;
+  return comma_separated;
+}
+
+// Overload needed to allow the container to be an initializer list. The default
+// type for T makes an empty initializer list work as well.
+template <typename T = int>
+string CommaSeparatedString(const std::initializer_list<T>& c,
+                            const char* prefix = "", const char* suffix = "") {
+  return CommaSeparatedString<std::initializer_list<T>>(c, prefix, suffix);
+}
+
+// Formats the container in the mathematical notation for a vector, e.g. (1, 3,
+// 7). StrAppend must support appending the elements of c.
+template <typename Container>
+string VectorString(const Container& c) {
+  return CommaSeparatedString(c, "(", ")");
+}
+
+// Overload needed to allow the container to be an initializer list. The default
+// type for T makes an empty initializer list work as well.
+template <typename T = int>
+string VectorString(const std::initializer_list<T>& c) {
+  return VectorString<std::initializer_list<T>>(c);
+}
+
 // Returns a PaddingConfig object that represents no padding for the given rank.
 PaddingConfig MakeNoPaddingConfig(int64 rank);
@@ -80,6 +80,26 @@ TEST(UtilTest, HumanReadableNumFlopsExample) {
   ASSERT_EQ("1.00GFLOP/s", HumanReadableNumFlops(1e9, 1e9));
 }
 
+TEST(UtilTest, CommaSeparatedString) {
+  EXPECT_EQ(CommaSeparatedString({}), "");
+  EXPECT_EQ(CommaSeparatedString({"hello world"}), "hello world");
+  EXPECT_EQ(CommaSeparatedString({1, 57, 2}, "foo", "bar"), "foo1, 57, 2bar");
+}
+
+TEST(UtilTest, VectorString) {
+  std::list<int64> empty_list;
+  EXPECT_EQ(VectorString(empty_list), "()");
+
+  std::vector<float> float_vector = {5.5};
+  EXPECT_EQ(VectorString(float_vector), "(5.5)");
+
+  std::set<const char*> string_set = {"a", "b"};
+  EXPECT_EQ(VectorString(string_set), "(a, b)");
+
+  EXPECT_EQ(VectorString({}), "()");
+  EXPECT_EQ(VectorString({1, 57, 2}), "(1, 57, 2)");
+}
+
 TEST(UtilTest, LogLines) {
   // Just make sure this code runs (not verifying the output).
   LogLines(tensorflow::INFO, "hello\n\nworld", __FILE__, __LINE__);
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import abc
 import contextlib
-import inspect
 import types
 
 import numpy as np
@@ -33,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_inspect
 
 
 _DISTRIBUTION_PUBLIC_METHOD_WRAPPERS = [
@@ -154,12 +154,12 @@ class _DistributionMeta(abc.ABCMeta):
       if class_special_attr_value is None:
         # No _special method available, no need to update the docstring.
        continue
-      class_special_attr_docstring = inspect.getdoc(class_special_attr_value)
+      class_special_attr_docstring = tf_inspect.getdoc(class_special_attr_value)
       if not class_special_attr_docstring:
         # No docstring to append.
         continue
       class_attr_value = _copy_fn(base_attr_value)
-      class_attr_docstring = inspect.getdoc(base_attr_value)
+      class_attr_docstring = tf_inspect.getdoc(base_attr_value)
       if class_attr_docstring is None:
         raise ValueError(
             "Expected base class fn to contain a docstring: %s.%s"
@@ -44,7 +44,7 @@ class _Gumbel(distribution.Distribution):
 
   where `loc = mu` and `scale = sigma`.
 
-  The cumulative densifyt function of this distribution is,
+  The cumulative density function of this distribution is,
 
   ```cdf(x; mu, sigma) = exp(-exp(-(x - mu) / sigma))```
 
@@ -18,12 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.util import tf_inspect
 
 
 _DIVERGENCES = {}
@@ -31,8 +30,8 @@ _DIVERGENCES = {}
 
 def _registered_kl(type_a, type_b):
   """Get the KL function registered for classes a and b."""
-  hierarchy_a = inspect.getmro(type_a)
-  hierarchy_b = inspect.getmro(type_b)
+  hierarchy_a = tf_inspect.getmro(type_a)
+  hierarchy_b = tf_inspect.getmro(type_b)
   dist_to_children = None
   kl_fn = None
   for mro_to_a, parent_a in enumerate(hierarchy_a):
@@ -61,8 +61,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import contextlib
 import functools
 
+from tensorflow.python.util import tf_contextlib
+from tensorflow.python.util import tf_decorator
 
 __all__ = ['arg_scope',
            'add_arg_scope',
@@ -106,7 +107,7 @@ def _add_op(op):
     _DECORATED_OPS[key_op] = _kwarg_names(op)
 
 
-@contextlib.contextmanager
+@tf_contextlib.contextmanager
 def arg_scope(list_ops_or_scope, **kwargs):
   """Stores the default arguments for the given set of list_ops.
 
@@ -170,7 +171,6 @@ def add_arg_scope(func):
   Returns:
     A tuple with the decorated function func_with_args().
   """
-  @functools.wraps(func)
  def func_with_args(*args, **kwargs):
     current_scope = _current_arg_scope()
     current_args = kwargs
@@ -181,8 +181,7 @@ def add_arg_scope(func):
     return func(*args, **current_args)
   _add_op(func)
   setattr(func_with_args, '_key_op', _key_op(func))
-  setattr(func_with_args, '__doc__', func.__doc__)
-  return func_with_args
+  return tf_decorator.make_decorator(func, func_with_args)
 
 
 def has_arg_scope(func):
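The blanket inspect → tf_inspect and contextlib → tf_contextlib swaps throughout this commit share one motivation: a plain wrapper function hides the wrapped function's argument list from introspection, and the tf_decorator/tf_inspect pair restores it. A standalone sketch of the underlying problem, using only the standard library:

```python
import functools
import inspect

def add_arg_scope(func):
    # A bare wrapper with no bookkeeping: introspection sees only *args/**kwargs.
    def bare(*args, **kwargs):
        return func(*args, **kwargs)

    # functools.wraps records func as recorded.__wrapped__, which
    # inspect.signature() follows; tf_decorator.make_decorator plays the
    # same role for tf_inspect in the diff above.
    @functools.wraps(func)
    def recorded(*args, **kwargs):
        return func(*args, **kwargs)

    return bare, recorded

def conv2d(inputs, num_outputs, kernel_size=3):
    pass

bare, recorded = add_arg_scope(conv2d)
print(list(inspect.signature(bare).parameters))      # ['args', 'kwargs']
print(list(inspect.signature(recorded).parameters))  # ['inputs', 'num_outputs', 'kernel_size']
```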
@@ -18,12 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 import numpy as np
 
 from tensorflow.contrib.keras.python import keras
 from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
 
 
 def compare_single_input_op_to_numpy(keras_op,
@@ -207,7 +206,7 @@ class BackendLinearAlgebraTest(test.TestCase):
     compare_single_input_op_to_numpy(keras_op, np_op, input_shape=(4, 7, 5),
                                      keras_kwargs={'axis': -1},
                                      np_kwargs={'axis': -1})
-    if 'keepdims' in inspect.getargspec(keras_op).args:
+    if 'keepdims' in tf_inspect.getargspec(keras_op).args:
       compare_single_input_op_to_numpy(keras_op, np_op,
                                        input_shape=(4, 7, 5),
                                        keras_kwargs={'axis': 1,
@@ -20,7 +20,6 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import json
 import os
 import re
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils import conv_utils
 from tensorflow.contrib.keras.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
 from tensorflow.contrib.keras.python.keras.utils.layer_utils import print_summary as print_layer_summary
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
 
 
 # pylint: disable=g-import-not-at-top
@@ -584,7 +584,7 @@ class Layer(object):
       user_kwargs = copy.copy(kwargs)
       if not _is_all_none(previous_mask):
         # The previous layer generated a mask.
-        if 'mask' in inspect.getargspec(self.call).args:
+        if 'mask' in tf_inspect.getargspec(self.call).args:
          if 'mask' not in kwargs:
             # If mask is explicitly passed to __call__,
             # we should override the default mask.
@@ -2166,7 +2166,7 @@ class Container(Layer):
             kwargs = {}
             if len(computed_data) == 1:
               computed_tensor, computed_mask = computed_data[0]
-              if 'mask' in inspect.getargspec(layer.call).args:
+              if 'mask' in tf_inspect.getargspec(layer.call).args:
                 if 'mask' not in kwargs:
                   kwargs['mask'] = computed_mask
               output_tensors = _to_list(layer.call(computed_tensor, **kwargs))
@@ -2177,7 +2177,7 @@ class Container(Layer):
             else:
               computed_tensors = [x[0] for x in computed_data]
               computed_masks = [x[1] for x in computed_data]
-              if 'mask' in inspect.getargspec(layer.call).args:
+              if 'mask' in tf_inspect.getargspec(layer.call).args:
                 if 'mask' not in kwargs:
                   kwargs['mask'] = computed_masks
               output_tensors = _to_list(layer.call(computed_tensors, **kwargs))
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import types as python_types
 
 import numpy as np
@@ -35,6 +34,7 @@ from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserializ
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_dump
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import func_load
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
 
 
 class Masking(Layer):
@@ -595,7 +595,7 @@ class Lambda(Layer):
 
   def call(self, inputs, mask=None):
     arguments = self.arguments
-    arg_spec = inspect.getargspec(self.function)
+    arg_spec = tf_inspect.getargspec(self.function)
     if 'mask' in arg_spec.args:
       arguments['mask'] = mask
     return self.function(inputs, **arguments)
@@ -20,12 +20,12 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 
 from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.engine import InputSpec
 from tensorflow.contrib.keras.python.keras.engine import Layer
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.util import tf_inspect
 
 
 class Wrapper(Layer):
@@ -284,7 +284,7 @@ class Bidirectional(Wrapper):
 
   def call(self, inputs, training=None, mask=None):
     kwargs = {}
-    func_args = inspect.getargspec(self.layer.call).args
+    func_args = tf_inspect.getargspec(self.layer.call).args
     if 'training' in func_args:
       kwargs['training'] = training
     if 'mask' in func_args:
@@ -18,11 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
-
 import numpy as np
 
 from tensorflow.contrib.keras.python import keras
+from tensorflow.python.util import tf_inspect
 
 
 def get_test_data(train_samples,
@@ -98,7 +97,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
     layer.set_weights(weights)
 
     # test and instantiation from weights
-    if 'weights' in inspect.getargspec(layer_cls.__init__):
+    if 'weights' in tf_inspect.getargspec(layer_cls.__init__):
       kwargs['weights'] = weights
       layer = layer_cls(**kwargs)
 
@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import marshal
 import sys
 import time
@@ -26,6 +25,8 @@ import types as python_types
 import numpy as np
 import six
 
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 _GLOBAL_CUSTOM_OBJECTS = {}
 
@@ -116,6 +117,7 @@ def get_custom_objects():
 
 
 def serialize_keras_object(instance):
+  _, instance = tf_decorator.unwrap(instance)
   if instance is None:
     return None
   if hasattr(instance, 'get_config'):
@@ -149,7 +151,7 @@ def deserialize_keras_object(identifier,
     if cls is None:
       raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)
     if hasattr(cls, 'from_config'):
-      arg_spec = inspect.getargspec(cls.from_config)
+      arg_spec = tf_inspect.getargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
         custom_objects = custom_objects or {}
         return cls.from_config(
@@ -19,13 +19,13 @@ from __future__ import division
 from __future__ import print_function
 
 import copy
-import inspect
 import types
 
 import numpy as np
 
 from tensorflow.contrib.keras.python.keras.models import Sequential
 from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
+from tensorflow.python.util import tf_inspect
 
 
 class BaseWrapper(object):
@@ -97,7 +97,7 @@ class BaseWrapper(object):
 
     legal_params = []
     for fn in legal_params_fns:
-      legal_params += inspect.getargspec(fn)[0]
+      legal_params += tf_inspect.getargspec(fn)[0]
     legal_params = set(legal_params)
 
     for params_name in params:
@@ -182,7 +182,7 @@ class BaseWrapper(object):
     """
     override = override or {}
     res = {}
-    fn_args = inspect.getargspec(fn)[0]
+    fn_args = tf_inspect.getargspec(fn)[0]
     for name, value in self.sk_params.items():
       if name in fn_args:
         res.update({name: value})
@@ -24,9 +24,9 @@ from __future__ import print_function
 
 import collections
 import functools
-import inspect
 import re
 
+from tensorflow.python.util import tf_inspect
+
 # used for register_type_abbreviation and _type_repr below.
 _TYPE_ABBREVIATIONS = {}
@@ -230,7 +230,7 @@ def accepts(*types):
 
   def check_accepts(f):
     """Check the types."""
-    spec = inspect.getargspec(f)
+    spec = tf_inspect.getargspec(f)
 
     num_function_arguments = len(spec.args)
     if len(types) != num_function_arguments:
@@ -24,11 +24,12 @@ from abc import abstractmethod
 from abc import abstractproperty
 
 import collections
-import inspect
 
 from .series import Series
 from .series import TransformedSeries
 
+from tensorflow.python.util import tf_inspect
+
 
 def _make_list_of_series(x):
   """Converts `x` into a list of `Series` if possible.
@@ -120,7 +121,7 @@ class Transform(object):
   def parameters(self):
     """A dict of names to values of properties marked with `@parameter`."""
     property_param_names = [name
-                            for name, func in inspect.getmembers(type(self))
+                            for name, func in tf_inspect.getmembers(type(self))
                             if (hasattr(func, "fget") and hasattr(
                                 getattr(func, "fget"), "is_parameter"))]
     return {name: getattr(self, name) for name in property_param_names}
@@ -218,7 +218,8 @@ def read_data_sets(train_dir,
   if fake_data:
 
     def fake():
-      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
+      return DataSet(
+          [], [], fake_data=True, one_hot=one_hot, dtype=dtype, seed=seed)
 
     train = fake()
     validation = fake()
@@ -260,13 +261,16 @@ def read_data_sets(train_dir,
   train_images = train_images[validation_size:]
   train_labels = train_labels[validation_size:]
 
-  train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
-  validation = DataSet(validation_images,
-                       validation_labels,
-                       dtype=dtype,
-                       reshape=reshape,
-                       seed=seed)
-  test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
+  train = DataSet(
+      train_images, train_labels, dtype=dtype, reshape=reshape, seed=seed)
+  validation = DataSet(
+      validation_images,
+      validation_labels,
+      dtype=dtype,
+      reshape=reshape,
+      seed=seed)
+  test = DataSet(
+      test_images, test_labels, dtype=dtype, reshape=reshape, seed=seed)
 
   return base.Datasets(train=train, validation=validation, test=test)
 
@@ -21,7 +21,6 @@ from __future__ import print_function
 
 import abc
 import copy
-import inspect
 import os
 import tempfile
 
@@ -70,6 +69,8 @@ from tensorflow.python.training import monitored_session
 from tensorflow.python.training import saver
 from tensorflow.python.training import summary_io
 from tensorflow.python.util import compat
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
 AS_ITERABLE_DATE = '2016-09-15'
@@ -185,14 +186,15 @@ def _model_fn_args(fn):
   Raises:
     ValueError: if partial function has positionally bound arguments
   """
+  _, fn = tf_decorator.unwrap(fn)
   if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
     # Handle functools.partial and similar objects.
     return tuple([
-        arg for arg in inspect.getargspec(fn.func).args[len(fn.args):]
+        arg for arg in tf_inspect.getargspec(fn.func).args[len(fn.args):]
        if arg not in set(fn.keywords.keys())
     ])
   # Handle function.
-  return tuple(inspect.getargspec(fn).args)
+  return tuple(tf_inspect.getargspec(fn).args)
 
 
 def _get_replica_device_setter(config):
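`_model_fn_args` has to special-case `functools.partial`, whose bound arguments should not be reported as free model-fn parameters. A minimal standalone sketch of that logic (helper name hypothetical):

```python
import functools
import inspect

def free_arg_names(fn):
    """Names of arguments still unbound on fn, partial-aware (sketch)."""
    if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
        # functools.partial: drop positionally bound args, then keyword-bound ones.
        return tuple(
            arg for arg in inspect.getfullargspec(fn.func).args[len(fn.args):]
            if arg not in set(fn.keywords))
    return tuple(inspect.getfullargspec(fn).args)

def model_fn(features, labels, mode, params):
    pass

print(free_arg_names(model_fn))                                # ('features', 'labels', 'mode', 'params')
print(free_arg_names(functools.partial(model_fn, params={})))  # ('features', 'labels', 'mode')
```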
@@ -52,7 +52,6 @@ from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import variables as variables_lib
@@ -63,7 +62,6 @@ from tensorflow.python.saved_model import tag_constants
 from tensorflow.python.training import basic_session_run_hooks
 from tensorflow.python.training import input as input_lib
 from tensorflow.python.training import monitored_session
 from tensorflow.python.training import queue_runner_impl
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import session_run_hook
 from tensorflow.python.util import compat
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
+from tensorflow.python.util import tf_inspect
 
 
 def assert_estimator_contract(tester, estimator_class):
@@ -31,7 +31,7 @@ def assert_estimator_contract(tester, estimator_class):
     tester: A tf.test.TestCase.
     estimator_class: 'type' object of pre-canned estimator.
   """
-  attributes = inspect.getmembers(estimator_class)
+  attributes = tf_inspect.getmembers(estimator_class)
   attribute_names = [a[0] for a in attributes]
 
   tester.assertTrue('config' in attribute_names)
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 import abc
-import inspect
 
 import six
 
@@ -38,14 +37,17 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import logging_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import weights_broadcast_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.summary import summary
 from tensorflow.python.training import training
+from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 
 
 class Head(object):
@@ -1663,12 +1665,10 @@ def _compute_weighted_loss(loss_unweighted, weight, name="loss"):
   if weight is None:
     loss = math_ops.reduce_mean(loss_unweighted, name=name_scope)
     return loss, loss
+  weight = weights_broadcast_ops.broadcast_weights(weight, loss_unweighted)
   with ops.name_scope(None, "weighted_loss",
                       (loss_unweighted, weight)) as name:
-    # TODO(ptucker): Support weight broadcasting, or switch to tf.losses.
-    weighted_loss = math_ops.multiply(
-        array_ops.reshape(loss_unweighted, shape=(-1,)),
-        array_ops.reshape(weight, shape=(-1,)), name=name)
+    weighted_loss = math_ops.multiply(loss_unweighted, weight, name=name)
   weighted_loss_mean = math_ops.reduce_mean(weighted_loss, name=name_scope)
   weighted_loss_normalized = math_ops.div(
       math_ops.reduce_sum(weighted_loss),
@@ -1697,9 +1697,10 @@ def _check_mode_valid(mode):
 
 def _get_arguments(func):
   """Returns a spec of given func."""
+  _, func = tf_decorator.unwrap(func)
   if hasattr(func, "__code__"):
     # Regular function.
-    return inspect.getargspec(func)
+    return tf_inspect.getargspec(func)
   elif hasattr(func, "__call__"):
     # Callable object.
     return _get_arguments(func.__call__)
@@ -1802,8 +1803,13 @@ def _float_weights_or_none(weights):
 
 
 def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
-  labels = ops.convert_to_tensor(labels)
+  labels = math_ops.to_float(labels)
+  weights = _float_weights_or_none(weights)
+  if weights is not None:
+    weights = weights_broadcast_ops.broadcast_weights(weights, labels)
   if class_id is not None:
+    if weights is not None:
+      weights = weights[:, class_id]
     labels = labels[:, class_id]
   return metrics_lib.streaming_mean(labels, weights=weights)
 
@@ -1811,11 +1817,13 @@ def _indicator_labels_streaming_mean(labels, weights=None, class_id=None):
 def _predictions_streaming_mean(predictions,
                                 weights=None,
                                 class_id=None):
-  predictions = ops.convert_to_tensor(predictions)
+  predictions = math_ops.to_float(predictions)
   weights = _float_weights_or_none(weights)
   if weights is not None:
-    weights = ops.convert_to_tensor(weights)
+    weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
   if class_id is not None:
     if weights is not None:
       weights = weights[:, class_id]
     predictions = predictions[:, class_id]
   return metrics_lib.streaming_mean(predictions, weights=weights)
 
@@ -1850,16 +1858,21 @@ def _class_labels_streaming_mean(labels, weights, class_id):
 
 def _streaming_auc(predictions, labels, weights=None, class_id=None,
                    curve="ROC"):
-  predictions = ops.convert_to_tensor(predictions)
-  labels = ops.convert_to_tensor(labels)
+  # pylint: disable=missing-docstring
+  predictions = math_ops.to_float(predictions)
+  if labels.dtype.base_dtype != dtypes.bool:
+    logging.warning("Casting %s labels to bool.", labels.dtype)
+    labels = math_ops.cast(labels, dtypes.bool)
+  weights = _float_weights_or_none(weights)
+  if weights is not None:
+    weights = weights_broadcast_ops.broadcast_weights(weights, predictions)
   if class_id is not None:
+    if weights is not None:
+      weights = weights[:, class_id]
     predictions = predictions[:, class_id]
     labels = labels[:, class_id]
   return metrics_lib.streaming_auc(
-      predictions,
-      math_ops.cast(labels, dtypes.bool),
-      weights=_float_weights_or_none(weights),
-      curve=curve)
+      predictions, labels, weights=weights, curve=curve)
 
 
 def _assert_class_id(class_id, num_classes=None):
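The repeated `broadcast_weights` + `[:, class_id]` pattern above replaces ad-hoc reshaping: weights are first broadcast to the full predictions shape so that per-class slicing stays aligned. A NumPy sketch of the idea (shapes invented):

```python
import numpy as np

predictions = np.random.rand(4, 3)    # batch of 4, 3 classes
weights = np.array([1., 0., 2., 1.])  # one weight per example

# Broadcast per-example weights up to the predictions shape, the role
# weights_broadcast_ops.broadcast_weights plays above.
full_weights = np.broadcast_to(weights[:, np.newaxis], predictions.shape)

# Now a per-class slice of predictions and weights stays aligned.
class_id = 2
class_predictions = predictions[:, class_id]
class_weights = full_weights[:, class_id]
assert class_predictions.shape == class_weights.shape == (4,)
```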
@@ -36,7 +36,6 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.losses import losses as losses_lib
 from tensorflow.python.platform import test
 # pylint: enable=g-bad-todo,g-import-not-at-top
 
 
 def _assert_variables(test_case,
@@ -260,8 +259,10 @@ class RegressionHeadTest(test.TestCase):
         ),
         expected_trainable=("regression_head/centered_bias_weight:0",))
     variables.global_variables_initializer().run()
-    _assert_summary_tags(
-        self, ["regression_head/loss", "regression_head/centered_bias/bias_0"])
+    _assert_summary_tags(self, [
+        "regression_head/loss",
+        "regression_head/centered_bias/bias_0"
+    ])
     _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops)
 
   def testRegressionErrorInSparseTensorLabels(self):
@@ -541,7 +542,26 @@ class MultiLabelHeadTest(test.TestCase):
       _assert_no_variables(self)
       _assert_summary_tags(self, ["multi_label_head/loss"])
       _assert_metrics(self, .089985214,
-                      self._expected_eval_metrics(2.69956), model_fn_ops)
+                      self._expected_eval_metrics(.89985214), model_fn_ops)
+
+  def testMultiLabelWithMultiDimensionalWeight(self):
+    n_classes = 3
+    head = head_lib.multi_label_head(
+        n_classes=n_classes,
+        weight_column_name="label_weight",
+        metric_class_ids=range(n_classes))
+    with ops.Graph().as_default(), session.Session():
+      model_fn_ops = head.create_model_fn_ops(
+          features={"label_weight": ((.1, .1, .1),)},
+          labels=self._labels,
+          mode=model_fn.ModeKeys.TRAIN,
+          train_op_fn=head_lib.no_op_train_fn,
+          logits=self._logits)
+      self._assert_output_alternatives(model_fn_ops)
+      _assert_no_variables(self)
+      _assert_summary_tags(self, ["multi_label_head/loss"])
+      _assert_metrics(self, .089985214,
+                      self._expected_eval_metrics(.89985214), model_fn_ops)
 
   def testMultiLabelWithCustomLoss(self):
     n_classes = 3
@@ -560,8 +580,9 @@ class MultiLabelHeadTest(test.TestCase):
       self._assert_output_alternatives(model_fn_ops)
      _assert_no_variables(self)
       _assert_summary_tags(self, ["multi_label_head/loss"])
-      _assert_metrics(self, 0.089985214,
-                      self._expected_eval_metrics(0.089985214), model_fn_ops)
+      expected_loss = .089985214
+      _assert_metrics(self, expected_loss,
+                      self._expected_eval_metrics(expected_loss), model_fn_ops)
 
   def testMultiLabelWithCenteredBias(self):
     n_classes = 3
@@ -910,9 +931,10 @@ class BinaryClassificationHeadTest(test.TestCase):
             "Adagrad:0"),),
         expected_trainable=("binary_logistic_head/centered_bias_weight:0",))
     variables.global_variables_initializer().run()
-    _assert_summary_tags(
-        self, ["binary_logistic_head/loss",
-               "binary_logistic_head/centered_bias/bias_0"])
+    _assert_summary_tags(self, [
+        "binary_logistic_head/loss",
+        "binary_logistic_head/centered_bias/bias_0"
+    ])
     expected_loss = .81326175
     _assert_metrics(self, expected_loss,
                     self._expected_eval_metrics(expected_loss), model_fn_ops)
@@ -1416,7 +1438,8 @@ class BinarySvmHeadTest(test.TestCase):
     with ops.Graph().as_default(), session.Session():
       weights = (7., 11.)
       model_fn_ops = head.create_model_fn_ops(
-          features={"weights": weights},
+          # We have to add an extra dim here for weights broadcasting to work.
+          features={"weights": tuple([(w,) for w in weights])},
          mode=model_fn.ModeKeys.TRAIN,
           labels=self._labels,
           train_op_fn=head_lib.no_op_train_fn,
@@ -1424,11 +1447,10 @@ class BinarySvmHeadTest(test.TestCase):
       self._assert_output_alternatives(model_fn_ops)
       _assert_no_variables(self)
       _assert_summary_tags(self, ["binary_svm_head/loss"])
-      expected_weighted_sum = np.sum(
-          np.multiply(weights, self._expected_losses))
-      _assert_metrics(self, expected_weighted_sum / len(weights), {
+      expected_weighted_losses = np.multiply(weights, self._expected_losses)
+      _assert_metrics(self, np.mean(expected_weighted_losses), {
           "accuracy": 1.,
-          "loss": expected_weighted_sum / np.sum(weights),
+          "loss": np.sum(expected_weighted_losses) / np.sum(weights),
       }, model_fn_ops)
 
   def testBinarySVMWithCenteredBias(self):
@@ -1450,9 +1472,10 @@ class BinarySvmHeadTest(test.TestCase):
         ),
         expected_trainable=("binary_svm_head/centered_bias_weight:0",))
     variables.global_variables_initializer().run()
-    _assert_summary_tags(
-        self, ["binary_svm_head/loss",
-               "binary_svm_head/centered_bias/bias_0"])
+    _assert_summary_tags(self, [
+        "binary_svm_head/loss",
+        "binary_svm_head/centered_bias/bias_0"
+    ])
     expected_loss = np.average(self._expected_losses)
     _assert_metrics(self, expected_loss, {
         "accuracy": 1.,
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 """ExportStrategy class represents different flavors of model export."""
 
 from __future__ import absolute_import
@@ -20,13 +19,14 @@ from __future__ import division
 from __future__ import print_function
 
 import collections
-import inspect
+
+from tensorflow.python.util import tf_inspect
 
 __all__ = ['ExportStrategy']
 
 
-class ExportStrategy(collections.namedtuple('ExportStrategy',
-                                            ['name', 'export_fn'])):
+class ExportStrategy(
+    collections.namedtuple('ExportStrategy', ['name', 'export_fn'])):
   """A class representing a type of model export.
 
   Typically constructed by a utility function specific to the exporter, such as
@@ -74,7 +74,7 @@ class ExportStrategy(
     """
     # don't break existing export_fns that don't accept checkpoint_path and
     # eval_result
-    export_fn_args = inspect.getargspec(self.export_fn).args
+    export_fn_args = tf_inspect.getargspec(self.export_fn).args
     kwargs = {}
     if 'checkpoint_path' in export_fn_args:
       kwargs['checkpoint_path'] = checkpoint_path
@@ -18,10 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import six
 
 from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_inspect
 
 
 def _assert_named_args(sentinel):
@@ -43,11 +43,11 @@ def _args(fn):
   if hasattr(fn, 'func') and hasattr(fn, 'keywords'):
     # Handle functools.partial and similar objects.
     return tuple([
-        arg for arg in inspect.getargspec(fn.func).args
+        arg for arg in tf_inspect.getargspec(fn.func).args
         if arg not in set(fn.keywords.keys())
     ])
   # Handle function.
-  return tuple(inspect.getargspec(fn).args)
+  return tuple(tf_inspect.getargspec(fn).args)
 
 
 _CANONICAL_LABELS_ARG = 'labels'
@@ -35,7 +35,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import inspect
 import os
 import time
 
@@ -53,6 +52,7 @@ from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import saver as saver_lib
 from tensorflow.python.training import summary_io
 from tensorflow.python.util import deprecation
+from tensorflow.python.util import tf_inspect
 
 
 # TODO(ptucker): Split each monitor class into a separate file.
@@ -1164,7 +1164,7 @@ class RunHookAdapterForMonitors(session_run_hook.SessionRunHook):
   def end(self, session):
     self._last_step = None
     for m in self._monitors:
-      if "session" in inspect.getargspec(m.end).args:
+      if "session" in tf_inspect.getargspec(m.end).args:
         m.end(session=session)
       else:
         m.end()
@@ -51,6 +51,7 @@ tf_custom_op_py_library(
         "//tensorflow/python:framework",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:init_ops",
+        "//tensorflow/python:layers",
        "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:partitioned_variables",
@@ -369,28 +369,28 @@ class RNNCellTest(test.TestCase):
       self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name])
       self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name])
 
-  def testUsingSecondCellInScopeWithExistingVariablesFails(self):
-    # This test should go away when this behavior is no longer an
-    # error (Approx. May 2017)
-    cell1 = core_rnn_cell_impl.LSTMCell(3)
-    cell2 = core_rnn_cell_impl.LSTMCell(3)
-    x = array_ops.zeros([1, 3])
-    m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
-    cell1(x, m)
-    with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
-      cell2(x, m)
+  # def testUsingSecondCellInScopeWithExistingVariablesFails(self):
+  #   # This test should go away when this behavior is no longer an
+  #   # error (Approx. May 2017)
+  #   cell1 = core_rnn_cell_impl.LSTMCell(3)
+  #   cell2 = core_rnn_cell_impl.LSTMCell(3)
+  #   x = array_ops.zeros([1, 3])
+  #   m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
+  #   cell1(x, m)
+  #   with self.assertRaisesRegexp(ValueError, r"LSTMCell\(..., reuse=True\)"):
+  #     cell2(x, m)
 
-  def testUsingCellInDifferentScopeFromFirstCallFails(self):
-    # This test should go away when this behavior is no longer an
-    # error (Approx. May 2017)
-    cell = core_rnn_cell_impl.LSTMCell(3)
-    x = array_ops.zeros([1, 3])
-    m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
-    with variable_scope.variable_scope("scope1"):
-      cell(x, m)
-    with variable_scope.variable_scope("scope2"):
-      with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
-        cell(x, m)
+  # def testUsingCellInDifferentScopeFromFirstCallFails(self):
+  #   # This test should go away when this behavior is no longer an
+  #   # error (Approx. May 2017)
+  #   cell = core_rnn_cell_impl.LSTMCell(3)
+  #   x = array_ops.zeros([1, 3])
+  #   m = core_rnn_cell_impl.LSTMStateTuple(*[array_ops.zeros([1, 3])] * 2)
+  #   with variable_scope.variable_scope("scope1"):
+  #     cell(x, m)
+  #   with variable_scope.variable_scope("scope2"):
+  #     with self.assertRaisesRegexp(ValueError, r"Attempt to reuse RNNCell"):
+  #       cell(x, m)
 
   def testEmbeddingWrapper(self):
     with self.test_session() as sess:
@@ -521,7 +521,7 @@ class LSTMTest(test.TestCase):
       input_value = np.random.randn(batch_size, input_size)
       sess.run(outputs, feed_dict={inputs[0]: input_value})
 
-  def testStateTupleWithProjAndSequenceLength(self):
+  def _testStateTupleWithProjAndSequenceLength(self):
     num_units = 3
     input_size = 5
     batch_size = 2
@@ -569,7 +569,7 @@ class RNNCellTest(test.TestCase):
     self.assertTrue(
         float(np.linalg.norm((state[0, :] - state[i, :]))) > 1e-6)
 
-  def testAttentionCellWrapperCorrectResult(self):
+  def _testAttentionCellWrapperCorrectResult(self):
     num_units = 4
     attn_length = 6
     batch_size = 2
@ -108,11 +108,11 @@ class BasicRNNCell(RNNCell):
|
||||
"""The most basic RNN cell."""
|
||||
|
||||
def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
|
||||
super(BasicRNNCell, self).__init__(_reuse=reuse)
|
||||
if input_size is not None:
|
||||
logging.warn("%s: The input_size parameter is deprecated.", self)
|
||||
self._num_units = num_units
|
||||
self._activation = activation
|
||||
self._reuse = reuse
|
||||
|
||||
@property
|
||||
def state_size(self):
|
||||
@ -122,11 +122,9 @@ class BasicRNNCell(RNNCell):
|
||||
def output_size(self):
|
||||
return self._num_units
|
||||
|
||||
def __call__(self, inputs, state, scope=None):
|
||||
def call(self, inputs, state):
|
||||
"""Most basic RNN: output = new_state = act(W * input + U * state + B)."""
|
||||
with _checked_scope(self, scope or "basic_rnn_cell", reuse=self._reuse):
|
||||
output = self._activation(
|
||||
_linear([inputs, state], self._num_units, True))
|
||||
output = self._activation(_linear([inputs, state], self._num_units, True))
|
||||
return output, output
|
||||
|
||||
|
||||
@ -134,11 +132,11 @@ class GRUCell(RNNCell):
  """Gated Recurrent Unit cell (cf. http://arxiv.org/abs/1406.1078)."""

  def __init__(self, num_units, input_size=None, activation=tanh, reuse=None):
    super(GRUCell, self).__init__(_reuse=reuse)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    self._num_units = num_units
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
@ -148,21 +146,15 @@ class GRUCell(RNNCell):
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Gated recurrent unit (GRU) with nunits cells."""
    with _checked_scope(self, scope or "gru_cell", reuse=self._reuse):
      with vs.variable_scope("gates"):  # Reset gate and update gate.
        # We start with bias of 1.0 to not reset and not update.
        value = sigmoid(_linear(
            [inputs, state], 2 * self._num_units, True, 1.0))
        r, u = array_ops.split(
            value=value,
            num_or_size_splits=2,
            axis=1)
      with vs.variable_scope("candidate"):
        c = self._activation(_linear([inputs, r * state],
                                     self._num_units, True))
      new_h = u * state + (1 - u) * c
    with vs.variable_scope("gates"):  # Reset gate and update gate.
      # We start with bias of 1.0 to not reset and not update.
      value = sigmoid(_linear([inputs, state], 2 * self._num_units, True, 1.0))
      r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1)
    with vs.variable_scope("candidate"):
      c = self._activation(_linear([inputs, r * state], self._num_units, True))
    new_h = u * state + (1 - u) * c
    return new_h, new_h
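The hunks above all make the same change: the public `__call__(inputs, state, scope=None)` entry point becomes a `call(inputs, state)` method, with scope and reuse handling hoisted into the `RNNCell` base class via `super(...).__init__(_reuse=reuse)`, so the per-cell `_checked_scope` blocks disappear. A minimal sketch of a custom cell written against the new contract (the class name and sizes are illustrative, not part of this commit, and assume the TF 1.x contrib API shown in this diff):

```python
import tensorflow as tf

class MinimalRNNCell(tf.contrib.rnn.RNNCell):
  """Toy cell in the post-refactor style (hypothetical example)."""

  def __init__(self, num_units, reuse=None):
    # Reuse bookkeeping moves into the base class; the cell no longer
    # stores self._reuse or opens its own _checked_scope.
    super(MinimalRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units

  @property
  def state_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  def call(self, inputs, state):
    # No scope parameter: the base class enters the cell's variable scope
    # before dispatching here.
    input_depth = inputs.get_shape()[1].value
    w = tf.get_variable("w", [input_depth + self._num_units, self._num_units])
    output = tf.tanh(tf.matmul(tf.concat([inputs, state], 1), w))
    return output, output
```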
@ -217,6 +209,7 @@ class BasicLSTMCell(RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(BasicLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
@ -226,7 +219,6 @@ class BasicLSTMCell(RNNCell):
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse

  @property
  def state_size(self):
@ -237,28 +229,28 @@ class BasicLSTMCell(RNNCell):
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Long short-term memory cell (LSTM)."""
    with _checked_scope(self, scope or "basic_lstm_cell", reuse=self._reuse):
      # Parameters of gates are concatenated into one multiply for efficiency.
      if self._state_is_tuple:
        c, h = state
      else:
        c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)
      concat = _linear([inputs, h], 4 * self._num_units, True)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    concat = _linear([inputs, h], 4 * self._num_units, True)

      new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
               self._activation(j))
      new_h = self._activation(new_c) * sigmoid(o)
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)

      if self._state_is_tuple:
        new_state = LSTMStateTuple(new_c, new_h)
      else:
        new_state = array_ops.concat([new_c, new_h], 1)
      return new_h, new_state
    new_c = (
        c * sigmoid(f + self._forget_bias) + sigmoid(i) * self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
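The fused-gate arithmetic in BasicLSTMCell is unchanged by this refactor: one matmul over `[inputs, h]` produces all four gate pre-activations, which are then split four ways. A small numpy re-statement of that step (a sketch only; shapes and names are illustrative, not from the diff):

```python
import numpy as np

def sigmoid(x):
  return 1.0 / (1.0 + np.exp(-x))

def basic_lstm_step(x, c, h, w, b, forget_bias=1.0):
  # w: [input_depth + num_units, 4 * num_units]; a single matmul computes
  # all gates, matching the _linear([inputs, h], 4 * num_units, True) call.
  concat = np.concatenate([x, h], axis=1).dot(w) + b
  i, j, f, o = np.split(concat, 4, axis=1)
  new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
  new_h = np.tanh(new_c) * sigmoid(o)
  return new_h, new_c
```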
class LSTMCell(RNNCell):
@ -319,6 +311,7 @@ class LSTMCell(RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(LSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
@ -341,7 +334,6 @@ class LSTMCell(RNNCell):
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation
    self._reuse = reuse

    if num_proj:
      self._state_size = (
@ -362,7 +354,7 @@ class LSTMCell(RNNCell):
  def output_size(self):
    return self._output_size

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
@ -371,7 +363,6 @@ class LSTMCell(RNNCell):
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to "lstm_cell".

    Returns:
      A tuple containing:
@ -400,9 +391,8 @@ class LSTMCell(RNNCell):
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with _checked_scope(self, scope or "lstm_cell",
                        initializer=self._initializer,
                        reuse=self._reuse) as unit_scope:
    scope = vs.get_variable_scope()
    with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
      if self._num_unit_shards is not None:
        unit_scope.set_partitioner(
            partitioned_variables.fixed_size_partitioner(
@ -481,13 +471,13 @@ class OutputProjectionWrapper(RNNCell):
      TypeError: if cell is not an RNNCell.
      ValueError: if output_size is not positive.
    """
    super(OutputProjectionWrapper, self).__init__(_reuse=reuse)
    if not isinstance(cell, RNNCell):
      raise TypeError("The parameter cell is not RNNCell.")
    if output_size < 1:
      raise ValueError("Parameter output_size must be > 0: %d." % output_size)
    self._cell = cell
    self._output_size = output_size
    self._reuse = reuse
    self._activation = activation

  @property
@ -502,15 +492,12 @@ class OutputProjectionWrapper(RNNCell):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run the cell and output projection on inputs, starting from state."""
    output, res_state = self._cell(inputs, state)
    # Default scope: "OutputProjectionWrapper"
    with _checked_scope(self, scope or "output_projection_wrapper",
                        reuse=self._reuse):
      projected = _linear(output, self._output_size, True)
      if self._activation:
        projected = self._activation(projected)
    projected = _linear(output, self._output_size, True)
    if self._activation:
      projected = self._activation(projected)
    return projected, res_state
@ -522,7 +509,8 @@ class InputProjectionWrapper(RNNCell):
  do the projection on this batch-concatenated sequence, then split it.
  """

  def __init__(self, cell, num_proj, activation=None, input_size=None):
  def __init__(self, cell, num_proj, activation=None, input_size=None,
               reuse=None):
    """Create a cell with input projection.

    Args:
@ -530,10 +518,14 @@ class InputProjectionWrapper(RNNCell):
      num_proj: Python integer.  The dimension to project to.
      activation: (optional) an optional activation function.
      input_size: Deprecated and unused.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope.  If not `True`, and the existing scope already has
        the given variables, an error is raised.

    Raises:
      TypeError: if cell is not an RNNCell.
    """
    super(InputProjectionWrapper, self).__init__(_reuse=reuse)
    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
    if not isinstance(cell, RNNCell):
@ -554,13 +546,12 @@ class InputProjectionWrapper(RNNCell):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run the input projection and then the cell."""
    # Default scope: "InputProjectionWrapper"
    with vs.variable_scope(scope or "input_projection_wrapper"):
      projected = _linear(inputs, self._num_proj, True)
      if self._activation:
        projected = self._activation(projected)
    projected = _linear(inputs, self._num_proj, True)
    if self._activation:
      projected = self._activation(projected)
    return self._cell(projected, state)
@ -847,6 +838,7 @@ class EmbeddingWrapper(RNNCell):
      TypeError: if cell is not an RNNCell.
      ValueError: if embedding_classes is not positive.
    """
    super(EmbeddingWrapper, self).__init__(_reuse=reuse)
    if not isinstance(cell, RNNCell):
      raise TypeError("The parameter cell is not RNNCell.")
    if embedding_classes <= 0 or embedding_size <= 0:
@ -856,7 +848,6 @@ class EmbeddingWrapper(RNNCell):
    self._embedding_classes = embedding_classes
    self._embedding_size = embedding_size
    self._initializer = initializer
    self._reuse = reuse

  @property
  def state_size(self):
@ -870,31 +861,31 @@ class EmbeddingWrapper(RNNCell):
    with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
      return self._cell.zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run the cell on embedded inputs."""
    with _checked_scope(self, scope or "embedding_wrapper", reuse=self._reuse):
      with ops.device("/cpu:0"):
        if self._initializer:
          initializer = self._initializer
        elif vs.get_variable_scope().initializer:
          initializer = vs.get_variable_scope().initializer
        else:
          # Default initializer for embeddings should have variance=1.
          sqrt3 = math.sqrt(3)  # Uniform(-sqrt(3), sqrt(3)) has variance=1.
          initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)
    with ops.device("/cpu:0"):
      if self._initializer:
        initializer = self._initializer
      elif vs.get_variable_scope().initializer:
        initializer = vs.get_variable_scope().initializer
      else:
        # Default initializer for embeddings should have variance=1.
        sqrt3 = math.sqrt(3)  # Uniform(-sqrt(3), sqrt(3)) has variance=1.
        initializer = init_ops.random_uniform_initializer(-sqrt3, sqrt3)

        if type(state) is tuple:
          data_type = state[0].dtype
        else:
          data_type = state.dtype
      if type(state) is tuple:
        data_type = state[0].dtype
      else:
        data_type = state.dtype

        embedding = vs.get_variable(
            "embedding", [self._embedding_classes, self._embedding_size],
            initializer=initializer,
            dtype=data_type)
        embedded = embedding_ops.embedding_lookup(
            embedding, array_ops.reshape(inputs, [-1]))
        return self._cell(embedded, state)
      embedding = vs.get_variable(
          "embedding", [self._embedding_classes, self._embedding_size],
          initializer=initializer,
          dtype=data_type)
      embedded = embedding_ops.embedding_lookup(embedding,
                                                array_ops.reshape(inputs, [-1]))

      return self._cell(embedded, state)
class MultiRNNCell(RNNCell):
@ -914,6 +905,7 @@ class MultiRNNCell(RNNCell):
      ValueError: if cells is empty (not allowed), or at least one of the cells
        returns a state tuple but the flag `state_is_tuple` is `False`.
    """
    super(MultiRNNCell, self).__init__()
    if not cells:
      raise ValueError("Must specify at least one cell for MultiRNNCell.")
    if not nest.is_sequence(cells):
@ -948,28 +940,29 @@ class MultiRNNCell(RNNCell):
      # presumably does not contain TensorArrays or anything else fancy
      return super(MultiRNNCell, self).zero_state(batch_size, dtype)

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run this multi-layer cell on inputs, starting from state."""
    with vs.variable_scope(scope or "multi_rnn_cell"):
      cur_state_pos = 0
      cur_inp = inputs
      new_states = []
      for i, cell in enumerate(self._cells):
        with vs.variable_scope("cell_%d" % i):
          if self._state_is_tuple:
            if not nest.is_sequence(state):
              raise ValueError(
                  "Expected state to be a tuple of length %d, but received: %s"
                  % (len(self.state_size), state))
            cur_state = state[i]
          else:
            cur_state = array_ops.slice(
                state, [0, cur_state_pos], [-1, cell.state_size])
            cur_state_pos += cell.state_size
          cur_inp, new_state = cell(cur_inp, cur_state)
          new_states.append(new_state)
    cur_state_pos = 0
    cur_inp = inputs
    new_states = []
    for i, cell in enumerate(self._cells):
      with vs.variable_scope("cell_%d" % i):
        if self._state_is_tuple:
          if not nest.is_sequence(state):
            raise ValueError(
                "Expected state to be a tuple of length %d, but received: %s" %
                (len(self.state_size), state))
          cur_state = state[i]
        else:
          cur_state = array_ops.slice(state, [0, cur_state_pos],
                                      [-1, cell.state_size])
          cur_state_pos += cell.state_size
        cur_inp, new_state = cell(cur_inp, cur_state)
        new_states.append(new_state)

    new_states = (tuple(new_states) if self._state_is_tuple else
                  array_ops.concat(new_states, 1))

    return cur_inp, new_states
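A hedged usage sketch for the refactored stack (TF 1.x contrib API as it appears in this diff; the sizes are illustrative):

```python
import tensorflow as tf

num_layers, num_units, batch_size, input_depth = 2, 64, 32, 128
cells = [tf.contrib.rnn.BasicLSTMCell(num_units) for _ in range(num_layers)]
stack = tf.contrib.rnn.MultiRNNCell(cells, state_is_tuple=True)

inputs = tf.placeholder(tf.float32, [batch_size, input_depth])
state = stack.zero_state(batch_size, tf.float32)
# One step: each layer still runs inside its own "cell_%d" variable scope,
# and the per-layer states come back as a tuple because state_is_tuple=True.
output, new_state = stack(inputs, state)
```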
@ -138,6 +138,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn(
          "%s: Using a concatenated state is slower and will soon be "
@ -173,7 +174,7 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
  def output_size(self):
    return self._output_size

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
@ -182,7 +183,6 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to "LSTMCell".

    Returns:
      A tuple containing:
@ -212,51 +212,49 @@ class CoupledInputForgetGateLSTMCell(core_rnn_cell.RNNCell):
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with _checked_scope(self, scope or "coupled_input_forget_gate_lstm_cell",
                        initializer=self._initializer, reuse=self._reuse):
      concat_w = _get_concat_variable(
          "W", [input_size.value + num_proj, 3 * self._num_units],
          dtype, self._num_unit_shards)
    concat_w = _get_concat_variable(
        "W", [input_size.value + num_proj, 3 * self._num_units],
        dtype, self._num_unit_shards)

      b = vs.get_variable(
          "B",
          shape=[3 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)
    b = vs.get_variable(
        "B",
        shape=[3 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

      # j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([inputs, m_prev], 1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)
    # j = new_input, f = forget_gate, o = output_gate
    cell_inputs = array_ops.concat([inputs, m_prev], 1)
    lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
    j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1)

      # Diagonal connections
      if self._use_peepholes:
        w_f_diag = vs.get_variable(
            "W_F_diag", shape=[self._num_units], dtype=dtype)
        w_o_diag = vs.get_variable(
            "W_O_diag", shape=[self._num_units], dtype=dtype)
    # Diagonal connections
    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
      else:
        f_act = sigmoid(f + self._forget_bias)
      c = (f_act * c_prev + (1 - f_act) * self._activation(j))
    if self._use_peepholes:
      f_act = sigmoid(f + self._forget_bias + w_f_diag * c_prev)
    else:
      f_act = sigmoid(f + self._forget_bias)
    c = (f_act * c_prev + (1 - f_act) * self._activation(j))

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)
    if self._use_peepholes:
      m = sigmoid(o + w_o_diag * c) * self._activation(c)
    else:
      m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        concat_w_proj = _get_concat_variable(
            "W_P", [self._num_units, self._num_proj],
            dtype, self._num_proj_shards)
    if self._num_proj is not None:
      concat_w_proj = _get_concat_variable(
          "W_P", [self._num_units, self._num_proj],
          dtype, self._num_proj_shards)

        m = math_ops.matmul(m, concat_w_proj)
        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type
      m = math_ops.matmul(m, concat_w_proj)
      if self._proj_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
        # pylint: enable=invalid-unary-operand-type

    new_state = (core_rnn_cell.LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
@ -301,6 +299,7 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(TimeFreqLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._cell_clip = cell_clip
@ -321,14 +320,12 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
  def state_size(self):
    return self._state_size

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: state Tensor, 2D, batch x state_size.
      scope: VariableScope for the created subgraph; defaults to
        "TimeFreqLSTMCell".

    Returns:
      A tuple containing:
@ -347,63 +344,63 @@ class TimeFreqLSTMCell(core_rnn_cell.RNNCell):
    freq_inputs = self._make_tf_features(inputs)
    dtype = inputs.dtype
    actual_input_size = freq_inputs[0].get_shape().as_list()[1]
    with _checked_scope(self, scope or "time_freq_lstm_cell",
                        initializer=self._initializer, reuse=self._reuse):
      concat_w = _get_concat_variable(
          "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
          dtype, self._num_unit_shards)
      b = vs.get_variable(
          "B",
          shape=[4 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)

      # Diagonal connections
    concat_w = _get_concat_variable(
        "W", [actual_input_size + 2*self._num_units, 4 * self._num_units],
        dtype, self._num_unit_shards)

    b = vs.get_variable(
        "B",
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(),
        dtype=dtype)

    # Diagonal connections
      if self._use_peepholes:
        w_f_diag = vs.get_variable(
            "W_F_diag", shape=[self._num_units], dtype=dtype)
        w_i_diag = vs.get_variable(
            "W_I_diag", shape=[self._num_units], dtype=dtype)
        w_o_diag = vs.get_variable(
            "W_O_diag", shape=[self._num_units], dtype=dtype)

      # initialize the first freq state to be zero
      m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
                                     self._num_units], dtype)
      for fq in range(len(freq_inputs)):
        c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
                                 [-1, self._num_units])
        m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
                                 [-1, self._num_units])
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
                                       1)
        lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
        i, j, f, o = array_ops.split(
            value=lstm_matrix, num_or_size_splits=4, axis=1)

    if self._use_peepholes:
      w_f_diag = vs.get_variable(
          "W_F_diag", shape=[self._num_units], dtype=dtype)
      w_i_diag = vs.get_variable(
          "W_I_diag", shape=[self._num_units], dtype=dtype)
      w_o_diag = vs.get_variable(
          "W_O_diag", shape=[self._num_units], dtype=dtype)
        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

    # initialize the first freq state to be zero
    m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]),
                                   self._num_units], dtype)
    for fq in range(len(freq_inputs)):
      c_prev = array_ops.slice(state, [0, 2*fq*self._num_units],
                               [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, (2*fq+1)*self._num_units],
                               [-1, self._num_units])
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      cell_inputs = array_ops.concat([freq_inputs[fq], m_prev, m_prev_freq],
                                     1)
      lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)
        if self._cell_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
          # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * tanh(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * tanh(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

        if self._use_peepholes:
          m = sigmoid(o + w_o_diag * c) * tanh(c)
        else:
          m = sigmoid(o) * tanh(c)
        m_prev_freq = m
        if fq == 0:
          state_out = array_ops.concat([c, m], 1)
          m_out = m
        else:
          state_out = array_ops.concat([state_out, c, m], 1)
          m_out = array_ops.concat([m_out, m], 1)
      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * tanh(c)
      else:
        m = sigmoid(o) * tanh(c)
      m_prev_freq = m
      if fq == 0:
        state_out = array_ops.concat([c, m], 1)
        m_out = m
      else:
        state_out = array_ops.concat([state_out, c, m], 1)
        m_out = array_ops.concat([m_out, m], 1)
    return m_out, state_out

  def _make_tf_features(self, input_feat):
@ -499,6 +496,7 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
    Raises:
      ValueError: if the num_frequency_blocks list is not specified
    """
    super(GridLSTMCell, self).__init__(_reuse=reuse)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated.  Use state_is_tuple=True.", self)
@ -550,15 +548,13 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
  def state_tuple_type(self):
    return self._state_tuple_type

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, feature_size].
      state: Tensor or tuple of Tensors, 2D, [batch, state_size], depends on the
        flag self._state_is_tuple.
      scope: (optional) VariableScope for the created subgraph; if None, it
        defaults to "GridLSTMCell".

    Returns:
      A tuple containing:
@ -573,21 +569,19 @@ class GridLSTMCell(core_rnn_cell.RNNCell):
    """
    batch_size = inputs.shape[0].value or array_ops.shape(inputs)[0]
    freq_inputs = self._make_tf_features(inputs)
    with _checked_scope(self, scope or "grid_lstm_cell",
                        initializer=self._initializer, reuse=self._reuse):
      m_out_lst = []
      state_out_lst = []
      for block in range(len(freq_inputs)):
        m_out_lst_current, state_out_lst_current = self._compute(
            freq_inputs[block], block, state, batch_size,
            state_is_tuple=self._state_is_tuple)
        m_out_lst.extend(m_out_lst_current)
        state_out_lst.extend(state_out_lst_current)
      if self._state_is_tuple:
        state_out = self._state_tuple_type(*state_out_lst)
      else:
        state_out = array_ops.concat(state_out_lst, 1)
      m_out = array_ops.concat(m_out_lst, 1)
    m_out_lst = []
    state_out_lst = []
    for block in range(len(freq_inputs)):
      m_out_lst_current, state_out_lst_current = self._compute(
          freq_inputs[block], block, state, batch_size,
          state_is_tuple=self._state_is_tuple)
      m_out_lst.extend(m_out_lst_current)
      state_out_lst.extend(state_out_lst_current)
    if self._state_is_tuple:
      state_out = self._state_tuple_type(*state_out_lst)
    else:
      state_out = array_ops.concat(state_out_lst, 1)
    m_out = array_ops.concat(m_out_lst, 1)
    return m_out, state_out

  def _compute(self, freq_inputs, block, state, batch_size,
@ -974,14 +968,12 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
        *([num_units, num_units] * self._total_blocks * 2))
    self._output_size = 2 * num_units * self._total_blocks * 2

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, [batch, num_units].
      state: tuple of Tensors, 2D, [batch, state_size].
      scope: (optional) VariableScope for the created subgraph; if None, it
        defaults to "BidirectionalGridLSTMCell".

    Returns:
      A tuple containing:
@ -1002,29 +994,27 @@ class BidirectionalGridLSTMCell(GridLSTMCell):
    bwd_inputs = fwd_inputs

    # Forward processing
    with _checked_scope(self, scope or "bidirectional_grid_lstm_cell",
                        initializer=self._initializer, reuse=self._reuse):
      with vs.variable_scope("fwd"):
        fwd_m_out_lst = []
        fwd_state_out_lst = []
        for block in range(len(fwd_inputs)):
          fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
              fwd_inputs[block], block, state, batch_size,
              state_prefix="fwd_state", state_is_tuple=True)
          fwd_m_out_lst.extend(fwd_m_out_lst_current)
          fwd_state_out_lst.extend(fwd_state_out_lst_current)
      # Backward processing
      bwd_m_out_lst = []
      bwd_state_out_lst = []
      with vs.variable_scope("bwd"):
        for block in range(len(bwd_inputs)):
          # Reverse the blocks
          bwd_inputs_reverse = bwd_inputs[block][::-1]
          bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
              bwd_inputs_reverse, block, state, batch_size,
              state_prefix="bwd_state", state_is_tuple=True)
          bwd_m_out_lst.extend(bwd_m_out_lst_current)
          bwd_state_out_lst.extend(bwd_state_out_lst_current)
    with vs.variable_scope("fwd"):
      fwd_m_out_lst = []
      fwd_state_out_lst = []
      for block in range(len(fwd_inputs)):
        fwd_m_out_lst_current, fwd_state_out_lst_current = self._compute(
            fwd_inputs[block], block, state, batch_size,
            state_prefix="fwd_state", state_is_tuple=True)
        fwd_m_out_lst.extend(fwd_m_out_lst_current)
        fwd_state_out_lst.extend(fwd_state_out_lst_current)
    # Backward processing
    bwd_m_out_lst = []
    bwd_state_out_lst = []
    with vs.variable_scope("bwd"):
      for block in range(len(bwd_inputs)):
        # Reverse the blocks
        bwd_inputs_reverse = bwd_inputs[block][::-1]
        bwd_m_out_lst_current, bwd_state_out_lst_current = self._compute(
            bwd_inputs_reverse, block, state, batch_size,
            state_prefix="bwd_state", state_is_tuple=True)
        bwd_m_out_lst.extend(bwd_m_out_lst_current)
        bwd_state_out_lst.extend(bwd_state_out_lst_current)
    state_out = self._state_tuple_type(*(fwd_state_out_lst + bwd_state_out_lst))
    # Outputs are always concated as it is never used separately.
    m_out = array_ops.concat(fwd_m_out_lst + bwd_m_out_lst, 1)
@ -1069,6 +1059,7 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
      ValueError: if cell returns a state tuple but the flag
        `state_is_tuple` is `False` or if attn_length is zero or less.
    """
    super(AttentionCellWrapper, self).__init__(_reuse=reuse)
    if not isinstance(cell, core_rnn_cell.RNNCell):
      raise TypeError("The parameter cell is not RNNCell.")
    if nest.is_sequence(cell.state_size) and not state_is_tuple:
@ -1107,42 +1098,40 @@ class AttentionCellWrapper(core_rnn_cell.RNNCell):
  def output_size(self):
    return self._attn_size

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Long short-term memory cell with attention (LSTMA)."""
    with _checked_scope(self, scope or "attention_cell_wrapper",
                        reuse=self._reuse):
      if self._state_is_tuple:
        state, attns, attn_states = state
      else:
        states = state
        state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
        attns = array_ops.slice(
            states, [0, self._cell.state_size], [-1, self._attn_size])
        attn_states = array_ops.slice(
            states, [0, self._cell.state_size + self._attn_size],
            [-1, self._attn_size * self._attn_length])
      attn_states = array_ops.reshape(attn_states,
                                      [-1, self._attn_length, self._attn_size])
      input_size = self._input_size
      if input_size is None:
        input_size = inputs.get_shape().as_list()[1]
      inputs = _linear([inputs, attns], input_size, True)
      lstm_output, new_state = self._cell(inputs, state)
      if self._state_is_tuple:
        new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
      else:
        new_state_cat = new_state
      new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
      with vs.variable_scope("attn_output_projection"):
        output = _linear([lstm_output, new_attns], self._attn_size, True)
      new_attn_states = array_ops.concat(
          [new_attn_states, array_ops.expand_dims(output, 1)], 1)
      new_attn_states = array_ops.reshape(
          new_attn_states, [-1, self._attn_length * self._attn_size])
      new_state = (new_state, new_attns, new_attn_states)
      if not self._state_is_tuple:
        new_state = array_ops.concat(list(new_state), 1)
      return output, new_state
    if self._state_is_tuple:
      state, attns, attn_states = state
    else:
      states = state
      state = array_ops.slice(states, [0, 0], [-1, self._cell.state_size])
      attns = array_ops.slice(
          states, [0, self._cell.state_size], [-1, self._attn_size])
      attn_states = array_ops.slice(
          states, [0, self._cell.state_size + self._attn_size],
          [-1, self._attn_size * self._attn_length])
    attn_states = array_ops.reshape(attn_states,
                                    [-1, self._attn_length, self._attn_size])
    input_size = self._input_size
    if input_size is None:
      input_size = inputs.get_shape().as_list()[1]
    inputs = _linear([inputs, attns], input_size, True)
    lstm_output, new_state = self._cell(inputs, state)
    if self._state_is_tuple:
      new_state_cat = array_ops.concat(nest.flatten(new_state), 1)
    else:
      new_state_cat = new_state
    new_attns, new_attn_states = self._attention(new_state_cat, attn_states)
    with vs.variable_scope("attn_output_projection"):
      output = _linear([lstm_output, new_attns], self._attn_size, True)
    new_attn_states = array_ops.concat(
        [new_attn_states, array_ops.expand_dims(output, 1)], 1)
    new_attn_states = array_ops.reshape(
        new_attn_states, [-1, self._attn_length * self._attn_size])
    new_state = (new_state, new_attns, new_attn_states)
    if not self._state_is_tuple:
      new_state = array_ops.concat(list(new_state), 1)
    return output, new_state

  def _attention(self, query, attn_states):
    conv2d = nn_ops.conv2d
@ -1213,6 +1202,7 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(LayerNormBasicLSTMCell, self).__init__(_reuse=reuse)

    if input_size is not None:
      logging.warn("%s: The input_size parameter is deprecated.", self)
@ -1256,34 +1246,31 @@ class LayerNormBasicLSTMCell(core_rnn_cell.RNNCell):
    out = nn_ops.bias_add(out, bias)
    return out

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """LSTM cell with layer normalization and recurrent dropout."""
    c, h = state
    args = array_ops.concat([inputs, h], 1)
    concat = self._linear(args)

    with _checked_scope(self, scope or "layer_norm_basic_lstm_cell",
                        reuse=self._reuse):
      c, h = state
      args = array_ops.concat([inputs, h], 1)
      concat = self._linear(args)
      i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
      if self._layer_norm:
        i = self._norm(i, "input")
        j = self._norm(j, "transform")
        f = self._norm(f, "forget")
        o = self._norm(o, "output")

    i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1)
    if self._layer_norm:
      i = self._norm(i, "input")
      j = self._norm(j, "transform")
      f = self._norm(f, "forget")
      o = self._norm(o, "output")
      g = self._activation(j)
      if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
        g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)

    g = self._activation(j)
    if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1:
      g = nn_ops.dropout(g, self._keep_prob, seed=self._seed)
      new_c = (c * math_ops.sigmoid(f + self._forget_bias)
               + math_ops.sigmoid(i) * g)
      if self._layer_norm:
        new_c = self._norm(new_c, "state")
      new_h = self._activation(new_c) * math_ops.sigmoid(o)

    new_c = (c * math_ops.sigmoid(f + self._forget_bias)
             + math_ops.sigmoid(i) * g)
    if self._layer_norm:
      new_c = self._norm(new_c, "state")
    new_h = self._activation(new_c) * math_ops.sigmoid(o)

    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
    return new_h, new_state
      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
      return new_h, new_state
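The `_norm` helper applied to each gate block above is standard layer normalization. A compact numpy re-statement for reference (a sketch only; the function name and epsilon are illustrative, not from the diff):

```python
import numpy as np

def layer_norm(x, gain, bias, eps=1e-6):
  # Normalize each row (one example) across its feature axis, then apply
  # a learned per-feature gain and bias, as done per gate block above.
  mu = x.mean(axis=1, keepdims=True)
  sigma = x.std(axis=1, keepdims=True)
  return gain * (x - mu) / (sigma + eps) + bias
```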
class NASCell(core_rnn_cell.RNNCell):
@ -1313,6 +1300,7 @@ class NASCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(NASCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._num_proj = num_proj
    self._use_biases = use_biases
@ -1333,14 +1321,13 @@ class NASCell(core_rnn_cell.RNNCell):
  def output_size(self):
    return self._output_size

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of NAS Cell.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: This must be a tuple of state Tensors, both `2-D`, with column
        sizes `c_state` and `m_state`.
      scope: VariableScope for the created subgraph; defaults to "nas_rnn".

    Returns:
      A tuple containing:
@ -1368,71 +1355,70 @@ class NASCell(core_rnn_cell.RNNCell):
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with _checked_scope(self, scope or "nas_rnn", reuse=self._reuse):
      # Variables for the NAS cell. W_m is all matrices multiplying the
      # hiddenstate and W_inputs is all matrices multiplying the inputs.
      concat_w_m = vs.get_variable(
          "recurrent_weights", [num_proj, 8 * self._num_units],
          dtype)
      concat_w_inputs = vs.get_variable(
          "weights", [input_size.value, 8 * self._num_units],
    # Variables for the NAS cell. W_m is all matrices multiplying the
    # hiddenstate and W_inputs is all matrices multiplying the inputs.
    concat_w_m = vs.get_variable(
        "recurrent_weights", [num_proj, 8 * self._num_units],
        dtype)
    concat_w_inputs = vs.get_variable(
        "weights", [input_size.value, 8 * self._num_units],
        dtype)

      m_matrix = math_ops.matmul(m_prev, concat_w_m)
      inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)

      if self._use_biases:
        b = vs.get_variable(
            "bias",
            shape=[8 * self._num_units],
            initializer=init_ops.zeros_initializer(),
            dtype=dtype)
        m_matrix = nn_ops.bias_add(m_matrix, b)

      # The NAS cell branches into 8 different splits for both the hiddenstate
      # and the input
      m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                        value=m_matrix)
      inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                             value=inputs_matrix)

      # First layer
      layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
      layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
      layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
      layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
      layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
      layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
      layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
      layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

      # Second layer
      l2_0 = tanh(layer1_0 * layer1_1)
      l2_1 = tanh(layer1_2 + layer1_3)
      l2_2 = tanh(layer1_4 * layer1_5)
      l2_3 = sigmoid(layer1_6 + layer1_7)

      # Inject the cell
      l2_0 = tanh(l2_0 + c_prev)

      # Third layer
      l3_0_pre = l2_0 * l2_1
      new_c = l3_0_pre  # create new cell
      l3_0 = l3_0_pre
      l3_1 = tanh(l2_2 + l2_3)

      # Final layer
      new_m = tanh(l3_0 * l3_1)

      # Projection layer if specified
      if self._num_proj is not None:
        concat_w_proj = vs.get_variable(
            "projection_weights", [self._num_units, self._num_proj],
            dtype)
        new_m = math_ops.matmul(new_m, concat_w_proj)

    m_matrix = math_ops.matmul(m_prev, concat_w_m)
    inputs_matrix = math_ops.matmul(inputs, concat_w_inputs)

    if self._use_biases:
      b = vs.get_variable(
          "bias",
          shape=[8 * self._num_units],
          initializer=init_ops.zeros_initializer(),
          dtype=dtype)
      m_matrix = nn_ops.bias_add(m_matrix, b)

    # The NAS cell branches into 8 different splits for both the hiddenstate
    # and the input
    m_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                      value=m_matrix)
    inputs_matrix_splits = array_ops.split(axis=1, num_or_size_splits=8,
                                           value=inputs_matrix)

    # First layer
    layer1_0 = sigmoid(inputs_matrix_splits[0] + m_matrix_splits[0])
    layer1_1 = relu(inputs_matrix_splits[1] + m_matrix_splits[1])
    layer1_2 = sigmoid(inputs_matrix_splits[2] + m_matrix_splits[2])
    layer1_3 = relu(inputs_matrix_splits[3] * m_matrix_splits[3])
    layer1_4 = tanh(inputs_matrix_splits[4] + m_matrix_splits[4])
    layer1_5 = sigmoid(inputs_matrix_splits[5] + m_matrix_splits[5])
    layer1_6 = tanh(inputs_matrix_splits[6] + m_matrix_splits[6])
    layer1_7 = sigmoid(inputs_matrix_splits[7] + m_matrix_splits[7])

    # Second layer
    l2_0 = tanh(layer1_0 * layer1_1)
    l2_1 = tanh(layer1_2 + layer1_3)
    l2_2 = tanh(layer1_4 * layer1_5)
    l2_3 = sigmoid(layer1_6 + layer1_7)

    # Inject the cell
    l2_0 = tanh(l2_0 + c_prev)

    # Third layer
    l3_0_pre = l2_0 * l2_1
    new_c = l3_0_pre  # create new cell
    l3_0 = l3_0_pre
    l3_1 = tanh(l2_2 + l2_3)

    # Final layer
    new_m = tanh(l3_0 * l3_1)

    # Projection layer if specified
    if self._num_proj is not None:
      concat_w_proj = vs.get_variable(
          "projection_weights", [self._num_units, self._num_proj],
          dtype)
      new_m = math_ops.matmul(new_m, concat_w_proj)

      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
      return new_m, new_state
    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_m)
    return new_m, new_state
class UGRNNCell(core_rnn_cell.RNNCell):
@ -1467,6 +1453,7 @@ class UGRNNCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(UGRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
@ -1481,13 +1468,12 @@ class UGRNNCell(core_rnn_cell.RNNCell):
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of UGRNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.
      scope: VariableScope for the created subgraph; defaults to "ugrnn_cell".

    Returns:
      new_output: batch x num units, Tensor representing the output of the UGRNN
@ -1506,8 +1492,8 @@ class UGRNNCell(core_rnn_cell.RNNCell):
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with _checked_scope(self, scope or "ugrnn_cell",
                        initializer=self._initializer, reuse=self._reuse):
    with vs.variable_scope(vs.get_variable_scope(),
                           initializer=self._initializer):
      cell_inputs = array_ops.concat([inputs, state], 1)
      rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True)

@ -1567,6 +1553,7 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(IntersectionRNNCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._initializer = initializer
    self._forget_bias = forget_bias
@ -1582,14 +1569,12 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
  def output_size(self):
    return self._num_units

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Run one step of the Intersection RNN.

    Args:
      inputs: input Tensor, 2D, batch x input size.
      state: state Tensor, 2D, batch x num units.
      scope: VariableScope for the created subgraph; defaults to
        "intersection_rnn_cell"

    Returns:
      new_y: batch x num units, Tensor representing the output of the +RNN
@ -1610,8 +1595,8 @@ class IntersectionRNNCell(core_rnn_cell.RNNCell):
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    with _checked_scope(self, scope or "intersection_rnn_cell",
                        initializer=self._initializer, reuse=self._reuse):
    with vs.variable_scope(vs.get_variable_scope(),
                           initializer=self._initializer):
      # read-in projections (should be used for first layer in deep +RNN
      # to transform size of inputs from I --> N)
      if input_size.value != self._num_units:
@ -1683,7 +1668,7 @@ class CompiledWrapper(core_rnn_cell.RNNCell):
      return not _REGISTERED_OPS[node_def.op].is_stateful

    with jit.experimental_jit_scope(compile_ops=compile_ops):
      return self._cell(inputs, state, scope=scope)
      return self._cell(inputs, state, scope)
def _random_exp_initializer(minval,
@ -1753,6 +1738,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
      in an existing scope.  If not `True`, and the existing scope already has
      the given variables, an error is raised.
    """
    super(PhasedLSTMCell, self).__init__(_reuse=reuse)
    self._num_units = num_units
    self._use_peepholes = use_peepholes
    self._leak = leak
@ -1782,7 +1768,7 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
    cycle_ratio = self._mod(shifted_time, period_casted) / period_casted
    return math_ops.cast(cycle_ratio, dtype=dtypes.float32)

  def __call__(self, inputs, state, scope=None):
  def call(self, inputs, state):
    """Phased LSTM Cell.

    Args:
@ -1792,7 +1778,6 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
        The second Tensor has shape [batch, features_size], and type float32.
        It stores the features.
      state: core_rnn_cell.LSTMStateTuple, state from previous timestep.
      scope: string, id of the variable scope.

    Returns:
      A tuple containing:
@ -1801,61 +1786,60 @@ class PhasedLSTMCell(core_rnn_cell.RNNCell):
      - A core_rnn_cell.LSTMStateTuple, containing 2 Tensors of float32, shape
        [batch_size, num_units], representing the new state and the output.
    """
    with _checked_scope(self, scope or "phased_lstm_cell", reuse=self._reuse):
      (c_prev, h_prev) = state
      (time, x) = inputs
    (c_prev, h_prev) = state
    (time, x) = inputs

      in_mask_gates = [x, h_prev]
      if self._use_peepholes:
        in_mask_gates.append(c_prev)
    in_mask_gates = [x, h_prev]
    if self._use_peepholes:
      in_mask_gates.append(c_prev)

      with vs.variable_scope("mask_gates"):
        mask_gates = math_ops.sigmoid(
            _linear(in_mask_gates, 2 * self._num_units, True))
        [input_gate, forget_gate] = array_ops.split(
            axis=1, num_or_size_splits=2, value=mask_gates)
    with vs.variable_scope("mask_gates"):
      mask_gates = math_ops.sigmoid(
          _linear(in_mask_gates, 2 * self._num_units, True))
      [input_gate, forget_gate] = array_ops.split(
          axis=1, num_or_size_splits=2, value=mask_gates)

      with vs.variable_scope("new_input"):
        new_input = math_ops.tanh(
            _linear([x, h_prev], self._num_units, True))
    with vs.variable_scope("new_input"):
      new_input = math_ops.tanh(
          _linear([x, h_prev], self._num_units, True))

      new_c = (c_prev * forget_gate + input_gate * new_input)
    new_c = (c_prev * forget_gate + input_gate * new_input)

      in_out_gate = [x, h_prev]
      if self._use_peepholes:
        in_out_gate.append(new_c)
    in_out_gate = [x, h_prev]
    if self._use_peepholes:
      in_out_gate.append(new_c)

      with vs.variable_scope("output_gate"):
        output_gate = math_ops.sigmoid(
            _linear(in_out_gate, self._num_units, True))
    with vs.variable_scope("output_gate"):
      output_gate = math_ops.sigmoid(
          _linear(in_out_gate, self._num_units, True))

      new_h = math_ops.tanh(new_c) * output_gate
    new_h = math_ops.tanh(new_c) * output_gate

      period = vs.get_variable(
          "period", [self._num_units],
          initializer=_random_exp_initializer(
              self._period_init_min, self._period_init_max))
      phase = vs.get_variable(
          "phase", [self._num_units],
          initializer=init_ops.random_uniform_initializer(
              0., period.initial_value))
      ratio_on = vs.get_variable(
          "ratio_on", [self._num_units],
          initializer=init_ops.constant_initializer(self._ratio_on),
          trainable=self._trainable_ratio_on)
    period = vs.get_variable(
        "period", [self._num_units],
        initializer=_random_exp_initializer(
            self._period_init_min, self._period_init_max))
    phase = vs.get_variable(
        "phase", [self._num_units],
        initializer=init_ops.random_uniform_initializer(
            0., period.initial_value))
    ratio_on = vs.get_variable(
        "ratio_on", [self._num_units],
        initializer=init_ops.constant_initializer(self._ratio_on),
        trainable=self._trainable_ratio_on)

      cycle_ratio = self._get_cycle_ratio(time, phase, period)
    cycle_ratio = self._get_cycle_ratio(time, phase, period)

      k_up = 2 * cycle_ratio / ratio_on
      k_down = 2 - k_up
      k_closed = self._leak * cycle_ratio
    k_up = 2 * cycle_ratio / ratio_on
    k_down = 2 - k_up
    k_closed = self._leak * cycle_ratio

      k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
      k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
    k = array_ops.where(cycle_ratio < ratio_on, k_down, k_closed)
    k = array_ops.where(cycle_ratio < 0.5 * ratio_on, k_up, k)

      new_c = k * new_c + (1 - k) * c_prev
      new_h = k * new_h + (1 - k) * h_prev
    new_c = k * new_c + (1 - k) * c_prev
    new_h = k * new_h + (1 - k) * h_prev

      new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)
    new_state = core_rnn_cell.LSTMStateTuple(new_c, new_h)

      return new_h, new_state
    return new_h, new_state
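The PhasedLSTM body above ends by building the piecewise-linear time gate `k` from `cycle_ratio`, `ratio_on`, and the leak rate. A numpy re-statement of just that gate (a sketch only; the argument names are illustrative):

```python
import numpy as np

def phased_lstm_time_gate(t, phase, period, ratio_on, leak=0.001):
  cycle_ratio = np.mod(t - phase, period) / period
  k_up = 2.0 * cycle_ratio / ratio_on    # first half of the open phase: ramp up
  k_down = 2.0 - k_up                    # second half: ramp back down
  k_closed = leak * cycle_ratio          # small leak while the gate is closed
  k = np.where(cycle_ratio < ratio_on, k_down, k_closed)
  k = np.where(cycle_ratio < 0.5 * ratio_on, k_up, k)
  return k
```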
@ -56,14 +56,19 @@ class GatherTreeOp : public OpKernel {
        errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ",
                                step_ids_shape.DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(sequence_length.shape()),
        errors::InvalidArgument("sequence_length must be a vector, saw shape: ",
        ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()),
        errors::InvalidArgument("sequence_length must be a matrix, saw shape: ",
                                sequence_length.shape().DebugString()));
    OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1),
                errors::InvalidArgument(
                    "Inconsistent batch sizes: sequence_length.shape[1] (",
                    "Inconsistent batch sizes: sequence_length.shape[0] (",
                    sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (",
                    step_ids_shape.dim_size(0), ")"));
                    step_ids_shape.dim_size(1), ")"));
    OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2),
                errors::InvalidArgument(
                    "Inconsistent batch sizes: sequence_length.shape[1] (",
                    sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (",
                    step_ids_shape.dim_size(2), ")"));
    OP_REQUIRES(
        ctx, step_ids_shape == parent_ids.shape(),
        errors::InvalidArgument(
@ -74,7 +79,7 @@ class GatherTreeOp : public OpKernel {
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams));
    typename TTypes<T, 3>::ConstTensor step_ids_t = step_ids.tensor<T, 3>();
    typename TTypes<T, 3>::ConstTensor parent_ids_t = parent_ids.tensor<T, 3>();
    typename TTypes<T>::ConstVec seq_len_t = sequence_length.vec<T>();
    typename TTypes<T>::ConstMatrix seq_len_t = sequence_length.matrix<T>();
    typename TTypes<T, 3>::Tensor beams_t = beams->tensor<T, 3>();
    functor::GatherTree<Device, T>()(ctx, device, step_ids_t, parent_ids_t,
                                     seq_len_t, beams_t);
@ -96,7 +101,7 @@ struct GatherTree<CPUDevice, int32> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  typename TTypes<int32, 3>::ConstTensor step_ids,
                  typename TTypes<int32, 3>::ConstTensor parent_ids,
                  typename TTypes<int32>::ConstVec sequence_length,
                  typename TTypes<int32>::ConstMatrix sequence_length,
                  typename TTypes<int32, 3>::Tensor beams) {
    const int64 max_time = parent_ids.dimension(0);
    const int64 batch_size = parent_ids.dimension(1);
@ -104,15 +109,10 @@ struct GatherTree<CPUDevice, int32> {
    beams.setConstant(-1);

    auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) {
      int32 seq_len_b = -1;
      int32 old_batch = -1;
      for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) {
        const int32 batch = i / beam_width;
        const int32 beam = i % beam_width;
        if (batch != old_batch) {
          seq_len_b = sequence_length(batch);
          old_batch = batch;
        }
        int32 seq_len_b = sequence_length(batch, beam);
        if (seq_len_b == 0) {
          continue;
        }
@ -148,14 +148,14 @@ struct GatherTree<CPUDevice, int32> {

#if GOOGLE_CUDA
namespace functor {
#define DECLARE_GPU_SPEC(T)                          \
  template <>                                        \
  void GatherTree<GPUDevice, T>::operator()(         \
      OpKernelContext* ctx, const GPUDevice& d,      \
      typename TTypes<T, 3>::ConstTensor step_ids,   \
      typename TTypes<T, 3>::ConstTensor parent_ids, \
      typename TTypes<T>::ConstVec sequence_length,  \
      typename TTypes<T, 3>::Tensor beams);          \
#define DECLARE_GPU_SPEC(T)                            \
  template <>                                          \
  void GatherTree<GPUDevice, T>::operator()(           \
      OpKernelContext* ctx, const GPUDevice& d,        \
      typename TTypes<T, 3>::ConstTensor step_ids,     \
      typename TTypes<T, 3>::ConstTensor parent_ids,   \
      typename TTypes<T>::ConstMatrix sequence_length, \
      typename TTypes<T, 3>::Tensor beams);            \
  extern template struct GatherTree<GPUDevice, T>;

DECLARE_GPU_SPEC(int32);
@ -31,7 +31,7 @@ struct GatherTree {
|
||||
void operator()(OpKernelContext* ctx, const Device& d,
|
||||
typename TTypes<T, 3>::ConstTensor step_ids,
|
||||
typename TTypes<T, 3>::ConstTensor parent_ids,
|
||||
typename TTypes<T>::ConstVec sequence_length,
|
||||
typename TTypes<T>::ConstMatrix sequence_length,
|
||||
typename TTypes<T, 3>::Tensor beams);
|
||||
};
|
||||
|
||||
|
@ -33,7 +33,7 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
|
||||
CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) {
|
||||
const int32 batch = i / beam_width;
|
||||
const int32 beam = i % beam_width;
|
||||
const int32 seq_len_b = ldg(sequence_length + batch);
|
||||
const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam);
|
||||
#define GET_IX(time_ix, beam_ix) \
|
||||
(batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix))
|
||||
const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam);
@ -59,7 +59,7 @@ struct GatherTree<GPUDevice, T> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T, 3>::ConstTensor step_ids,
typename TTypes<T, 3>::ConstTensor parent_ids,
typename TTypes<T>::ConstVec sequence_length,
typename TTypes<T>::ConstMatrix sequence_length,
typename TTypes<T, 3>::Tensor beams) {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);

@ -32,17 +32,20 @@ REGISTER_OP("GatherTree")
ShapeHandle step_ids, parent_ids, sequence_length;

// step_ids, parent_ids, and output are all shaped:
// [batch_size, max_time, beam_width].
// sequence_length is shaped [batch_size].
// [max_time, batch_size, beam_width].
// sequence_length is shaped [batch_size, beam_width].
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sequence_length));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length));

DimensionHandle batch_size = c->Dim(step_ids, 1);
DimensionHandle beam_width = c->Dim(step_ids, 2);

TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids));
TF_RETURN_IF_ERROR(
c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size));
TF_RETURN_IF_ERROR(
c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width));

c->set_output(0, step_ids);
return tensorflow::Status::OK();
@ -58,7 +61,7 @@ TODO(ebrevdo): fill in

step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
sequence_length: `[batch_size]`.
sequence_length: `[batch_size, beam_width]`.
beams: `[max_time, batch_size, beam_width]`.
)doc");
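A small usage sketch of the updated op contract (toy values; the import path matches the tests later in this diff):

    import numpy as np
    from tensorflow.contrib.seq2seq.python.ops import beam_search_ops

    step_ids = np.zeros([4, 2, 3], np.int32)    # [max_time, batch_size, beam_width]
    parent_ids = np.zeros([4, 2, 3], np.int32)  # same shape as step_ids
    sequence_length = [[4, 4, 4], [2, 2, 2]]    # now rank 2: one length per beam

    beams = beam_search_ops.gather_tree(
        step_ids, parent_ids, sequence_length)  # [max_time, batch_size, beam_width]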

@ -109,7 +109,7 @@ class AttentionWrapperTest(test.TestCase):
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))

final_outputs, final_state = decoder.dynamic_decode(my_decoder)
final_outputs, final_state, _ = decoder.dynamic_decode(my_decoder)

self.assertTrue(
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))

@ -24,6 +24,7 @@ import numpy as np
from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -41,24 +42,32 @@ class TestGatherTree(test.TestCase):
"""Tests the gather_tree function."""

def test_gather_tree(self):
predicted_ids = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[2, 3, 4], [5, 6, 7],
[8, 9, 10]]]).transpose([1, 0, 2])
parent_ids = np.array([
[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
[[0, 0, 0], [1, 2, 0], [2, 1, 1]],
]).transpose([1, 0, 2])
expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6],
[8, 9, 10]]]).transpose([1, 0, 2])
# (max_time = 3, batch_size = 2, beam_width = 3)

res = beam_search_decoder._gather_tree(
ops.convert_to_tensor(predicted_ids), ops.convert_to_tensor(parent_ids))
# create (batch_size, max_time, beam_width) matrix and transpose it
predicted_ids = np.array(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
dtype=np.int32).transpose([1, 0, 2])
parent_ids = np.array(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2]],
[[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
dtype=np.int32).transpose([1, 0, 2])

# sequence_lengths is shaped (batch_size = 2, beam_width = 3)
sequence_lengths = [[3, 3, 3], [3, 3, 3]]

expected_result = np.array(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
[[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2])

res = beam_search_ops.gather_tree(
predicted_ids, parent_ids, sequence_lengths)

with self.test_session() as sess:
res_ = sess.run(res)

np.testing.assert_array_equal(expected_result, res_)
self.assertAllEqual(expected_result, res_)


class TestEosMasking(test.TestCase):
@ -80,18 +89,18 @@ class TestEosMasking(test.TestCase):
probs = sess.run(probs)
masked = sess.run(masked)

np.testing.assert_array_equal(probs[0][0], masked[0][0])
np.testing.assert_array_equal(probs[0][2], masked[0][2])
np.testing.assert_array_equal(probs[1][0], masked[1][0])
self.assertAllEqual(probs[0][0], masked[0][0])
self.assertAllEqual(probs[0][2], masked[0][2])
self.assertAllEqual(probs[1][0], masked[1][0])

np.testing.assert_equal(masked[0][1][0], 0)
np.testing.assert_equal(masked[1][1][0], 0)
np.testing.assert_equal(masked[1][2][0], 0)
self.assertEqual(masked[0][1][0], 0)
self.assertEqual(masked[1][1][0], 0)
self.assertEqual(masked[1][2][0], 0)

for i in range(1, 5):
np.testing.assert_approx_equal(masked[0][1][i], np.finfo('float32').min)
np.testing.assert_approx_equal(masked[1][1][i], np.finfo('float32').min)
np.testing.assert_approx_equal(masked[1][2][i], np.finfo('float32').min)
self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
self.assertAllClose(masked[1][2][i], np.finfo('float32').min)


class TestBeamStep(test.TestCase):
@ -142,12 +151,11 @@ class TestBeamStep(test.TestCase):
outputs_, next_state_, state_, log_probs_ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])

np.testing.assert_array_equal(outputs_.predicted_ids, [[3, 3, 2], [2, 2,
1]])
np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
np.testing.assert_array_equal(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
np.testing.assert_array_equal(next_state_.finished, [[False, False, False],
[False, False, False]])
self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
self.assertAllEqual(next_state_.finished, [[False, False, False],
[False, False, False]])

expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@ -158,7 +166,7 @@ class TestBeamStep(test.TestCase):
expected_log_probs[1][0] += log_probs_[1, 2, 2]
expected_log_probs[1][1] += log_probs_[1, 1, 2]
expected_log_probs[1][2] += log_probs_[1, 0, 1]
np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
self.assertAllEqual(next_state_.log_probs, expected_log_probs)

def test_step_with_eos(self):
dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
@ -197,12 +205,11 @@ class TestBeamStep(test.TestCase):
outputs_, next_state_, state_, log_probs_ = sess.run(
[outputs, next_beam_state, beam_state, log_probs])

np.testing.assert_array_equal(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
np.testing.assert_array_equal(outputs_.predicted_ids, [[0, 3, 2], [2, 0,
1]])
np.testing.assert_array_equal(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
np.testing.assert_array_equal(next_state_.finished, [[True, False, False],
[False, True, False]])
self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
self.assertAllEqual(next_state_.finished, [[True, False, False],
[False, True, False]])

expected_log_probs = []
expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
@ -211,7 +218,7 @@ class TestBeamStep(test.TestCase):
expected_log_probs[0][2] += log_probs_[0, 0, 2]
expected_log_probs[1][0] += log_probs_[1, 1, 2]
expected_log_probs[1][2] += log_probs_[1, 0, 1]
np.testing.assert_array_equal(next_state_.log_probs, expected_log_probs)
self.assertAllEqual(next_state_.log_probs, expected_log_probs)


class BeamSearchDecoderTest(test.TestCase):
@ -259,8 +266,9 @@ class BeamSearchDecoderTest(test.TestCase):
output_layer=output_layer,
length_penalty_weight=0.0)

final_outputs, final_state = decoder.dynamic_decode(
bsd, output_time_major=time_major, maximum_iterations=max_out)
final_outputs, final_state, final_sequence_lengths = (
decoder.dynamic_decode(
bsd, output_time_major=time_major, maximum_iterations=max_out))

def _t(shape):
if time_major:
@ -284,16 +292,18 @@ class BeamSearchDecoderTest(test.TestCase):
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
'final_outputs': final_outputs,
'final_state': final_state
'final_state': final_state,
'final_sequence_lengths': final_sequence_lengths
})

# Mostly a smoke test
time_steps = max_out
max_sequence_length = np.max(sess_results['final_sequence_lengths'])

# A smoke test
self.assertEqual(
_t((batch_size, time_steps, beam_width)),
_t((batch_size, max_sequence_length, beam_width)),
sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
self.assertEqual(
_t((batch_size, time_steps, beam_width)), sess_results[
_t((batch_size, max_sequence_length, beam_width)), sess_results[
'final_outputs'].beam_search_decoder_output.predicted_ids.shape)

def testDynamicDecodeRNNBatchMajorNoAttention(self):

@ -38,7 +38,7 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
sequence_length = [[3, 3, 3]]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
beams = beam_search_ops.gather_tree(
@ -54,7 +54,7 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
sequence_length = [[3, 3, 3]]
with ops.device("/cpu:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids, parent_ids=parent_ids,
@ -73,7 +73,7 @@ class GatherTreeTest(test.TestCase):
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
sequence_length = [3]
sequence_length = [[3, 3, 3]]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
with ops.device("/gpu:0"):
@ -84,7 +84,8 @@ class GatherTreeTest(test.TestCase):
self.assertAllEqual(expected_result, beams.eval())

def testGatherTreeBatch(self):
sequence_length = [0, 1, 2, 3]
# sequence_length is [batch_size, beam_width] = [4, 5]
sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5]

with self.test_session(use_gpu=True):
# (max_time = 4, batch_size = 4, beam_width = 5)

@ -60,9 +60,9 @@ class DynamicDecodeRNNTest(test.TestCase):
initial_state=cell.zero_state(
dtype=dtypes.float32, batch_size=batch_size))

final_outputs, final_state = decoder.dynamic_decode(
my_decoder, output_time_major=time_major,
maximum_iterations=maximum_iterations)
final_outputs, final_state, final_sequence_length = (
decoder.dynamic_decode(my_decoder, output_time_major=time_major,
maximum_iterations=maximum_iterations))

def _t(shape):
if time_major:
@ -73,6 +73,9 @@ class DynamicDecodeRNNTest(test.TestCase):
isinstance(final_outputs, basic_decoder.BasicDecoderOutput))
self.assertTrue(isinstance(final_state, core_rnn_cell.LSTMStateTuple))

self.assertEqual(
(batch_size,),
tuple(final_sequence_length.get_shape().as_list()))
self.assertEqual(
_t((batch_size, None, cell_depth)),
tuple(final_outputs.rnn_output.get_shape().as_list()))
@ -83,7 +86,8 @@ class DynamicDecodeRNNTest(test.TestCase):
sess.run(variables.global_variables_initializer())
sess_results = sess.run({
"final_outputs": final_outputs,
"final_state": final_state
"final_state": final_state,
"final_sequence_length": final_sequence_length,
})

# Mostly a smoke test
@ -131,7 +135,7 @@ class DynamicDecodeRNNTest(test.TestCase):
# Match the variable scope of dynamic_rnn below so we end up
# using the same variables
with vs.variable_scope("root") as scope:
final_decoder_outputs, final_decoder_state = decoder.dynamic_decode(
final_decoder_outputs, final_decoder_state, _ = decoder.dynamic_decode(
my_decoder,
# impute_finished=True ensures outputs and final state
# match those of dynamic_rnn called with sequence_length not None

@ -454,6 +454,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
up to the next cell in an RNN stack or to the top RNN output.
name: Name to use when creating ops.
"""
super(AttentionWrapper, self).__init__()
if not isinstance(cell, core_rnn_cell.RNNCell):
raise TypeError(
"cell must be an RNNCell, saw type: %s" % type(cell).__name__)
@ -515,7 +516,7 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
dtype),
alignment_history=alignment_history)

def __call__(self, inputs, state, tiling_factor=1, scope=None):
def __call__(self, inputs, state, tiling_factor=1):
"""Perform a step of attention-wrapped RNN.

- Step 1: Mix the `inputs` and previous step's `attention` output via
@ -536,7 +537,6 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
tensors from the previous time step.
tiling_factor: An integer factor for which to tile the batch dimension.
Used with BeamSearchDecoder.
scope: Must be `None`.

Returns:
A tuple `(attention_or_cell_output, next_state)`, where:
@ -548,50 +548,46 @@ class AttentionWrapper(core_rnn_cell.RNNCell):
Raises:
NotImplementedError: if `scope` is not `None`.
"""
if scope is not None:
raise NotImplementedError("scope not None is not supported")
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

with variable_scope.variable_scope("attention"):
# Step 1: Calculate the true inputs to the cell based on the
# previous attention value.
cell_inputs = self._cell_input_fn(inputs, state.attention)
cell_state = state.cell_state
cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
score = self._attention_mechanism(cell_output, tiling_factor)
alignments = self._probability_fn(score)

score = self._attention_mechanism(cell_output, tiling_factor)
alignments = self._probability_fn(score)
# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units].
# we then squeeze out the singleton dim.
attention_mechanism_values = _maybe_tile_batch(
self._attention_mechanism.values, tiling_factor)

# Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
expanded_alignments = array_ops.expand_dims(alignments, 1)
# Context is the inner product of alignments and values along the
# memory time dimension.
# alignments shape is
# [batch_size, 1, memory_time]
# attention_mechanism.values shape is
# [batch_size, memory_time, attention_mechanism.num_units]
# the batched matmul is over memory_time, so the output shape is
# [batch_size, 1, attention_mechanism.num_units].
# we then squeeze out the singleton dim.
attention_mechanism_values = _maybe_tile_batch(
self._attention_mechanism.values, tiling_factor)
context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
context = array_ops.squeeze(context, [1])

context = math_ops.matmul(expanded_alignments, attention_mechanism_values)
context = array_ops.squeeze(context, [1])
attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))

attention = self._attention_layer(
array_ops.concat([cell_output, context], 1))
if self._alignment_history:
alignment_history = state.alignment_history.write(
state.time, alignments)
else:
alignment_history = ()

if self._alignment_history:
alignment_history = state.alignment_history.write(
state.time, alignments)
else:
alignment_history = ()

next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignment_history=alignment_history)
next_state = AttentionWrapperState(
time=state.time + 1,
cell_state=next_cell_state,
attention=attention,
alignment_history=alignment_history)

if self._output_attention:
return attention, next_state
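For reference, a pure-numpy sketch of the context computation that moved under the "attention" variable scope above (sizes here are illustrative only):

    import numpy as np

    batch_size, memory_time, num_units = 2, 5, 7
    alignments = np.random.rand(batch_size, memory_time)
    values = np.random.rand(batch_size, memory_time, num_units)

    # [batch_size, 1, memory_time] x [batch_size, memory_time, num_units]
    expanded = alignments[:, np.newaxis, :]
    context = np.matmul(expanded, values)  # [batch_size, 1, num_units]
    context = np.squeeze(context, axis=1)  # [batch_size, num_units]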

@ -19,9 +19,9 @@ from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.contrib.rnn import core_rnn_cell
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -33,7 +33,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import nest


@ -202,20 +201,24 @@ class BeamSearchDecoder(decoder.Decoder):

return (finished, start_inputs, initial_state)

def finalize(self, outputs, final_state):
def finalize(self, outputs, final_state, sequence_lengths):
"""Finalize and return the predicted_ids.

Args:
outputs: An instance of BeamSearchDecoderOutput.
final_state: An instance of BeamSearchDecoderState. Passed through to the
output.
sequence_lengths: An `int32` tensor shaped `[batch_size, beam_width]`.
The sequence lengths determined for each beam during decode.

Returns:
outputs: An instance of FinalBeamSearchDecoderOutput where the
predicted_ids are the result of calling _gather_tree.
final_state: The same input instance of BeamSearchDecoderState.
"""
predicted_ids = _gather_tree(outputs.predicted_ids, outputs.parent_ids)
predicted_ids = beam_search_ops.gather_tree(
outputs.predicted_ids, outputs.parent_ids,
sequence_length=sequence_lengths)
outputs = FinalBeamSearchDecoderOutput(
beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
return outputs, final_state
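`finalize` now consumes the per-beam lengths tracked by the decode loop and runs the native `gather_tree` over the accumulated ids. A hedged caller-side sketch, where `bsd`, `outputs`, `final_state`, and `seq_lens` are assumed names for a BeamSearchDecoder, its stacked outputs, its state, and the tracked lengths:

    final_outputs, final_state = bsd.finalize(outputs, final_state, seq_lens)
    # predicted_ids is time-major [max_time, batch_size, beam_width]; per the
    # removed reference implementation further down, beam 0 holds the best path
    best_path = final_outputs.predicted_ids[:, :, 0]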
@ -536,42 +539,6 @@ def _mask_probs(probs, eos_token, finished):
return finished_examples + non_finished_examples


def _gather_tree_py(values, parents):
"""Gathers path through a tree backwards from the leave nodes.

Used to reconstruct beams given their parents.

Args:
values: A [T, batch_size, beam_width] tensor of indices.
parents: A [T, batch_size, beam_width] tensor of parent beam ids.

Returns:
The [T, batch_size, beam_width] numpy array of paths. For a given batch
entry b, the best path is given by ret[:, b, 0].
"""
num_timesteps = values.shape[0]
num_beams = values.shape[2]
batch_size = values.shape[1]
ret = np.zeros_like(values)  # [T, MB, BW]
ret[-1, :, :] = values[-1, :, :]
for beam_id in range(num_beams):
for batch in range(batch_size):
parent = parents[-1][batch][beam_id]
for timestep in reversed(range(num_timesteps - 1)):
ret[timestep, batch, beam_id] = values[timestep][batch][parent]
parent = parents[timestep][batch][parent]
# now we are going to return ret as a [ts, mb, bw] tensor
return np.array(ret).astype(values.dtype)


def _gather_tree(values, parents):
"""Tensor version of _gather_tree_py."""
ret = script_ops.py_func(
func=_gather_tree_py, inp=[values, parents], Tout=values.dtype)
ret.set_shape(values.get_shape().as_list())
return ret
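The removed `py_func` reference above can still serve as a cross-check against the native op; a hedged, self-contained sketch with toy data in which every beam runs to max_time:

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.seq2seq.python.ops import beam_search_ops

    values = np.random.randint(0, 9, size=(4, 2, 3)).astype(np.int32)
    parents = np.random.randint(0, 3, size=(4, 2, 3)).astype(np.int32)
    seq_len = np.full((2, 3), 4, dtype=np.int32)  # [batch_size, beam_width]

    with tf.Session() as sess:
      beams = sess.run(beam_search_ops.gather_tree(values, parents, seq_len))
    print(beams.shape)  # (4, 2, 3)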


def _tensor_gather_helper(gather_indices, gather_from, range_input, range_size,
final_shape):
range_ = array_ops.expand_dims(math_ops.range(range_input) * range_size, 1)

@ -154,11 +154,11 @@ def dynamic_decode(decoder,
scope: Optional variable scope to use.

Returns:
`(final_outputs, final_state)`.
`(final_outputs, final_state, final_sequence_lengths)`.

Raises:
TypeError: if `decoder` is not an instance of `Decoder`.
ValueError: if maximum_iterations is provided but is not a scalar.
ValueError: if `maximum_iterations` is provided but is not a scalar.
"""
if not isinstance(decoder, Decoder):
raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
@ -184,6 +184,8 @@ def dynamic_decode(decoder,
if maximum_iterations is not None:
initial_finished = math_ops.logical_or(
initial_finished, 0 >= maximum_iterations)
initial_sequence_lengths = array_ops.zeros_like(
initial_finished, dtype=dtypes.int32)
initial_time = constant_op.constant(0, dtype=dtypes.int32)

def _shape(batch_size, from_shape):
@ -206,10 +208,10 @@ def dynamic_decode(decoder,
decoder.output_dtype)

def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
finished):
finished, unused_sequence_lengths):
return math_ops.logical_not(math_ops.reduce_all(finished))

def body(time, outputs_ta, state, inputs, finished):
def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
"""Internal while_loop body.

Args:
@ -217,10 +219,13 @@ def dynamic_decode(decoder,
outputs_ta: structure of TensorArray.
state: (structure of) state tensors and TensorArrays.
inputs: (structure of) input tensors.
finished: 1-D bool tensor.
finished: bool tensor (keeping track of what's finished).
sequence_lengths: int32 tensor (keeping track of time of finish).

Returns:
`(time + 1, outputs_ta, next_state, next_inputs, next_finished)`.
`(time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)`.
"""
(next_outputs, decoder_state, next_inputs,
decoder_finished) = decoder.step(time, inputs, state)
@ -228,6 +233,10 @@ def dynamic_decode(decoder,
if maximum_iterations is not None:
next_finished = math_ops.logical_or(
next_finished, time + 1 >= maximum_iterations)
next_sequence_lengths = array_ops.where(
math_ops.logical_and(math_ops.logical_not(finished), next_finished),
array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
sequence_lengths)

nest.assert_same_structure(state, decoder_state)
nest.assert_same_structure(outputs_ta, next_outputs)
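The `array_ops.where` above stamps each entry with `time + 1` exactly once, on the step where it first flips from unfinished to finished; a numpy sketch of the same rule:

    import numpy as np

    def update_lengths(seq_lens, finished, next_finished, time):
      just_finished = np.logical_and(np.logical_not(finished), next_finished)
      return np.where(just_finished, time + 1, seq_lens)

    lens = update_lengths(np.zeros(3, np.int32),
                          np.array([False, True, False]),  # finished before step
                          np.array([False, True, True]),   # finished after step
                          time=4)
    print(lens)  # [0 0 5] -- only the newly finished entry is stamped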
@ -260,26 +269,30 @@ def dynamic_decode(decoder,

outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
outputs_ta, emit)
return (time + 1, outputs_ta, next_state, next_inputs, next_finished)
return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
next_sequence_lengths)

res = control_flow_ops.while_loop(
condition,
body,
loop_vars=[
initial_time, initial_outputs_ta, initial_state, initial_inputs,
initial_finished
initial_finished, initial_sequence_lengths,
],
parallel_iterations=parallel_iterations,
swap_memory=swap_memory)

final_outputs_ta = res[1]
final_state = res[2]
final_sequence_lengths = res[5]

final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

if hasattr(decoder, "finalize"):
final_outputs, final_state = decoder.finalize(
final_outputs, final_state, final_sequence_lengths)

if not output_time_major:
final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

if hasattr(decoder, "finalize"):
final_outputs, final_state = decoder.finalize(final_outputs, final_state)

return final_outputs, final_state
return final_outputs, final_state, final_sequence_lengths
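Note the reordering in this hunk: `finalize` now runs before the optional batch-major transpose, since its `gather_tree` expects time-major `[max_time, batch_size, beam_width]` inputs. A shape-level numpy sketch of the transpose that follows it:

    import numpy as np

    stacked = np.zeros([7, 2, 3])  # [max_time, batch_size, beam_width], time-major
    # ... finalize / gather_tree operate on this layout ...
    batch_major = np.transpose(stacked, [1, 0, 2])
    print(batch_major.shape)  # (2, 7, 3)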

@ -19,13 +19,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import inspect

from six import exec_
from tensorflow.contrib.specs.python import params_ops
from tensorflow.contrib.specs.python import specs_lib
from tensorflow.contrib.specs.python import specs_ops
from tensorflow.python.util import tf_inspect


def eval_params(params, environment=None):
@ -44,7 +42,8 @@ def eval_params(params, environment=None):
"""
specs_lib.check_keywords(params)
bindings = {}
if environment: bindings.update(environment)
if environment:
bindings.update(environment)
exec_(params, vars(params_ops), bindings)  # pylint: disable=exec-used
return bindings

@ -71,7 +70,8 @@ def eval_spec(spec, environment=None):
"""
specs_lib.check_keywords(spec)
bindings = {}
if environment: bindings.update(environment)
if environment:
bindings.update(environment)
exec_(spec, vars(specs_ops), bindings)  # pylint: disable=exec-used
return bindings

@ -141,7 +141,7 @@ class LocalImport(object):
self.names = names

def __enter__(self):
self.frame = inspect.currentframe()
self.frame = tf_inspect.currentframe()
bindings = self.frame.f_back.f_globals
self.old = {k: bindings.get(k, None) for k in self.names.keys()}
bindings.update(self.names)
@ -151,7 +151,9 @@ class LocalImport(object):
bindings = self.frame.f_back.f_globals
bindings.update(self.old)
for k, v in self.old.items():
if v is None: del bindings[k]
if v is None:
del bindings[k]
del self.frame


ops = LocalImport(specs_ops)
@ -1,6 +1,10 @@
# Description:
# Verbs RDMA communication interfaces and implementations for TensorFlow.

package(default_visibility = [
"//tensorflow:__subpackages__",
])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])
@ -31,13 +35,10 @@ load(
"tf_proto_library_cc",
)

package(default_visibility = [
"//tensorflow:__subpackages__",
])

tf_proto_library_cc(
name = "verbs_service_proto",
srcs = ["verbs_service.proto"],
has_services = 1,
cc_api_version = 2,
visibility = [
"//tensorflow:__subpackages__",
@ -19,11 +19,11 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

namespace tensorflow {

Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
::grpc::ClientContext ctx;
ctx.set_fail_fast(false);
SetDeadline(&ctx, call_options->GetTimeout());
@ -31,14 +31,14 @@ Status GrpcVerbsClient::GetRemoteAddress(CallOptions* call_options,
}

Status GrpcVerbsClient::GetRemoteAddress(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
GetRemoteAddressResponse* response) {
CallOptions call_options;
call_options.SetTimeout(-1);  // no time out
call_options.SetTimeout(-1);  // no time out
return GetRemoteAddress(&call_options, request, response);
}

void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
int64 time_in_ms) {
void GrpcVerbsClient::SetDeadline(::grpc::ClientContext* ctx,
int64 time_in_ms) {
if (time_in_ms > 0) {
ctx->set_deadline(gpr_time_from_millis(time_in_ms, GPR_TIMESPAN));
}
@ -16,11 +16,11 @@ limitations under the License.
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_

#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"

namespace tensorflow {

@ -28,24 +28,23 @@ namespace tensorflow {
class GrpcVerbsClient {
public:
explicit GrpcVerbsClient(SharedGrpcChannelPtr client_channel)
: stub_(grpc::VerbsService::NewStub(client_channel)) {}
: stub_(grpc::VerbsService::NewStub(client_channel)) {}
~GrpcVerbsClient() {}

Status GetRemoteAddress(CallOptions* call_options,
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);
Status GetRemoteAddress(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);

GetRemoteAddressResponse* response);

private:
std::unique_ptr<grpc::VerbsService::Stub> stub_;

void SetDeadline(::grpc::ClientContext* ctx, int64 time_in_ms);


TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsClient);
};

}  // namespace tensorflow

#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_GRPC_VERBS_CLIENT_H_
@ -26,10 +26,10 @@ limitations under the License.
namespace tensorflow {

GrpcVerbsService::GrpcVerbsService(const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder)
: is_shutdown_(false), worker_env_(worker_env) {
builder->RegisterService(&verbs_service_);
cq_ = builder->AddCompletionQueue().release();
::grpc::ServerBuilder* builder)
: is_shutdown_(false), worker_env_(worker_env) {
builder->RegisterService(&verbs_service_);
cq_ = builder->AddCompletionQueue().release();
}

GrpcVerbsService::~GrpcVerbsService() {
@ -52,7 +52,7 @@ void GrpcVerbsService::Shutdown() {
new ::grpc::Alarm(cq_, gpr_now(GPR_CLOCK_MONOTONIC), nullptr);
}
}


// This macro creates a new request for the given RPC method name
// (e.g., `ENQUEUE_REQUEST(GetRemoteAddress, false);`), and enqueues it on
// `this->cq_`.
@ -64,17 +64,17 @@ void GrpcVerbsService::Shutdown() {
// The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests.
#define ENQUEUE_REQUEST(method, supports_cancel) \
do { \
mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \
Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&verbs_service_, cq_, \
&grpc::VerbsService::AsyncService::Request##method, \
&GrpcVerbsService::method##Handler, \
(supports_cancel)); \
} \
#define ENQUEUE_REQUEST(method, supports_cancel) \
do { \
mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \
Call<GrpcVerbsService, grpc::VerbsService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequest(&verbs_service_, cq_, \
&grpc::VerbsService::AsyncService::Request##method, \
&GrpcVerbsService::method##Handler, \
(supports_cancel)); \
} \
} while (0)

// This method blocks forever handling requests from the completion queue.
@ -97,8 +97,8 @@ void GrpcVerbsService::HandleRPCsLoop() {
}
}

void GrpcVerbsService::GetRemoteAddressHandler(WorkerCall
<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
void GrpcVerbsService::GetRemoteAddressHandler(
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call) {
Status s = GetRemoteAddressSync(&call->request, &call->response);
call->SendResponse(ToGrpcStatus(s));
ENQUEUE_REQUEST(GetRemoteAddress, false);
@ -106,8 +106,8 @@ void GrpcVerbsService::GetRemoteAddressHandler(WorkerCall

// synchronous method
Status GrpcVerbsService::GetRemoteAddressSync(
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response) {
// analyzing request
// the channel setting part is redundant.
const string remote_host_name = request->host_name();
@ -115,7 +115,7 @@ Status GrpcVerbsService::GetRemoteAddressSync(
CHECK(rc);
RdmaAddress ra;
ra.lid = request->channel().lid();
ra.qpn = request->channel().qpn();
ra.qpn = request->channel().qpn();
ra.psn = request->channel().psn();
rc->SetRemoteAddress(ra, false);
rc->Connect();
@ -140,8 +140,8 @@ Status GrpcVerbsService::GetRemoteAddressSync(
CHECK(i == RdmaChannel::kNumMessageBuffers);

// setting up response
response->set_host_name(worker_env_->session_mgr->
LegacySession()->worker_name);
response->set_host_name(
worker_env_->session_mgr->LegacySession()->worker_name);
Channel* channel_info = response->mutable_channel();
channel_info->set_lid(rc->self().lid);
channel_info->set_qpn(rc->self().qpn);
@ -151,12 +151,12 @@ Status GrpcVerbsService::GetRemoteAddressSync(
mr->set_remote_addr(reinterpret_cast<uint64>(mb[i]->buffer()));
mr->set_rkey(mb[i]->self()->rkey);
}
return Status::OK();
return Status::OK();
}

// Create a GrpcVerbsService, then assign it to a given handle.
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder) {
::grpc::ServerBuilder* builder) {
*handle = new GrpcVerbsService(worker_env, builder);
}
@ -18,12 +18,12 @@ limitations under the License.

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/grpc_verbs_service_impl.h"
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
#include "tensorflow/core/lib/core/refcount.h"

namespace grpc {
class ServerBuilder;
@ -44,27 +44,27 @@ class GrpcVerbsService : public AsyncServiceInterface {
private:
template <class RequestMessage, class ResponseMessage>
using WorkerCall = Call<GrpcVerbsService, grpc::VerbsService::AsyncService,
RequestMessage, ResponseMessage>;
void GetRemoteAddressHandler(WorkerCall
<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
RequestMessage, ResponseMessage>;
void GetRemoteAddressHandler(
WorkerCall<GetRemoteAddressRequest, GetRemoteAddressResponse>* call);
Status GetRemoteAddressSync(const GetRemoteAddressRequest* request,
GetRemoteAddressResponse* response);

::grpc::ServerCompletionQueue* cq_;
GetRemoteAddressResponse* response);

::grpc::ServerCompletionQueue* cq_;
grpc::VerbsService::AsyncService verbs_service_;
mutex shutdown_mu_;
bool is_shutdown_ GUARDED_BY(shutdown_mu_);
::grpc::Alarm* shutdown_alarm_;
// not owned
RdmaMgr* rdma_mgr_;
const WorkerEnv* const worker_env_;
const WorkerEnv* const worker_env_;

TF_DISALLOW_COPY_AND_ASSIGN(GrpcVerbsService);
};

// Create a GrpcVerbsService, then assign it to a given handle.
void SetNewVerbsService(GrpcVerbsService** handle, const WorkerEnv* worker_env,
::grpc::ServerBuilder* builder);
::grpc::ServerBuilder* builder);

}  // namespace tensorflow
@ -43,7 +43,7 @@ VerbsService::Stub::Stub(
const std::shared_ptr< ::grpc::ChannelInterface>& channel)
: channel_(channel),
rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0],
::grpc::RpcMethod::NORMAL_RPC, channel) {}
::grpc::RpcMethod::NORMAL_RPC, channel) {}

::grpc::Status VerbsService::Stub::GetRemoteAddress(
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
@ -48,16 +48,16 @@ class VerbsService GRPC_FINAL {
class StubInterface {
public:
virtual ~StubInterface() {}
virtual ::grpc::Status GetRemoteAddress(::grpc::ClientContext* context,
const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) = 0;
virtual ::grpc::Status GetRemoteAddress(
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) = 0;
};
class Stub GRPC_FINAL : public StubInterface {
public:
Stub(const std::shared_ptr< ::grpc::ChannelInterface>& channel);
::grpc::Status GetRemoteAddress(::grpc::ClientContext* context,
const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) GRPC_OVERRIDE;
::grpc::Status GetRemoteAddress(
::grpc::ClientContext* context, const GetRemoteAddressRequest& request,
GetRemoteAddressResponse* response) GRPC_OVERRIDE;

private:
std::shared_ptr< ::grpc::ChannelInterface> channel_;
@ -15,16 +15,16 @@ limitations under the License.

#ifdef TENSORFLOW_USE_VERBS

#include <cstdlib>
#include "tensorflow/contrib/verbs/rdma.h"
#include <cstdlib>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/session_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
@ -35,12 +35,12 @@ namespace tensorflow {
namespace {
// hash name to 32-bit integer
uint32_t NameHash(const string& name) {
return Hash32(name.data(), name.size(), 0x1234ABCD);
return Hash32(name.data(), name.size(), 0x1234ABCD);
}

// convenience function for printing message
string MessageTypeToString(RdmaMessageType rmt) {
switch(rmt){
switch (rmt) {
case RDMA_MESSAGE_ACK:
return "RDMA_MESSAGE_ACK";
break;
@ -59,11 +59,11 @@ string MessageTypeToString(RdmaMessageType rmt) {
case RDMA_MESSAGE_TENSOR_WRITE:
return "RDMA_MESSAGE_TENSOR_WRITE";
break;
default:
default:
return "UNKNOWN MESSAGE";
}
}
}
}  // namespace

ibv_context* open_default_device() {
ibv_device** dev_list;
@ -89,29 +89,28 @@ RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env)
worker_env_(worker_env) {
event_channel_ = ibv_create_comp_channel(context_);
CHECK(event_channel_) << "Failed to create completion channel";
cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_, 0);
cq_ = ibv_create_cq(context_, MAX_CONCURRENT_WRITES * 2, NULL, event_channel_,
0);
CHECK(cq_) << "Failed to create completion queue";
CHECK(!ibv_req_notify_cq(cq_, 0)) << "Failed to request CQ notification";
polling_thread_.reset(Env::Default()->StartThread(
ThreadOptions(), "RdmaAdapterCQThread",
[this] {Process_CQ(); }));
VLOG(2) << "Start RdmaAdapter: " << name();
ThreadOptions(), "RdmaAdapterCQThread", [this] { Process_CQ(); }));
VLOG(2) << "Start RdmaAdapter: " << name();
}

RdmaAdapter::~RdmaAdapter() {
polling_thread_.reset();
CHECK(!ibv_destroy_cq(cq_)) << "Failed to destroy CQ";
CHECK(!ibv_destroy_comp_channel(event_channel_)) << "Failed to destroy channel";
CHECK(!ibv_destroy_comp_channel(event_channel_))
<< "Failed to destroy channel";
CHECK(!ibv_dealloc_pd(pd_)) << "Failed to deallocate PD";
CHECK(!ibv_close_device(context_)) << "Failed to release context";
}

string RdmaAdapter::name() const {
return string(context_->device->name);
}
string RdmaAdapter::name() const { return string(context_->device->name); }

// Function to process incoming messages
// There are two types of messages:
// There are two types of messages:
// 1. IBV_WC_RECV_RDMA_WITH_IMM (receive)
// 2. IBV_WC_RDMA_WRITE (send))
void RdmaAdapter::Process_CQ() {
@ -123,15 +122,14 @@ void RdmaAdapter::Process_CQ() {
ibv_ack_cq_events(cq, 1);
CHECK(!ibv_req_notify_cq(cq_, 0));

int ne = ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2,
static_cast<ibv_wc*>(wc_));
int ne =
ibv_poll_cq(cq_, MAX_CONCURRENT_WRITES * 2, static_cast<ibv_wc*>(wc_));
CHECK_GE(ne, 0);
for (int i = 0; i < ne; ++i) {
CHECK(wc_[i].status == IBV_WC_SUCCESS) << "Failed status \n"
<< ibv_wc_status_str(wc_[i].status)
<< " " << wc_[i].status << " "
<< static_cast<int>(wc_[i].wr_id)
<< " "<< wc_[i].vendor_err;
CHECK(wc_[i].status == IBV_WC_SUCCESS)
<< "Failed status \n"
<< ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " "
<< static_cast<int>(wc_[i].wr_id) << " " << wc_[i].vendor_err;
if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) {
RdmaChannel* rc = reinterpret_cast<RdmaChannel*>(wc_[i].wr_id);
// put back a recv wr.
@ -142,8 +140,8 @@ void RdmaAdapter::Process_CQ() {
RdmaMessage rm;
RdmaMessage::ParseMessage(rm, rb->buffer_);
VLOG(2) << "recv RDMA message: " << MessageTypeToString(rm.type_);

if (rm.type_ == RDMA_MESSAGE_ACK) {

if (rm.type_ == RDMA_MESSAGE_ACK) {
// receive an ack to a message
rb = rc->tx_message_buffer_;
rb->SetBufferStatus(remote, idle);
@ -155,12 +153,12 @@ void RdmaAdapter::Process_CQ() {
ab->SendNextItem();
// find or create buffer
RdmaBuffer* tb = rc->FindOrCreateBuffer(rm.name_);
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
tb->EnqueueItem(key_with_step_id);
// send the next tensor
worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();});
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_IDLE) {
// receive tensor-buffer-ready message
// send ack to release remote tx message buffer
RdmaBuffer* ab = rc->tx_ack_buffer_;
@ -168,7 +166,7 @@ void RdmaAdapter::Process_CQ() {
// find buffer
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
tb->SetBufferStatus(remote, idle);
worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();});
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) {
// remote host requests to create a tensor buffer;
// send ack to release remote tx message buffer
@ -194,31 +192,30 @@ void RdmaAdapter::Process_CQ() {
mb->EnqueueItem(message);
mb->SendNextItem();
} else if (rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE) {
// remote creates a buffer and responds
// remote creates a buffer and responds
// send ack to release remote tx message buffer
RdmaBuffer* ab = rc->tx_ack_buffer_;
ab->SendNextItem();
// find buffer
RdmaBuffer* tb = rc->FindBuffer(rm.name_);
CHECK(rm.buffer_size_ == tb->size_)
<< "rm.buffer_size = " << rm.buffer_size_
<< "tb->size_ = " << tb->size_
<< "rm.name_ = " << rm.name_;
CHECK(rm.buffer_size_ == tb->size_)
<< "rm.buffer_size = " << rm.buffer_size_
<< "tb->size_ = " << tb->size_ << "rm.name_ = " << rm.name_;
RemoteMR rmr;
rmr.remote_addr = rm.remote_addr_;
rmr.rkey = rm.rkey_;
tb->SetRemoteMR(rmr, true);
tb->SetBufferStatus(local, idle);
tb->SetBufferStatus(remote, idle);
worker_env_->compute_pool->Schedule([tb](){tb->SendNextItem();});
worker_env_->compute_pool->Schedule([tb]() { tb->SendNextItem(); });
} else if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
// tensor RDMA write completed
worker_env_->compute_pool->Schedule([rm, rc](){
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
worker_env_->compute_pool->Schedule([rm, rc]() {
string key_with_step_id =
VerbsUtil::AppendStepidToKey(rm.name_, rm.step_id_);
rc->RunRecvCallback(key_with_step_id);
});
}
});
}
} else if (wc_[i].opcode == IBV_WC_RDMA_WRITE) {
RdmaBuffer* rb = reinterpret_cast<RdmaBuffer*>(wc_[i].wr_id);
rb->SetBufferStatus(local, idle);
@ -226,7 +223,7 @@ void RdmaAdapter::Process_CQ() {
RdmaMessage::ParseMessage(rm, rb->buffer_);
VLOG(2) << "sent RDMA message: " << MessageTypeToString(rm.type_);
if (rm.type_ != RDMA_MESSAGE_ACK) {
worker_env_->compute_pool->Schedule([rb](){rb->SendNextItem();});
worker_env_->compute_pool->Schedule([rb]() { rb->SendNextItem(); });
}
}
}
@ -235,9 +232,7 @@ void RdmaAdapter::Process_CQ() {

RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
const string remote_name)
: adapter_(adapter),
local_name_(local_name),
remote_name_(remote_name) {
: adapter_(adapter), local_name_(local_name), remote_name_(remote_name) {
// Create queue pair
{
struct ibv_qp_init_attr attr;
@ -263,21 +258,21 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
attr.port_num = 1;
attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE;

int mask = IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT
| IBV_QP_ACCESS_FLAGS;
int mask =
IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS;
CHECK(!ibv_modify_qp(qp_, &attr, mask)) << "Failed to set QP to INIT";
}

// Local address
{
struct ibv_port_attr attr;
CHECK(!ibv_query_port(adapter_->context_, (uint8_t) 1, &attr))
CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr))
<< "Query port";
self_.lid = attr.lid;
self_.qpn = qp_->qp_num;
self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
}


// create message and ack buffers, then initialize the tables.
{
const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer",
@ -303,7 +298,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
buffer_index_name_table_.insert({index, buffer_names[i]});
buffer_name_index_table_.insert({buffer_names[i], index});
}


// Initiate recv
for (int i = 0; i < 100; i++) {
Recv();
@ -320,17 +315,17 @@ RdmaChannel::~RdmaChannel() {
}

void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
mutex_lock lock{mu_};
if ((override) || (!remote_set_)) {
remote_.lid = ra.lid;
remote_.qpn = ra.qpn;
remote_.psn = ra.psn;
remote_set_ = true;
} else {
CHECK(remote_.lid == ra.lid);
CHECK(remote_.qpn == ra.qpn);
CHECK(remote_.psn == ra.psn);
}
mutex_lock lock{mu_};
if ((override) || (!remote_set_)) {
remote_.lid = ra.lid;
remote_.qpn = ra.qpn;
remote_.psn = ra.psn;
remote_set_ = true;
} else {
CHECK(remote_.lid == ra.lid);
CHECK(remote_.qpn == ra.qpn);
CHECK(remote_.psn == ra.psn);
}
}

// Adding tokens to the completion queue
@ -338,7 +333,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) {
void RdmaChannel::Recv() {
struct ibv_recv_wr wr;
memset(&wr, 0, sizeof(wr));
wr.wr_id = (uint64_t) this;
wr.wr_id = (uint64_t)this;
struct ibv_recv_wr* bad_wr;
CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv";
}
@ -347,12 +342,11 @@ void RdmaChannel::Recv() {
// Args:
//   buffer_name: name of the buffer
// Returns:
//   32-bit index
uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name){

//   32-bit index
uint32_t RdmaChannel::LookupBufferIndex(const string& buffer_name) {
mutex_lock lock{bt_mu_};
BufferNameIndexTable::iterator iter = buffer_name_index_table_.find(
buffer_name);
BufferNameIndexTable::iterator iter =
buffer_name_index_table_.find(buffer_name);
CHECK(iter != buffer_name_index_table_.end());
return iter->second;
}
@ -380,14 +374,14 @@ RdmaBuffer* RdmaChannel::FindBuffer(const string& name) {
}

// Find a buffer if it exists, otherwise create one.
// The memory inside the created buffer is not allocated.
// Args:
// The memory inside the created buffer is not allocated.
// Args:
//   name: the name of the buffer
//   buffer_type: TENSOR, MESSAGE or ACK.
// Returns:
//   the named buffer
RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
BufferType buffer_type) {
RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
BufferType buffer_type) {
mutex_lock lock{bt_mu_};
RdmaBuffer* rb;
// find index
@ -405,7 +399,7 @@ RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,
} else if (buffer_type == MESSAGE) {
rb = new RdmaMessageBuffer(this, name);
} else if (buffer_type == ACK) {
rb = new RdmaAckBuffer(this, name);
rb = new RdmaAckBuffer(this, name);
}
buffer_name_index_table_.insert({name, index});
buffer_index_name_table_.insert({index, name});
@ -417,20 +411,19 @@ RdmaBuffer* RdmaChannel::FindOrCreateBuffer(const string& name,

// Insert callback to the callback_table.
// The callback is activated when the corresponding tensor is received.
// Arg:
// Arg:
//   key: the name of the tensor
//   recv_done: the callback associated with the tensor.
// Returns:
//   None
void RdmaChannel::InsertRecvCallback(const string& key,
std::function<void()> recv_done) {

void RdmaChannel::InsertRecvCallback(const string& key,
std::function<void()> recv_done) {
mutex_lock lock{ct_mu_};
callback_table_.insert({key, recv_done});
}

// Remove callback from the callback_table.
// Arg:
// Arg:
//   key: the name of the tensor
// Returns:
//   None
@ -440,7 +433,7 @@ void RdmaChannel::RemoveRecvCallback(const string& key) {
}

// Run named callback in the callback_table.
// Arg:
// Arg:
//   key: the name of the tensor
// Returns:
//   None
@ -484,17 +477,15 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.ah_attr.sl = 0;
attr.ah_attr.src_path_bits = 0;
attr.ah_attr.port_num = 1;


int r;
CHECK(!(r = ibv_modify_qp(qp_, &attr,
IBV_QP_STATE |
IBV_QP_AV |
IBV_QP_PATH_MTU |
IBV_QP_DEST_QPN |
IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC |
IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r;

IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU |
IBV_QP_DEST_QPN | IBV_QP_RQ_PSN |
IBV_QP_MAX_DEST_RD_ATOMIC |
IBV_QP_MIN_RNR_TIMER)))
<< "QP to Ready to Receive " << r;

memset(&attr, 0, sizeof(ibv_qp_attr));
attr.qp_state = IBV_QPS_RTS;
attr.sq_psn = self_.psn;
@ -502,15 +493,13 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
attr.retry_cnt = 7;
attr.rnr_retry = 7; /* infinite */
attr.max_rd_atomic = 1;


CHECK(!(r = ibv_modify_qp(qp_, &attr,
IBV_QP_STATE |
IBV_QP_TIMEOUT |
IBV_QP_RETRY_CNT |
IBV_QP_RNR_RETRY |
IBV_QP_SQ_PSN |
IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r;

IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT |
IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN |
IBV_QP_MAX_QP_RD_ATOMIC)))
<< "QP to Ready to Send " << r;

connected_ = true;
} else {
LOG(INFO) << "channel already connected";
@ -518,7 +507,7 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
}
RdmaBuffer::RdmaBuffer(RdmaChannel* channel, string name)
|
||||
: channel_(channel), name_(name) {}
|
||||
: channel_(channel), name_(name) {}
|
||||
|
||||
RdmaBuffer::~RdmaBuffer() {
|
||||
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
|
||||
@ -528,9 +517,9 @@ RdmaBuffer::~RdmaBuffer() {
|
||||
void RdmaBuffer::FreeBuffer() {
|
||||
if ((buffer_ != nullptr) && buffer_on_host_) {
|
||||
free(buffer_);
|
||||
}
|
||||
}
|
||||
// TODO
|
||||
// release buffer if it is on device.
|
||||
// release buffer if it is on device.
|
||||
// We don't support RDMABuffer on device at this moment.
|
||||
}
|
||||
|
||||
@ -548,14 +537,12 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
|
||||
if (local_status_ != none) {
|
||||
// delete existing buffer
|
||||
CHECK(!ibv_dereg_mr(self_)) << "ibv_dereg_mr failed";
|
||||
FreeBuffer();
|
||||
FreeBuffer();
|
||||
}
|
||||
size_ = size;
|
||||
buffer_ = malloc(size_);
|
||||
self_ = ibv_reg_mr(channel_->adapter_->pd_,
|
||||
buffer_, size_,
|
||||
IBV_ACCESS_LOCAL_WRITE |
|
||||
IBV_ACCESS_REMOTE_WRITE);
|
||||
self_ = ibv_reg_mr(channel_->adapter_->pd_, buffer_, size_,
|
||||
IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE);
|
||||
CHECK(self_) << "Failed to register memory region";
|
||||
buffer_on_host_ = true;
|
||||
local_status_ = idle;
|
||||
@ -572,53 +559,52 @@ void RdmaBuffer::CreateCPUBuffer(size_t size, bool lock) {
|
||||
// None
|
||||
void RdmaBuffer::SetRemoteMR(RemoteMR rmr, bool override) {
|
||||
mutex_lock lock{mu_};
|
||||
if ((override) || (remote_status_ == none)) {
|
||||
if ((override) || (remote_status_ == none)) {
|
||||
remote_.remote_addr = rmr.remote_addr;
|
||||
remote_.rkey = rmr.rkey;
|
||||
remote_status_ = idle;
|
||||
} else {
|
||||
CHECK(remote_.remote_addr == rmr.remote_addr);
|
||||
CHECK(remote_.rkey == rmr.rkey);
|
||||
}
|
||||
CHECK(remote_.rkey == rmr.rkey);
|
||||
}
|
||||
}
|
||||
|
||||
// Put a task in the buffer's job queue
|
||||
void RdmaBuffer::EnqueueItem(string item){
|
||||
void RdmaBuffer::EnqueueItem(string item) {
|
||||
mutex_lock lock{mu_};
|
||||
queue_.push(item);
|
||||
}
|
||||
|
||||
// Rdma-Write the content of the buffer
|
||||
void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) {
|
||||
|
||||
struct ibv_sge list;
|
||||
list.addr = (uint64_t) buffer_;
|
||||
list.addr = (uint64_t)buffer_;
|
||||
list.length = buffer_size;
|
||||
list.lkey = self_->lkey;
|
||||
|
||||
struct ibv_send_wr wr;
|
||||
memset(&wr, 0, sizeof(wr));
|
||||
wr.wr_id = (uint64_t) this;
|
||||
wr.wr_id = (uint64_t)this;
|
||||
wr.sg_list = &list;
|
||||
wr.num_sge = 1;
|
||||
wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;
|
||||
wr.send_flags = IBV_SEND_SIGNALED;
|
||||
wr.imm_data = imm_data;
|
||||
wr.wr.rdma.remote_addr = (uint64_t) remote_.remote_addr;
|
||||
wr.wr.rdma.remote_addr = (uint64_t)remote_.remote_addr;
|
||||
wr.wr.rdma.rkey = remote_.rkey;
|
||||
|
||||
struct ibv_send_wr *bad_wr;
|
||||
struct ibv_send_wr* bad_wr;
|
||||
CHECK(!ibv_post_send(channel_->qp_, &wr, &bad_wr)) << "Failed to post send";
|
||||
}
|
||||
|
||||
RdmaAckBuffer::RdmaAckBuffer(RdmaChannel* channel, string name)
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
RdmaMessageBuffer::RdmaMessageBuffer(RdmaChannel* channel, string name)
|
||||
: RdmaBuffer(channel, name) {}
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
RdmaTensorBuffer::RdmaTensorBuffer(RdmaChannel* channel, string name)
|
||||
: RdmaBuffer(channel, name) {}
|
||||
: RdmaBuffer(channel, name) {}
|
||||
|
||||
// Send the next ack from the buffer's job queue.
|
||||
void RdmaAckBuffer::SendNextItem() {
|
||||
@ -636,13 +622,12 @@ void RdmaAckBuffer::SendNextItem() {
|
||||
void RdmaMessageBuffer::SendNextItem() {
|
||||
uint32_t imm_data = LookupBufferIndex("rx_message_buffer");
|
||||
mu_.lock();
|
||||
if (!queue_.empty() && (local_status_ == idle)
|
||||
&& (remote_status_ == idle)) {
|
||||
if (!queue_.empty() && (local_status_ == idle) && (remote_status_ == idle)) {
|
||||
local_status_ = busy;
|
||||
remote_status_= busy;
|
||||
remote_status_ = busy;
|
||||
string message = queue_.front();
|
||||
queue_.pop();
|
||||
// local/remote_status_ won't be set back to idle
|
||||
// local/remote_status_ won't be set back to idle
|
||||
// unitl Write() is successful
|
||||
mu_.unlock();
|
||||
memcpy(buffer_, message.data(), message.size());
|
||||
@ -665,61 +650,56 @@ void RdmaTensorBuffer::SendNextItem() {
|
||||
}
|
||||
// send the tensor if a key is acquired.
|
||||
if (key_with_step_id != "") {
|
||||
VLOG(2) << "try to send tensor: " << key_with_step_id;
|
||||
VLOG(2) << "try to send tensor: " << key_with_step_id;
|
||||
string key;
|
||||
int64 step_id;
|
||||
VerbsUtil::GetKeyAndStepId(key_with_step_id, key, step_id);
|
||||
CHECK(key.compare(name_) == 0);
|
||||
Rendezvous::ParsedKey parsed;
|
||||
Rendezvous::ParseKey(key, &parsed);
|
||||
Rendezvous::DoneCallback cb = [this, key_with_step_id, key,
|
||||
step_id, parsed](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
CHECK(status.ok()) << "RecvLocalAsync was not ok, key"
|
||||
<< key_with_step_id
|
||||
<< " error message: " << status.error_message();
|
||||
Rendezvous::DoneCallback cb = [this, key_with_step_id, key, step_id,
|
||||
parsed](const Status& status,
|
||||
const Rendezvous::Args& send_args,
|
||||
const Rendezvous::Args& recv_args,
|
||||
const Tensor& in, bool is_dead) {
|
||||
CHECK(status.ok()) << "RecvLocalAsync was not ok, key" << key_with_step_id
|
||||
<< " error message: " << status.error_message();
|
||||
size_t buffer_size = RdmaMessage::kMessageTotalBytes;
|
||||
size_t tensor_bytes = 0;
|
||||
TensorProto proto;
|
||||
// Figures out which device the tensor is hosted on.
|
||||
Device* src_dev = nullptr;
|
||||
Status s =
|
||||
channel_->adapter_->worker_env_->
|
||||
device_mgr->LookupDevice(parsed.src_device, &src_dev);
|
||||
Status s = channel_->adapter_->worker_env_->device_mgr->LookupDevice(
|
||||
parsed.src_device, &src_dev);
|
||||
CHECK(s.ok()) << "src device not found";
|
||||
// Does the device have the right incarnation number we expect?
|
||||
CHECK(src_dev->attributes().incarnation() ==
|
||||
parsed.src_incarnation)
|
||||
<< "RecvTensor expects a different device incarnation: "
|
||||
<< parsed.src_incarnation
|
||||
<< " vs. "
|
||||
<< src_dev->attributes().incarnation()
|
||||
<< ". Your worker job was probably restarted. Check your "
|
||||
<< "worker job for the reason why it was restarted.";
|
||||
CHECK(src_dev->attributes().incarnation() == parsed.src_incarnation)
|
||||
<< "RecvTensor expects a different device incarnation: "
|
||||
<< parsed.src_incarnation << " vs. "
|
||||
<< src_dev->attributes().incarnation()
|
||||
<< ". Your worker job was probably restarted. Check your "
|
||||
<< "worker job for the reason why it was restarted.";
|
||||
Device* dst_dev = nullptr;
|
||||
// destination is on CPU.
|
||||
s = channel_->adapter_->worker_env_->
|
||||
device_mgr->LookupDevice("CPU:0", &dst_dev);
|
||||
CHECK(s.ok())<< "dst device not found";
|
||||
s = channel_->adapter_->worker_env_->device_mgr->LookupDevice("CPU:0",
|
||||
&dst_dev);
|
||||
CHECK(s.ok()) << "dst device not found";
|
||||
AllocatorAttributes dst_alloc_attr;
|
||||
dst_alloc_attr.set_on_host(true);
|
||||
// string tensor needs to be serialized
|
||||
if (src_dev->tensorflow_gpu_device_info() &&
|
||||
(!send_args.alloc_attrs.on_host())) {
|
||||
if (src_dev->tensorflow_gpu_device_info() &&
|
||||
(!send_args.alloc_attrs.on_host())) {
|
||||
CHECK(send_args.device_context)
|
||||
<< "send dev name: " << src_dev->name()
|
||||
<< " gpu_info: " << src_dev->tensorflow_gpu_device_info();
|
||||
// "val" is on a GPU. Uses GPUUtil to fill the proto.
|
||||
s = VerbsUtil::SetProtoFromGPUSync(in, src_dev,
|
||||
send_args.device_context,
|
||||
&proto, is_dead);
|
||||
CHECK(s.ok()) << "set proto from gpu sync";
|
||||
s = VerbsUtil::SetProtoFromGPUSync(
|
||||
in, src_dev, send_args.device_context, &proto, is_dead);
|
||||
CHECK(s.ok()) << "set proto from gpu sync";
|
||||
} else {
|
||||
// tensor is in CPU memory.
|
||||
in.AsProtoTensorContent(&proto);
|
||||
}
|
||||
}
|
||||
tensor_bytes = proto.ByteSize();
|
||||
// maybe some margin for string tensor?
|
||||
buffer_size += tensor_bytes;
|
||||
@ -734,13 +714,12 @@ void RdmaTensorBuffer::SendNextItem() {
|
||||
rm.tensor_bytes_ = tensor_bytes;
|
||||
rm.buffer_size_ = buffer_size;
|
||||
mu_.lock();
|
||||
if (local_status_ == none ||
|
||||
(buffer_size > size_ &&
|
||||
local_status_ == idle &&
|
||||
if (local_status_ == none ||
|
||||
(buffer_size > size_ && local_status_ == idle &&
|
||||
remote_status_ == idle)) {
|
||||
if ((local_status_ != none) && (buffer_size > size_)) {
|
||||
CHECK(rm.data_type_ == DT_STRING)
|
||||
<< "Only string tensor allows to change size";
|
||||
CHECK(rm.data_type_ == DT_STRING)
|
||||
<< "Only string tensor allows to change size";
|
||||
}
|
||||
CreateCPUBuffer(buffer_size, false);
|
||||
mu_.unlock();
|
||||
@ -752,29 +731,29 @@ void RdmaTensorBuffer::SendNextItem() {
|
||||
rm.rkey_ = self_->rkey;
|
||||
string message = RdmaMessage::CreateMessage(rm);
|
||||
channel_->tx_message_buffer_->EnqueueItem(message);
|
||||
channel_->tx_message_buffer_->SendNextItem();
|
||||
} else if((local_status_ == idle) && (remote_status_ == idle)) {
|
||||
channel_->tx_message_buffer_->SendNextItem();
|
||||
} else if ((local_status_ == idle) && (remote_status_ == idle)) {
|
||||
// both buffers are ready, send the tensor
|
||||
local_status_ = busy;
|
||||
remote_status_ = busy;
|
||||
// local/remote_status_ won't be set back to idle
|
||||
// local/remote_status_ won't be set back to idle
|
||||
// unitl Write() is successful
|
||||
mu_.unlock();
|
||||
CHECK((buffer_size == size_ && rm.data_type_ != DT_STRING) ||
|
||||
(buffer_size <= size_ && rm.data_type_ == DT_STRING))
|
||||
<< "tensor and buffer size do not agree!"
|
||||
<< " buffer_size = " << size_
|
||||
<< " requested tensor size = " << buffer_size
|
||||
<< in.DebugString();
|
||||
<< "tensor and buffer size do not agree!"
|
||||
<< " buffer_size = " << size_
|
||||
<< " requested tensor size = " << buffer_size << in.DebugString();
|
||||
uint32_t imm_data = LookupBufferIndex(key);
|
||||
rm.type_ = RDMA_MESSAGE_TENSOR_WRITE;
|
||||
string message = RdmaMessage::CreateMessage(rm);
|
||||
memcpy(buffer_, message.data(), message.size());
|
||||
if (!is_dead) {
|
||||
// copy the tensor buffer content
|
||||
void* output = static_cast<void*>(static_cast<char*>(
|
||||
buffer_) + RdmaMessage::kTensorBufferStartIndex);
|
||||
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
|
||||
void* output =
|
||||
static_cast<void*>(static_cast<char*>(buffer_) +
|
||||
RdmaMessage::kTensorBufferStartIndex);
|
||||
CHECK(tensor_bytes + RdmaMessage::kTensorBufferStartIndex <= size_);
|
||||
proto.SerializeToArray(output, tensor_bytes);
|
||||
} else {
|
||||
buffer_size = RdmaMessage::kMessageTotalBytes;
|
||||
@ -789,8 +768,8 @@ void RdmaTensorBuffer::SendNextItem() {
|
||||
// Use default session (legacy_session_)
|
||||
// TODO use WorkerSessionForSession
|
||||
// need to pass in session handle
|
||||
channel_->adapter_->worker_env_->session_mgr->
|
||||
LegacySession()->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
|
||||
channel_->adapter_->worker_env_->session_mgr->LegacySession()
|
||||
->rendezvous_mgr->RecvLocalAsync(step_id, parsed, cb);
|
||||
}
|
||||
}
|
||||
|
||||
@ -811,8 +790,10 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
|
||||
// TENSOR_WRITE: type|name_size|tensor_name|step_id|...|is_dead
|
||||
// |data_type|tensor_shape|tensor_bytes
|
||||
// BUFFER_IDLE: type|name_size|buffer_name
|
||||
// BUFFER_REQUEST: type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
|
||||
// BUFFER_RESPONSE: type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
|
||||
// BUFFER_REQUEST:
|
||||
// type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
|
||||
// BUFFER_RESPONSE:
|
||||
// type|name_size|buffer_name|...|buffer_size|remote_addr|rkey|
|
||||
char message[kMessageTotalBytes];
|
||||
// type
|
||||
message[kTypeStartIndex] = static_cast<char>(rm.type_) & 0xff;
|
||||
@ -821,32 +802,32 @@ string RdmaMessage::CreateMessage(const RdmaMessage& rm) {
|
||||
// name
|
||||
memcpy(&message[kNameStartIndex], rm.name_.data(), rm.name_.size());
|
||||
// buffer_size, remote_addr, rkey
|
||||
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
|
||||
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
|
||||
(rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
|
||||
memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
|
||||
sizeof(rm.buffer_size_));
|
||||
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
|
||||
sizeof(rm.remote_addr_));
|
||||
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
|
||||
memcpy(&message[kBufferSizeStartIndex], &rm.buffer_size_,
|
||||
sizeof(rm.buffer_size_));
|
||||
memcpy(&message[kRemoteAddrStartIndex], &rm.remote_addr_,
|
||||
sizeof(rm.remote_addr_));
|
||||
memcpy(&message[kRkeyStartIndex], &rm.rkey_, sizeof(rm.rkey_));
|
||||
}
|
||||
// step_id
|
||||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
|
||||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
|
||||
(rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
|
||||
memcpy(&message[kStepIdStartIndex], &rm.step_id_, sizeof(rm.step_id_));
|
||||
}
|
||||
// is_dead, data_type, tensor_shape, tensor_bytes
|
||||
if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
|
||||
memcpy(&message[kIsDeadStartIndex], &rm.is_dead_, sizeof(rm.is_dead_));
|
||||
|
||||
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
|
||||
sizeof(rm.data_type_));
|
||||
memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
|
||||
sizeof(rm.tensor_shape_));
|
||||
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
|
||||
sizeof(rm.tensor_bytes_));
|
||||
|
||||
memcpy(&message[kDataTypeStartIndex], &rm.data_type_,
|
||||
sizeof(rm.data_type_));
|
||||
memcpy(&message[kTensorShapeStartIndex], &rm.tensor_shape_,
|
||||
sizeof(rm.tensor_shape_));
|
||||
memcpy(&message[kTensorBytesStartIndex], &rm.tensor_bytes_,
|
||||
sizeof(rm.tensor_bytes_));
|
||||
}
|
||||
return string(message, kMessageTotalBytes);
|
||||
}
|
||||
}
|
||||
|
||||
// Parse a RdmaMessage according to the pre-defined format
|
||||
// Args:
|
||||
@ -865,27 +846,26 @@ void RdmaMessage::ParseMessage(RdmaMessage& rm, void* buffer) {
|
||||
// buffer_size, remote_addr, rkey
|
||||
if ((rm.type_ == RDMA_MESSAGE_BUFFER_REQUEST) ||
|
||||
(rm.type_ == RDMA_MESSAGE_BUFFER_RESPONSE)) {
|
||||
memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
|
||||
sizeof(rm.buffer_size_));
|
||||
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
|
||||
sizeof(rm.remote_addr_));
|
||||
memcpy(&rm.rkey_, &message[kRkeyStartIndex],
|
||||
sizeof(rm.rkey_));
|
||||
memcpy(&rm.buffer_size_, &message[kBufferSizeStartIndex],
|
||||
sizeof(rm.buffer_size_));
|
||||
memcpy(&rm.remote_addr_, &message[kRemoteAddrStartIndex],
|
||||
sizeof(rm.remote_addr_));
|
||||
memcpy(&rm.rkey_, &message[kRkeyStartIndex], sizeof(rm.rkey_));
|
||||
}
|
||||
// step_id
|
||||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
|
||||
if ((rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) ||
|
||||
(rm.type_ == RDMA_MESSAGE_TENSOR_REQUEST)) {
|
||||
memcpy(&rm.step_id_, &message[kStepIdStartIndex], sizeof(rm.step_id_));
|
||||
}
|
||||
// data_type, tensor_bytes, tensor_shape, is_dead
|
||||
// data_type, tensor_bytes, tensor_shape, is_dead
|
||||
if (rm.type_ == RDMA_MESSAGE_TENSOR_WRITE) {
|
||||
memcpy(&rm.is_dead_, &message[kIsDeadStartIndex], sizeof(rm.is_dead_));
|
||||
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
|
||||
sizeof(rm.data_type_));
|
||||
memcpy(&rm.data_type_, &message[kDataTypeStartIndex],
|
||||
sizeof(rm.data_type_));
|
||||
memcpy(&rm.tensor_shape_, &message[kTensorShapeStartIndex],
|
||||
sizeof(rm.tensor_shape_));
|
||||
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
|
||||
sizeof(rm.tensor_bytes_));
|
||||
sizeof(rm.tensor_shape_));
|
||||
memcpy(&rm.tensor_bytes_, &message[kTensorBytesStartIndex],
|
||||
sizeof(rm.tensor_bytes_));
|
||||
}
|
||||
}
|
||||
|
||||
|
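The Write() path above is the heart of the transport: the payload is pushed straight into pre-registered remote memory, and the 32-bit imm_data travels with the completion so the receiver can tell which buffer was written. The helper below is a minimal stand-alone sketch against the libibverbs API only, not TensorFlow code; PostWriteWithImm and its parameters are illustrative names, and the connected qp, registered mr, and the peer's remote_addr/rkey are assumed to have been exchanged already.

#include <infiniband/verbs.h>
#include <cstdint>
#include <cstring>

// Minimal sketch: post one RDMA_WRITE_WITH_IMM on an already-connected QP.
// 'mr' must cover [local_buf, local_buf + len); remote_addr/rkey must come
// from the peer's registered region (exchanged out of band, e.g. over gRPC).
// Returns 0 on success, else the error code from ibv_post_send.
int PostWriteWithImm(ibv_qp* qp, ibv_mr* mr, void* local_buf, size_t len,
                     uint64_t remote_addr, uint32_t rkey, uint32_t imm_data) {
  ibv_sge sge;
  std::memset(&sge, 0, sizeof(sge));
  sge.addr = reinterpret_cast<uint64_t>(local_buf);
  sge.length = static_cast<uint32_t>(len);
  sge.lkey = mr->lkey;  // local key of the registered source region

  ibv_send_wr wr;
  std::memset(&wr, 0, sizeof(wr));
  wr.wr_id = reinterpret_cast<uint64_t>(local_buf);  // echoed in the CQE
  wr.sg_list = &sge;
  wr.num_sge = 1;
  wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM;  // data plus a 32-bit immediate
  wr.send_flags = IBV_SEND_SIGNALED;       // ask for a completion entry
  wr.imm_data = imm_data;                  // receiver reads this from its CQE
  wr.wr.rdma.remote_addr = remote_addr;
  wr.wr.rdma.rkey = rkey;

  ibv_send_wr* bad_wr = nullptr;
  return ibv_post_send(qp, &wr, &bad_wr);
}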
@ -19,11 +19,11 @@ limitations under the License.

#ifdef TENSORFLOW_USE_VERBS

#include <infiniband/verbs.h>
#include <cstring>  // for memset
#include <functional>
#include <memory>  // for shared_ptr
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

@ -37,43 +37,46 @@ namespace tensorflow {

// structure to save the address of remote channels.
struct RdmaAddress {
  uint32_t lid;
  uint32_t qpn;
  uint32_t psn;
};
// structure to save information for remote memory regions.
struct RemoteMR {
  uint64_t remote_addr;
  uint32_t rkey;
};
enum BufferStatus { none, idle, busy };
enum Location { local, remote };
enum BufferType { ACK, MESSAGE, TENSOR };
enum RdmaMessageType {
  RDMA_MESSAGE_ACK,
  RDMA_MESSAGE_BUFFER_IDLE,
  RDMA_MESSAGE_BUFFER_REQUEST,
  RDMA_MESSAGE_BUFFER_RESPONSE,
  RDMA_MESSAGE_TENSOR_REQUEST,
  RDMA_MESSAGE_TENSOR_WRITE
};
class RdmaBuffer;
// Class that represents the Rdma Adapter.
// Responsible for creation of the completion queue, and handling
// of work completions.
class RdmaAdapter {
  friend class RdmaChannel;
  friend class RdmaBuffer;
  friend class RdmaAckBuffer;
  friend class RdmaMessageBuffer;
  friend class RdmaTensorBuffer;
  friend class RdmaMgr;
  friend class RdmaRemoteRendezvous;

 public:
  RdmaAdapter(const WorkerEnv* worker_env);
  ~RdmaAdapter();
  // Adapter name, e.g. mlx5_0.
  string name() const;
  void Process_CQ();

 protected:
  static const int MAX_CONCURRENT_WRITES = 1000;
  ibv_context* context_;
@ -94,36 +97,39 @@ class RdmaAdapter {
// Class that represents a connection to a remote Rdma peer.
// Responsible for connecting queue pairs.
class RdmaChannel {
  friend class RdmaAdapter;
  friend class RdmaBuffer;
  friend class RdmaAckBuffer;
  friend class RdmaMessageBuffer;
  friend class RdmaTensorBuffer;
  friend class RdmaMgr;
  friend class RdmaRemoteRendezvous;

 public:
  explicit RdmaChannel(const RdmaAdapter* adapter, const string local_name,
                       const string remote_name_);
  ~RdmaChannel();
  inline const RdmaAddress& self() { return self_; }
  RdmaAddress address() const;
  inline const std::vector<RdmaBuffer*>& message_buffers() const {
    return message_buffers_;
  }
  void Connect(const RdmaAddress& remoteAddr);
  void Connect();
  void Recv();
  RdmaBuffer* FindBuffer(const uint32_t index);
  RdmaBuffer* FindBuffer(const string& name);
  RdmaBuffer* FindOrCreateBuffer(const string& name,
                                 BufferType buffer_type = TENSOR);
  uint32_t LookupBufferIndex(const string& buffer_name);
  void SetRemoteAddress(const RdmaAddress& ra, bool override);
  void InsertRecvCallback(const string& key, std::function<void()> recv_done);
  void RemoveRecvCallback(const string& key);
  void RunRecvCallback(const string& key);
  static const int kNumMessageBuffers = 4;

 protected:
  const RdmaAdapter* adapter_;
  RdmaAddress self_;
  string local_name_;
  string remote_name_;
@ -151,10 +157,11 @@ class RdmaChannel {

// Class that represents a buffer for Rdma writes and reads.
class RdmaBuffer {
  friend class RdmaChannel;
  friend class RdmaAdapter;
  friend class RdmaMgr;
  friend class RdmaRemoteRendezvous;

 public:
  explicit RdmaBuffer(RdmaChannel* channel, string name);
  virtual ~RdmaBuffer();
@ -173,10 +180,11 @@ class RdmaBuffer {
  void FreeBuffer();
  void EnqueueItem(string Item);
  virtual void SendNextItem(){};
  void CreateCPUBuffer(size_t size, bool lock = true);
  void SetRemoteMR(RemoteMR rmi, bool override);
  uint32_t LookupBufferIndex(const string& buffer_name) {
    return const_cast<RdmaChannel*>(channel_)->LookupBufferIndex(buffer_name);
  }
  void Write(uint32_t imm_data, size_t buffer_size);

 protected:
@ -188,7 +196,7 @@ class RdmaBuffer {
  ibv_mr* self_ = nullptr;
  mutex mu_;
  RemoteMR remote_;
  std::queue<string> queue_ GUARDED_BY(mu_);
  BufferStatus local_status_ GUARDED_BY(mu_) = none;
  BufferStatus remote_status_ GUARDED_BY(mu_) = none;
};
@ -201,8 +209,9 @@ class RdmaAckBuffer : public RdmaBuffer {
};

class RdmaMessageBuffer : public RdmaBuffer {
  friend class RdmaChannel;
  friend class RdmaAapater;

 public:
  explicit RdmaMessageBuffer(RdmaChannel* channel, string name);
  virtual ~RdmaMessageBuffer() override {}
@ -228,40 +237,41 @@ struct RdmaMessage {
  DataType data_type_;
  TensorShape tensor_shape_;
  size_t tensor_bytes_;

  // type|name_size|name|step_id|buffer_size|remote_addr|rkey|is_dead|...
  //   1B|    2B   | 512|  8B   |    8B     |     8B    | 4B |    1B |...
  // ...|data_type|tensor_shape|tensor_bytes|tensor_buffer
  // ...|   XB    |    XB      |    8B      |...
  //
  static const size_t kNameCapacity = 512;
  static const size_t kTypeStartIndex = 0;
  static const size_t kNameSizeStartIndex = kTypeStartIndex + sizeof(type_);
  static const size_t kNameStartIndex =
      kNameSizeStartIndex + sizeof(name_size_);
  static const size_t kStepIdStartIndex = kNameStartIndex + kNameCapacity;
  static const size_t kBufferSizeStartIndex =
      kStepIdStartIndex + sizeof(step_id_);
  static const size_t kRemoteAddrStartIndex =
      kBufferSizeStartIndex + sizeof(buffer_size_);
  static const size_t kRkeyStartIndex =
      kRemoteAddrStartIndex + sizeof(remote_addr_);
  static const size_t kIsDeadStartIndex = kRkeyStartIndex + sizeof(rkey_);
  static const size_t kDataTypeStartIndex =
      kIsDeadStartIndex + sizeof(is_dead_);
  static const size_t kTensorShapeStartIndex =
      kDataTypeStartIndex + sizeof(data_type_);
  static const size_t kTensorBytesStartIndex =
      kTensorShapeStartIndex + sizeof(TensorShape);
  static const size_t kTensorBufferStartIndex =
      kTensorBytesStartIndex + sizeof(tensor_bytes_);
  static const size_t kMessageTotalBytes = kTensorBufferStartIndex;
  static const size_t kRdmaMessageBufferSize = kMessageTotalBytes;
  static const size_t kRdmaAckBufferSize = kMessageTotalBytes;
  static string CreateMessage(const RdmaMessage& rm);
  static void ParseMessage(RdmaMessage& rm, void* buffer);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_VERBS
#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_H_
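The offset chain above is easy to get wrong by hand. The stand-alone sketch below mirrors the same accumulation pattern with plain fixed-width stand-ins, so the arithmetic can be checked at compile time; the field types here are assumptions chosen only to make the example self-contained (the real struct uses DataType and TensorShape, whose sizes are implementation-defined).

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-ins for the wire fields; widths follow the layout comment above
// (1B type, 2B name_size, 512B name, 8B step_id, 8B buffer_size,
//  8B remote_addr, 4B rkey, 1B is_dead).
struct WireLayout {
  static const size_t kNameCapacity = 512;
  static const size_t kTypeStart = 0;
  static const size_t kNameSizeStart = kTypeStart + sizeof(uint8_t);
  static const size_t kNameStart = kNameSizeStart + sizeof(uint16_t);
  static const size_t kStepIdStart = kNameStart + kNameCapacity;
  static const size_t kBufferSizeStart = kStepIdStart + sizeof(int64_t);
  static const size_t kRemoteAddrStart = kBufferSizeStart + sizeof(uint64_t);
  static const size_t kRkeyStart = kRemoteAddrStart + sizeof(uint64_t);
  static const size_t kIsDeadStart = kRkeyStart + sizeof(uint32_t);
};

static_assert(WireLayout::kIsDeadStart == 1 + 2 + 512 + 8 + 8 + 8 + 4,
              "offsets must accumulate exactly like the layout comment");

int main() {
  // Every later field starts where the previous one ends, so a reader and
  // a writer that share this header can never disagree on offsets.
  std::printf("is_dead starts at byte %zu\n", WireLayout::kIsDeadStart);
  return 0;
}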
@ -15,8 +15,8 @@ limitations under the License.

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include <vector>
#include "tensorflow/contrib/verbs/grpc_verbs_client.h"
#include "tensorflow/contrib/verbs/verbs_service.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.h"
@ -25,7 +25,7 @@ limitations under the License.

namespace tensorflow {

RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
                 GrpcChannelCache* const channel_cache)
    : worker_env_(worker_env), channel_cache_(channel_cache) {
  rdma_adapter_ = new RdmaAdapter(worker_env_);
@ -34,14 +34,15 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
  // need to pass in session handle
  local_worker_ = worker_env_->session_mgr->LegacySession()->worker_name;
  std::vector<string> workers;
  worker_env_->session_mgr->LegacySession()->worker_cache->ListWorkers(
      &workers);
  num_remote_workers_ = workers.size() - 1;
  VLOG(2) << "rmda_mgr on local worker: " << local_worker_;
  for (size_t i = 0; i < workers.size(); i++) {
    if (local_worker_.compare(workers[i]) != 0) {
      channel_table_.insert(
          {workers[i],
           new RdmaChannel(rdma_adapter_, local_worker_, workers[i])});
    }
  }
}
@ -49,16 +50,16 @@ RdmaMgr::RdmaMgr(const WorkerEnv* const worker_env,
// Setup Rdma channels between peers.
// This is done at the beginning of the server setup.

void RdmaMgr::SetupChannels() {
  for (const auto& p : channel_table_) {
    string worker_name = p.first;
    LOG(INFO) << "connecting to remote node " << worker_name;
    RdmaChannel* rc = p.second;
    GetRemoteAddressRequest req;
    GetRemoteAddressResponse resp;
    // get the channel cache
    SharedGrpcChannelPtr client_channel =
        channel_cache_->FindWorkerChannel(worker_name);
    GrpcVerbsClient* client = new GrpcVerbsClient(client_channel);
    CHECK(client != nullptr) << "No worker known as " << worker_name;

@ -70,8 +71,8 @@ void RdmaMgr::SetupChannels() {
    channel_info->set_psn(rc->self_.psn);
    for (int i = 0; i < RdmaChannel::kNumMessageBuffers; i++) {
      MemoryRegion* mr = req.add_mr();
      mr->set_remote_addr(
          reinterpret_cast<uint64_t>(rc->message_buffers_[i]->buffer_));
      mr->set_rkey(rc->message_buffers_[i]->self_->rkey);
    }
    // synchronous call
@ -79,10 +80,10 @@ void RdmaMgr::SetupChannels() {
    // save obtained remote addresses
    // connect to the remote channel
    if (s.ok()) {
      CHECK(worker_name.compare(resp.host_name()) == 0);
      RdmaAddress ra;
      ra.lid = resp.channel().lid();
      ra.qpn = resp.channel().qpn();
      ra.psn = resp.channel().psn();
      rc->SetRemoteAddress(ra, false);
      rc->Connect();
@ -112,7 +113,7 @@ void RdmaMgr::SetupChannels() {

RdmaMgr::~RdmaMgr() {
  for (const auto& p : channel_table_) delete p.second;
  channel_table_.clear();
  delete rdma_adapter_;
}
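The constructor above builds one channel per remote peer, keyed by worker name, and skips the local worker itself. A tiny standard-library sketch of that pattern, with Channel as a hypothetical stand-in for RdmaChannel:

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an RdmaChannel between two named workers.
struct Channel { std::string local, remote; };

int main() {
  const std::string local_worker = "/job:worker/replica:0/task:0";
  const std::vector<std::string> workers = {
      "/job:worker/replica:0/task:0", "/job:worker/replica:0/task:1",
      "/job:worker/replica:0/task:2"};

  std::unordered_map<std::string, Channel> channel_table;
  for (const auto& w : workers) {
    if (w != local_worker) channel_table.insert({w, Channel{local_worker, w}});
  }
  // One channel per remote peer: workers.size() - 1 entries.
  std::cout << channel_table.size() << " remote channels\n";  // prints 2
  return 0;
}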
@ -28,9 +28,8 @@ limitations under the License.
namespace tensorflow {

class RdmaMgr {
 public:
  explicit RdmaMgr(const WorkerEnv* const worker_env,
                   GrpcChannelCache* const channel_cache);
  ~RdmaMgr();
  RdmaChannel* FindChannel(const string& key);
@ -45,11 +44,11 @@ class RdmaMgr {
  RdmaAdapter* rdma_adapter_;
  typedef std::unordered_map<string, RdmaChannel*> ChannelTable;
  ChannelTable channel_table_;

  TF_DISALLOW_COPY_AND_ASSIGN(RdmaMgr);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_VERBS
#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_MGR_H_
@ -15,8 +15,8 @@ limitations under the License.

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/rdma_rendezvous_mgr.h"
#include <unordered_set>
#include "tensorflow/contrib/verbs/verbs_util.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@ -29,14 +29,16 @@ namespace tensorflow {

class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
 public:
  RdmaRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
                       int64 step_id, RdmaMgr* rdma_mgr)
      : BaseRemoteRendezvous(env, worker_name, step_id, true),
        rdma_mgr_(rdma_mgr) {}

 protected:
  void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                           const Rendezvous::Args& args,
                           DoneCallback done) override;

 private:
  ~RdmaRemoteRendezvous() override {}
  RdmaMgr* rdma_mgr_;
@ -45,13 +47,13 @@ class RdmaRemoteRendezvous : public BaseRemoteRendezvous {
};

void RdmaRemoteRendezvous::RecvFromRemoteAsync(
    const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
    DoneCallback done) {
  Status s;
  // parse src_name and dst_name
  string src_name, dst_name, unused;
  if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &src_name,
                                        &unused)) {
    s = errors::Internal("Could not parse src name.");
  }
  CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
@ -59,8 +61,8 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
    done(s, Args(), recv_args, Tensor{}, false);
    return;
  }
  if (!DeviceNameUtils::SplitDeviceName(parsed.dst_device, &dst_name,
                                        &unused)) {
    s = errors::Internal("Could not parse dst name.");
  }
  CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
@ -73,52 +75,52 @@ void RdmaRemoteRendezvous::RecvFromRemoteAsync(
  string key(std::move(parsed.FullKey().ToString()));
  string key_with_step_id = VerbsUtil::AppendStepidToKey(key, step_id_);
  // insert callback
  rc->InsertRecvCallback(key_with_step_id, [this, key, key_with_step_id, rc,
                                            recv_args, parsed, done]() {
    Status s;
    Device* src_dev;
    s = env_->device_mgr->LookupDevice("CPU:0", &src_dev);
    CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
    if (!s.ok()) {
      done(s, Args(), recv_args, Tensor(), true);
      return;
    }
    Device* dst_dev;
    s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_dev);
    CHECK(s.ok()) << "s is not ok, error code " << s.error_message();
    if (!s.ok()) {
      done(s, Args(), recv_args, Tensor(), true);
      return;
    }
    RdmaBuffer* rb = rc->FindBuffer(key);
    RdmaMessage rm;
    CHECK(rb->size_ >= RdmaMessage::kMessageTotalBytes);
    RdmaMessage::ParseMessage(rm, rb->buffer_);
    CHECK(rm.type_ == RDMA_MESSAGE_TENSOR_WRITE);
    Tensor val;
    if (!rm.is_dead_) {
      void* input = static_cast<char*>(rb->buffer_) +
                    RdmaMessage::kTensorBufferStartIndex;
      TensorProto proto;
      CHECK(rm.tensor_bytes_ + RdmaMessage::kTensorBufferStartIndex <=
            rb->size_);
      CHECK(ParseProtoUnlimited(&proto, input, rm.tensor_bytes_))
          << "fail to parse proto from array";
      s = dst_dev->MakeTensorFromProto(proto, recv_args.alloc_attrs, &val);
    }

    rc->RemoveRecvCallback(key_with_step_id);
    // create message
    RdmaMessage br;
    br.type_ = RDMA_MESSAGE_BUFFER_IDLE;
    br.name_size_ = key.size();
    br.name_ = key;
    string message = RdmaMessage::CreateMessage(br);
    RdmaBuffer* tb = rc->tx_message_buffer_;
    tb->EnqueueItem(message);
    tb->SendNextItem();
    done(s, Args(), recv_args, val, rm.is_dead_);
  });
  // append key to message queue
  RdmaBuffer* rb = rc->tx_message_buffer_;
  RdmaMessage rm;
@ -141,7 +143,7 @@ BaseRemoteRendezvous* RdmaRendezvousMgr::Create(int64 step_id,
                                                const string& worker_name) {
  return new RdmaRemoteRendezvous(worker_env, worker_name, step_id, rdma_mgr_);
}

}  // end namespace tensorflow

#endif
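The recv path above hinges on a small registry: a callback is parked under `key;step_id` before the request goes out, and the completion handler runs and erases it when the tensor lands. Below is a minimal thread-safe sketch of that insert/run/remove contract using only the standard library; the class name and methods are illustrative, not the TensorFlow API.

#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

// Minimal sketch of the recv-callback table: callers park a completion
// under a key, the I/O thread runs and erases it when data arrives.
class CallbackTable {
 public:
  void Insert(const std::string& key, std::function<void()> cb) {
    std::lock_guard<std::mutex> l(mu_);
    table_[key] = std::move(cb);
  }
  void RunAndRemove(const std::string& key) {
    std::function<void()> cb;
    {
      std::lock_guard<std::mutex> l(mu_);
      auto it = table_.find(key);
      if (it == table_.end()) return;  // nothing parked under this key
      cb = std::move(it->second);
      table_.erase(it);
    }
    cb();  // run outside the lock so the callback may re-enter the table
  }

 private:
  std::mutex mu_;
  std::unordered_map<std::string, std::function<void()>> table_;
};

int main() {
  CallbackTable t;
  t.Insert("tensor_a;42", [] { std::cout << "tensor_a, step 42 arrived\n"; });
  t.RunAndRemove("tensor_a;42");  // fires exactly once
  t.RunAndRemove("tensor_a;42");  // no-op: already consumed
  return 0;
}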
@ -47,12 +47,12 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr {
 public:
  explicit RdmaRendezvousMgr(const WorkerEnv* env, const string& worker_name,
                             WorkerCacheInterface* worker_cache);
  void SetRdmaMgr(RdmaMgr* rdma_mgr) { rdma_mgr_ = rdma_mgr; }

 protected:
  BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
                               const string& worker_name) override;

 private:
  RdmaMgr* rdma_mgr_;
  TF_DISALLOW_COPY_AND_ASSIGN(RdmaRendezvousMgr);
@ -60,5 +60,5 @@ class RdmaRendezvousMgr : public BaseRendezvousMgr {

}  // end namespace tensorflow

#endif  // TENSORFLOW_USE_VERBS
#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_RDMA_RENDEZVOUS_MGR_H_
@ -27,8 +27,9 @@ namespace tensorflow {

namespace {
// static utility function
RendezvousMgrInterface* NewRdmaRendezvousMgr(
    const WorkerEnv* env, const string& worker_name,
    WorkerCacheInterface* worker_cache) {
  return new RdmaRendezvousMgr(env, worker_name, worker_cache);
}

@ -46,7 +47,7 @@ VerbsServer::~VerbsServer() {
}

Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
                                        GrpcChannelCache** channel_cache) {
  string name_prefix =
      strings::StrCat("/job:", server_def.job_name(), "/replica:0",
                      "/task:", server_def.task_index());
@ -54,41 +55,43 @@ Status VerbsServer::ChannelCacheFactory(const ServerDef& server_def,
  GrpcChannelSpec channel_spec;
  TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));

  *channel_cache =
      NewGrpcChannelCache(channel_spec, GetChannelCreationFunction(server_def));

  const string host_port = (*channel_cache)->TranslateTask(name_prefix);
  int requested_port;

  if (!strings::safe_strto32(str_util::Split(host_port, ':')[1],
                             &requested_port)) {
    return errors::Internal("Could not parse port for local server from \"",
                            (*channel_cache)->TranslateTask(name_prefix),
                            "\".");
  }
  if (requested_port != bound_port()) {
    return errors::InvalidArgument("Requested port ", requested_port,
                                   " differs from expected port ",
                                   bound_port());
  }

  return Status::OK();
}

Status VerbsServer::Init(ServiceInitFunction service_func,
                         RendezvousMgrCreationFunction rendezvous_mgr_func) {
  Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
  {
    mutex_lock l(mu_);
    CHECK_EQ(verbs_state_, DISCONNECTED);
    CHECK(ChannelCacheFactory(server_def(), &channel_cache_).ok());
    rdma_mgr_ = new RdmaMgr(worker_env(), channel_cache_);
    // set rdma_mgr for verbs_service and rdma_rendezvous_mgr
    verbs_service_->SetRdmaMgr(rdma_mgr_);
    // hardcoded to default session (legacy_session_)
    // TODO: use WorkerSessionForSession
    // need to pass in session handle
    dynamic_cast<RdmaRendezvousMgr*>(
        worker_env()->session_mgr->LegacySession()->rendezvous_mgr.get())
        ->SetRdmaMgr(rdma_mgr_);
  }
  return s;
}
@ -100,9 +103,9 @@ Status VerbsServer::Start() {
  if (verbs_state_ == DISCONNECTED) {
    // verbs_thread needs to be initiated
    // before rdma_mgr sets up the rdma channels.
    verbs_thread_.reset(worker_env()->env->StartThread(
        ThreadOptions(), "TF_verbs_service",
        [this] { verbs_service_->HandleRPCsLoop(); }));
    rdma_mgr_->SetupChannels();
    verbs_state_ = CONNECTED;
  }
@ -124,10 +127,10 @@ Status VerbsServer::Join() {

/* static */
Status VerbsServer::Create(const ServerDef& server_def, Env* env,
                           std::unique_ptr<ServerInterface>* out_server) {
  std::unique_ptr<VerbsServer> ret(new VerbsServer(server_def, Env::Default()));
  ServiceInitFunction service_func = [&ret](const WorkerEnv* worker_env,
                                            ::grpc::ServerBuilder* builder) {
    return SetNewVerbsService(&ret->verbs_service_, worker_env, builder);
  };
  TF_RETURN_IF_ERROR(ret->Init(service_func, NewRdmaRendezvousMgr));
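Start() above is careful to launch the RPC-handling thread before SetupChannels(), because each peer's address exchange is itself served by that loop; if every server tried to connect first, they would all block waiting on each other. A small stand-alone sketch of the same ordering with std::thread, where the function names are illustrative stand-ins:

#include <atomic>
#include <chrono>
#include <iostream>
#include <thread>

std::atomic<bool> serving{false};

// Stand-in for the verbs service loop that answers peers' address requests.
void HandleRpcsLoop() {
  serving = true;
  std::this_thread::sleep_for(std::chrono::milliseconds(50));  // serve a while
}

// Stand-in for SetupChannels(): it can only succeed once peers can answer,
// which is why the service thread must already be running.
void SetupChannels() {
  while (!serving) std::this_thread::yield();  // wait until the loop is up
  std::cout << "channels connected\n";
}

int main() {
  std::thread service(HandleRpcsLoop);  // 1. start answering RPCs
  SetupChannels();                      // 2. only then dial the peers
  service.join();
  return 0;
}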
@ -18,8 +18,8 @@ limitations under the License.

#ifdef TENSORFLOW_USE_VERBS

#include "tensorflow/contrib/verbs/grpc_verbs_service.h"
#include "tensorflow/contrib/verbs/rdma_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"

namespace tensorflow {
@ -27,7 +27,7 @@ namespace tensorflow {
class VerbsServer : public GrpcServer {
 protected:
  VerbsServer(const ServerDef& server_def, Env* env);

 public:
  static Status Create(const ServerDef& server_def, Env* env,
                       std::unique_ptr<ServerInterface>* out_server);
@ -39,21 +39,22 @@ class VerbsServer : public GrpcServer {
  // Implementations of ServerInterface methods.
  Status Start() override;
  Status Join() override;

 protected:
  Status Init(ServiceInitFunction service_func,
              RendezvousMgrCreationFunction rendezvous_mgr_func);
  Status ChannelCacheFactory(const ServerDef& server_def,
                             GrpcChannelCache** channel_cache);

 private:
  RdmaMgr* rdma_mgr_;

  // Guards state transitions.
  mutex mu_;

  enum State { DISCONNECTED, CONNECTED };
  State verbs_state_ GUARDED_BY(mu_);

  GrpcVerbsService* verbs_service_ = nullptr;
  std::unique_ptr<Thread> verbs_thread_ GUARDED_BY(mu_);
  GrpcChannelCache* channel_cache_ = nullptr;
@ -61,5 +62,5 @@ class VerbsServer : public GrpcServer {

}  // namespace tensorflow

#endif  // TENSORFLOW_USE_VERBS
#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_VERBS_VERBS_SERVER_LIB_H_
@ -29,7 +29,7 @@ option java_package = "org.tensorflow.contrib.verbs";
message Channel {
  int32 lid = 1;
  int32 qpn = 2;
  int32 psn = 3;
}

message MemoryRegion {
@ -39,15 +39,14 @@ message MemoryRegion {
message GetRemoteAddressRequest {
  string host_name = 1;
  Channel channel = 2;
  repeated MemoryRegion mr = 3;
}

message GetRemoteAddressResponse {
  string host_name = 1;
  Channel channel = 2;
  repeated MemoryRegion mr = 3;
}

////////////////////////////////////////////////////////////////////////////////
//
@ -56,5 +55,6 @@ message GetRemoteAddressResponse {
////////////////////////////////////////////////////////////////////////////////

service VerbsService {
  rpc GetRemoteAddress(GetRemoteAddressRequest)
      returns (GetRemoteAddressResponse);
}
@ -22,30 +22,27 @@ namespace tensorflow {

// static sync wrapper:
Status VerbsUtil::SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
                                      const DeviceContext* device_context,
                                      TensorProto* proto, bool is_dead) {
  Notification n;
  Status status;
  GPUUtil::SetProtoFromGPU(tensor, dev, device_context, proto, is_dead,
                           [&n, &status](const Status& s) {
                             status = s;
                             n.Notify();
                           });
  n.WaitForNotification();
  return status;
}

// static
string VerbsUtil::AppendStepidToKey(const string& key, int64 step_id) {
  return strings::StrCat(key, ";", step_id);
}

// static
void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id, string& key,
                                int64& step_id) {
  StringPiece s(key_with_step_id);
  // a key (with step_id) has exactly 6 parts if split by ";"
  // part 1: src_device;
@ -55,10 +52,10 @@ void VerbsUtil::GetKeyAndStepId(const string& key_with_step_id,
  // part 5: frame_iter.frame_id:frame_iter.iter_id
  // part 6: step_id
  std::vector<string> parts = str_util::Split(s, ';');
  CHECK(parts.size() == 6) << "Key with step_id must have 6 parts";
  strings::safe_strto64(parts[5], &step_id);
  parts.pop_back();                        // remove step_id
  key.assign(str_util::Join(parts, ";"));  // stitch them together
}

}  // namespace tensorflow
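The round trip above is plain string surgery: append ";<step_id>" on the send side, split on ";" and drop the last part on the receive side. Here is a standard-library-only sketch of the same round trip; the key text is a made-up example in the rendezvous-key shape described by the comments.

#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::string AppendStepIdToKey(const std::string& key, long long step_id) {
  return key + ";" + std::to_string(step_id);
}

void GetKeyAndStepId(const std::string& key_with_step_id, std::string& key,
                     long long& step_id) {
  std::vector<std::string> parts;
  std::stringstream ss(key_with_step_id);
  std::string part;
  while (std::getline(ss, part, ';')) parts.push_back(part);
  assert(parts.size() == 6 && "key with step_id must have 6 parts");
  step_id = std::stoll(parts.back());
  parts.pop_back();  // drop step_id, re-join the original 5-part key
  key = parts[0];
  for (size_t i = 1; i < parts.size(); ++i) key += ";" + parts[i];
}

int main() {
  // Example in the rendezvous key shape: src;incarnation;dst;name;frame:iter
  const std::string key = "/job:a/task:0;1;/job:b/task:1;edge_1_x;0:0";
  std::string out;
  long long step = 0;
  GetKeyAndStepId(AppendStepIdToKey(key, 42), out, step);
  std::cout << out << " step=" << step << "\n";  // original key and 42
  return 0;
}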
@ -28,14 +28,13 @@ class TensorProto;

class VerbsUtil {
 public:
  // synchronous wrapper of SetProtoFromGPU
  static Status SetProtoFromGPUSync(const Tensor& tensor, Device* dev,
                                    const DeviceContext* device_context,
                                    TensorProto* proto, bool is_dead);
  static string AppendStepidToKey(const string& key, int64 step_id);
  static void GetKeyAndStepId(const string& key_with_step_id, string& key,
                              int64& step_id);
};

}  // namespace tensorflow
@ -1570,6 +1570,7 @@ tf_cuda_library(
        ":lib_internal",
        ":proto_text",
        ":protos_all_cc",
        "//tensorflow/core/kernels:function_ops",
    ],
    alwayslink = 1,
)
@ -30,6 +30,11 @@ struct BuildGraphOptions {
  // the former via "ref" fetch_endpoints.
  std::vector<string> target_nodes;

  // If `true`, uses Arg/Retval to implement feeds/fetches; otherwise
  // uses Recv/Send to implement feeds/fetches.
  // TODO(mrry): Remove this when the distributed runtime supports Arg/Retval.
  bool use_function_convention = false;

  string DebugString() const;
};
@ -43,7 +43,7 @@ namespace tensorflow {
namespace {

bool IsConstantFoldable(const Node* n,
                        const std::function<bool(const Node*)>& consider) {
  if (n->op_def().is_stateful()) {
    return false;
  }
@ -71,7 +71,8 @@ void CopyTensor::ViaDMA(StringPiece edge_name, DeviceContext* send_dev_context,
    if (ri.sender_device_type == src_device_type &&
        ri.receiver_device_type == dst_device_type) {
      ri.copy_function(send_dev_context, recv_dev_context, src, dst,
                       src_alloc_attr, dst_alloc_attr, input, output,
                       std::move(done));
      return;
    }
  }
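The only substantive change in this hunk is passing the DoneCallback with std::move: the callback is not used again after the dispatch, so moving it hands over whatever state the std::function captured instead of copying it. A minimal sketch of the same idiom, with illustrative names:

#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Dispatch a completion callback to its final consumer. Because 'done' is
// dead after this call, std::move transfers its captured state (here a
// std::string) instead of copying it.
void Dispatch(std::function<void()> done) { done(); }

void ViaDmaLike(std::function<void()> done) {
  // ... issue the copy, then hand the callback to the completion path:
  Dispatch(std::move(done));  // last use of 'done' in this scope
}

int main() {
  std::string payload = "copy finished";
  ViaDmaLike([payload] { std::cout << payload << "\n"; });
  return 0;
}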
@ -361,7 +361,6 @@ Status DirectSession::ExtendLocked(const GraphDef& graph) {
  return Status::OK();
}

Status DirectSession::Run(const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
@ -426,13 +425,34 @@ Status DirectSession::Run(const RunOptions& run_options,
        executor_step_count, input_tensor_names, output_names, target_nodes));
  }

  // Configure a call frame for the step, which we use to feed and
  // fetch values to and from the executors.
  FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      TF_RETURN_IF_ERROR(
          ResourceHandleToInputTensor(it.second, &tensor_from_handle));
      feed_args[executors_and_keys->input_name_to_index[it.first]] =
          tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }
  Status s = call_frame.SetArgs(feed_args);
  if (errors::IsInternal(s)) {
    return errors::InvalidArgument(s.error_message());
  } else if (!s.ok()) {
    return s;
  }

  // Create a run state and start execution.
  RunState run_state(args.step_id, &devices_);
  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
  CancellationManager step_cancellation_manager;
  args.call_frame = &call_frame;

  // Start parallel Executors.
  const size_t num_executors = executors_and_keys->items.size();
@ -535,8 +555,22 @@ Status DirectSession::Run(const RunOptions& run_options,
  }

  // Receive outputs.
  if (outputs) {
    std::vector<Tensor> sorted_outputs;
    Status s = call_frame.ConsumeRetvals(&sorted_outputs);
    if (errors::IsInternal(s)) {
      return errors::InvalidArgument(s.error_message());
    } else if (!s.ok()) {
      return s;
    }
    outputs->clear();
    outputs->reserve(sorted_outputs.size());
    for (const string& output_name : output_names) {
      outputs->emplace_back(
          std::move(sorted_outputs[executors_and_keys
                                       ->output_name_to_index[output_name]]));
    }
  }

  // Save the output tensors of this run we choose to keep.
  TF_RETURN_IF_ERROR(
@ -706,11 +740,11 @@ Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  // Send inputs.
  Status s = SendPRunInputs(inputs, executors_and_keys, run_state->rendez);

  // Receive outputs.
  if (s.ok()) {
    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
  }

  // Save the output tensors of this run we choose to keep.
@ -770,16 +804,17 @@ Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
  }
}

Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
                                     const ExecutorsAndKeys* executors_and_keys,
                                     IntraProcessRendezvous* rendez) {
  Status s;
  Rendezvous::ParsedKey parsed;
  // Insert the input tensors into the local rendezvous by their
  // rendezvous key.
  for (const auto& input : inputs) {
    auto it =
        executors_and_keys->input_name_to_rendezvous_key.find(input.first);
    if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
      return errors::Internal("'", input.first, "' is not a pre-defined feed.");
    }
    const string& input_key = it->second;
@ -808,10 +843,10 @@ Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
  return Status::OK();
}

Status DirectSession::RecvPRunOutputs(
    const std::vector<string>& output_names,
    const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
    std::vector<Tensor>* outputs) {
  Status s;
  if (!output_names.empty()) {
    outputs->resize(output_names.size());
@ -822,8 +857,9 @@ Status DirectSession::RecvPRunOutputs(
  for (size_t output_offset = 0; output_offset < output_names.size();
       ++output_offset) {
    const string& output_name = output_names[output_offset];
    auto it =
        executors_and_keys->output_name_to_rendezvous_key.find(output_name);
    if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
      return errors::Internal("'", output_name,
                              "' is not a pre-defined fetch.");
    }
@ -987,14 +1023,16 @@ Status DirectSession::GetOrCreateExecutors(
  options.feed_endpoints = inputs_sorted;
  options.fetch_endpoints = outputs_sorted;
  options.target_nodes = tn_sorted;
  options.use_function_convention = !run_state_args->is_partial_run;

  std::shared_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  // The executor_lock_ is intentionally released while executor is
  // being created.
  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &ek->flib_def,
                                  run_state_args, &ek->input_types,
                                  &ek->output_types));

  if (run_state_args->is_partial_run) {
    ek->graph = std::move(run_state_args->graph);
@ -1079,17 +1117,37 @@ Status DirectSession::GetOrCreateExecutors(
    item->executor.reset(executor);
  }

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_index[input] = i;
    }
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
    for (size_t i = 0; i < inputs_sorted.size(); ++i) {
      const string& input = inputs_sorted[i];
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (size_t i = 0; i < outputs_sorted.size(); ++i) {
      const string& output = outputs_sorted[i];
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  // Reacquire the lock, try to insert into the map.
@ -1110,7 +1168,8 @@ Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types) {
  mutex_lock l(graph_def_lock_);
  std::unique_ptr<SimpleClientGraph> client_graph;

@ -1135,6 +1194,23 @@ Status DirectSession::CreateGraphs(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }

  if (subgraph_options.feed_endpoints.size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.feed_endpoints.size(),
|
||||
" versus number of pruned feed endpoints = ",
|
||||
client_graph->feed_types.size());
|
||||
}
|
||||
if (subgraph_options.fetch_endpoints.size() !=
|
||||
client_graph->fetch_types.size()) {
|
||||
return errors::Internal(
|
||||
"Graph pruning failed: requested number of fetch endpoints = ",
|
||||
subgraph_options.fetch_endpoints.size(),
|
||||
" versus number of pruned fetch endpoints = ",
|
||||
client_graph->fetch_types.size());
|
||||
}
|
||||
|
||||
auto current_stateful_placements = execution_state->GetStatefulPlacements();
|
||||
// Update our current state based on the execution_state's
|
||||
// placements. If there are any mismatches for a node,
|
||||
@ -1240,6 +1316,8 @@ Status DirectSession::CreateGraphs(
|
||||
}
|
||||
}
|
||||
*flib_def = std::move(client_graph->flib_def);
|
||||
std::swap(*input_types, client_graph->feed_types);
|
||||
std::swap(*output_types, client_graph->fetch_types);
|
||||
return s;
|
||||
}
|
||||
|
||||
|
@ -132,8 +132,13 @@ class DirectSession : public Session {
NameNodeMap name_to_node;
std::unique_ptr<FunctionLibraryDefinition> flib_def;
std::vector<PerPartitionExecutorsAndLib> items;
std::unordered_map<string, string> input_keys;
std::unordered_map<string, string> output_keys;
std::unordered_map<string, size_t> input_name_to_index;
std::unordered_map<string, string> input_name_to_rendezvous_key;
std::unordered_map<string, size_t> output_name_to_index;
std::unordered_map<string, string> output_name_to_rendezvous_key;

DataTypeVector input_types;
DataTypeVector output_types;
};

// For each live partial execution, the session maintains a RunState.
@ -187,7 +192,8 @@ class DirectSession : public Session {
const BuildGraphOptions& options,
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args);
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types);

::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
@ -196,17 +202,17 @@ class DirectSession : public Session {
const Tensor& resource_tensor, Tensor* retrieved_tensor);

// Feeds more inputs to the executors, triggering further execution.
::tensorflow::Status SendInputs(
::tensorflow::Status SendPRunInputs(
const std::vector<std::pair<string, Tensor>>& inputs,
const ExecutorsAndKeys* executors_and_keys,
IntraProcessRendezvous* rendez);

// Fetches more outputs from the executors. It waits until the output
// tensors are computed.
::tensorflow::Status RecvOutputs(const std::vector<string>& output_names,
const ExecutorsAndKeys* executors_and_keys,
RunState* run_state,
std::vector<Tensor>* outputs);
::tensorflow::Status RecvPRunOutputs(
const std::vector<string>& output_names,
const ExecutorsAndKeys* executors_and_keys, RunState* run_state,
std::vector<Tensor>* outputs);

// Check if the specified fetches can be computed from the feeds
// that we have already provided.

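Aside from the diff itself: a minimal standalone sketch of the two feed/fetch lookup conventions this change separates — ordinal indices for the function calling convention used by full Run(), and precomputed rendezvous-key strings for partial runs. All names below (FeedFetchMaps, the key format) are hypothetical stand-ins, not TensorFlow's API.

// Standalone sketch (hypothetical names; simplified from the diff above).
#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

struct FeedFetchMaps {
  // Full Run(): function calling convention, name -> argument ordinal.
  std::unordered_map<std::string, size_t> input_name_to_index;
  // PRun(): rendezvous calling convention, name -> precomputed key.
  std::unordered_map<std::string, std::string> input_name_to_rendezvous_key;
};

int main() {
  FeedFetchMaps maps;
  std::vector<std::string> inputs_sorted = {"x:0", "y:0"};
  bool is_partial_run = true;
  for (size_t i = 0; i < inputs_sorted.size(); ++i) {
    if (!is_partial_run) {
      maps.input_name_to_index[inputs_sorted[i]] = i;  // ordinal index
    } else {
      // The real code derives this from device attributes and FrameAndIter;
      // the key format here is illustrative only.
      maps.input_name_to_rendezvous_key[inputs_sorted[i]] =
          "/device:CPU:0;" + inputs_sorted[i] + ";0:0";
    }
  }
  std::cout << maps.input_name_to_rendezvous_key["x:0"] << "\n";
}
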
@ -1434,7 +1434,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
} else {
num_outstanding_ops_ = ready.size();
root_frame_->iterations[0]->outstanding_ops = ready.size();
done_cb_ = done;
done_cb_ = std::move(done);
// Schedule to run all the ready ops in thread pool.
ScheduleReady(ready, nullptr);
}
@ -2560,7 +2560,7 @@ bool ExecutorState::FrameState::CleanupIterations(const GraphView* gview,
}

void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
(new ExecutorState(args, this))->RunAsync(done);
(new ExecutorState(args, this))->RunAsync(std::move(done));
}

} // end namespace

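For context on the std::move changes above and in the hunks that follow: moving a std::function hands over its captured state instead of copying it, which matters when the callback captures tensors or other heavy objects. A minimal standalone sketch, with a hypothetical Runner class rather than the executor's real types:

// Standalone sketch (hypothetical Runner; not TensorFlow's executor).
#include <functional>
#include <iostream>
#include <memory>
#include <utility>

class Runner {
 public:
  void RunAsync(std::function<void()> done) {
    done_ = std::move(done);  // take ownership; no copy of the captures
    done_();
  }

 private:
  std::function<void()> done_;
};

int main() {
  auto payload = std::make_shared<int>(42);  // stand-in for captured state
  Runner r;
  r.RunAsync([payload] { std::cout << *payload << "\n"; });
}
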
@ -604,7 +604,7 @@ struct CustomCreatorSingleton {

void Set(CustomKernelCreator cb) {
mutex_lock l(mu);
custom_creator = cb;
custom_creator = std::move(cb);
}

CustomKernelCreator Get() {
@ -621,7 +621,7 @@ CustomCreatorSingleton* GetCustomCreatorSingleton() {
} // end namespace

void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb) {
GetCustomCreatorSingleton()->Set(cb);
GetCustomCreatorSingleton()->Set(std::move(cb));
}

FunctionLibraryRuntime* NewFunctionLibraryRuntime(
@ -631,7 +631,7 @@ FunctionLibraryRuntime* NewFunctionLibraryRuntime(
CustomKernelCreator custom_kernel_creator) {
return new FunctionLibraryRuntimeImpl(dmgr, env, device, graph_def_version,
lib_def, optimizer_options,
custom_kernel_creator);
std::move(custom_kernel_creator));
}

FunctionLibraryRuntime* NewFunctionLibraryRuntime(

@ -44,7 +44,7 @@ Status GetOpSig(const string& op, const OpDef** sig) {
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
w->Schedule(fn);
w->Schedule(std::move(fn));
}

void HasError(const Status& s, const string& substr) {
@ -654,7 +654,8 @@ namespace {

bool DoNothing(Graph* g) { return false; }

string Optimize(std::function<bool(Graph* g)> pass, const FunctionDef& fdef) {
string Optimize(const std::function<bool(Graph* g)>& pass,
const FunctionDef& fdef) {
InstantiationResult result;
InstantiateAttrValueMap empty;
TF_CHECK_OK(InstantiateFunction(fdef, empty, GetOpSig, &result));

@ -130,9 +130,11 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library,
}

// Call RewriteGraphForExecution
subgraph::RewriteGraphMetadata metadata;
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
graph_to_run.get(), input_names, output_names, {} /* target nodes */,
cpu_device_->attributes()));
cpu_device_->attributes(), false /* use_function_convention */,
&metadata));

// Create the local executor and the Rendezvous for fetching back the
// constants.

@ -106,7 +106,7 @@ void IntraProcessRendezvous::SameWorkerRecvDone(
CopyTensor::ViaDMA(parsed.edge_name, send_args.device_context,
recv_args.device_context, src_device, dst_device,
send_args.alloc_attrs, recv_args.alloc_attrs, &in, out,
done);
std::move(done));
}

void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
@ -132,7 +132,8 @@ void IntraProcessRendezvous::RecvAsync(const ParsedKey& parsed,
};

if (status.ok() && in.IsInitialized()) {
SameWorkerRecvDone(parsed, send_args, recv_args, in, out, final_callback);
SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
std::move(final_callback));
} else {
final_callback(status);
}

@ -21,9 +21,9 @@ limitations under the License.
namespace tensorflow {
namespace {

// Replaces ReadVariableOp nodes which are only used by Sends and sinks with
// _UnsafeReadVariable nodes, as this transforamtion is safe and will improve
// performance.
// Replaces ReadVariableOp nodes which are only used by Sends, sinks,
// and function Retvals with _UnsafeReadVariable nodes, as this
// transformation is safe and will improve performance.
class ResourceVariableReadPass : public GraphOptimizationPass {
public:
Status Run(const GraphOptimizationPassOptions& options) override {
@ -43,7 +43,8 @@ class ResourceVariableReadPass : public GraphOptimizationPass {
if (n->type_string() == "ReadVariableOp") {
bool skip = false;
for (const Edge* e : n->out_edges()) {
if (!e->dst()->IsSend() && e->dst()->name() != "_SINK") {
if (!e->dst()->IsSend() && e->dst()->type_string() != "_Retval" &&
e->dst()->name() != "_SINK") {
skip = true;
}
}

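For context: the hunk above widens the set of consumers under which a ReadVariableOp may safely be replaced, adding function return values (_Retval) to Sends and the sink. A standalone sketch of that predicate, using hypothetical stand-in types rather than TensorFlow's Edge/Node:

// Standalone sketch of the widened skip-predicate (hypothetical types).
#include <string>
#include <vector>

struct Consumer {
  bool is_send;
  std::string type;
  std::string name;
};

bool OnlySafeConsumers(const std::vector<Consumer>& consumers) {
  for (const Consumer& dst : consumers) {
    if (!dst.is_send && dst.type != "_Retval" && dst.name != "_SINK") {
      return false;  // some other consumer observes the read; skip rewrite
    }
  }
  return true;
}

int main() {
  std::vector<Consumer> consumers = {{false, "_Retval", "ret0"}};
  return OnlySafeConsumers(consumers) ? 0 : 1;
}
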
@ -284,9 +284,11 @@ Status SimpleGraphExecutionState::InitBaseGraph(
if (session_options_ &&
session_options_->config.graph_options().place_pruned_graph()) {
// Rewrite the graph before placement.
rewrite_metadata_.reset(new subgraph::RewriteGraphMetadata);
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
new_graph.get(), options.feed_endpoints, options.fetch_endpoints,
options.target_nodes, device_set_->client_device()->attributes()));
options.target_nodes, device_set_->client_device()->attributes(),
options.use_function_convention, rewrite_metadata_.get()));
}

// Save stateful placements before placing.
@ -333,15 +335,26 @@ Status SimpleGraphExecutionState::BuildGraph(
std::unique_ptr<Graph> ng(new Graph(flib_def_.get()));
CopyGraph(*graph_, ng.get());

subgraph::RewriteGraphMetadata rewrite_metadata;
if (session_options_ == nullptr ||
!session_options_->config.graph_options().place_pruned_graph()) {
// Extract the subset of the graph that needs to be run, adding feed/fetch
// ops as needed.
TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
ng.get(), options.feed_endpoints, options.fetch_endpoints,
options.target_nodes, device_set_->client_device()->attributes()));
options.target_nodes, device_set_->client_device()->attributes(),
options.use_function_convention, &rewrite_metadata));
} else {
// This SimpleGraphExecutionState represents a graph that was
// pruned when this was constructed, so we copy the metadata from
// a member variable.
CHECK(rewrite_metadata_);
rewrite_metadata = *rewrite_metadata_;
}

CHECK_EQ(options.feed_endpoints.size(), rewrite_metadata.feed_types.size());
CHECK_EQ(options.fetch_endpoints.size(), rewrite_metadata.fetch_types.size());

// Make a fresh copy of the function library for the client graph.
std::unique_ptr<FunctionLibraryDefinition> flib(
new FunctionLibraryDefinition(*flib_def_));
@ -363,7 +376,8 @@ Status SimpleGraphExecutionState::BuildGraph(
// since the local CostModel used to record its stats is sized by
// the largest node id.
std::unique_ptr<SimpleClientGraph> dense_copy(
new SimpleClientGraph(std::move(flib)));
new SimpleClientGraph(std::move(flib), rewrite_metadata.feed_types,
rewrite_metadata.fetch_types));
CopyGraph(*ng, &dense_copy->graph);

// TODO(vrv): We should check invariants of the graph here.

@ -39,6 +39,10 @@ struct SessionOptions;
class StepStats;
class Timeline;

namespace subgraph {
struct RewriteGraphMetadata;
}

struct SimpleGraphExecutionStateOptions {
const DeviceSet* device_set = nullptr;
const SessionOptions* session_options = nullptr;
@ -50,13 +54,19 @@ struct SimpleGraphExecutionStateOptions {
// A SimpleClientGraph is simply a sub-graph of the full graph as induced by
// BuildGraphOptions.
struct SimpleClientGraph {
explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib)
: flib_def(std::move(flib)), graph(flib_def.get()) {}
explicit SimpleClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
DataTypeVector feed_types,
DataTypeVector fetch_types)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
fetch_types(std::move(fetch_types)) {}
// Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph;
int32 placement_version;
DataTypeVector feed_types;
DataTypeVector fetch_types;
};

// SimpleGraphExecutionState is responsible for generating an
@ -190,6 +200,10 @@ class SimpleGraphExecutionState {
// and may be updated by a graph optimization pass.
std::unique_ptr<FunctionLibraryDefinition> flib_def_;

// `rewrite_metadata_` is only set for SimpleGraphExecutionState
// objects created by `MakeForPrunedGraph()`.
std::unique_ptr<subgraph::RewriteGraphMetadata> rewrite_metadata_;

// The dataflow graph owned by this object.
Graph* graph_;

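For context: the constructor change above threads the feed/fetch data types captured during graph pruning into the client graph, taking the vectors by value and moving them into members. A simplified standalone sketch of that pattern, with stand-in types rather than TensorFlow's:

// Standalone sketch (stand-in types; not TensorFlow's SimpleClientGraph).
#include <utility>
#include <vector>

using DataTypeVector = std::vector<int>;  // stand-in for TF's enum vector

struct ClientGraphSketch {
  ClientGraphSketch(DataTypeVector feed_types, DataTypeVector fetch_types)
      : feed_types(std::move(feed_types)),    // pass-by-value + move:
        fetch_types(std::move(fetch_types)) {}  // one copy at most
  DataTypeVector feed_types;
  DataTypeVector fetch_types;
};

int main() {
  DataTypeVector feeds = {1};       // e.g. one feed type
  DataTypeVector fetches = {1, 3};  // e.g. two fetch types
  ClientGraphSketch cg(std::move(feeds), std::move(fetches));
  return cg.feed_types.size() == 1 && cg.fetch_types.size() == 2 ? 0 : 1;
}
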
@ -63,8 +63,9 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
};

// static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env,
const string& worker_name, WorkerCacheInterface* worker_cache) {
RendezvousMgrInterface* NewRpcRendezvousMgr(
const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* worker_cache) {
return new RpcRendezvousMgr(env, worker_name, worker_cache);
}

@ -76,7 +77,7 @@ GrpcServer::GrpcServer(const ServerDef& server_def, Env* env)
GrpcServer::~GrpcServer() {
TF_CHECK_OK(Stop());
TF_CHECK_OK(Join());


delete master_service_;
delete worker_service_;

@ -100,7 +101,7 @@ GrpcServer::~GrpcServer() {
}

Status GrpcServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendevous_mgr_func) {
RendezvousMgrCreationFunction rendevous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@ -193,6 +194,8 @@ Status GrpcServer::Init(ServiceInitFunction service_func,

// Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
rendevous_mgr_func == nullptr ?
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
@ -222,6 +225,10 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK();
}

Status GrpcServer::Init() {
return Init(nullptr, nullptr);
}

Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) {

@ -37,14 +37,15 @@ class GrpcWorker;
class Master;

// function that creates a RendezvousMgr.
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*,
const std::string& worker_name, WorkerCacheInterface* worker_cache)>
RendezvousMgrCreationFunction;
typedef std::function<RendezvousMgrInterface*(
const WorkerEnv*, const std::string& worker_name,
WorkerCacheInterface* worker_cache)>
RendezvousMgrCreationFunction;

// function that registers a service to the server. The service needs to
// be registered before builder.BuildAndStart().
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
ServiceInitFunction;
typedef std::function<void(const WorkerEnv*, ::grpc::ServerBuilder*)>
ServiceInitFunction;

class GrpcServer : public ServerInterface {
protected:
@ -68,6 +69,8 @@ class GrpcServer : public ServerInterface {
Status Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func);

Status Init();

// A subclass can override this method to support secure credentials.
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const;
@ -90,7 +93,7 @@ class GrpcServer : public ServerInterface {
int bound_port() const { return bound_port_; }

WorkerEnv* worker_env() { return &worker_env_; }


const ServerDef& server_def() const { return server_def_; }

private:
@ -115,7 +118,7 @@ class GrpcServer : public ServerInterface {
// Stop(), Join()
enum State { NEW, STARTED, STOPPED };
State state_ GUARDED_BY(mu_);


// Implementation of a TensorFlow master, and RPC polling thread.
MasterEnv master_env_;
std::unique_ptr<Master> master_impl_;

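For context: the Init() overload above lets callers pass nullptr factories and fall back to the stock RPC rendezvous manager. A standalone sketch of that default-factory pattern, with hypothetical interfaces rather than the TF/gRPC ones:

// Standalone sketch of the nullptr-factory fallback (hypothetical types).
#include <functional>
#include <memory>

struct RendezvousMgr { virtual ~RendezvousMgr() = default; };
struct RpcRendezvousMgr : RendezvousMgr {};

using MgrFactory = std::function<RendezvousMgr*()>;

std::unique_ptr<RendezvousMgr> MakeMgr(MgrFactory factory) {
  // nullptr factory means "use the stock implementation".
  return std::unique_ptr<RendezvousMgr>(
      factory == nullptr ? new RpcRendezvousMgr() : factory());
}

int main() {
  auto default_mgr = MakeMgr(nullptr);                       // stock manager
  auto custom_mgr = MakeMgr([] { return new RpcRendezvousMgr(); });
  return (default_mgr && custom_mgr) ? 0 : 1;
}
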
@ -789,7 +789,7 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
rets->clear();
rets->reserve(rets_.size());
for (size_t i = 0; i < rets_.size(); ++i) {
auto item = rets_[i];
const auto& item = rets_[i];
if (item.has_val) {
rets->push_back(item.val);
} else {
@ -799,6 +799,19 @@ Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const {
return Status::OK();
}

Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets) {
rets->clear();
rets->reserve(rets_.size());
for (size_t i = 0; i < rets_.size(); ++i) {
if (rets_[i].has_val) {
rets->emplace_back(std::move(rets_[i].val));
} else {
return errors::Internal("Retval[", i, "] does not have value");
}
}
return Status::OK();
}

Status FunctionCallFrame::GetArg(int index, Tensor* val) const {
if (index < 0 || static_cast<size_t>(index) >= args_.size()) {
return errors::InvalidArgument("GetArg ", index, " is not within [0, ",

@ -259,6 +259,7 @@ class FunctionCallFrame {
// Caller methods.
Status SetArgs(gtl::ArraySlice<Tensor> args);
Status GetRetvals(std::vector<Tensor>* rets) const;
Status ConsumeRetvals(std::vector<Tensor>* rets);

// Callee methods.
Status GetArg(int index, Tensor* val) const;

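For context: the GetRetvals/ConsumeRetvals pair above distinguishes copying the stored return values from moving them out of the frame. A simplified standalone sketch of the consuming variant, with stand-in types rather than TensorFlow's Tensor:

// Standalone sketch of move-out retval consumption (stand-in types).
#include <string>
#include <utility>
#include <vector>

struct Slot {
  bool has_val;
  std::string val;  // stand-in for Tensor
};

bool ConsumeRetvals(std::vector<Slot>& slots, std::vector<std::string>* rets) {
  rets->clear();
  rets->reserve(slots.size());
  for (Slot& s : slots) {
    if (!s.has_val) return false;            // retval without a value: error
    rets->emplace_back(std::move(s.val));    // move, don't copy
  }
  return true;
}

int main() {
  std::vector<Slot> slots = {{true, "t0"}, {true, "t1"}};
  std::vector<std::string> rets;
  return ConsumeRetvals(slots, &rets) && rets.size() == 2 ? 0 : 1;
}
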
@ -126,29 +126,33 @@ FunctionDef XTimes16() {
{{"y", "y:y:0"}});
}

FunctionDef WXPlusB() {
return FDH::Define(
// Name
"WXPlusB",
// Args
{"w: T", "x: T", "b: T"},
// Return values
{"y: T"},
// Attr def
{"T: {float, double}"},
// Nodes
{{{"mm"},
"MatMul",
{"w", "x"},
{{"T", "$T"},
{"transpose_a", false},
{"transpose_b", false},
#ifdef INTEL_MKL
}},
FunctionDef WXPlusB(){return FDH::Define(
// Name
"WXPlusB",
// Args
{"w: T", "x: T", "b: T"},
// Return values
{"y: T"},
// Attr def
{"T: {float, double}"},
// Nodes
{
{{"mm"},
"MatMul",
{"w", "x"},
{
{"T", "$T"}, {"transpose_a", false}, {"transpose_b", false},
#ifdef INTEL_MKL
}},
#else
{"_kernel", "eigen"}}},
#endif
{{"y"}, "Add", {"mm", "b"}, {{"T", "$T"}}}});
{
{"y"}, "Add", {"mm", "b"}, {
{ "T", "$T" }
}
}
});
}

FunctionDef Swap() {

@ -63,7 +63,7 @@ namespace tensorflow {
// P = BiasAdd(O, C)
//
// We merge them into Conv2DWithBias as:
// P = MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
// P = _MklConv2DWithBias(A, A_m, B, B_m, C, C_m)
//
// The meaning of A_m, B_m and C_m is explained in B.1.
//
@ -115,7 +115,7 @@ namespace tensorflow {
// Since every rewritten node generates twice the number of inputs and
// outputs, one could imagine various orderings among Tensorflow tensors
// and Mkl tensors. E.g., assume an op 'Conv2D' that takes (A, B) as
// inputs, then the new op 'MklConv2D' can take inputs A, B, A_m and B_m
// inputs, then the new op '_MklConv2D' can take inputs A, B, A_m and B_m
// in A, A_m, B, B_m order or it can also take them in A, B, A_m, B_m
// order. Among N inputs one can get N! permutations.
//
@ -239,15 +239,15 @@ namespace tensorflow {
// -------------------------------------------
// Consider BiasAddGrad op as:
//
// O = MklConv2D(A, B, C, A_m, B_m, C_m)
// O = _MklConv2D(A, B, C, A_m, B_m, C_m)
// P = BiasAddGrad(O)
//
// Then we rewrite it as:
//
// P = Conv2DWithBiasBackpropBias(O, O_m)
//
// 'Distance' between input of BiasAddGrad and MklConv2D in terms of hops is
// the context matching depth. If MklConv2DWithBias is not within the context
// 'Distance' between input of BiasAddGrad and _MklConv2D in terms of hops is
// the context matching depth. If _MklConv2DWithBias is not within the context
// matching depth, then we do not rewrite BiasAddGrad.

// How many hops do we search for matching node in the backward dataflow graph?
@ -261,74 +261,66 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
// NOTE: names are alphabetically sorted.
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
csinfo_.concatv2 = "ConcatV2";
csinfo_.conv2d = "Conv2D";
csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
csinfo_.concatv2 = "ConcatV2";
csinfo_.conv2d = "Conv2D";
csinfo_.conv2d_grad_input = "Conv2DBackpropInput";
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
csinfo_.mkl_conv2d = "MklConv2D";
csinfo_.mkl_conv2d_with_bias = "MklConv2DWithBias";
csinfo_.lrn = "LRN";
csinfo_.lrn_grad = "LRNGrad";
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_with_bias = "_MklConv2DWithBias";
csinfo_.mkl_conv2d_with_bias_backprop_bias =
"MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.reshape = "Reshape";
csinfo_.relu_grad = "ReluGrad";
csinfo_.split = "Split";
"_MklConv2DWithBiasBackpropBias";
csinfo_.relu = "Relu";
csinfo_.reshape = "Reshape";
csinfo_.relu_grad = "ReluGrad";
csinfo_.split = "Split";

// NOTE: names are alphabetically sorted.
rinfo_.push_back({csinfo_.avg_pool,
GetMklOpName(csinfo_.avg_pool),
1, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avg_pool, GetMklOpName(csinfo_.avg_pool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avg_pool_grad,
GetMklOpName(csinfo_.avg_pool_grad),
2, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
GetMklOpName(csinfo_.concat),
0, CopyAttrsConcat, AlwaysRewrite});
rinfo_.push_back({csinfo_.concatv2,
GetMklOpName(csinfo_.concatv2),
0, CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d,
GetMklOpName(csinfo_.conv2d),
2, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.avg_pool_grad), 2, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.concat, GetMklOpName(csinfo_.concat), 0,
CopyAttrsConcat, AlwaysRewrite});
rinfo_.push_back({csinfo_.concatv2, GetMklOpName(csinfo_.concatv2), 0,
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d, GetMklOpName(csinfo_.conv2d), 2,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
GetMklOpName(csinfo_.conv2d_grad_filter),
3, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2d_grad_filter), 3,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
GetMklOpName(csinfo_.conv2d_grad_input),
3, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2d_grad_input), 3,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
GetMklOpName(csinfo_.fused_batch_norm),
5, CopyAttrsFusedBatchNorm, AlwaysRewrite});
GetMklOpName(csinfo_.fused_batch_norm), 5,
CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm_grad,
GetMklOpName(csinfo_.fused_batch_norm_grad),
5, CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn,
GetMklOpName(csinfo_.lrn),
1, CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad,
GetMklOpName(csinfo_.lrn_grad),
3, CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool,
GetMklOpName(csinfo_.max_pool),
1, CopyAttrsPooling, AlwaysRewrite});
GetMklOpName(csinfo_.fused_batch_norm_grad), 5,
CopyAttrsFusedBatchNorm, AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn, GetMklOpName(csinfo_.lrn), 1, CopyAttrsLRN,
AlwaysRewrite});
rinfo_.push_back({csinfo_.lrn_grad, GetMklOpName(csinfo_.lrn_grad), 3,
CopyAttrsLRN, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool, GetMklOpName(csinfo_.max_pool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.max_pool_grad,
GetMklOpName(csinfo_.max_pool_grad),
3, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.relu,
GetMklOpName(csinfo_.relu),
1, CopyAttrsRelu, AlwaysRewrite});
GetMklOpName(csinfo_.max_pool_grad), 3, CopyAttrsPooling,
AlwaysRewrite});
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
CopyAttrsRelu, AlwaysRewrite});
rinfo_.push_back({csinfo_.reshape, GetMklOpName(csinfo_.reshape), 2,
CopyAttrsReshape, AlwaysRewrite});

@ -339,8 +331,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});

// Add a rule for merging nodes
minfo_.push_back(
{csinfo_.mkl_conv2d, csinfo_.bias_add, 0, csinfo_.mkl_conv2d_with_bias});
minfo_.push_back({csinfo_.mkl_conv2d, csinfo_.bias_add, 0,
csinfo_.mkl_conv2d_with_bias});

// We use maxhop of 10 based on empirical observations. Also, these are
// maxhops in backward data-flow graph. Since input of forward nodes
@ -374,7 +366,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// A function handler to copy attributes from an old node to a new node.
std::function<void(const Node*, NodeBuilder*)> copy_attrs;
std::function<bool(const Node*)> rewrite_rule; // A rule under which to
// rewrite this node.
// rewrite this node.
} RewriteInfo;

/// Structure to specify a forward op, a backward op, and the slot numbers
@ -477,7 +469,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// Concat, Split are vararg nodes.
inline bool IsVarArgNode(Node* n) {
if (n->type_string() == csinfo_.concat ||
if (n->type_string() == csinfo_.concat ||
n->type_string() == csinfo_.concatv2 ||
n->type_string() == csinfo_.split) {
return true;
@ -496,9 +488,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
inline int GetTensorListLength(const OpDef::ArgDef& arg, Node* n) {
CHECK_EQ(ArgIsList(arg), true);
int N = 0;
const string attr_name = !arg.type_list_attr().empty() ?
arg.type_list_attr() :
arg.number_attr();
const string attr_name = !arg.type_list_attr().empty()
? arg.type_list_attr()
: arg.number_attr();
if (!arg.type_list_attr().empty()) {
std::vector<DataType> value;
TF_CHECK_OK(GetNodeAttr(n->def(), attr_name, &value));
@ -514,7 +506,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// TODO(nhasabni) We should move this to mkl_util.h.
inline string GetMklOpName(const string& name) const {
// Prefix that we add to Tensorflow op name to construct Mkl op name.
const char* const kMklOpPrefix = "Mkl";
const char* const kMklOpPrefix = "_Mkl";
return string(kMklOpPrefix) + name;
}

@ -598,9 +590,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// @return None
void GetNodesProducingTFTensorList(
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);

// Get nodes that will feed a list of Mkl tensors to the new
// node that we are constructing.
@ -616,10 +608,11 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// @output output_nodes - the list of new nodes creating Mkl tensors
//
// @return None
void GetNodesProducingMklTensorList(std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);
void GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes);

// Get a node that will feed an Mkl tensor to the new
// node that we are constructing. The output node could be (1) 'n'
@ -635,7 +628,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// will feed the tensor
// @return None
void GetNodeProducingMklTensor(std::unique_ptr<Graph>* g, Node* n,
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot);
int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot);

// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'old_node'. Inputs to 'nb' are
@ -648,11 +642,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
//
// Returns Status::OK() if setting up inputs is successful, otherwise
// returns appropriate status code.
int SetUpContiguousInputs(std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node,
std::vector<NodeBuilder::NodeOut>* workspace_tensors,
bool are_workspace_tensors_available);
int SetUpContiguousInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node,
std::vector<NodeBuilder::NodeOut>* workspace_tensors,
bool are_workspace_tensors_available);

// Setup new inputs using old inputs 'inputs' for the rewritten node in 'nb'
// in graph 'g'. Original node is input in 'orig_node'.
@ -672,8 +667,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// tensors, if they need to be added, will be set into these tensors.
// If we set workspace tensors, then are_ws_tensors_added should be true.
void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orig_node,
NodeBuilder* nb, std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added);
NodeBuilder* nb,
std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added);

// Functions specific to operators to copy attributes
// We need operator-specific function to copy attributes because the framework
@ -732,9 +728,8 @@ static void FillInputs(const Node* n,
}

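For context: the kMklOpPrefix change above renames MKL-rewritten ops from "Mkl*" to "_Mkl*" (the leading underscore conventionally marks internal ops), so Conv2D now becomes _MklConv2D. A standalone sketch of the name-mangling helper, written as a free function purely for illustration:

// Standalone sketch of the "_Mkl" prefix helper (free function for
// illustration; in the pass it is a const member function).
#include <iostream>
#include <string>

std::string GetMklOpName(const std::string& name) {
  const char* const kMklOpPrefix = "_Mkl";
  return std::string(kMklOpPrefix) + name;
}

int main() {
  std::cout << GetMklOpName("Conv2D") << "\n";  // prints _MklConv2D
}
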
void MklLayoutRewritePass::GetNodesProducingTFTensorList(
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes) {
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@ -767,34 +762,33 @@ void MklLayoutRewritePass::GetNodesProducingTFTensorList(
}

// TODO(nhasabni) We should move this to mkl_util.h.
void MklLayoutRewritePass::GetDummyMklTensorNode(
std::unique_ptr<Graph>* g, Node** out, Node* orig_node) {
void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
Node** out, Node* orig_node) {
// We use a tensor of shape {8} and value 0,0,0,0,0,0,0,0 to represent
// dummy Mkl tensor. 8 = 2*size_t.
const DataType dt = DataTypeToEnum<uint8>::v();
TensorProto proto;
proto.set_dtype(dt);
uint8 zero[8] = {0, 0, 0, 0, 0, 0, 0, 0};
proto.set_tensor_content(const_cast<const void*>(
static_cast<void*>(&zero)), 8);
proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
8);
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// the same device as the
// device of the original
// node.
.Finalize(&**g, out));
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}

void MklLayoutRewritePass::GetNodesProducingMklTensorList(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs,
int* input_idx, int list_length,
std::vector<NodeBuilder::NodeOut>* output_nodes) {
const gtl::InlinedVector<std::pair<Node*, int>, 4>& inputs, int* input_idx,
int list_length, std::vector<NodeBuilder::NodeOut>* output_nodes) {
CHECK_LT(*input_idx, inputs.size());
CHECK_GT(list_length, 0);
CHECK_NOTNULL(output_nodes);
@ -819,8 +813,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// If it is a list, then create a list of Mkl dummy nodes.
for (int j = 0; j < N; j++) {
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
mkl_node_output_slot));
output_nodes->push_back(
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
}
(*input_idx)++;
list_length -= N;
@ -829,8 +823,8 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, n, slot, &mkl_node, &mkl_node_output_slot);
output_nodes->push_back(NodeBuilder::NodeOut(mkl_node,
mkl_node_output_slot));
output_nodes->push_back(
NodeBuilder::NodeOut(mkl_node, mkl_node_output_slot));
(*input_idx)++;
list_length--;
}
@ -841,9 +835,9 @@ void MklLayoutRewritePass::GetNodesProducingMklTensorList(
// node that we are constructing. An input node could be (1) 'n'
// if it is Mkl layer, or (2) a dummy node producing dummy Mkl tensor
// if 'n' is not an Mkl layer.
void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
Node* n,
int n_output_slot, Node** mkl_node, int* mkl_node_output_slot) {
void MklLayoutRewritePass::GetNodeProducingMklTensor(
std::unique_ptr<Graph>* g, Node* n, int n_output_slot, Node** mkl_node,
int* mkl_node_output_slot) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(mkl_node);
CHECK_NOTNULL(mkl_node_output_slot);
@ -859,8 +853,8 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
// output slot number for Mkl tensor would be N+slot number of TensorFlow
// tensor, where N is total number of TensorFlow tensors.
*mkl_node = n;
*mkl_node_output_slot = GetTensorMetaDataIndex(n_output_slot,
n->num_outputs());
*mkl_node_output_slot =
GetTensorMetaDataIndex(n_output_slot, n->num_outputs());
} else {
// If we have not visited the node and rewritten it, then we need
// to create a dummy node that will feed a dummy Mkl tensor to this node.
@ -872,7 +866,8 @@ void MklLayoutRewritePass::GetNodeProducingMklTensor(std::unique_ptr<Graph>* g,
}
}

int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr<Graph>* g,
int MklLayoutRewritePass::SetUpContiguousInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node,
std::vector<NodeBuilder::NodeOut>* workspace_tensors,
@ -931,16 +926,16 @@ int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr<Graph>* g,
if (ArgIsList(arg)) {
std::vector<NodeBuilder::NodeOut> new_node_inputs;
int N = GetTensorListLength(arg, old_node);
GetNodesProducingMklTensorList(g, old_node_inputs, &iidx,
N, &new_node_inputs);
GetNodesProducingMklTensorList(g, old_node_inputs, &iidx, N,
&new_node_inputs);
nb->Input(new_node_inputs);
nn_slot_idx++;
} else {
Node* mkl_node = nullptr;
int mkl_node_output_slot = 0;
GetNodeProducingMklTensor(g, old_node_inputs[iidx].first,
old_node_inputs[iidx].second,
&mkl_node, &mkl_node_output_slot);
old_node_inputs[iidx].second, &mkl_node,
&mkl_node_output_slot);
nb->Input(mkl_node, mkl_node_output_slot);
iidx++;
nn_slot_idx++;
@ -961,7 +956,8 @@ int MklLayoutRewritePass::SetUpContiguousInputs(std::unique_ptr<Graph>* g,
return nn_slot_idx;
}

Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr<Graph>* g,
Status MklLayoutRewritePass::SetUpInputs(
std::unique_ptr<Graph>* g,
const gtl::InlinedVector<std::pair<Node*, int>, 4>& old_node_inputs,
NodeBuilder* nb, Node* old_node) {
// Let's check if we need to add workspace tensors for this node.
@ -975,13 +971,14 @@ Status MklLayoutRewritePass::SetUpInputs(std::unique_ptr<Graph>* g,
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
// TODO(nhasabni): implement this function just for same of completion.
// We do not use interleaved ordering right now.
return Status(error::Code::UNIMPLEMENTED,
"Interleaved ordering of tensors is currently not supported.");
return Status(
error::Code::UNIMPLEMENTED,
"Interleaved ordering of tensors is currently not supported.");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
new_node_input_slots = SetUpContiguousInputs(g, old_node_inputs, nb,
old_node, &workspace_tensors,
are_workspace_tensors_available);
new_node_input_slots = SetUpContiguousInputs(
g, old_node_inputs, nb, old_node, &workspace_tensors,
are_workspace_tensors_available);
}

// Sanity check
@ -1023,20 +1020,19 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// same the device as the
// device of the original
// node.
.Finalize(&**g, out));
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orig_node->def().device()) // We place this node on
// same the device as the
// device of the original
// node.
.Finalize(&**g, out));
(*out)->set_assigned_device_name(orig_node->assigned_device_name());
}

void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
Node* orig_node, NodeBuilder* nb,
std::vector<NodeBuilder::NodeOut>* ws_tensors,
bool* are_ws_tensors_added) {
void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
std::unique_ptr<Graph>* g, Node* orig_node, NodeBuilder* nb,
std::vector<NodeBuilder::NodeOut>* ws_tensors, bool* are_ws_tensors_added) {
bool workspace_edge_added = false; // Default initializer
CHECK_NOTNULL(are_ws_tensors_added);
*are_ws_tensors_added = false; // Default initializer
@ -1071,8 +1067,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
nb->Attr("workspace_enabled", false);
}
} else if (orig_node->type_string() == ws.bwd_op &&
mkl_op_registry::IsMklOp(
GetMklOpName(orig_node->type_string()), T)) {
mkl_op_registry::IsMklOp(GetMklOpName(orig_node->type_string()),
T)) {
// If this op is a bwd op, then we need to add workspace edge and
// it's Mkl tensor edge between its corresponding fwd op and this
// op. Corresponding fwd op is specified in 'fwd_op' field of
@ -1094,8 +1090,9 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// Add workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(e->src(), ws.ws_fwd_slot));
// Add Mkl tensor edge for workspace edge between fwd op and bwd op.
ws_tensors->push_back(NodeBuilder::NodeOut(e->src(),
DataIndexToMetaDataIndex(ws.ws_fwd_slot, e->src()->num_outputs())));
ws_tensors->push_back(NodeBuilder::NodeOut(
e->src(), DataIndexToMetaDataIndex(ws.ws_fwd_slot,
e->src()->num_outputs())));
*are_ws_tensors_added = true;
// In terms of input ordering, we add these calls to add Input
// here because workspace edge (and its Mkl tensor) is the last
@ -1154,8 +1151,8 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu",
&use_cudnn_on_gpu));
TF_CHECK_OK(
GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));

// Add attributes to new node.
nb->Attr("T", T);
@ -1307,14 +1304,14 @@ void MklLayoutRewritePass::CopyAttrsFusedBatchNorm(const Node* orig_node,
}

void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
NodeBuilder* nb) {
NodeBuilder* nb) {
DataType T;
DataType Tshape;


// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tshape", &Tshape));


// Add attributes to new node.
nb->Attr("T", T);
nb->Attr("Tshape", Tshape);
@ -1435,7 +1432,7 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
// 2. Get inputs from both the nodes.
// Find the 2 inputs from the conv and the bias from the add Bias.
// Get operand 0, 1 of conv2D and their Mkl tensors.
CHECK_EQ(pred->in_edges().size(), 4); // MklConv2D must have 4 inputs.
CHECK_EQ(pred->in_edges().size(), 4); // _MklConv2D must have 4 inputs.
// Get operand 1 of add_bias
// BiasAdd must have 2 inputs: Conv, bias
CHECK_EQ(succ->in_edges().size(), 2);
@ -1538,15 +1535,15 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
DataType orig_T, ctx_T;
string orig_data_format, ctx_data_format;
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &orig_T));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format",
&orig_data_format));
TF_CHECK_OK(
GetNodeAttr(orig_node->def(), "data_format", &orig_data_format));
TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "T", &ctx_T));
TF_CHECK_OK(GetNodeAttr(fwd_node->def(), "data_format",
&ctx_data_format));
TF_CHECK_OK(
GetNodeAttr(fwd_node->def(), "data_format", &ctx_data_format));

if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
orig_node->assigned_device_name() !=
fwd_node->assigned_device_name() ||
fwd_node->assigned_device_name() ||
orig_node->def().device() != fwd_node->def().device()) {
return Status(
error::Code::INVALID_ARGUMENT,
@ -1613,9 +1610,10 @@ Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g,
if (e->src_output() < 0) {
(*g)->AddEdge(new_node, e->src_output(), e->dst(), e->dst_input());
} else {
(*g)->AddEdge(new_node, GetTensorDataIndex(e->src_output(),
e->src()->num_outputs()),
e->dst(), e->dst_input());
(*g)->AddEdge(
new_node,
GetTensorDataIndex(e->src_output(), e->src()->num_outputs()),
e->dst(), e->dst_input());
}
}

@ -110,13 +110,11 @@ class MklLayoutPassTest : public ::testing::Test {
|
||||
};
|
||||
|
||||
REGISTER_OP("Input").Output("o: float").SetIsStateful();
|
||||
REGISTER_OP("InputList").Output("o: N * float")
|
||||
.Attr("N: int")
|
||||
.SetIsStateful();
|
||||
REGISTER_OP("InputList").Output("o: N * float").Attr("N: int").SetIsStateful();
|
||||
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
|
||||
REGISTER_OP("Int32Input").Output("o: int32").SetIsStateful();
|
||||
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
|
||||
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("_MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to node merge optiimization
|
||||
@ -137,16 +135,16 @@ TEST_F(MklLayoutPassTest, Basic) {
|
||||
|
||||
// Test set 1: Conv2D + AddBias
|
||||
|
||||
// C=MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
|
||||
// C=MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
|
||||
// C=_MklConv2D(A,M,B,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved ordering)
|
||||
// C=_MklConv2D(A,B,M,N); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous ordering)
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput'}"
|
||||
"node { name: 'N' op: 'MklInput'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput'}"
|
||||
"node { name: 'N' op: '_MklInput'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
@ -163,22 +161,22 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive) {
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
|
||||
"M(MklInput);N(MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
|
||||
"M(_MklInput);N(_MklInput);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"DMT/_0->E:5;E->Z;M->E:3;N->E:4;Y->Z:1");
|
||||
}
|
||||
|
||||
// C=MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
|
||||
// C=MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
|
||||
// C=_MklConv2D(A,M:1,B,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for interleaved)
|
||||
// C=_MklConv2D(A,B,M:1,N:1); E=BiasAdd(C,D); Z=Sub(E,Y) (for contiguous)
|
||||
// Test for correct output slots selected
|
||||
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
|
||||
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
|
||||
InitGraph(
|
||||
"node { name: 'A' op: 'Input'}"
|
||||
"node { name: 'B' op: 'Input'}"
|
||||
"node { name: 'M' op: 'MklInput2'}"
|
||||
"node { name: 'N' op: 'MklInput2'}"
|
||||
"node { name: 'C' op: 'MklConv2D'"
|
||||
"node { name: 'M' op: '_MklInput2'}"
|
||||
"node { name: 'N' op: '_MklInput2'}"
|
||||
"node { name: 'C' op: '_MklConv2D'"
|
||||
" attr { key: 'T' value { type: DT_FLOAT } }"
|
||||
" attr { key: 'data_format' value { s: 'NCHW' } }"
|
||||
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
|
||||
@ -195,15 +193,15 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive1) {
|
||||
" attr {key: 'T' value { type: DT_FLOAT } }"
|
||||
" input: ['E', 'Y']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(MklConv2DWithBias);"
|
||||
"M(MklInput2);N(MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"A(Input);B(Input);D(Input);DMT/_0(Const);E(_MklConv2DWithBias);"
|
||||
"M(_MklInput2);N(_MklInput2);Y(Input);Z(Sub)|A->E;B->E:1;D->E:2;"
|
||||
"DMT/_0->E:5;E->Z;M:1->E:3;N:1->E:4;Y->Z:1");
|
||||
}
|
||||

// C=Conv2D(A,B); E=BiasAdd(C,D); Z=Sub(E,Y);
// This is a case of node rewrite followed by node merge.
// We will first rewrite Conv2D to MklConv2D, and then merge MklConv2D
// with BiasAdd to produce MklConv2DWithBias.
// We will first rewrite Conv2D to _MklConv2D, and then merge _MklConv2D
// with BiasAdd to produce _MklConv2DWithBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
@@ -227,19 +225,19 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Positive2) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(MklConv2DWithBias);Y(Input);Z(Sub)|"
"DMT/_2(Const);E(_MklConv2DWithBias);Y(Input);Z(Sub)|"
"A->E;B->E:1;D->E:2;DMT/_0->E:3;DMT/_1->E:4;DMT/_2->E:5;"
"E->Z;Y->Z:1");
}

// Graph contains only MklConv2D, no AddBias.
// Graph contains only _MklConv2D, no AddBias.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@@ -247,18 +245,18 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
"A->C;B->C:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);M(_MklInput);N(_MklInput)|"
"A->C;B->C:1;M->C:2;N->C:3");
}

// MklConv2D output does not go to BiasAdd.
// _MklConv2D output does not go to BiasAdd.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@@ -270,21 +268,21 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow1) {
"node { name: 'F' op: 'BiasAdd'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D', 'E'] }"); // Output of MklConv2D does not go to BiasAdd.
" input: ['D', 'E'] }"); // Output of _MklConv2D does not go to BiasAdd.
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
"M(MklInput);N(MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
"M(_MklInput);N(_MklInput)|A->C;B->C:1;D->F;E->F:1;M->C:2;N->C:3");
}

// MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// _MklConv2D has two outgoing edges: BiasAdd and some other dummy node (Add).
// Merge should not be done in such case.
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@@ -302,8 +300,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_Dataflow2) {
" attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Input);F(BiasAdd);"
"G(Add);M(MklInput);N(MklInput)|A->C;B->C:1;C->G;D->F;"
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Input);F(BiasAdd);"
"G(Add);M(_MklInput);N(_MklInput)|A->C;B->C:1;C->G;D->F;"
"E->F:1;E->G:1;M->C:2;N->C:3");
}

@@ -313,9 +311,9 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@@ -328,26 +326,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_AttrMismatch) {
" attr { key: 'data_format' value { s: 'NHCW' } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(BiasAdd);M(MklInput);"
"N(MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(BiasAdd);M(_MklInput);"
"N(_MklInput)|A->C;B->C:1;C->E;D->E:1;M->C:2;N->C:3");
}

// Disabling Conv2DBackpropBias test for now as we have disabled rewrite
// of BiasAddGrad into BackpropBias
#if 0
// Test set 2: MklConv2D..BiasAddGrad -> MklConv2DWithBiasBackpropBias
// Test set 2: _MklConv2D..BiasAddGrad -> _MklConv2DWithBiasBackpropBias
// rewrite tests

// D=MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
// D=_MklConv2D(A,M,B,N,C,O); E=Sub(D,A); F=BiasAddGrad(E)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'C' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'O' op: 'MklInput'}"
"node { name: 'D' op: 'MklConv2DWithBias'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'O' op: '_MklInput'}"
"node { name: 'D' op: '_MklConv2DWithBias'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@@ -362,25 +360,25 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Positive) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MklConv2DWithBias);DMT/_0(Const);"
"E(Sub);F(MklConv2DWithBiasBackpropBias);M(MklInput);N(MklInput);"
"O(MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
"A(Input);B(Input);C(Input);D(_MklConv2DWithBias);DMT/_0(Const);"
"E(Sub);F(_MklConv2DWithBiasBackpropBias);M(_MklInput);N(_MklInput);"
"O(_MklInput)|A->D;A->E:1;B->D:1;C->D:2;D->E;DMT/_0->F:1;E->F;"
"M->D:3;N->D:4;O->D:5");
}
#endif

// No MklConv2D in context, but Conv2D in context.
// Only Conv2D would be rewritten to MklConv2D, but no rewrite
// No _MklConv2D in context, but Conv2D in context.
// Only Conv2D would be rewritten to _MklConv2D, but no rewrite
// for BiasAddGrad should happen.
// C=MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
// C=MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
// C=_MklConv2D(A,M,B,N); D=Sub(C,A); E=BiasAddGrad(D) (for interleaved)
// C=_MklConv2D(A,B,M,N); D=Sub(C,A); E=BiasAddGrad(D) (for contiguous)
TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_No_MklConv2DWithBias) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
@@ -395,8 +393,8 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Neg_NoMklConv2DWithBias) {
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Sub);E(BiasAddGrad);"
"M(MklInput);N(MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
"A(Input);B(Input);C(_MklConv2D);D(Sub);E(BiasAddGrad);"
"M(_MklInput);N(_MklInput)|A->C;A->D:1;B->C:1;C->D;D->E;"
"M->C:2;N->C:3");
}

@@ -509,7 +507,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Basic) {
"node { name: 'D' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['B', 'C'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
"A(Input);B(Input);C(_MklConv2D);D(Mul);DMT/_0(Const);DMT/_1(Const)|"
"A->C;B->C:1;B->D;C->D:1;DMT/_0->C:2;DMT/_1->C:3");
}

@@ -536,7 +534,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Conv2D_Positive1) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(MklConv2D);D(MklConv2D);DMT/_0(Const);"
"A(Input);B(Input);C(_MklConv2D);D(_MklConv2D);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->C;A->D;B->C:1;C->D:1;C->E;"
"C:1->D:3;D->E:1;DMT/_0->C:2;DMT/_1->C:3;DMT/_2->D:2");
}
@@ -578,7 +576,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Basic) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(MklConcat);DMT/_0(Const);"
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D;B->D:1;B->D:2;C->E;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
@@ -617,8 +615,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_Mkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(MklConv2D);"
"F(MklConv2D);G(Const);H(MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
"DMT/_4->H:3;E->H:1;E:1->H:4;F->H:2;F:1->H:5;G->H;H->I:1");
}
@@ -652,8 +650,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Concat_Input_MixedMkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(MklConv2D);F(Mul);G(Const);"
"H(MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcat);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:3;DMT/_3->H:5;E->H:1;E:1->H:4;F->H:2;"
"G->H;H->I:1");
}
@@ -678,7 +676,7 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Basic) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Const);B(InputList);C(Input);D(MklConcat);DMT/_0(Const);"
"A(Const);B(InputList);C(Input);D(_MklConcat);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(Mul)|A->D:2;B->D;B:1->D:1;C->E;"
"D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
}
@@ -719,8 +717,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_Mkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(MklConv2D);"
"F(MklConv2D);G(Const);H(MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(_MklConv2D);"
"F(_MklConv2D);G(Const);H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;"
"D->F:1;DMT/_0->F:2;DMT/_1->F:3;DMT/_2->E:2;DMT/_3->E:3;"
"DMT/_4->H:5;E->H;E:1->H:3;F->H:1;F:1->H:4;G->H:2;H->I:1");
}
@@ -755,8 +753,8 @@ TEST_F(MklLayoutPassTest, NodeRewrite_ConcatV2_Input_MixedMkl) {
" input: ['A', 'H'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);E(MklConv2D);F(Mul);G(Const);"
"H(MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_2(Const);DMT/_3(Const);E(_MklConv2D);F(Mul);G(Const);"
"H(_MklConcatV2);I(Mul)|A->E;A->I;B->E:1;C->F;D->F:1;DMT/_0->E:2;"
"DMT/_1->E:3;DMT/_2->H:4;DMT/_3->H:5;E->H;E:1->H:3;F->H:1;"
"G->H:2;H->I:1");
}
@@ -804,9 +802,10 @@ TEST_F(MklLayoutPassTest, MaxPoolLRN_Positive) {
"node { name: 'H' op: 'Input'}"
"node { name: 'I' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['H', 'G'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(MklMaxPoolGrad);F(Input);G(MklLRNGrad);H(Input);I(Mul)|"
EXPECT_EQ(
DoMklLayoutOptimizationPass(),
"A(Input);B(_MklLRN);C(_MklMaxPool);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklMaxPoolGrad);F(Input);G(_MklLRNGrad);H(Input);I(Mul)|"
"A->B;B->C;B->E;B->G:2;B:1->G:3;B:2->C:1;B:2->E:4;B:2->G:6;B:3->G:7;"
"C->E:1;C:1->E:3;C:2->E:5;C:3->E:7;D->E:2;DMT/_0->B:1;DMT/_1->E:6;"
"DMT/_2->G:5;E->G;E:1->G:4;F->G:1;G->I:1;H->I");
@@ -837,8 +836,8 @@ TEST_F(MklLayoutPassTest, LRN_Positive) {
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(MklLRNGrad);F(Mul)|"
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);E(_MklLRNGrad);F(Mul)|"
"A->B;B->E:2;B:1->E:3;B:2->E:6;B:3->E:7;C->E;C->F;D->E:1;"
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:5;E->F:1");
}
@@ -858,7 +857,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative1) {
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(Mul);DMT/_0(Const)|"
"A(Input);B(_MklLRN);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
}

@@ -879,7 +878,7 @@ TEST_F(MklLayoutPassTest, LRN_Negative2) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MklLRNGrad);DMT/_0(Const);"
"A(Input);B(Input);C(Input);D(_MklLRNGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");
@@ -919,9 +918,9 @@ TEST_F(MklLayoutPassTest, LRN_Negative3) {
"node { name: 'G' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['E', 'F'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"A(Input);B(_MklLRN);C(Input);D(Input);DMT/_0(Const);DMT/_1(Const);"
"DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);DMT/_5(Const);"
"DMT/_6(Const);E(MklLRNGrad);F(MklLRNGrad);G(Mul)|A->B;B->E:2;"
"DMT/_6(Const);E(_MklLRNGrad);F(_MklLRNGrad);G(Mul)|A->B;B->E:2;"
"B->F:1;B:1->E:3;B:2->E:6;B:2->F:5;B:3->E:7;C->E;C->F;D->E:1;"
"D->F:2;DMT/_0->B:1;DMT/_1->F:3;DMT/_2->F:7;DMT/_3->F:4;"
"DMT/_4->F:6;DMT/_5->E:4;DMT/_6->E:5;E->G;F->G:1");
@@ -950,8 +949,8 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Positive) {
"node { name: 'F' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'E'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(MklMaxPoolGrad);F(Mul)|"
"A(Input);B(_MklMaxPool);C(Input);D(Input);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);E(_MklMaxPoolGrad);F(Mul)|"
"A->B;B->E:1;B:1->E:3;B:2->E:5;B:3->E:7;C->E;C->F;D->E:2;"
"DMT/_0->B:1;DMT/_1->E:4;DMT/_2->E:6;E->F:1");
}
@@ -972,7 +971,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative1) {
"node { name: 'C' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'B'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(MklMaxPool);C(Mul);DMT/_0(Const)|"
"A(Input);B(_MklMaxPool);C(Mul);DMT/_0(Const)|"
"A->B;A->C;B->C:1;DMT/_0->B:1");
}

@@ -994,7 +993,7 @@ TEST_F(MklLayoutPassTest, NodeWorkspace_MaxPool_Negative2) {
"node { name: 'E' op: 'Mul' attr { key: 'T' value { type: DT_FLOAT } }"
" input: ['A', 'D'] }");
EXPECT_EQ(DoMklLayoutOptimizationPass(),
"A(Input);B(Input);C(Input);D(MklMaxPoolGrad);DMT/_0(Const);"
"A(Input);B(Input);C(Input);D(_MklMaxPoolGrad);DMT/_0(Const);"
"DMT/_1(Const);DMT/_2(Const);DMT/_3(Const);DMT/_4(Const);E(Mul)|"
"A->D;A->E;B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:7;"
"DMT/_2->D:4;DMT/_3->D:5;DMT/_4->D:6");

@@ -123,22 +123,24 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
TF_CHECK_OK(GetNodeAttr(src->def(), "T", &src_datatype));
TF_CHECK_OK(GetNodeAttr(dst->def(), "T", &dst_datatype));
if (src_datatype != dst_datatype) {
string err_msg = "T attribute of " + src->name() + " and " +
dst->name() + " do not match. Will not insert" +
string err_msg = "T attribute of " + src->name() + " and " + dst->name() +
" do not match. Will not insert" +
" MklToTf node in such case.";
return Status(error::Code::INVALID_ARGUMENT, err_msg.c_str());
}

// Build the conversion node and specify src as input.
TF_CHECK_OK(NodeBuilder((*g)->NewName("Mkl2Tf"), "MklToTf")
.Input(src, e->src_output())
.Input(src, DataIndexToMetaDataIndex(
e->src_output(), src->num_outputs())) // Get an Mkl tensor slot
// from the Tf tensor slot.
.Device(src->def().device()) // We want to get conversion node
// on same device as source node.
.Attr("T", src_datatype)
.Finalize(&**g, &conversion_node));
TF_CHECK_OK(
NodeBuilder((*g)->NewName("Mkl2Tf"), "_MklToTf")
.Input(src, e->src_output())
.Input(src, DataIndexToMetaDataIndex(
e->src_output(),
src->num_outputs())) // Get an Mkl tensor slot
// from the Tf tensor slot.
.Device(src->def().device()) // We want to get conversion node
// on same device as source node.
.Attr("T", src_datatype)
.Finalize(&**g, &conversion_node));

CHECK_NOTNULL(conversion_node);
if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
@@ -191,8 +193,8 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {

// We skip adding MklToTf on an edge between X->MklToTf or
// MklToTf->X, where X is any node.
if (src->type_string().compare("MklToTf") == 0 ||
dst->type_string().compare("MklToTf") == 0) {
if (src->type_string().compare("_MklToTf") == 0 ||
dst->type_string().compare("_MklToTf") == 0) {
continue;
}

@@ -246,8 +248,7 @@ bool InsertMklToTfConversionNodes(std::unique_ptr<Graph>* g) {
return MklToTfConversionPass().RunPass(g);
}

Status MklToTfConversionPass::Run(
const GraphOptimizationPassOptions& options) {
Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
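For context, a pass with this Run(const GraphOptimizationPassOptions&) signature is hooked in through TensorFlow's optimization pass registry; a minimal sketch, with the grouping and phase chosen here only for illustration (the actual registration is not part of this hunk):

#include "tensorflow/core/common_runtime/optimization_registry.h"

// Illustrative registration; the grouping and phase values are assumptions.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PARTITIONING, 1,
                      MklToTfConversionPass);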

@@ -15,8 +15,8 @@ limitations under the License.

#ifdef INTEL_MKL

#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/util/mkl_util.h"

#include <algorithm>
#include <string>
@@ -110,7 +110,7 @@ class MklToTfConversionPass : public ::testing::Test {

REGISTER_OP("Input").Output("o: float").SetIsStateful();
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
REGISTER_OP("_MklInput").Output("o: uint8").SetIsStateful();

TEST_F(MklToTfConversionPass, Basic) {
InitGraph(
@@ -131,47 +131,49 @@ TEST_F(MklToTfConversionPass, Basic) {
TEST_F(MklToTfConversionPass, Positive) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
"Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:2;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:1;Mkl2Tf/_0->E;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'Input'}"
"node { name: 'E' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['C', 'D']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(Input);E(Sub);M(MklInput);"
"Mkl2Tf/_0(MklToTf);N(MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(Input);E(Sub);M(_MklInput);"
"_Mkl2Tf/_0(_MklToTf);N(_MklInput)|A->C;B->C:1;C->Mkl2Tf/_0;"
"C:1->Mkl2Tf/_0:1;D->E:1;M->C:2;Mkl2Tf/_0->E;N->C:3");
}
}
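In both branches the expected string shows the same wiring for the inserted conversion node:

// C->Mkl2Tf/_0     : the data tensor feeds input 0 of the conversion node,
// C:1->Mkl2Tf/_0:1 : the matching Mkl metadata tensor feeds input 1,
// Mkl2Tf/_0->E     : the converted plain-TF tensor then feeds the Sub node.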

@@ -182,55 +184,57 @@ TEST_F(MklToTfConversionPass, Positive) {
TEST_F(MklToTfConversionPass, Negative_DoubleInsert) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: 'MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'M', 'B', 'N']}"
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
"F(Sub);M(MklInput);N(MklInput)|"
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
"F(Sub);M(_MklInput);N(_MklInput)|"
"A->C;B->C:2;C->D;C:1->D:1;D->F;E->F:1;M->C:1;N->C:3");
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
InitGraph(
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: 'MklInput'}"
"node { name: 'N' op: 'MklInput'}"
"node { name: 'C' op: 'MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } }"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: 'MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
"node { name: 'A' op: 'Input'}"
"node { name: 'B' op: 'Input'}"
"node { name: 'M' op: '_MklInput'}"
"node { name: 'N' op: '_MklInput'}"
"node { name: 'C' op: '_MklConv2D'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" attr { key: 'use_cudnn_on_gpu' value { b: false } }"
" attr { key: 'strides' value { list: {i: 1, i:1, i:1, i:1} } "
"}"
" attr { key: 'padding' value { s: 'SAME' } }"
" input: ['A', 'B', 'M', 'N']}"
"node { name: 'D' op: '_MklToTf'"
" attr { key: 'T' value { type: DT_FLOAT } }"
" attr { key: 'data_format' value { s: 'NCHW' } }"
" input: ['C:0', 'C:1']}"
"node { name: 'E' op: 'Input'}"
"node { name: 'F' op: 'Sub'"
" attr {key: 'T' value { type: DT_FLOAT } }"
" input: ['D', 'E']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(MklConv2D);D(MklToTf);E(Input);"
"F(Sub);M(MklInput);N(MklInput)|"
"A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
"A(Input);B(Input);C(_MklConv2D);D(_MklToTf);E(Input);"
"F(Sub);M(_MklInput);N(_MklInput)|"
"A->C;B->C:1;C->D;C:1->D:1;D->F;E->F:1;M->C:2;N->C:3");
}
}

@@ -258,7 +262,7 @@ TEST_F(MklToTfConversionPass, Negative_NoMklLayer) {
" input: ['E', 'Y']}");
EXPECT_EQ(DoRunMklToTfConversionPass(),
"A(Input);B(Input);C(Conv2D);D(Input);E(BiasAdd);Y(Input);Z(Sub)|"
"A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
"A->C;B->C:1;C->E;D->E:1;E->Z;Y->Z:1");
}

static void BM_RunMklToTfConversionPass(int iters, int op_nodes) {

@@ -55,8 +55,13 @@ namespace {
// state).
static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
const gtl::ArraySlice<string>& fed_outputs,
subgraph::NameIndex* name_index) {
for (const string& t : fed_outputs) {
bool use_function_convention,
subgraph::NameIndex* name_index,
DataTypeVector* out_feed_types) {
out_feed_types->clear();
out_feed_types->reserve(fed_outputs.size());
for (size_t i = 0; i < fed_outputs.size(); ++i) {
const string& t = fed_outputs[i];
TensorId id(ParseTensorName(t));

auto iter = name_index->find(id.first);
@@ -71,17 +76,31 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
}

Node* recv_node;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
"_Recv")
.Attr("tensor_type", BaseType(n->output_type(id.second)))
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &recv_node));

if (!use_function_convention) {
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_recv_", id.first, "_", id.second),
"_Recv")
.Attr("tensor_type", BaseType(n->output_type(id.second)))
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &recv_node));
} else {
// NOTE(mrry): We must include the index as part of the node
// name, because _Arg is a "stateful" kernel and therefore
// its name must uniquely identify a kernel instance across all
// graphs in the same session.
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_arg_", id.first, "_",
id.second, "_", i),
"_Arg")
.Attr("T", BaseType(n->output_type(id.second)))
.Attr("index", static_cast<int32>(i))
.Finalize(g, &recv_node));
}
recv_node->set_assigned_device_name(device_info.name());

// Copy the _output_shapes from the original node to the feed node,
@@ -130,6 +149,7 @@ static Status FeedInputs(Graph* g, const DeviceAttributes& device_info,
}
g->RemoveEdge(e);
}
out_feed_types->push_back(BaseType(n->output_type(id.second)));
}
return Status::OK();
}
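The two branches above produce differently named feed nodes; matching the subgraph_test.cc expectations later in this diff:

// use_function_convention == false: _recv_<node>_<slot>          e.g. _recv_input_1
// use_function_convention == true:  _arg_<node>_<slot>_<index>   e.g. _arg_input_1_0
// The trailing <index> is the feed's position in fed_outputs, which is what
// keeps _Arg kernel names unique across graphs in the same session.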
@@ -181,9 +201,14 @@ namespace subgraph {

Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,
const gtl::ArraySlice<string>& fetch_outputs,
NameIndex* name_index, std::vector<Node*>* fetch_nodes) {
fetch_nodes->clear();
for (const string& t : fetch_outputs) {
bool use_function_convention, NameIndex* name_index,
std::vector<Node*>* out_fetch_nodes,
DataTypeVector* out_fetch_types) {
out_fetch_nodes->clear();
out_fetch_nodes->reserve(fetch_outputs.size());
for (size_t i = 0; i < fetch_outputs.size(); ++i) {
const string& t = fetch_outputs[i];

// Parse t into node_name and output_index.
TensorId id(ParseTensorName(t));

@@ -213,25 +238,39 @@ Status FetchOutputs(Graph* g, const DeviceAttributes& device_info,

// Create the fetch Node and connect it up
Node* send_node;
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
"_Send")
.Input(n, id.second)
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &send_node));
if (!use_function_convention) {
TF_RETURN_IF_ERROR(
NodeBuilder(strings::StrCat("_send_", id.first, "_", id.second),
"_Send")
.Input(n, id.second)
.Attr("tensor_name", t)
.Attr("send_device", device_info.name())
.Attr("recv_device", device_info.name())
.Attr("send_device_incarnation",
static_cast<int64>(device_info.incarnation()))
.Attr("client_terminated", true)
.Finalize(g, &send_node));
} else {
// NOTE(mrry): We must include the index as part of the node
// name, because _Retval is a "stateful" kernel and therefore
// its name must uniquely identify a kernel instance across all
// graphs in the same session.
TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat("_retval_", id.first, "_",
id.second, "_", i),
"_Retval")
.Input(n, id.second)
.Attr("T", BaseType(n->output_type(id.second)))
.Attr("index", static_cast<int32>(i))
.Finalize(g, &send_node));
}
send_node->set_assigned_device_name(device_info.name());
VLOG(1) << "Created fetch node: " << SummarizeNodeDef(send_node->def());

// Update the index.
(*name_index)[send_node->name()] = send_node;

g->AddControlEdge(send_node, g->sink_node());
fetch_nodes->push_back(send_node);
out_fetch_nodes->push_back(send_node);
out_fetch_types->push_back(BaseType(n->output_type(id.second)));
}

return Status::OK();
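FetchOutputs mirrors FeedInputs on the output side; the node names it produces, again matching the test expectations below:

// use_function_convention == false: _send_<node>_<slot>            e.g. _send_t3_a_0
// use_function_convention == true:  _retval_<node>_<slot>_<index>  e.g. _retval_t3_a_0_0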
@@ -241,7 +280,8 @@ Status RewriteGraphForExecution(
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
const gtl::ArraySlice<string>& fetch_outputs,
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info) {
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata) {
if (fetch_outputs.empty() && target_node_names.empty()) {
return errors::InvalidArgument(
"Must specify at least one target to fetch or execute.");
@@ -274,18 +314,21 @@ Status RewriteGraphForExecution(
// currently listed in "fetch_nodes". We pass "name_index" so the index is
// kept up to date.
if (!fed_outputs.empty()) {
TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs, &name_index));
TF_RETURN_IF_ERROR(FeedInputs(g, device_info, fed_outputs,
use_function_convention, &name_index,
&out_metadata->feed_types));
}

// Add the fetch nodes, also updating "name_index".
std::vector<Node*> fetch_nodes;
if (!fetch_outputs.empty()) {
TF_RETURN_IF_ERROR(
FetchOutputs(g, device_info, fetch_outputs, &name_index, &fetch_nodes));
TF_RETURN_IF_ERROR(FetchOutputs(g, device_info, fetch_outputs,
use_function_convention, &name_index,
&fetch_nodes, &out_metadata->fetch_types));
}

// Prune the graph to only compute what is needed for the fetch nodes and the
// targets nodes.
// target nodes.
if (!fetch_nodes.empty() || !target_node_names.empty()) {
TF_RETURN_IF_ERROR(
PruneForTargets(g, name_index, fetch_nodes, target_node_names));

@@ -26,6 +26,18 @@ limitations under the License.
namespace tensorflow {
namespace subgraph {

// Information about a graph rewritten by `RewriteGraphForExecution()`.
struct RewriteGraphMetadata {
// The element type of each tensor fed to this subgraph. The order
// of types corresponds to the order of tensor names in
// `fed_outputs` when calling `RewriteGraphForExecution()`.
DataTypeVector feed_types;
// The element type of each tensor fetched from this subgraph. The
// order of types corresponds to the order of tensor names in
// `fetch_outputs` when calling `RewriteGraphForExecution()`.
DataTypeVector fetch_types;
};
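A minimal caller sketch for the extended signature (the feed and fetch names are hypothetical, not from the diff; the shape mirrors the updated SubgraphTest helper further down):

Status RewriteSketch(Graph* g, const DeviceAttributes& device_info) {
  std::vector<string> feeds = {"input:1"};   // hypothetical feed
  std::vector<string> fetches = {"t2:0"};    // hypothetical fetch
  std::vector<string> targets;               // no extra target nodes
  subgraph::RewriteGraphMetadata metadata;
  TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
      g, feeds, fetches, targets, device_info,
      true /* use_function_convention */, &metadata));
  // metadata.feed_types[i] and metadata.fetch_types[i] line up with
  // feeds[i] and fetches[i] respectively.
  CHECK_EQ(feeds.size(), metadata.feed_types.size());
  CHECK_EQ(fetches.size(), metadata.fetch_types.size());
  return Status::OK();
}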

// Rewrite the graph structure of "*g" to deal with feeding node
// outputs, fetching node outputs, and only running a subset of the
// graph. "fed_outputs" and "fetch_outputs" are both lists of
@@ -56,7 +68,8 @@ Status RewriteGraphForExecution(
Graph* g, const gtl::ArraySlice<string>& fed_outputs,
const gtl::ArraySlice<string>& fetch_outputs,
const gtl::ArraySlice<string>& target_node_names,
const DeviceAttributes& device_info);
const DeviceAttributes& device_info, bool use_function_convention,
RewriteGraphMetadata* out_metadata);

typedef std::unordered_map<StringPiece, Node*, StringPiece::Hasher> NameIndex;

@@ -104,7 +104,8 @@ class SubgraphTest : public ::testing::Test {
}

string Subgraph(const string& fed_str, const string& fetch_str,
const string& targets_str) {
const string& targets_str,
bool use_function_convention = false) {
Graph* subgraph = new Graph(OpRegistry::Global());
CopyGraph(*g_, subgraph);
std::vector<string> fed =
@@ -114,13 +115,18 @@ class SubgraphTest : public ::testing::Test {
std::vector<string> targets =
str_util::Split(targets_str, ',', str_util::SkipEmpty());

Status s = subgraph::RewriteGraphForExecution(subgraph, fed, fetch, targets,
device_info_);
subgraph::RewriteGraphMetadata metadata;
Status s = subgraph::RewriteGraphForExecution(
subgraph, fed, fetch, targets, device_info_, use_function_convention,
&metadata);
if (!s.ok()) {
delete subgraph;
return s.ToString();
}

EXPECT_EQ(fed.size(), metadata.feed_types.size());
EXPECT_EQ(fetch.size(), metadata.fetch_types.size());

// Replace the graph with the subgraph for the rest of the display program
g_.reset(subgraph);
return "OK";
@@ -178,6 +184,20 @@ TEST_F(SubgraphTest, FedOutputs1) {
ExpectNodes("W1,W2,_recv_input_1,t1,t2");
}

TEST_F(SubgraphTest, FedOutputs1_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
EXPECT_EQ("OK",
Subgraph("input:1", "", "t2", true /* use_function_convention */));
ExpectNodes("W1,W2,_arg_input_1_0,t1,t2");
}

TEST_F(SubgraphTest, FedRefNode) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
@@ -189,7 +209,19 @@ TEST_F(SubgraphTest, FedRefNode) {
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
}

TEST_F(SubgraphTest, FedOutputs2) {
TEST_F(SubgraphTest, FedRefNode_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 't1' op: 'TestMul' input: [ 'W2', 'W1' ] }");
EXPECT_EQ("OK",
Subgraph("W1:0", "", "t1", true /* use_function_convention */));
ExpectNodes("_arg_W1_0_0,W2,t1");
Node* n = FindNode("_arg_W1_0_0");
EXPECT_FALSE(IsRefType(CHECK_NOTNULL(n)->output_type(0)));
}

TEST_F(SubgraphTest, FedOutputs2_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
@@ -200,8 +232,9 @@ TEST_F(SubgraphTest, FedOutputs2) {
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
// We feed input:1, but nothing connects to it, so the _recv(input:1)
// node also disappears.
EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2"));
ExpectNodes("_recv_t1_0,_recv_W2_0,t2");
EXPECT_EQ("OK", Subgraph("input:1,t1,W2", "", "t2",
true /* use_function_convention */));
ExpectNodes("_arg_t1_0_1,_arg_W2_0_2,t2");
}

TEST_F(SubgraphTest, FetchOutputs1) {
@@ -218,6 +251,22 @@ TEST_F(SubgraphTest, FetchOutputs1) {
"W1,W2,input,t1,t2,_send_W2_0,_send_input_1,_send_t1_0,_send_t2_0");
}

TEST_F(SubgraphTest, FetchOutputs1_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
EXPECT_EQ("OK", Subgraph("", "W2,input:1,t1,t2", "t2",
true /* use_function_convention */));
ExpectNodes(
"W1,W2,input,t1,t2,_retval_W2_0_0,_retval_input_1_1,_retval_t1_0_2,_"
"retval_t2_0_3");
}

TEST_F(SubgraphTest, FetchOutputs2) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
@@ -231,6 +280,20 @@ TEST_F(SubgraphTest, FetchOutputs2) {
ExpectNodes("W1,W2,input,t1,t2,t3_a,_send_t3_a_0");
}

TEST_F(SubgraphTest, FetchOutputs2_FunctionConvention) {
ExpectOK(
"node { name: 'W1' op: 'TestParams' }"
"node { name: 'W2' op: 'TestParams' }"
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W2', 't1' ] }"
"node { name: 't3_a' op: 'TestRelu' input: 't2' }"
"node { name: 't3_b' op: 'TestRelu' input: 't2' }");
EXPECT_EQ("OK",
Subgraph("", "t3_a", "t2", true /* use_function_convention */));
ExpectNodes("W1,W2,input,t1,t2,t3_a,_retval_t3_a_0_0");
}

TEST_F(SubgraphTest, ChainOfFools) {
ExpectOK(
"node { name: 'a' op: 'TestParams' }"
@@ -315,7 +378,8 @@ TEST_F(SubgraphTest, FedOutputsPreservesOutputShapes) {
REGISTER_OP("In").Output("o: float");
REGISTER_OP("Op").Input("i: float").Output("o: float");

static void BM_Subgraph(int iters, int num_nodes) {
static void BM_SubgraphHelper(int iters, int num_nodes,
bool use_function_convention) {
DeviceAttributes device_info;
device_info.set_name("/job:a/replica:0/task:0/cpu:0");
device_info.set_device_type(DeviceType(DEVICE_CPU).type());
@@ -347,12 +411,26 @@ static void BM_Subgraph(int iters, int num_nodes) {
while (--iters > 0) {
Graph* subgraph = new Graph(OpRegistry::Global());
CopyGraph(g, subgraph);
TF_CHECK_OK(subgraph::RewriteGraphForExecution(subgraph, fed, fetch,
targets, device_info));
subgraph::RewriteGraphMetadata metadata;
TF_CHECK_OK(subgraph::RewriteGraphForExecution(
subgraph, fed, fetch, targets, device_info, use_function_convention,
&metadata));
delete subgraph;
}
}

static void BM_Subgraph(int iters, int num_nodes) {
BM_SubgraphHelper(iters, num_nodes, false /* use_function_convention */);
}
static void BM_SubgraphFunctionConvention(int iters, int num_nodes) {
BM_SubgraphHelper(iters, num_nodes, true /* use_function_convention */);
}
BENCHMARK(BM_Subgraph)->Arg(100)->Arg(1000)->Arg(10000)->Arg(100000);
BENCHMARK(BM_SubgraphFunctionConvention)
->Arg(100)
->Arg(1000)
->Arg(10000)
->Arg(100000);

} // namespace
} // namespace tensorflow

@@ -18,6 +18,13 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

bool IsDequeueOp(const NodeDef& node) {
static const std::set<std::string> dequeue_ops = {
"QueueDequeueManyV2", "QueueDequeueMany", "QueueDequeueV2",
"QueueDequeue"};
return dequeue_ops.count(node.op()) > 0;
}

bool IsPlaceholder(const NodeDef& node) {
const auto op = node.op();
return op == "Placeholder" || op == "PlaceholderV2";
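A small usage sketch for these predicates; the helper function, includes, and invocation context are assumptions rather than part of the diff:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/platform/logging.h"

// Hypothetical helper: count the feed-like nodes these predicates match.
void CountGraphInputs(const tensorflow::GraphDef& graph_def) {
  int dequeues = 0, placeholders = 0;
  for (const tensorflow::NodeDef& node : graph_def.node()) {
    if (tensorflow::grappler::IsDequeueOp(node)) ++dequeues;
    if (tensorflow::grappler::IsPlaceholder(node)) ++placeholders;
  }
  LOG(INFO) << dequeues << " dequeue ops, " << placeholders << " placeholders";
}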

@@ -21,6 +21,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {

bool IsDequeueOp(const NodeDef& node);
bool IsPlaceholder(const NodeDef& node);
bool IsVariable(const NodeDef& node);