Merge pull request #9677 from vrv/branch_155159972

Branch 155159972
This commit is contained in:
Vijay Vasudevan 2017-05-04 22:34:48 -07:00 committed by GitHub
commit ce02c770fb
192 changed files with 8204 additions and 4109 deletions

2
configure vendored
View File

@ -385,7 +385,7 @@ fi
# Append CC optimization flags to bazel.rc
for opt in $CC_OPT_FLAGS; do
write_to_bazelrc 'build:opt --cxxopt=$opt --copt=$opt'
write_to_bazelrc "build:opt --cxxopt=$opt --copt=$opt"
done
# Run the gen_git_source to create links where bazel can track dependencies for

View File

@ -58,6 +58,7 @@ tf_cuda_library(
"//tensorflow/cc/saved_model:loader",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
"//tensorflow/cc:grad_ops",
"//tensorflow/cc:scope_internal",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",

View File

@ -91,6 +91,7 @@ cc_library(
deps = [
":array_grad",
":math_grad",
":nn_grad",
],
)

View File

@ -125,7 +125,7 @@ XlaDevice::XlaDevice(const SessionOptions& options,
const DeviceType& jit_device_name,
perftools::gputools::Platform* platform,
Allocator* xla_allocator)
: LocalDevice(options, attrs, xla_allocator),
: LocalDevice(options, attrs),
device_ordinal_(device_ordinal),
jit_device_name_(jit_device_name),
xla_allocator_(xla_allocator),

View File

@ -76,8 +76,7 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options,
options,
Device::BuildDeviceAttributes(
"", type, Bytes(256 << 20), DeviceLocality(),
strings::StrCat("device: XLA compilation device ", type.type())),
cpu_allocator()),
strings::StrCat("device: XLA compilation device ", type.type()))),
allocator_(new XlaCompilationAllocator()) {}
XlaCompilationDevice::~XlaCompilationDevice() {}

View File

@ -668,6 +668,14 @@ class ComputationBuilder {
// then Build() should be used instead.
Computation BuildAndNoteError();
// Returns the first error that was encountered while building the
// computation. When an error is encountered, by default we return a vacuous
// ComputationDataHandle and inform the user of the error that occurred while
// building the computation when they make a final call to Build().
//
// See also set_die_immediately_on_error().
Status first_error() const { return first_error_; }
private:
using PopulateLiteral = std::function<void(Literal*)>;

View File

@ -201,7 +201,8 @@ void IrEmitter::InitializeIrFunction(const string& function_name,
if (&argument == retval) {
continue;
}
compute_function_->setDoesNotAlias(argument.getArgNo() + 1);
compute_function_->addAttribute(argument.getArgNo() + 1,
llvm::Attribute::NoAlias);
}
ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(

View File

@ -196,7 +196,7 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
ir_emitter_context_->buffer_assignment().GetTempAllocation()) {
kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, allocation->size());
}
kernel->setDoesNotAlias(temp_buffer_arg_no + 1);
kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias);
// Add the declaration of this kernel to llvm.nvvm.annotations so that NVPTX
// treats it as a CUDA kernel.

View File

@ -705,7 +705,8 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
CHECK(ShapeUtil::IsArray(instruction->shape()) &&
ShapeUtil::IsArray(operand->shape()));
if (instruction->IsElementwiseOnOperand(operand_no) &&
if ((instruction->IsElementwiseOnOperand(operand_no) ||
InstructionRequiresInputLayoutEqualToOutputLayout(instruction)) &&
!ShapeUtil::IsScalar(operand->shape()) &&
ShapeUtil::Rank(operand->shape()) ==
ShapeUtil::Rank(instruction->shape())) {

View File

@ -248,6 +248,15 @@ class LayoutAssignment : public HloPassInterface {
return Status::OK();
}
// This method can be overriden to mark instructions as requiring the operands
// to have the same layout as the result, for performance or correctness. This
// will propagate constraints through the instruction from the result into the
// operands.
virtual bool InstructionRequiresInputLayoutEqualToOutputLayout(
const HloInstruction* instruction) {
return false;
}
// Construct contraints and assign layouts to all instructions in the
// computation satisfying the given ComputationLayout. Layouts constraints are
// added, then propagated until all LogicalBuffers in the computation are

View File

@ -244,8 +244,11 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
}
if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) {
return InvalidArgument(
"cannot concatenate arrays with different ranks: %lld vs %lld",
ShapeUtil::Rank(*arg_shape), ShapeUtil::Rank(*shape));
"Cannot concatenate arrays with different ranks: %lld (%s) vs %lld "
"(%s)",
ShapeUtil::Rank(*arg_shape),
ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape),
ShapeUtil::HumanString(*shape).c_str());
}
if (arg_shape->element_type() != shape->element_type()) {
return InvalidArgument(

View File

@ -118,6 +118,7 @@ set(tf_proto_text_srcs
"tensorflow/core/framework/types.proto"
"tensorflow/core/framework/versions.proto"
"tensorflow/core/lib/core/error_codes.proto"
"tensorflow/core/protobuf/cluster.proto"
"tensorflow/core/protobuf/config.proto"
"tensorflow/core/protobuf/debug.proto"
"tensorflow/core/protobuf/rewriter_config.proto"

View File

@ -22,6 +22,7 @@ set(tf_op_lib_names
"image_ops"
"io_ops"
"linalg_ops"
"lookup_ops"
"logging_ops"
"math_ops"
"nn_ops"

View File

@ -203,6 +203,7 @@ add_python_module("tensorflow/python/estimator")
add_python_module("tensorflow/python/estimator/export")
add_python_module("tensorflow/python/estimator/inputs")
add_python_module("tensorflow/python/estimator/inputs/queues")
add_python_module("tensorflow/python/feature_column")
add_python_module("tensorflow/python/framework")
add_python_module("tensorflow/python/grappler")
add_python_module("tensorflow/python/kernel_tests")
@ -596,6 +597,7 @@ GENERATE_PYTHON_OP_LIB("image_ops")
GENERATE_PYTHON_OP_LIB("io_ops")
GENERATE_PYTHON_OP_LIB("linalg_ops")
GENERATE_PYTHON_OP_LIB("logging_ops")
GENERATE_PYTHON_OP_LIB("lookup_ops")
GENERATE_PYTHON_OP_LIB("nn_ops")
GENERATE_PYTHON_OP_LIB("parsing_ops")
GENERATE_PYTHON_OP_LIB("random_ops")

View File

@ -710,25 +710,6 @@ cuda_py_test(
],
)
cuda_py_test(
name = "identity_test",
size = "small",
srcs = ["python/kernel_tests/bijectors/identity_test.py"],
additional_deps = [
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
"//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
)
cuda_py_test(
name = "inline_test",
size = "small",

View File

@ -25,6 +25,7 @@ from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops.binomial import *
from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.conditional_distribution import *
from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.deterministic import *
from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform
@ -44,12 +45,10 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
from tensorflow.contrib.distributions.python.ops.sample_stats import *
from tensorflow.contrib.distributions.python.ops.transformed_distribution import *
from tensorflow.contrib.distributions.python.ops.wishart import *
from tensorflow.python.ops.distributions.bernoulli import *
from tensorflow.python.ops.distributions.beta import *
from tensorflow.python.ops.distributions.categorical import *
from tensorflow.python.ops.distributions.conditional_distribution import *
from tensorflow.python.ops.distributions.dirichlet import *
from tensorflow.python.ops.distributions.dirichlet_multinomial import *
from tensorflow.python.ops.distributions.distribution import *
@ -60,6 +59,7 @@ from tensorflow.python.ops.distributions.laplace import *
from tensorflow.python.ops.distributions.multinomial import *
from tensorflow.python.ops.distributions.normal import *
from tensorflow.python.ops.distributions.student_t import *
from tensorflow.python.ops.distributions.transformed_distribution import *
from tensorflow.python.ops.distributions.uniform import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

View File

@ -23,9 +23,9 @@ import itertools
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -20,12 +20,12 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.chain import Chain
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution as transformed_distribution_lib
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import gamma as gamma_lib
from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import PowerTransform
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -21,9 +21,9 @@ from __future__ import print_function
import numpy as np
from scipy import special
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.platform import test

View File

@ -20,9 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_bijective_and_finite
from tensorflow.contrib.distributions.python.ops.bijectors.bijector_test_util import assert_scalar_congruency
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import Softplus
from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite
from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency
from tensorflow.python.platform import test
rng = np.random.RandomState(42)

View File

@ -43,7 +43,6 @@ from tensorflow.contrib.distributions.python.ops.bijectors.chain import *
from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import *
from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import *
from tensorflow.contrib.distributions.python.ops.bijectors.exp import *
from tensorflow.contrib.distributions.python.ops.bijectors.identity import *
from tensorflow.contrib.distributions.python.ops.bijectors.inline import *
from tensorflow.contrib.distributions.python.ops.bijectors.invert import *
from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import *
@ -52,6 +51,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered impo
from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import *
from tensorflow.contrib.distributions.python.ops.bijectors.softplus import *
from tensorflow.python.ops.distributions.bijector import *
from tensorflow.python.ops.distributions.identity_bijector import Identity
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member

View File

@ -1,29 +0,0 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Identity bijector."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.contrib.distributions.python.ops.bijectors.identity_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ["Identity"]
remove_undocumented(__name__, _allowed_symbols)

View File

@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.contrib.distributions.python.ops import conditional_distribution
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import conditional_distribution
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.contrib import linalg
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
@ -29,6 +28,7 @@ from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import logistic
from tensorflow.contrib.distributions.python.ops import transformed_distribution
# Bijectors must be directly imported because `remove_undocumented` prevents
# individual file imports.
from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import Sigmoid
@ -27,6 +26,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@ -30,6 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -19,13 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.distributions.python.ops import bijectors
from tensorflow.contrib.distributions.python.ops import transformed_distribution
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.distributions import student_t
from tensorflow.python.ops.distributions import transformed_distribution
from tensorflow.python.ops.distributions import util as distribution_util

View File

@ -108,6 +108,7 @@ tf_custom_op_py_library(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/feature_column",
"@six_archive//:six",
],
)

View File

@ -136,8 +136,10 @@ from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.layers.python.ops import bucketization_op
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@ -1497,7 +1499,10 @@ def _real_valued_var_len_column(column_name,
is_sparse)
class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
class _RealValuedColumn(
_FeatureColumn,
fc_core._DenseColumn, # pylint: disable=protected-access
collections.namedtuple(
"_RealValuedColumn",
["column_name", "dimension", "default_value", "dtype", "normalizer"])):
"""Represents a real valued feature column also known as continuous features.
@ -1569,6 +1574,23 @@ class _RealValuedColumn(_FeatureColumn, collections.namedtuple(
def _to_dense_tensor(self, input_tensor):
return input_tensor
@property
def _variable_shape(self):
return tensor_shape.TensorShape((self.dimension))
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
del weight_collections
del trainable
return inputs.get(self)
def _transform_feature(self, inputs):
return math_ops.to_float(
self._normalized_input_tensor(inputs.get(self.name)))
@property
def _parse_example_config(self):
return self.config
def real_valued_column(column_name,
dimension=1,

View File

@ -27,14 +27,15 @@ from tensorflow.contrib.layers.python.layers import feature_column
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.feature_column import feature_column as fc_core
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
@ -223,7 +224,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(keys_sparse, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertEqual(output[keys_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[keys_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[keys_sparse].indices.eval(),
@ -241,7 +242,7 @@ class TransformerTest(test.TestCase):
output = feature_column_ops._Transformer(features).transform(keys_sparse)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
# While the input is a dense Tensor, the output should be a SparseTensor.
self.assertIsInstance(output, sparse_tensor.SparseTensor)
self.assertEqual(output.dtype, dtypes.int64)
@ -310,7 +311,7 @@ class TransformerTest(test.TestCase):
self.assertIn(weighted_ids, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[weighted_ids][0].dense_shape.eval(),
ids_tensor.dense_shape.eval())
self.assertAllEqual(output[weighted_ids][0].indices.eval(),
@ -340,7 +341,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -362,7 +363,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -386,7 +387,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -408,7 +409,7 @@ class TransformerTest(test.TestCase):
self.assertEqual(len(output), 1)
self.assertIn(vocab_sparse, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertEqual(output[vocab_sparse].values.dtype, dtypes.int64)
self.assertAllEqual(output[vocab_sparse].values.eval(), [1, 2, 0, 1])
self.assertAllEqual(output[vocab_sparse].indices.eval(),
@ -600,7 +601,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
one_hot_column, embedding_column, real_valued_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
def testRealValuedColumn(self):
@ -610,6 +611,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testRealValuedColumnWithMultiDimensions(self):
real_valued = feature_column.real_valued_column("price", 2)
@ -620,6 +625,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval())
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testRealValuedColumnSparse(self):
sparse_real_valued = feature_column._real_valued_var_len_column(
@ -640,6 +649,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testRealValuedColumnWithMultiDimensionsAndNormalizer(self):
real_valued = feature_column.real_valued_column(
@ -651,6 +664,10 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[real_valued])
with self.test_session():
self.assertAllClose(output.eval(), features["price"].eval() - 2)
# Verify cross compatibility: Core builder output should equal to contrib.
self.assertAllClose(output.eval(),
fc_core.make_input_layer(features,
[real_valued]).eval())
def testBucketizedColumnWithNormalizerSucceedsForDNN(self):
bucket = feature_column.bucketized_column(
@ -697,7 +714,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_column])
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
output.eval())
@ -715,7 +732,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
output.eval())
@ -733,7 +750,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
output.eval())
@ -767,7 +784,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[one_hot_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual([3, 10], output.eval().shape)
def testEmbeddingColumnSucceedsForDNN(self):
@ -874,7 +891,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
@ -897,7 +914,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output.eval().shape, [2, 10])
def testEmbeddingColumnWithCrossedColumnSucceedsForDNN(self):
@ -948,7 +965,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
"Error creating input layer for column: ids_weighted_by_weights"):
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
feature_column_ops.input_from_feature_columns(features, [weighted_ids])
def testCrossedColumnFailsForDNN(self):
@ -1055,7 +1072,7 @@ class CreateInputLayersForDNNsTest(test.TestCase):
[embeded_sparse])
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
# score: (sum of weights)
self.assertAllEqual(output.eval(), [[10.], [50.], [0.]])
@ -1293,7 +1310,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, 4])
@ -1327,7 +1344,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = np.array([4, 3, hash_buckets])
@ -1357,7 +1374,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
@ -1386,7 +1403,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
self.assertAllEqual(expected_input_shape, model_input.shape)
@ -1416,7 +1433,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
embedding_weights)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
model_input, gradients = sess.run([model_input_tensor, gradient_tensor])
expected_input_shape = [4, 3, embedding_dimension]
@ -1483,7 +1500,7 @@ class SequenceInputFromFeatureColumnTest(test.TestCase):
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
model_input = sess.run(model_input_tensor)
expected_input_shape = [
@ -1564,7 +1581,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_ids], num_outputs=5)
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
def testWeightedSparseColumnWithDenseInputTensor(self):
@ -1580,7 +1597,7 @@ class WeightedSumTest(test.TestCase):
with self.test_session():
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(logits.eval().shape, [2, 5])
def testCrossedColumn(self):
@ -1634,7 +1651,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
with self.test_session() as sess:
variables_lib.initialize_all_variables().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (3, 1))
@ -1709,7 +1726,7 @@ class WeightedSumTest(test.TestCase):
features, [age, language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@ -1749,7 +1766,7 @@ class WeightedSumTest(test.TestCase):
self.assertEqual(len(variables), 1)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@ -1813,7 +1830,7 @@ class WeightedSumTest(test.TestCase):
features, [weighted_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllClose(output.eval(), [[0.], [0.]])
@ -1841,7 +1858,7 @@ class WeightedSumTest(test.TestCase):
features, [language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
# score: 0.1 + language_weight['hindi'] + language_weight['english']
sess.run(bias.assign([0.1]))
@ -1864,7 +1881,7 @@ class WeightedSumTest(test.TestCase):
features, [movies], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[movies][0]
self.assertEqual(weights.get_shape(), (15, 1))
@ -1898,7 +1915,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4))
@ -1922,7 +1939,7 @@ class WeightedSumTest(test.TestCase):
features, [language_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[language_language][0]
sess.run(weights.assign(weights + 0.4))
@ -1955,7 +1972,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language][0]
sess.run(weights.assign(weights + 0.4))
@ -1996,7 +2013,7 @@ class WeightedSumTest(test.TestCase):
scope=scope))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertEqual(2, len(column_to_variable[country]))
self.assertEqual(3, len(column_to_variable[language]))
@ -2033,7 +2050,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, incomes], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
incomes_weights = column_to_variable[incomes][0]
sess.run(incomes_weights.assign([[0.1], [0.2], [0.3]]))
@ -2069,7 +2086,7 @@ class WeightedSumTest(test.TestCase):
features, [country, age, height, incomes], num_outputs=5))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
height_weights = column_to_variable[height][0]
sess.run(
@ -2099,7 +2116,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
sess.run(column_to_variable[bucket][0].assign([[0.1], [0.2], [0.3],
[0.4]]))
@ -2127,7 +2144,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 1
sess.run(column_to_variable[bucket][0].assign(
@ -2156,7 +2173,7 @@ class WeightedSumTest(test.TestCase):
features, [bucket, country], num_outputs=5))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
# dimension = 2, bucket_size = 4, num_classes = 5
sess.run(column_to_variable[bucket][0].assign(
@ -2192,7 +2209,7 @@ class WeightedSumTest(test.TestCase):
features, [country_price], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[country_price][0]
sess.run(weights.assign(weights + 0.4))
@ -2231,7 +2248,7 @@ class WeightedSumTest(test.TestCase):
features, [country_language_price], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[country_language_price][0]
sess.run(weights.assign(weights + 0.4))
@ -2255,7 +2272,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@ -2270,7 +2287,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@ -2285,7 +2302,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.6], [0.7]])
@ -2306,7 +2323,7 @@ class WeightedSumTest(test.TestCase):
features, [product], num_outputs=1))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
product_weights = column_to_variable[product][0]
sess.run(product_weights.assign([[0.1], [0.2], [0.3], [0.4], [0.5]]))
self.assertAllClose(output.eval(), [[0.1], [0.5], [0.3]])
@ -2318,7 +2335,7 @@ class WeightedSumTest(test.TestCase):
features, [feature_column.real_valued_column("age")], num_outputs=3)
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
sess.run(bias.assign([0.1, 0.2, 0.3]))
self.assertAllClose(output.eval(), [[0.1, 0.2, 0.3], [0.1, 0.2, 0.3],
[0.1, 0.2, 0.3], [0.1, 0.2, 0.3]])
@ -2332,7 +2349,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (1, 3))
sess.run(weights.assign([[0.01, 0.03, 0.05]]))
@ -2356,7 +2373,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
sess.run(
@ -2382,7 +2399,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@ -2422,7 +2439,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@ -2451,7 +2468,7 @@ class WeightedSumTest(test.TestCase):
features, [column], num_outputs=3))
with self.test_session() as sess:
variables_lib.global_variables_initializer().run()
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
weights = column_to_variable[column][0]
self.assertEqual(weights.get_shape(), (5, 3))
@ -2516,7 +2533,7 @@ class ParseExampleTest(test.TestCase):
self.assertIn(bucket, output)
self.assertIn(wire_cast, output)
with self.test_session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(output[bucket].eval(), [[2, 3, 0]])
self.assertAllEqual(output[wire_cast].indices.eval(), [[0, 0], [0, 1]])
self.assertAllEqual(output[wire_cast].values.eval(), [2, 0])

View File

@ -46,7 +46,7 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
Args:
uniform: Whether to use uniform or normal distributed random initialization.
seed: A Python integer. Used to create random seeds. See
@{set_random_seed} for behavior.
@{tf.set_random_seed} for behavior.
dtype: The data type. Only floating point types are supported.
Returns:
@ -96,7 +96,7 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False,
mode: String. 'FAN_IN', 'FAN_OUT', 'FAN_AVG'.
uniform: Whether to use uniform or normal distributed random initialization.
seed: A Python integer. Used to create random seeds. See
@{set_random_seed} for behavior.
@{tf.set_random_seed} for behavior.
dtype: The data type. Only floating point types are supported.
Returns:

View File

@ -38,8 +38,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@ -157,7 +157,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
self.context_feature_columns)
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.tables_initializer())
sess.run(lookup_ops.tables_initializer())
sequence_input_val = sess.run(sequence_input)
expected_shape = np.array([
3, # expected batch size
@ -178,7 +178,7 @@ class DynamicRnnEstimatorTest(test.TestCase):
# Obtain values of activations and final state.
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.tables_initializer())
sess.run(lookup_ops.tables_initializer())
activations, final_state = sess.run([activations_t, final_state_t])
expected_activations_shape = np.array([3, 2, self.NUM_LABEL_COLUMNS])

View File

@ -57,7 +57,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@ -1292,7 +1292,7 @@ class Estimator(BaseEstimator):
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()),
data_flow_ops.tables_initializer())
lookup_ops.tables_initializer())
# Perform the export
builder = saved_model_builder.SavedModelBuilder(export_dir)

View File

@ -32,7 +32,7 @@ from tensorflow.core.framework import summary_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.losses import losses as losses_lib
from tensorflow.python.platform import test
@ -1214,7 +1214,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(
[0, 2],
model_fn_ops.predictions["classes"].eval())
@ -1266,7 +1266,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.), (0., 0., 1.),))
with session.Session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(
[b"key0", b"key2"],
model_fn_ops.predictions["classes"].eval())
@ -1301,7 +1301,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((1., 0., 0.),))
with session.Session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])
@ -1327,7 +1327,7 @@ class MultiClassHeadTest(test.TestCase):
train_op_fn=head_lib.no_op_train_fn,
logits=((0., 0., 1.),))
with session.Session():
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertIsNone(model_fn_ops.train_op)
_assert_no_variables(self)
_assert_summary_tags(self, ["loss"])

View File

@ -35,8 +35,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@ -55,7 +55,7 @@ class PrepareInputsForRnnTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.initialize_all_tables())
sess.run(lookup_ops.tables_initializer())
features_val = sess.run(features_by_time)
self.assertAllEqual(expected, features_val)
@ -316,7 +316,7 @@ class StateSavingRnnEstimatorTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.global_variables_initializer())
sess.run(data_flow_ops.initialize_all_tables())
sess.run(lookup_ops.tables_initializer())
actual_sequence, actual_context = sess.run(
[sequence, context])
assert_equal(expected_sequence, actual_sequence)

View File

@ -647,6 +647,10 @@ class Experiment(object):
if _sentinel is not None:
raise ValueError("_call_train should be called with keyword args only")
# Estimator in core cannot work with monitors. We need to convert them
# to hooks. For Estimator in contrib, it is converted internally. So, it is
# safe to convert for both cases.
hooks = monitors.replace_monitors_with_hooks(hooks, self._estimator)
if self._core_estimator_used:
return self._estimator.train(input_fn=input_fn,
steps=steps,

View File

@ -24,7 +24,6 @@ import time
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import run_config
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib
@ -461,7 +460,8 @@ class ExperimentTest(test.TestCase):
self.assertEqual(1, est.eval_count)
self.assertEqual(1, len(est.monitors))
self.assertEqual([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
self.assertTrue(isinstance(est.monitors[0],
session_run_hook.SessionRunHook))
def test_train_hooks_extend_does_not_mutate_input_hooks(self):
for est in self._estimators_for_tests():
@ -563,7 +563,8 @@ class ExperimentTest(test.TestCase):
self.assertEqual(1, est.export_count)
self.assertEqual(1, len(est.monitors))
self.assertEqual([noop_hook], est.eval_hooks)
self.assertTrue(isinstance(est.monitors[0], monitors.ValidationMonitor))
self.assertTrue(isinstance(est.monitors[0],
session_run_hook.SessionRunHook))
def test_train_and_evaluate_with_no_eval_during_training(self):
for est in self._estimators_for_tests():

View File

@ -37,8 +37,8 @@ from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
@ -429,11 +429,14 @@ def _get_ready_op():
def _get_local_init_op():
"""Returns the local init ops to initialize tables and local variables."""
local_init_op = _get_first_op_from_collection(
ops.GraphKeys.LOCAL_INIT_OP)
if local_init_op is None:
op_list = [variables.local_variables_initializer(),
data_flow_ops.tables_initializer()]
op_list = [
variables.local_variables_initializer(),
lookup_ops.tables_initializer()
]
if op_list:
local_init_op = control_flow_ops.group(*op_list)
ops.add_to_collection(ops.GraphKeys.LOCAL_INIT_OP, local_init_op)
@ -680,7 +683,7 @@ def run_feeds_iter(output_dict, feed_dicts, restore_checkpoint_path=None):
else:
session.run(variables.global_variables_initializer())
session.run(variables.local_variables_initializer())
session.run(data_flow_ops.tables_initializer())
session.run(lookup_ops.tables_initializer())
coord = coordinator.Coordinator()
threads = None
try:

View File

@ -28,7 +28,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
@ -67,17 +67,17 @@ def _export_graph(graph, saver, checkpoint_path, export_dir,
with graph.as_default():
with tf_session.Session('') as session:
variables.local_variables_initializer()
data_flow_ops.tables_initializer()
lookup_ops.tables_initializer()
saver.restore(session, checkpoint_path)
export = exporter.Exporter(saver)
export.init(init_op=control_flow_ops.group(
export.init(
init_op=control_flow_ops.group(
variables.local_variables_initializer(),
data_flow_ops.tables_initializer()),
lookup_ops.tables_initializer()),
default_graph_signature=default_graph_signature,
named_graph_signatures=named_graph_signatures,
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS))
assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS))
return export.export(export_dir, contrib_variables.get_global_step(),
session, exports_to_keep=exports_to_keep)

View File

@ -13,19 +13,10 @@ py_library(
name = "lookup_py",
srcs = [
"__init__.py",
"lookup_ops.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:data_flow_ops_gen",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:string_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python/feature_column:lookup_ops",
],
)
@ -39,11 +30,11 @@ py_test(
"//tensorflow/python:array_ops",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python:data_flow_ops",
"//tensorflow/python:errors",
"//tensorflow/python:framework",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:lookup_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
"//tensorflow/python:variables",

View File

@ -47,7 +47,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.lookup.lookup_ops import *
from tensorflow.python.feature_column.lookup_ops import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented

View File

@ -31,7 +31,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver
@ -125,7 +125,7 @@ class HashTableOpTest(test.TestCase):
table3 = lookup.HashTable(
lookup.KeyValueTensorInitializer(keys, values), default_val)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(3, table1.size().eval())
self.assertAllEqual(3, table2.size().eval())
self.assertAllEqual(3, table3.size().eval())
@ -1184,7 +1184,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_file(self):
@ -1198,7 +1198,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_file(self):
@ -1212,7 +1212,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_file_with_default_value(self):
@ -1224,7 +1224,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_file_with_oov_buckets(self):
@ -1236,7 +1236,7 @@ class IndexTableFromFile(test.TestCase):
constant_op.constant(["salad", "surgery", "tarkus", "toccata"]))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual(
(
1, # From vocabulary file.
@ -1259,7 +1259,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, -1, -1), ids.eval())
self.assertEqual(2, table.size().eval())
@ -1286,7 +1286,7 @@ class IndexTableFromFile(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), ids.eval())
self.assertEqual(3, table.size().eval())
@ -1345,7 +1345,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(("salad", "surgery", "tarkus")))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_tensor_init(self):
@ -1356,7 +1356,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
@ -1367,7 +1367,7 @@ class IndexTableFromTensor(test.TestCase):
constant_op.constant((1, -1000, 11), dtype=dtypes.int64))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_index_table_from_tensor_with_default_value(self):
@ -1378,7 +1378,7 @@ class IndexTableFromTensor(test.TestCase):
ids = table.lookup(constant_op.constant(["salad", "surgery", "tarkus"]))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), ids.eval())
def test_index_table_from_tensor_missing_mapping(self):
@ -1394,7 +1394,7 @@ class IndexTableFromTensor(test.TestCase):
self.assertRaises(errors_impl.OpError, ids.eval)
with self.assertRaisesRegexp(
errors_impl.OpError, "keys and values cannot be empty"):
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
def test_index_table_from_tensor_with_invalid_hashers(self):
with self.test_session():
@ -1422,7 +1422,7 @@ class StringToIndexTest(test.TestCase):
indices = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, indices.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), indices.eval())
@ -1433,7 +1433,7 @@ class StringToIndexTest(test.TestCase):
_ = lookup.string_to_index(feats, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError,
data_flow_ops.tables_initializer().run)
lookup_ops.tables_initializer().run)
def test_string_to_index_with_default_value(self):
default_value = -42
@ -1444,7 +1444,7 @@ class StringToIndexTest(test.TestCase):
feats, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, indices.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((1, 2, default_value), indices.eval())
@ -1463,7 +1463,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file)
features = table.lookup(constant_op.constant([0, 1, 2, 3], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval())
@ -1475,7 +1475,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
vocabulary_file=vocabulary_file, default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval())
@ -1489,7 +1489,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
default_value=default_value)
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", default_value, default_value),
features.eval())
@ -1501,7 +1501,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
init = data_flow_ops.tables_initializer()
init = lookup_ops.tables_initializer()
self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
"Invalid vocab_size", init.run)
@ -1513,7 +1513,7 @@ class IndexToStringTableFromFileTest(test.TestCase):
features = table.lookup(constant_op.constant([1, 2, 4], dtypes.int64))
self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", b"UNK"), features.eval())
@ -1528,7 +1528,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
indices = constant_op.constant([0, 1, 2, 3], dtypes.int64)
features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
features.eval())
@ -1540,7 +1540,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
mapping=mapping_strings)
indices = constant_op.constant([0, 1, 4], dtypes.int64)
features = table.lookup(indices)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), features.eval())
def test_index_to_string_with_default_value(self):
@ -1553,7 +1553,7 @@ class IndexToStringTableFromTensorTest(test.TestCase):
features = table.lookup(indices)
self.assertRaises(errors_impl.OpError, features.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value),
features.eval())
@ -1567,7 +1567,7 @@ class IndexToStringTest(test.TestCase):
feats = lookup.index_to_string(indices, mapping=mapping_strings)
self.assertRaises(errors_impl.OpError, feats.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
feats.eval())
@ -1577,11 +1577,11 @@ class IndexToStringTest(test.TestCase):
mapping_strings = constant_op.constant(["hello", "hello"])
indices = constant_op.constant([0, 1, 4], dtypes.int64)
feats = lookup.index_to_string(indices, mapping=mapping_strings)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())
self.assertRaises(errors_impl.OpError,
data_flow_ops.tables_initializer().run)
lookup_ops.tables_initializer().run)
def test_index_to_string_with_default_value(self):
default_value = b"NONE"
@ -1592,7 +1592,7 @@ class IndexToStringTest(test.TestCase):
indices, mapping=mapping_strings, default_value=default_value)
self.assertRaises(errors_impl.OpError, feats.eval)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
@ -1755,7 +1755,7 @@ class InitializeTableFromFileOpTest(test.TestCase):
default_value,
shared_name=shared_name)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
input_string = constant_op.constant(["brain", "salad", "tank"])
@ -2081,7 +2081,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
hasher_spec=lookup.StrongHashSpec((1, 2)),
name="table2")
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
input_string = constant_op.constant(
["fruit", "brain", "salad", "surgery", "UNK"])
@ -2167,7 +2167,7 @@ class IdTableWithHashBucketsTest(test.TestCase):
default_value2),
oov_buckets)
data_flow_ops.tables_initializer().run()
lookup_ops.tables_initializer().run()
input_string_1 = constant_op.constant(
["brain", "salad", "surgery", "UNK"])

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.cc
tensorflow/core/protobuf/queue_runner.pb.cc
tensorflow/core/protobuf/named_tensor.pb.cc
tensorflow/core/protobuf/meta_graph.pb.cc
tensorflow/core/protobuf/cluster.pb.cc
tensorflow/core/protobuf/config.pb.cc
tensorflow/core/protobuf/rewriter_config.pb.cc
tensorflow/core/protobuf/debug.pb.cc

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.pb.h
tensorflow/core/protobuf/queue_runner.pb.h
tensorflow/core/protobuf/named_tensor.pb.h
tensorflow/core/protobuf/meta_graph.pb.h
tensorflow/core/protobuf/cluster.pb.h
tensorflow/core/protobuf/config.pb.h
tensorflow/core/protobuf/debug.pb.h
tensorflow/core/protobuf/rewriter_config.pb.h

View File

@ -1,6 +1,7 @@
tensorflow/core/util/saved_tensor_slice.pb_text.cc
tensorflow/core/util/memmapped_file_system.pb_text.cc
tensorflow/core/protobuf/saver.pb_text.cc
tensorflow/core/protobuf/cluster.pb_text.cc
tensorflow/core/protobuf/config.pb_text.cc
tensorflow/core/protobuf/debug.pb_text.cc
tensorflow/core/protobuf/rewriter_config.pb_text.cc

View File

@ -7,6 +7,7 @@ tensorflow/core/protobuf/saver.proto
tensorflow/core/protobuf/queue_runner.proto
tensorflow/core/protobuf/named_tensor.proto
tensorflow/core/protobuf/meta_graph.proto
tensorflow/core/protobuf/cluster.proto
tensorflow/core/protobuf/config.proto
tensorflow/core/protobuf/debug.proto
tensorflow/core/protobuf/rewriter_config.proto

View File

@ -1338,6 +1338,87 @@ def streaming_sparse_precision_at_top_k(top_k_predictions,
name=name_scope)
def sparse_recall_at_top_k(labels,
top_k_predictions,
class_id=None,
weights=None,
metrics_collections=None,
updates_collections=None,
name=None):
"""Computes recall@k of top-k predictions with respect to sparse labels.
If `class_id` is specified, we calculate recall by considering only the
entries in the batch for which `class_id` is in the label, and computing
the fraction of them for which `class_id` is in the top-k `predictions`.
If `class_id` is not specified, we'll calculate recall as how often on
average a class among the labels of a batch entry is in the top-k
`predictions`.
`sparse_recall_at_top_k` creates two local variables, `true_positive_at_<k>`
and `false_negative_at_<k>`, that are used to compute the recall_at_k
frequency. This frequency is ultimately returned as `recall_at_<k>`: an
idempotent operation that simply divides `true_positive_at_<k>` by total
(`true_positive_at_<k>` + `false_negative_at_<k>`).
For estimation of the metric over a stream of data, the function creates an
`update_op` operation that updates these variables and returns the
`recall_at_<k>`. Set operations applied to `top_k` and `labels` calculate the
true positives and false negatives weighted by `weights`. Then `update_op`
increments `true_positive_at_<k>` and `false_negative_at_<k>` using these
values.
If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
Args:
labels: `int64` `Tensor` or `SparseTensor` with shape
[D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
target classes for the associated prediction. Commonly, N=1 and `labels`
has shape [batch_size, num_labels]. [D1, ... DN] must match
`top_k_predictions`. Values should be in range [0, num_classes), where
num_classes is the last dimension of `predictions`. Values outside this
range always count towards `false_negative_at_<k>`.
top_k_predictions: Integer `Tensor` with shape [D1, ... DN, k] where
N >= 1. Commonly, N=1 and top_k_predictions has shape [batch size, k].
The final dimension contains the indices of top-k labels. [D1, ... DN]
must match `labels`.
class_id: Integer class ID for which we want binary metrics. This should be
in range [0, num_classes), where num_classes is the last dimension of
`predictions`. If class_id is outside this range, the method returns NAN.
weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
`labels`. If the latter, it must be broadcastable to `labels` (i.e., all
dimensions must be either `1`, or the same as the corresponding `labels`
dimension).
metrics_collections: An optional list of collections that values should
be added to.
updates_collections: An optional list of collections that updates should
be added to.
name: Name of new update operation, and namespace for other dependent ops.
Returns:
recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
by the sum of `true_positives` and `false_negatives`.
update_op: `Operation` that increments `true_positives` and
`false_negatives` variables appropriately, and whose value matches
`recall`.
Raises:
ValueError: If `weights` is not `None` and its shape doesn't match
`predictions`, or if either `metrics_collections` or `updates_collections`
are not a list or tuple.
"""
default_name = _at_k_name('recall', class_id=class_id)
with ops.name_scope(name, default_name, (top_k_predictions, labels,
weights)) as name_scope:
return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access
labels=labels,
predictions_idx=top_k_predictions,
class_id=class_id,
weights=weights,
metrics_collections=metrics_collections,
updates_collections=updates_collections,
name=name_scope)
def streaming_sparse_average_precision_at_k(predictions,
labels,
k,
@ -2288,6 +2369,7 @@ def _remove_squeezable_dimensions(predictions, labels, weights):
__all__ = [
'aggregate_metric_map',
'aggregate_metrics',
'sparse_recall_at_top_k',
'streaming_accuracy',
'streaming_auc',
'streaming_false_negatives',
@ -2310,7 +2392,9 @@ __all__ = [
'streaming_root_mean_squared_error',
'streaming_sensitivity_at_specificity',
'streaming_sparse_average_precision_at_k',
'streaming_sparse_average_precision_at_top_k',
'streaming_sparse_precision_at_k',
'streaming_sparse_precision_at_top_k',
'streaming_sparse_recall_at_k',
'streaming_specificity_at_sensitivity',
'streaming_true_negatives',

View File

@ -2958,8 +2958,38 @@ class StreamingSparseRecallTest(test.TestCase):
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
def _test_sparse_recall_at_top_k(self,
labels,
top_k_predictions,
expected,
class_id=None,
weights=None):
with ops.Graph().as_default() as g, self.test_session(g):
if weights is not None:
weights = constant_op.constant(weights, dtypes_lib.float32)
metric, update = metric_ops.sparse_recall_at_top_k(
labels=labels,
top_k_predictions=constant_op.constant(top_k_predictions,
dtypes_lib.int32),
class_id=class_id,
weights=weights)
# Fails without initialized vars.
self.assertRaises(errors_impl.OpError, metric.eval)
self.assertRaises(errors_impl.OpError, update.eval)
variables.variables_initializer(variables.local_variables()).run()
# Run per-step op and assert expected values.
if math.isnan(expected):
self.assertTrue(math.isnan(update.eval()))
self.assertTrue(math.isnan(metric.eval()))
else:
self.assertEqual(expected, update.eval())
self.assertEqual(expected, metric.eval())
def test_one_label_at_k1_nan(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -2970,9 +3000,12 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (-1, 0, 1, 4):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_one_label_at_k1_no_predictions(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -2981,9 +3014,12 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 0 predictions.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0, class_id=2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0, class_id=2)
def test_one_label_at_k1(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -2992,13 +3028,18 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, class_id=3)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 1, class_id=3)
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2)
def test_one_label_at_k1_weighted(self):
predictions = [[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]]
top_k_predictions = [[3], [3]]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 0, 1], [0, 0, 1, 0]])
dense_labels = np.array([[3], [2]], dtype=np.int64)
@ -3007,6 +3048,8 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 3: 1 label, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=3, weights=(0.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3014,6 +3057,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0,))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(1.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3021,6 +3070,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(2.0,))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(2.0,))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3028,6 +3083,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=3,
weights=(0.0, 0.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=3,
weights=(0.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3035,6 +3096,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=3,
weights=(0.0, 1.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=3,
weights=(0.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3042,6 +3109,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0, 0.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(1.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3049,6 +3122,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1,
class_id=3,
weights=(1.0, 1.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1,
class_id=3,
weights=(1.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3056,6 +3135,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2,
class_id=3,
weights=(2.0, 3.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=2.0 / 2,
class_id=3,
weights=(2.0, 3.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3063,6 +3148,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=3.0 / 3,
class_id=3,
weights=(3.0, 2.0))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=3.0 / 3,
class_id=3,
weights=(3.0, 2.0))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3070,6 +3161,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.3 / 0.3,
class_id=3,
weights=(0.3, 0.6))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=0.3 / 0.3,
class_id=3,
weights=(0.3, 0.6))
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3077,32 +3174,70 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.6 / 0.6,
class_id=3,
weights=(0.6, 0.3))
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=0.6 / 0.6,
class_id=3,
weights=(0.6, 0.3))
# All classes: 2 labels, 2 predictions, 1 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=NAN, weights=(0.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, weights=(0.0,))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, weights=(1.0,))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, weights=(2.0,))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 1, weights=(1.0, 0.0))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0 / 1, weights=(0.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=1.0 / 2, weights=(1.0, 1.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, weights=(1.0, 1.0))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=2.0 / 5, weights=(2.0, 3.0))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=3.0 / 5, weights=(3.0, 2.0))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=3.0 / 5, weights=(3.0, 2.0))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.3 / 0.9, weights=(0.3, 0.6))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.3 / 0.9, weights=(0.3, 0.6))
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=1, expected=0.6 / 0.9, weights=(0.6, 0.3))
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.6 / 0.9, weights=(0.6, 0.3))
def test_three_labels_at_k5_nan(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@ -3112,10 +3247,16 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_three_labels_at_k5_no_predictions(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@ -3124,10 +3265,16 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 8: 1 label, no predictions.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=8)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0 / 1, class_id=8)
def test_three_labels_at_k5(self):
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sparse_labels = _binary_2d_label_to_sparse_value(
[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
@ -3136,23 +3283,35 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=2.0 / 2, class_id=2)
# Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 1, class_id=5)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 1, class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0 / 1, class_id=7)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0 / 1, class_id=7)
# All classes: 6 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=3.0 / 6)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=3.0 / 6)
def test_three_labels_at_k5_some_out_of_range(self):
"""Tests that labels outside the [0, n_classes) count in denominator."""
predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
top_k_predictions = [
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
]
sp_labels = sparse_tensor.SparseTensorValue(
indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
[1, 3]],
@ -3167,6 +3326,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=2.0 / 2,
class_id=2)
self._test_sparse_recall_at_top_k(
sp_labels,
top_k_predictions,
expected=2.0 / 2,
class_id=2)
# Class 5: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@ -3175,6 +3339,11 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=1.0 / 1,
class_id=5)
self._test_sparse_recall_at_top_k(
sp_labels,
top_k_predictions,
expected=1.0 / 1,
class_id=5)
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@ -3183,16 +3352,30 @@ class StreamingSparseRecallTest(test.TestCase):
k=5,
expected=0.0 / 1,
class_id=7)
self._test_sparse_recall_at_top_k(
sp_labels,
top_k_predictions,
expected=0.0 / 1,
class_id=7)
# All classes: 8 labels, 3 correct.
self._test_streaming_sparse_recall_at_k(
predictions=predictions, labels=sp_labels, k=5, expected=3.0 / 8)
self._test_sparse_recall_at_top_k(
sp_labels, top_k_predictions, expected=3.0 / 8)
def test_3d_nan(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
[[0, 1, 1, 0, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 1, 1, 0]]])
@ -3207,12 +3390,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (0, 3, 4, 6, 9, 10):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, class_id=class_id)
def test_3d_no_predictions(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
sparse_labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3229,12 +3421,21 @@ class StreamingSparseRecallTest(test.TestCase):
for class_id in (1, 8):
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=0.0, class_id=class_id)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=0.0, class_id=class_id)
def test_3d(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3244,24 +3445,39 @@ class StreamingSparseRecallTest(test.TestCase):
# Class 2: 4 labels, all correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=4.0 / 4, class_id=2)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=4.0 / 4, class_id=2)
# Class 5: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=2.0 / 2, class_id=5)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=2.0 / 2, class_id=5)
# Class 7: 2 labels, 1 incorrect.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=1.0 / 2, class_id=7)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=1.0 / 2, class_id=7)
# All classes: 12 labels, 7 correct.
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=7.0 / 12)
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=7.0 / 12)
def test_3d_ignore_all(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3276,6 +3492,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=class_id,
weights=[[0], [0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=class_id,
weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions,
labels,
@ -3283,16 +3505,33 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=class_id,
weights=[[0, 0], [0, 0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0], [0]])
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, weights=[[0], [0]])
self._test_streaming_sparse_recall_at_k(
predictions, labels, k=5, expected=NAN, weights=[[0, 0], [0, 0]])
self._test_sparse_recall_at_top_k(
labels, top_k_predictions, expected=NAN, weights=[[0, 0], [0, 0]])
def test_3d_ignore_some(self):
predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
[[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
top_k_predictions = [[
[9, 4, 6, 2, 0],
[5, 7, 2, 9, 6],
], [
[5, 7, 2, 9, 6],
[9, 4, 6, 2, 0],
]]
labels = _binary_3d_label_to_sparse_value(
[[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
[0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
@ -3307,6 +3546,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0,
class_id=2,
weights=[[1], [0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=2.0 / 2.0,
class_id=2,
weights=[[1], [0]])
# Class 2: 2 labels, both correct.
self._test_streaming_sparse_recall_at_k(
@ -3316,6 +3561,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=2.0 / 2.0,
class_id=2,
weights=[[0], [1]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=2.0 / 2.0,
class_id=2,
weights=[[0], [1]])
# Class 7: 1 label, correct.
self._test_streaming_sparse_recall_at_k(
@ -3325,6 +3576,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 1.0,
class_id=7,
weights=[[0], [1]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 1.0,
class_id=7,
weights=[[0], [1]])
# Class 7: 1 label, incorrect.
self._test_streaming_sparse_recall_at_k(
@ -3334,6 +3591,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=0.0 / 1.0,
class_id=7,
weights=[[1], [0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=0.0 / 1.0,
class_id=7,
weights=[[1], [0]])
# Class 7: 2 labels, 1 correct.
self._test_streaming_sparse_recall_at_k(
@ -3343,6 +3606,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=1.0 / 2.0,
class_id=7,
weights=[[1, 0], [1, 0]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=1.0 / 2.0,
class_id=7,
weights=[[1, 0], [1, 0]])
# Class 7: No labels.
self._test_streaming_sparse_recall_at_k(
@ -3352,6 +3621,12 @@ class StreamingSparseRecallTest(test.TestCase):
expected=NAN,
class_id=7,
weights=[[0, 1], [0, 1]])
self._test_sparse_recall_at_top_k(
labels,
top_k_predictions,
expected=NAN,
class_id=7,
weights=[[0, 1], [0, 1]])
def test_sparse_tensor_value(self):
predictions = [[0.1, 0.3, 0.2, 0.4],

View File

@ -304,6 +304,7 @@ filegroup(
exclude = [
"**/METADATA",
"**/OWNERS",
"tools/**",
],
),
visibility = ["//tensorflow:__subpackages__"],
@ -351,3 +352,27 @@ tf_kernel_library(
"//third_party/eigen3",
],
)
py_binary(
name = "checkpoint_convert",
srcs = ["python/tools/checkpoint_convert.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)
py_test(
name = "checkpoint_convert_test",
size = "small",
srcs = ["python/tools/checkpoint_convert_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":checkpoint_convert",
"//tensorflow/python:client_testlib",
],
)

View File

@ -74,7 +74,41 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
g, _ = core_rnn_cell_impl.BasicRNNCell(2)(x, m)
cell = core_rnn_cell_impl.BasicRNNCell(2)
g, _ = cell(x, m)
self.assertEqual(
["root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
[v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g], {x.name: np.array([[1., 1.]]),
m.name: np.array([[0.1, 0.1]])})
self.assertEqual(res[0].shape, (1, 2))
def testBasicRNNCellNotTrainable(self):
with self.test_session() as sess:
def not_trainable_getter(getter, *args, **kwargs):
kwargs["trainable"] = False
return getter(*args, **kwargs)
with variable_scope.variable_scope(
"root", initializer=init_ops.constant_initializer(0.5),
custom_getter=not_trainable_getter):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 2])
cell = core_rnn_cell_impl.BasicRNNCell(2)
g, _ = cell(x, m)
self.assertFalse(cell.trainable_variables)
self.assertEqual(
["root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/basic_rnn_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME],
[v.name for v in cell.non_trainable_variables])
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g], {x.name: np.array([[1., 1.]]),
@ -114,10 +148,23 @@ class RNNCellTest(test.TestCase):
"root", initializer=init_ops.constant_initializer(0.5)):
x = array_ops.zeros([1, 2])
m = array_ops.zeros([1, 8])
g, out_m = core_rnn_cell_impl.MultiRNNCell(
cell = core_rnn_cell_impl.MultiRNNCell(
[core_rnn_cell_impl.BasicLSTMCell(
2, state_is_tuple=False) for _ in range(2)],
state_is_tuple=False)(x, m)
state_is_tuple=False)
g, out_m = cell(x, m)
expected_variable_names = [
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._WEIGHTS_VARIABLE_NAME,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0"
% core_rnn_cell_impl._BIAS_VARIABLE_NAME]
self.assertEqual(
expected_variable_names, [v.name for v in cell.trainable_variables])
self.assertFalse(cell.non_trainable_variables)
sess.run([variables_lib.global_variables_initializer()])
res = sess.run(
[g, out_m],
@ -125,15 +172,7 @@ class RNNCellTest(test.TestCase):
m.name: 0.1 * np.ones([1, 8])})
self.assertEqual(len(res), 2)
variables = variables_lib.global_variables()
self.assertEqual(4, len(variables))
self.assertEquals(variables[0].op.name,
"root/multi_rnn_cell/cell_0/basic_lstm_cell/weights")
self.assertEquals(variables[1].op.name,
"root/multi_rnn_cell/cell_0/basic_lstm_cell/biases")
self.assertEquals(variables[2].op.name,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/weights")
self.assertEquals(variables[3].op.name,
"root/multi_rnn_cell/cell_1/basic_lstm_cell/biases")
self.assertEqual(expected_variable_names, [v.name for v in variables])
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.24024698, 0.24024698]])
expected_mem = np.array([[

View File

@ -27,7 +27,6 @@ from __future__ import division
from __future__ import print_function
import collections
import contextlib
import hashlib
import math
import numbers
@ -57,53 +56,6 @@ _BIAS_VARIABLE_NAME = "biases"
_WEIGHTS_VARIABLE_NAME = "weights"
@contextlib.contextmanager
def _checked_scope(cell, scope, reuse=None, **kwargs):
if reuse is not None:
kwargs["reuse"] = reuse
with vs.variable_scope(scope, **kwargs) as checking_scope:
scope_name = checking_scope.name
if hasattr(cell, "_scope"):
cell_scope = cell._scope # pylint: disable=protected-access
if cell_scope.name != checking_scope.name:
raise ValueError(
"Attempt to reuse RNNCell %s with a different variable scope than "
"its first use. First use of cell was with scope '%s', this "
"attempt is with scope '%s'. Please create a new instance of the "
"cell if you would like it to use a different set of weights. "
"If before you were using: MultiRNNCell([%s(...)] * num_layers), "
"change to: MultiRNNCell([%s(...) for _ in range(num_layers)]). "
"If before you were using the same cell instance as both the "
"forward and reverse cell of a bidirectional RNN, simply create "
"two instances (one for forward, one for reverse). "
"In May 2017, we will start transitioning this cell's behavior "
"to use existing stored weights, if any, when it is called "
"with scope=None (which can lead to silent model degradation, so "
"this error will remain until then.)"
% (cell, cell_scope.name, scope_name, type(cell).__name__,
type(cell).__name__))
else:
weights_found = False
try:
with vs.variable_scope(checking_scope, reuse=True):
vs.get_variable(_WEIGHTS_VARIABLE_NAME)
weights_found = True
except ValueError:
pass
if weights_found and reuse is None:
raise ValueError(
"Attempt to have a second RNNCell use the weights of a variable "
"scope that already has weights: '%s'; and the cell was not "
"constructed as %s(..., reuse=True). "
"To share the weights of an RNNCell, simply "
"reuse it in your second calculation, or create a new one with "
"the argument reuse=True." % (scope_name, type(cell).__name__))
# Everything is OK. Update the cell's scope and yield it.
cell._scope = checking_scope # pylint: disable=protected-access
yield checking_scope
class BasicRNNCell(RNNCell):
"""The most basic RNN cell."""

View File

@ -39,9 +39,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
_checked_scope = core_rnn_cell_impl._checked_scope # pylint: disable=protected-access
def _get_concat_variable(name, shape, dtype, num_shards):
"""Get a sharded variable concatenated into one tensor."""
sharded_variable = _get_sharded_variable(name, shape, dtype, num_shards)

View File

@ -0,0 +1,231 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert checkpoints using RNNCells to new name convention.
Usage:
python checkpoint_convert [--write_v1_checkpoint] \
'/path/to/checkpoint' '/path/to/new_checkpoint'
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import re
import sys
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib
_RNN_NAME_REPLACEMENTS = collections.OrderedDict([
############################################################################
# contrib/rnn/python/ops/core_rnn_cell_impl.py
# BasicRNNCell
('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
# GRUCell
('gru_cell/weights', 'gru_cell/kernel'),
('gru_cell/biases', 'gru_cell/bias'),
('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
('gru_cell/gates/biases', 'gru_cell/gates/bias'),
('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
# BasicLSTMCell
('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
# LSTMCell
('lstm_cell/weights', 'lstm_cell/kernel'),
('lstm_cell/biases', 'lstm_cell/bias'),
('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
# OutputProjectionWrapper
('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
# InputProjectionWrapper
('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
############################################################################
# contrib/rnn/python/ops/lstm_ops.py
# LSTMBlockFusedCell ??
('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
############################################################################
# contrib/rnn/python/ops/rnn_cell.py
# LayerNormBasicLSTMCell
('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
# UGRNNCell, not found in g3, but still need it?
('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
('ugrnn_cell/biases', 'ugrnn_cell/bias'),
# NASCell
('nas_rnn/weights', 'nas_rnn/kernel'),
('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
# IntersectionRNNCell
('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
('intersection_rnn_cell/in_projection/weights',
'intersection_rnn_cell/in_projection/kernel'),
('intersection_rnn_cell/in_projection/biases',
'intersection_rnn_cell/in_projection/bias'),
# PhasedLSTMCell
('phased_lstm_cell/mask_gates/weights',
'phased_lstm_cell/mask_gates/kernel'),
('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
('phased_lstm_cell/output_gate/weights',
'phased_lstm_cell/output_gate/kernel'),
('phased_lstm_cell/output_gate/biases',
'phased_lstm_cell/output_gate/bias'),
# AttentionCellWrapper
('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
('attention_cell_wrapper/attn_output_projection/weights',
'attention_cell_wrapper/attn_output_projection/kernel'),
('attention_cell_wrapper/attn_output_projection/biases',
'attention_cell_wrapper/attn_output_projection/bias'),
('attention_cell_wrapper/attention/weights',
'attention_cell_wrapper/attention/kernel'),
('attention_cell_wrapper/attention/biases',
'attention_cell_wrapper/attention/bias'),
])
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
('LSTMCell/W_', 'lstm_cell/weights/part_'),
('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
('GRUCell/W_', 'gru_cell/weights/part_'),
('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
])
def _rnn_name_replacement(var_name):
for pattern in _RNN_NAME_REPLACEMENTS:
if pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(pattern, _RNN_NAME_REPLACEMENTS[pattern])
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
break
return var_name
def _rnn_name_replacement_sharded(var_name):
for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
if pattern in var_name:
old_var_name = var_name
var_name = var_name.replace(pattern,
_RNN_SHARDED_NAME_REPLACEMENTS[pattern])
logging.info('Converted: %s --> %s' % (old_var_name, var_name))
return var_name
def _split_sharded_vars(name_shape_map):
"""Split shareded variables.
Args:
name_shape_map: A dict from variable name to variable shape.
Returns:
not_sharded: Names of the non-sharded variables.
sharded: Names of the sharded varibales.
"""
sharded = []
not_sharded = []
for name in name_shape_map:
if re.match(name, '_[0-9]+$'):
if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
sharded.append(name)
else:
not_sharded.append(name)
else:
not_sharded.append(name)
return not_sharded, sharded
def convert_names(checkpoint_from_path,
checkpoint_to_path,
write_v1_checkpoint=False):
"""Migrates the names of variables within a checkpoint.
Args:
checkpoint_from_path: Path to source checkpoint to be read in.
checkpoint_to_path: Path to checkpoint to be written out.
write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
Returns:
A dictionary that maps the new variable names to the Variable objects.
A dictionary that maps the old variable names to the new variable names.
"""
with ops.Graph().as_default():
logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
name_shape_map = reader.get_variable_to_shape_map()
not_sharded, sharded = _split_sharded_vars(name_shape_map)
new_variable_map = {}
conversion_map = {}
for var_name in not_sharded:
new_var_name = _rnn_name_replacement(var_name)
tensor = reader.get_tensor(var_name)
var = variables.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
for var_name in sharded:
new_var_name = _rnn_name_replacement_sharded(var_name)
var = variables.Variable(tensor, name=var_name)
new_variable_map[new_var_name] = var
if new_var_name != var_name:
conversion_map[var_name] = new_var_name
write_version = (saver_pb2.SaverDef.V1
if write_v1_checkpoint else saver_pb2.SaverDef.V2)
saver = saver_lib.Saver(new_variable_map, write_version=write_version)
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
saver.save(sess, checkpoint_to_path)
logging.info('Summary:')
logging.info(' Converted %d variable name(s).' % len(new_variable_map))
return new_variable_map, conversion_map
def main(_):
convert_names(
FLAGS.checkpoint_from_path,
FLAGS.checkpoint_to_path,
write_v1_checkpoint=FLAGS.write_v1_checkpoint)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument('checkpoint_from_path', type=str,
help='Path to source checkpoint to be read in.')
parser.add_argument('checkpoint_to_path', type=str,
help='Path to checkpoint to be written out.')
parser.add_argument('--write_v1_checkpoint', action='store_true',
help='Write v1 checkpoint')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -0,0 +1,108 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Unit tests for checkpoint converter."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import tempfile
from tensorflow.contrib.rnn.python.tools import checkpoint_convert
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
class CheckpointConvertTest(test.TestCase):
def setUp(self):
self._old_ckpt_path = tempfile.mktemp()
self._new_ckpt_path = tempfile.mktemp()
ops.reset_default_graph()
def tearDown(self):
for file_name in glob.glob(self._old_ckpt_path + "*"):
os.remove(file_name)
for file_name in glob.glob(self._new_ckpt_path + "*"):
os.remove(file_name)
def testReplacementDictsContainUniqueAndNonEmptyVariableNames(self):
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
new_name = checkpoint_convert._RNN_NAME_REPLACEMENTS[old_name]
self.assertTrue(old_name)
self.assertTrue(new_name)
self.assertNotEqual(old_name, new_name)
for old_name in checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS:
new_name = checkpoint_convert._RNN_SHARDED_NAME_REPLACEMENTS[old_name]
self.assertTrue(old_name)
self.assertTrue(new_name)
self.assertNotEqual(old_name, new_name)
def testConversionFromV2WithConvertedVariableNamesSucceeds(self):
variables.Variable(10.0, name="a")
for old_name in checkpoint_convert._RNN_NAME_REPLACEMENTS:
variables.Variable(20.0, name=old_name)
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path)
self.assertTrue(glob.glob(self._new_ckpt_path + "*"))
self.assertItemsEqual(
["a"] + list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values()),
new_var_map.keys())
self.assertEqual(checkpoint_convert._RNN_NAME_REPLACEMENTS, conversion_map)
def testConversionFromV2WithoutConvertedVariableNamesSucceeds(self):
variables.Variable(10.0, name="a")
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path)
self.assertItemsEqual(["a"], new_var_map.keys())
self.assertFalse(conversion_map)
def testConversionToV1Succeeds(self):
variables.Variable(10.0, name="a")
variables.Variable(
20.0, name=list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1])
with session.Session() as sess:
saver = saver_lib.Saver()
sess.run(variables.global_variables_initializer())
saver.save(sess, self._old_ckpt_path)
new_var_map, conversion_map = checkpoint_convert.convert_names(
self._old_ckpt_path, self._new_ckpt_path, write_v1_checkpoint=True)
self.assertItemsEqual(
["a", list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]],
new_var_map.keys())
self.assertEqual(
{list(checkpoint_convert._RNN_NAME_REPLACEMENTS.keys())[-1]:
list(checkpoint_convert._RNN_NAME_REPLACEMENTS.values())[-1]},
conversion_map)
if __name__ == "__main__":
test.main()

View File

@ -261,7 +261,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.platform import tf_logging as logging
@ -657,7 +657,7 @@ def train(train_op,
if local_init_op == _USE_DEFAULT:
local_init_op = control_flow_ops.group(
tf_variables.local_variables_initializer(),
data_flow_ops.tables_initializer())
lookup_ops.tables_initializer())
if sync_optimizer is not None and isinstance(
sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer):

View File

@ -154,6 +154,7 @@ CORE_PROTO_SRCS = [
"framework/versions.proto",
"lib/core/error_codes.proto",
"protobuf/config.proto",
"protobuf/cluster.proto",
"protobuf/debug.proto",
"protobuf/queue_runner.proto",
"protobuf/rewriter_config.proto",
@ -506,6 +507,7 @@ tf_gen_op_libs(
"image_ops",
"io_ops",
"linalg_ops",
"lookup_ops",
"logging_ops",
"math_ops",
"nn_ops",
@ -582,6 +584,7 @@ cc_library(
":image_ops_op_lib",
":io_ops_op_lib",
":linalg_ops_op_lib",
":lookup_ops_op_lib",
":logging_ops_op_lib",
":math_ops_op_lib",
":nn_ops_op_lib",
@ -708,6 +711,7 @@ cc_library(
"//tensorflow/core/kernels:image",
"//tensorflow/core/kernels:io",
"//tensorflow/core/kernels:linalg",
"//tensorflow/core/kernels:lookup",
"//tensorflow/core/kernels:logging",
"//tensorflow/core/kernels:math",
"//tensorflow/core/kernels:multinomial_op",

View File

@ -23,8 +23,7 @@ limitations under the License.
namespace tensorflow {
Device::Device(Env* env, const DeviceAttributes& device_attributes,
Allocator* device_allocator)
Device::Device(Env* env, const DeviceAttributes& device_attributes)
: DeviceBase(env), device_attributes_(device_attributes) {
CHECK(DeviceNameUtils::ParseFullName(name(), &parsed_name_))
<< "Invalid device name: " << name();

View File

@ -53,8 +53,7 @@ namespace tensorflow {
class Device : public DeviceBase {
public:
Device(Env* env, const DeviceAttributes& device_attributes,
Allocator* device_allocator);
Device(Env* env, const DeviceAttributes& device_attributes);
~Device() override;
// Full name of this device (see top comment).

View File

@ -29,10 +29,18 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
for (Device* d : devices) {
devices_.push_back(d);
// Register under both the full name and the local name.
// Register under the (1) full name, (2) canonical name, and (3) local name.
string full_name = d->name();
device_map_[CopyToBackingStore(full_name)] = d;
DeviceNameUtils::ParsedName parsed_name = d->parsed_name();
if (parsed_name.has_job && parsed_name.has_replica &&
parsed_name.has_task && parsed_name.has_type && parsed_name.has_id) {
string canonical_name = DeviceNameUtils::FullName(
parsed_name.job, parsed_name.replica, parsed_name.task,
parsed_name.type, parsed_name.id);
device_map_[CopyToBackingStore(canonical_name)] = d;
}
string lname = DeviceNameUtils::LocalName(d->name());
device_map_[CopyToBackingStore(lname)] = d;
device_type_counts_[d->device_type()]++;
@ -40,7 +48,8 @@ DeviceMgr::DeviceMgr(const std::vector<Device*>& devices)
}
DeviceMgr::~DeviceMgr() {
for (auto p : devices_) delete p;
// TODO(b/37437134): Remove destructor after converting to std::unique_ptr.
for (Device* p : devices_) delete p;
}
StringPiece DeviceMgr::CopyToBackingStore(StringPiece s) {
@ -85,6 +94,12 @@ Status DeviceMgr::LookupDevice(StringPiece name, Device** device) const {
Status s;
auto iter = device_map_.find(name);
if (iter == device_map_.end()) {
std::vector<StringPiece> device_names;
for (auto&& itr : device_map_) {
device_names.push_back(itr.first);
}
LOG(WARNING) << "Unknown device: " << name
<< " all devices: " << str_util::Join(device_names, ", ");
return errors::InvalidArgument(name, " unknown device.");
}
*device = iter->second;

View File

@ -36,6 +36,7 @@ class DeviceMgr {
public:
// Takes ownership of each device in 'devices'.
// TODO(zhifengc): Other initialization information.
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
explicit DeviceMgr(const std::vector<Device*>& devices);
~DeviceMgr();
@ -61,6 +62,7 @@ class DeviceMgr {
int NumDeviceType(const string& type) const;
private:
// TODO(b/37437134): Use std::unique_ptr's to track ownership.
typedef gtl::InlinedVector<Device*, 8> DeviceVec;
DeviceVec devices_;

View File

@ -39,7 +39,10 @@ class DeviceSet {
// Set the device designated as the "client". This device
// must also be registered via AddDevice().
void set_client_device(Device* device) { client_device_ = device; }
void set_client_device(Device* device) {
DCHECK(client_device_ == nullptr);
client_device_ = device;
}
// Returns a pointer to the device designated as the "client".
Device* client_device() const { return client_device_; }

View File

@ -27,8 +27,7 @@ namespace {
static Device* Dev(const char* type, const char* name) {
class FakeDevice : public Device {
public:
explicit FakeDevice(const DeviceAttributes& attr)
: Device(nullptr, attr, nullptr) {}
explicit FakeDevice(const DeviceAttributes& attr) : Device(nullptr, attr) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes) override { return nullptr; }
};

View File

@ -179,10 +179,9 @@ BaseGPUDevice::BaseGPUDevice(const SessionOptions& options, const string& name,
int gpu_id, const string& physical_device_desc,
Allocator* gpu_allocator, Allocator* cpu_allocator,
bool sync_every_op, int32 max_streams)
: LocalDevice(options,
Device::BuildDeviceAttributes(name, DEVICE_GPU, memory_limit,
locality, physical_device_desc),
gpu_allocator),
: LocalDevice(options, Device::BuildDeviceAttributes(name, DEVICE_GPU,
memory_limit, locality,
physical_device_desc)),
gpu_allocator_(gpu_allocator),
cpu_allocator_(cpu_allocator),
gpu_id_(gpu_id),

View File

@ -60,10 +60,8 @@ struct LocalDevice::EigenThreadPoolInfo {
};
LocalDevice::LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes,
Allocator* device_allocator)
: Device(options.env, attributes, device_allocator),
owned_tp_info_(nullptr) {
const DeviceAttributes& attributes)
: Device(options.env, attributes), owned_tp_info_(nullptr) {
// If we're running on the CPU, log warnings if we're not compiled using the
// best flags for performance.
port::WarnAboutUnusedCPUFeatures();

View File

@ -33,8 +33,8 @@ struct SessionOptions;
// GPUDevice into more 'process-wide' abstractions.
class LocalDevice : public Device {
public:
LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes,
Allocator* device_allocator);
LocalDevice(const SessionOptions& options,
const DeviceAttributes& attributes);
~LocalDevice() override;
private:

View File

@ -0,0 +1,54 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/renamed_device.h"
namespace tensorflow {
// TODO(saeta): Convert to returning a std::unique_ptr?
/* static */
Device* RenamedDevice::NewRenamedDevice(const string& new_base,
Device* underlying,
bool owns_underlying) {
DeviceNameUtils::ParsedName parsed_name;
CHECK(DeviceNameUtils::ParseFullName(new_base, &parsed_name));
DeviceNameUtils::ParsedName underlying_parsed_name =
underlying->parsed_name();
CHECK(underlying_parsed_name.has_type);
CHECK(underlying_parsed_name.has_id);
parsed_name.type = underlying_parsed_name.type;
parsed_name.id = underlying_parsed_name.id;
string name = DeviceNameUtils::FullName(parsed_name.job, parsed_name.replica,
parsed_name.task, parsed_name.type,
parsed_name.id);
DeviceAttributes attributes(underlying->attributes());
attributes.set_name(name);
return new RenamedDevice(underlying, attributes, owns_underlying);
}
RenamedDevice::RenamedDevice(Device* underlying,
const DeviceAttributes& attributes,
bool owns_underlying)
: Device(underlying->env(), attributes),
underlying_(underlying),
owns_underlying_(owns_underlying) {}
RenamedDevice::~RenamedDevice() {
if (owns_underlying_) {
delete underlying_;
}
}
} // namespace tensorflow

View File

@ -0,0 +1,119 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
// Wraps a device with a new name, delegating work to the wrapped device.
//
// This class is used to wrap local devices when using clusterspec propagation
// where the name of a particular device may change in the context of a given
// session.
class RenamedDevice : public Device {
public:
static Device* NewRenamedDevice(const string& new_base, Device* underlying,
bool owns_underlying);
~RenamedDevice() override;
// Below are virtual methods defined on DeviceBase
bool RequiresRecordingAccessedTensors() const override {
return underlying_->RequiresRecordingAccessedTensors();
}
const CpuWorkerThreads* tensorflow_cpu_worker_threads() const override {
return underlying_->tensorflow_cpu_worker_threads();
}
const GpuDeviceInfo* tensorflow_gpu_device_info() const override {
return underlying_->tensorflow_gpu_device_info();
}
Allocator* GetAllocator(AllocatorAttributes attr) override {
return underlying_->GetAllocator(attr);
}
Allocator* GetStepAllocator(AllocatorAttributes attr,
ResourceMgr* step_resource_manager) override {
return underlying_->GetStepAllocator(attr, step_resource_manager);
}
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
#ifdef TENSORFLOW_USE_SYCL
const Eigen::SyclDevice* eigen_sycl_device() const override {
return underlying_->eigen_sycl_device();
}
#endif
PerOpGpuDevice* MakeGpuDevice() override {
return underlying_->MakeGpuDevice();
}
void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
DeviceContext* dc, Allocator* allocator) override {
underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
const AllocatorAttributes alloc_attrs,
Tensor* tensor) override {
return underlying_->MakeTensorFromProto(tensor_proto, alloc_attrs, tensor);
}
// Below are virtual methods defined on Device
void Compute(OpKernel* op_kernel, OpKernelContext* context) override {
underlying_->Compute(op_kernel, context);
}
void ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context,
AsyncOpKernel::DoneCallback done) override {
underlying_->ComputeAsync(op_kernel, context, std::move(done));
}
void ConsumeListOfAccessedTensors(
DeviceContext* context, const TensorReferenceVector& tensors) override {
underlying_->ConsumeListOfAccessedTensors(context, tensors);
}
Status Sync() override { return underlying_->Sync(); }
Status MaybeRewriteGraph(const FunctionDefLibrary& library,
std::unique_ptr<Graph>* graph) override {
return underlying_->MaybeRewriteGraph(library, graph);
}
Status FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) override {
return underlying_->FillContextMap(graph, device_context_map);
}
private:
RenamedDevice(Device* underlying, const DeviceAttributes& attributes,
bool owns_underlying);
Device* const underlying_;
const bool owns_underlying_;
};
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENAMED_DEVICE_H_

View File

@ -66,7 +66,7 @@ class DummyOp : public OpKernel {
class FakeDevice : public Device {
private:
explicit FakeDevice(const DeviceAttributes& device_attributes)
: Device(nullptr, device_attributes, nullptr) {}
: Device(nullptr, device_attributes) {}
public:
Status Sync() override { return errors::Unimplemented("FakeDevice::Sync()"); }

View File

@ -38,10 +38,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
const string& name, Bytes memory_limit,
const DeviceLocality& locality,
Allocator* allocator)
: LocalDevice(options,
Device::BuildDeviceAttributes(name, DEVICE_CPU, memory_limit,
locality),
allocator),
: LocalDevice(options, Device::BuildDeviceAttributes(
name, DEVICE_CPU, memory_limit, locality)),
allocator_(allocator) {}
ThreadPoolDevice::~ThreadPoolDevice() {}

View File

@ -77,7 +77,6 @@ cc_library(
],
deps = [
":graph_mgr",
":rendezvous_mgr_interface",
":worker_cache",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
@ -92,9 +91,9 @@ cc_library(
deps = [
":graph_mgr",
":worker_session",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
],
)
@ -237,6 +236,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:worker_proto_cc",
],
)

View File

@ -35,9 +35,8 @@ limitations under the License.
namespace tensorflow {
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env,
const string& worker_name)
: worker_env_(worker_env), worker_name_(worker_name) {}
BaseRendezvousMgr::BaseRendezvousMgr(const WorkerEnv* worker_env)
: worker_env_(worker_env) {}
BaseRendezvousMgr::~BaseRendezvousMgr() {
for (auto& p : table_) {
@ -47,7 +46,7 @@ BaseRendezvousMgr::~BaseRendezvousMgr() {
}
}
Rendezvous* BaseRendezvousMgr::Find(int64 step_id) {
RemoteRendezvous* BaseRendezvousMgr::Find(int64 step_id) {
return FindOrCreate(step_id);
}
@ -55,7 +54,7 @@ BaseRemoteRendezvous* BaseRendezvousMgr::FindOrCreate(int64 step_id) {
mutex_lock l(mu_);
Table::iterator iter = table_.find(step_id);
if (iter == table_.end()) {
auto rr = Create(step_id, worker_env_, worker_name_);
auto rr = Create(step_id, worker_env_);
iter = table_.insert({step_id, rr}).first;
}
iter->second->Ref();
@ -128,14 +127,12 @@ void BaseRendezvousMgr::CleanupAll() {
}
}
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
const string& worker_name,
int64 step_id,
BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
bool tolerate_dup_recv)
: env_(env),
worker_name_(worker_name),
step_id_(step_id),
local_(NewLocalRendezvous(tolerate_dup_recv)) {}
local_(NewLocalRendezvous(tolerate_dup_recv)),
session_(nullptr) {}
BaseRemoteRendezvous::~BaseRemoteRendezvous() {
CHECK(active_.empty());
@ -150,6 +147,41 @@ static bool IsLocalDevice(const string& worker_name,
return device_name.starts_with(worker_name);
}
Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
CHECK_NE(session, nullptr) << "session must not be null!";
std::vector<DeferredCall> deferred_calls;
{
mutex_lock l(mu_);
if (session_ != nullptr) {
if (session_->worker_name == session->worker_name) {
LOG(INFO) << "Skipping rendezvous re-initialization.";
return Status::OK();
}
Status s = errors::Internal(
"Double init! Worker names would have changed from: ",
session_->worker_name, " -> ", session->worker_name);
LOG(WARNING) << s;
return s;
}
session_ = session;
std::swap(deferred_calls, deferred_calls_);
}
for (DeferredCall& call : deferred_calls) {
RecvLocalAsyncInternal(call.parsed, std::move(call.done));
}
return Status::OK();
}
WorkerSession* BaseRemoteRendezvous::session() {
mutex_lock l(mu_);
return session_;
}
bool BaseRemoteRendezvous::is_initialized() {
mutex_lock l(mu_);
return is_initialized_locked();
}
Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
const Rendezvous::Args& args,
const Tensor& val, const bool is_dead) {
@ -157,10 +189,12 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
DCHECK(is_initialized_locked());
if (!IsLocalDevice(session_->worker_name, parsed.src_device)) {
return errors::InvalidArgument(
"Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
session_->worker_name);
}
if (!IsLocalDevice(worker_name_, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", worker_name_);
}
// Buffers "val" and "device_context" in local_.
return local_->Send(parsed, args, val, is_dead);
@ -168,17 +202,24 @@ Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
Status BaseRemoteRendezvous::ValidateDevices(const ParsedKey& parsed,
bool is_src) {
// Cache session pointer to avoid repeatedly taking & releasing the lock
// (e.g. calling session())
WorkerSession* sess = nullptr;
{
mutex_lock l(mu_);
if (!status_.ok()) return status_;
if (!is_initialized_locked()) {
return errors::Internal("ValidateDevices called before initialization.");
}
if (is_src && !IsLocalDevice(worker_name_, parsed.src_device)) {
sess = session_;
}
if (is_src && !IsLocalDevice(sess->worker_name, parsed.src_device)) {
return errors::InvalidArgument("Invalid rendezvous key (src): ",
parsed.FullKey(), " @ ", worker_name_);
parsed.FullKey(), " @ ", sess->worker_name);
}
if (!is_src && !IsLocalDevice(worker_name_, parsed.dst_device)) {
if (!is_src && !IsLocalDevice(sess->worker_name, parsed.dst_device)) {
return errors::InvalidArgument("Invalid rendezvous key (dst): ",
parsed.FullKey(), " @ ", worker_name_);
parsed.FullKey(), " @ ", sess->worker_name);
}
return Status::OK();
}
@ -244,6 +285,7 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
const Rendezvous::Args& recv_args,
DoneCallback done) {
VLOG(1) << "RemoteRendezvous Recv " << this << " " << parsed.FullKey();
CHECK(is_initialized()) << "RecvAsync called when uninitialized.";
Status s = ValidateDevices(parsed, false /*!is_src*/);
if (!s.ok()) {
done(s, Args(), recv_args, Tensor(), false);
@ -280,6 +322,26 @@ void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
DoneCallback done) {
{
mutex_lock l(mu_);
if (!is_initialized_locked()) {
// RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
// remote worker) before the RunStep (or PartialRunStep) RPC from the
// master arrives. RecvLocalAsync thus buffers the arguments until after
// the RemoteRendezvous is Initialize()'d, when it completes the
// rendezvous logic. At some point after Initialize() is called, a Tensor
// is produced locally that will then be sent in response to the incoming
// RPC.
DeferredCall call(parsed, std::move(done));
deferred_calls_.push_back(call);
return;
}
}
RecvLocalAsyncInternal(parsed, std::move(done));
}
void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
DoneCallback done) {
Status s = ValidateDevices(parsed, true /* is_src */);
if (!s.ok()) {
done(s, Args(), Args(), Tensor(), false);
@ -318,4 +380,8 @@ void BaseRemoteRendezvous::DeregisterCall(BaseRecvTensorCall* call) {
active_.erase(call);
}
BaseRemoteRendezvous::DeferredCall::DeferredCall(const ParsedKey& parsed,
DoneCallback done)
: parsed(parsed), done(std::move(done)) {}
} // end namespace tensorflow

View File

@ -59,15 +59,17 @@ class BaseRecvTensorCall;
// RendezvousMgr must have keys generated by Rendezvous::CreateKey().
class BaseRendezvousMgr : public RendezvousMgrInterface {
public:
explicit BaseRendezvousMgr(const WorkerEnv* worker_env,
const string& worker_name);
explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
~BaseRendezvousMgr() override;
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
Rendezvous* Find(int64 step_id) override;
//
// Note: the caller must guarantee to eventually call Initialize on the
// returned RemoteRendezvous
RemoteRendezvous* Find(int64 step_id) override;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.
@ -91,8 +93,7 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
protected:
virtual BaseRemoteRendezvous* Create(int64 step_id,
const WorkerEnv* worker_env,
const string& worker_name) = 0;
const WorkerEnv* worker_env) = 0;
private:
// Maps step_id to rendezvous.
@ -100,7 +101,6 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Not owned.
const WorkerEnv* const worker_env_;
const string worker_name_;
mutex mu_;
Table table_ GUARDED_BY(mu_);
@ -116,10 +116,13 @@ class BaseRendezvousMgr : public RendezvousMgrInterface {
// Buffering of Tensor values is delegated to a "local" Rendezvous
// obtained from NewLocalRendezvous(). This class just adds
// functionality to coordinate with remote workers.
class BaseRemoteRendezvous : public Rendezvous {
class BaseRemoteRendezvous : public RemoteRendezvous {
public:
BaseRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
int64 step_id, bool tolerate_dup_recv);
BaseRemoteRendezvous(const WorkerEnv* env, int64 step_id,
bool tolerate_dup_recv);
// Upgrades the BaseRemoteRendezvous to full initialization.
Status Initialize(WorkerSession* session) override;
// Forwards to local_, where the Tensor "val" will be buffered and
// any waiting callback stored.
@ -163,10 +166,13 @@ class BaseRemoteRendezvous : public Rendezvous {
// Removes "call" from active_ if "call" is in active_.
void DeregisterCall(BaseRecvTensorCall* call);
WorkerSession* session();
bool is_initialized();
~BaseRemoteRendezvous() override;
const WorkerEnv* const env_; // Not owned.
const string worker_name_;
const int64 step_id_;
private:
@ -176,10 +182,24 @@ class BaseRemoteRendezvous : public Rendezvous {
// Status given by StartAbort() if any.
Status status_ GUARDED_BY(mu_);
WorkerSession* session_ GUARDED_BY(mu_); // Not owned.
// Data structures to handle calls when partially initialized.
struct DeferredCall {
const ParsedKey parsed;
DoneCallback done;
DeferredCall(const ParsedKey& parsed, DoneCallback done);
};
std::vector<DeferredCall> deferred_calls_ GUARDED_BY(mu_);
// Active outstanding RecvTensor calls.
gtl::FlatSet<BaseRecvTensorCall*> active_ GUARDED_BY(mu_);
bool is_initialized_locked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return session_ != nullptr;
}
// If "is_src" is true, checks that the rendezvous key "parsed"'s
// source is in this process. If "is_src" is false, checks that the
// rendezvous key "parsed"'s destination is in this process.
@ -194,6 +214,9 @@ class BaseRemoteRendezvous : public Rendezvous {
const Rendezvous::Args& out_args, const Tensor& in,
Tensor* out, StatusCallback done);
// Must be called only if fully initialized.
void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
};

View File

@ -46,10 +46,8 @@ limitations under the License.
namespace tensorflow {
GraphMgr::GraphMgr(const WorkerEnv* worker_env,
RendezvousMgrInterface* rendezvous_mgr)
: worker_env_(worker_env), rendezvous_mgr_(rendezvous_mgr), table_(5) {
CHECK(rendezvous_mgr) << "Rendezvous mgr was null";
GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
: worker_env_(worker_env), device_mgr_(device_mgr), table_(5) {
// The default value of sync_on_finish will be flipped soon and this
// environment variable will be removed as well.
Status status =
@ -148,7 +146,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
};
popts.get_incarnation = [this](const string& name) -> int64 {
Device* device = nullptr;
Status s = worker_env_->device_mgr->LookupDevice(name, &device);
Status s = device_mgr_->LookupDevice(name, &device);
if (s.ok()) {
return device->attributes().incarnation();
} else {
@ -193,8 +191,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
ExecutionUnit* unit = &(item->units.back());
// Find the device.
Status s =
worker_env_->device_mgr->LookupDevice(device_name, &unit->device);
Status s = device_mgr_->LookupDevice(device_name, &unit->device);
if (!s.ok()) {
// Remove the empty unit from the item as the item destructor wants all
// units to have valid devices.
@ -214,7 +211,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
// Function library runtime.
unit->lib = NewFunctionLibraryRuntime(
worker_env_->device_mgr, worker_env_->env, unit->device,
device_mgr_, worker_env_->env, unit->device,
subgraph->versions().producer(), item->lib_def,
graph_options.optimizer_options());
@ -419,14 +416,14 @@ void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
}
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in);
rendezvous->Unref();
return s;
}
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out);
rendezvous->Unref();
return s;
@ -434,7 +431,7 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) {
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) {
rendezvous->Unref();
@ -443,7 +440,8 @@ void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
}
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts,
WorkerSession* session,
const ExecutorOpts& /*opts*/,
StepStatsCollector* collector,
CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
@ -464,10 +462,14 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
return;
}
Rendezvous* rendezvous = rendezvous_mgr_->Find(step_id);
RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = rendezvous->Initialize(session);
// Sends values specified by the caller.
Status s = SendInputsToRendezvous(rendezvous, in);
if (s.ok()) {
s = SendInputsToRendezvous(rendezvous, in);
}
if (!s.ok()) {
done(s);
item->Unref();
@ -492,10 +494,9 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
StatusCallback done) {
const int num_units = item->units.size();
CHECK_GE(num_units, 1);
ScopedStepContainer* step_container =
new ScopedStepContainer(step_id, [this](const string& name) {
worker_env_->device_mgr->ClearContainers({name});
});
ScopedStepContainer* step_container = new ScopedStepContainer(
step_id,
[this](const string& name) { device_mgr_->ClearContainers({name}); });
// NOTE: Transfer one ref of rendezvous and item.
ExecutorBarrier* barrier =
new ExecutorBarrier(num_units, rendezvous,

View File

@ -37,6 +37,8 @@ namespace tensorflow {
class ExecutorOpts;
class StepStatsCollector;
class RendezvousMgrInterface;
class DeviceMgr;
struct WorkerSession;
// GraphMgr keeps track of a set of graphs that are registered with a
// TensorFlow worker. Each registered graph is identified by a handle
@ -62,8 +64,7 @@ class RendezvousMgrInterface;
// EXPECT_EQ(out["c"], Tensor({4, 6}));
class GraphMgr {
public:
explicit GraphMgr(const WorkerEnv* worker_env,
RendezvousMgrInterface* rendezvous_mgr);
explicit GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr);
~GraphMgr();
// Registers a graph. Fills in "handle"
@ -78,8 +79,8 @@ class GraphMgr {
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
void ExecuteAsync(const string& handle, const int64 step_id,
const ExecutorOpts& opts, StepStatsCollector* collector,
CostGraphDef* cost_graph,
WorkerSession* session, const ExecutorOpts& opts,
StepStatsCollector* collector, CostGraphDef* cost_graph,
CancellationManager* cancellation_manager,
const NamedTensors& in, StatusCallback done);
@ -131,7 +132,7 @@ class GraphMgr {
};
const WorkerEnv* worker_env_; // Not owned.
RendezvousMgrInterface* rendezvous_mgr_; // Not owned.
DeviceMgr* device_mgr_;
CostModelManager cost_model_manager_;

View File

@ -34,6 +34,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
@ -48,12 +49,17 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
namespace {
const char* const kGrpcProtocol = "grpc://";
} // namespace
Master::Master(MasterEnv* env, double session_gc_seconds)
: env_(env),
last_1000_steps_(1000),
@ -290,25 +296,122 @@ void Master::CreateSession(const CreateSessionRequest* req,
CreateSessionResponse* resp, MyClosure done) {
SchedClosure([this, req, resp, done]() {
Status status;
WorkerCacheFactoryOptions worker_cache_factory_options;
string grpc_protocol("grpc");
worker_cache_factory_options.protocol = &grpc_protocol;
auto call_done = gtl::MakeCleanup([&status, &done] { done(status); });
status = ValidateExternalGraphDefSyntax(req->graph_def());
if (!status.ok()) return;
// Ping all the workers and build the list of devices that the
// session will use.
// The following 4 variables are set differently, depending on whether this
// session uses a client-provided clusterspec or not.
WorkerCacheInterface* worker_cache = nullptr;
// Note: worker_cache_ptr will be null except if this session is using a
// client-supplied ClusterDef (ClusterSpec propagation).
std::unique_ptr<WorkerCacheInterface> worker_cache_ptr;
std::unique_ptr<DeviceSet> device_set;
// TODO(saeta): Convert to std::make_unique when available.
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devices(
new std::vector<std::unique_ptr<Device>>());
status = DeviceFinder::GetRemoteDevices(req->config().device_filters(),
env_, env_->worker_cache,
remote_devices.get());
if (req->config().has_cluster_def()) {
worker_cache_factory_options.cluster_def = &req->config().cluster_def();
// Set the server_def's job_name and task_index fields.
string normalized_string;
string grpc_protocol(kGrpcProtocol);
if (req->target().compare(0, grpc_protocol.length(), grpc_protocol) ==
0) {
normalized_string =
req->target().substr(grpc_protocol.length(), string::npos);
} else {
normalized_string = req->target();
}
for (auto&& job : req->config().cluster_def().job()) {
for (auto&& task : job.tasks()) {
if (task.second == normalized_string) {
if (worker_cache_factory_options.job_name != nullptr) {
status = errors::InvalidArgument(
"Found multiple matching tasks that correspond to "
"to the master. Master target: '",
req->target(), "'. ClusterDef: ",
req->config().cluster_def().ShortDebugString());
LOG(ERROR) << status;
return;
}
if (env_->local_devices[0]->parsed_name().job == job.name() &&
env_->local_devices[0]->parsed_name().task == task.first) {
// TODO(b/37868888): Remove this limitation when resolved
status = errors::InvalidArgument(
"The ClusterSpec names the job and task index to be the same "
"names that were provided when the server booted. This is "
"currently not allowed. Job: ",
job.name(), ", task index: ", task.first);
return;
}
worker_cache_factory_options.job_name = &job.name();
worker_cache_factory_options.task_index = task.first;
}
}
}
// Create the worker cache from the computed server_def.
status = env_->worker_cache_factory(worker_cache_factory_options,
&worker_cache);
if (!status.ok()) return;
worker_cache_ptr = std::unique_ptr<WorkerCacheInterface>(worker_cache);
// Ping all the workers and build the list of devices that the
// session will use.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
if (!status.ok()) return;
device_set.reset(new DeviceSet);
for (auto&& d : *remote_devices) {
device_set->AddDevice(d.get());
DeviceNameUtils::ParsedName name = d->parsed_name();
if (name.job == *worker_cache_factory_options.job_name &&
name.task == worker_cache_factory_options.task_index &&
name.type == "CPU") {
device_set->set_client_device(d.get());
}
}
} else {
worker_cache = env_->worker_cache;
// Ping all the workers and build the list of devices that the
// session will use.
status =
DeviceFinder::GetRemoteDevices(req->config().device_filters(), env_,
worker_cache, remote_devices.get());
if (!status.ok()) return;
device_set.reset(new DeviceSet);
for (auto&& d : *remote_devices) {
device_set->AddDevice(d.get());
}
int num_local_devices = 0;
for (Device* d : env_->local_devices) {
device_set->AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
device_set->set_client_device(d);
}
num_local_devices++;
}
}
CHECK(device_set->client_device());
SessionOptions options;
options.config = req->config();
MasterSession* session =
env_->master_session_factory(options, env_, std::move(remote_devices));
MasterSession* session = env_->master_session_factory(
options, env_, std::move(remote_devices), std::move(worker_cache_ptr),
std::move(device_set));
GraphDef* gdef =
const_cast<CreateSessionRequest*>(req)->mutable_graph_def();
status = session->Create(gdef);
status = session->Create(gdef, worker_cache_factory_options);
if (!status.ok()) {
session->Close().IgnoreError();
session->Unref();

View File

@ -19,17 +19,41 @@ limitations under the License.
#include <functional>
#include <vector>
#include "tensorflow/core/distributed_runtime/master_session.h"
#include "tensorflow/core/protobuf/cluster.pb.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
class Device;
class DeviceSet;
class Env;
class MasterSession;
class OpRegistryInterface;
class WorkerCacheInterface;
// Options passed to the worker_cache_factory function.
struct WorkerCacheFactoryOptions {
const ClusterDef* cluster_def = nullptr;
const string* job_name = nullptr;
int task_index;
const string* protocol = nullptr;
WorkerCacheFactoryOptions() {}
// Construct from a ServerDef proto.
//
// Note: server_def must outlive WorkerCacheFactoryOptions!
WorkerCacheFactoryOptions(const ServerDef& server_def) {
if (server_def.has_cluster() && !server_def.job_name().empty()) {
cluster_def = &server_def.cluster();
job_name = &server_def.job_name();
task_index = server_def.task_index();
protocol = &server_def.protocol();
}
}
};
// The master environment class, which holds a bag of pointers to
// per-master state.
//
@ -57,8 +81,14 @@ struct MasterEnv {
// `MasterEnv*` is retained by the caller.
std::function<MasterSession*(
SessionOptions, MasterEnv*,
std::unique_ptr<std::vector<std::unique_ptr<Device>>>)>
std::unique_ptr<std::vector<std::unique_ptr<Device>>>,
std::unique_ptr<WorkerCacheInterface>,
std::unique_ptr<DeviceSet> device_set)>
master_session_factory;
std::function<Status(const WorkerCacheFactoryOptions&,
WorkerCacheInterface**)>
worker_cache_factory;
};
} // end namespace tensorflow

View File

@ -36,11 +36,13 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@ -162,7 +164,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// Partitions the graph into subgraphs and registers them on
// workers.
Status RegisterPartitions(const PartitionOptions& popts,
const FunctionDefLibrary& func_def_lib);
const FunctionLibraryDefinition& flib_def);
// Runs one step of all partitions.
Status RunPartitions(const MasterEnv* env, int64 step_id,
@ -273,7 +275,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
};
Status MasterSession::ReffedClientGraph::RegisterPartitions(
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib) {
const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) {
{ // Ensure register once.
mu_.lock();
if (!init_started_) {
@ -292,7 +294,8 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
graph_defs_for_publishing.push_back(&name_def.second);
}
stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
s = DoRegisterPartitions(popts, func_def_lib, std::move(graph_defs));
s = DoRegisterPartitions(popts, flib_def.ToProto(),
std::move(graph_defs));
}
mu_.lock();
init_result_ = s;
@ -527,6 +530,7 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
c->req->set_is_partial(is_partial_);
c->req->set_is_last_partial_run(is_last_partial_run);
}
c->req->set_session_handle(session_handle_);
c->req->set_graph_handle(part.graph_handle);
c->req->set_step_id(step_id);
*c->req->mutable_exec_opts() = exec_opts;
@ -870,6 +874,7 @@ void MasterSession::ReffedClientGraph::DeregisterPartitions() {
// The graph handle may be empty if we failed during partition registration.
if (!part.graph_handle.empty()) {
Call* c = new Call;
c->req.set_session_handle(session_handle_);
c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture `worker_cache_` since `this`
// could be deleted before the callback is called.
@ -972,31 +977,25 @@ string BuildGraphOptionsString(const BuildGraphOptions& opts) {
MasterSession::MasterSession(
const SessionOptions& opt, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory)
: session_opts_(opt),
env_(env),
handle_(strings::FpToString(random::New64())),
remote_devs_(std::move(remote_devs)),
worker_cache_(std::move(worker_cache)),
devices_(std::move(device_set)),
stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0),
run_graphs_(5),
partial_run_graphs_(5) {
UpdateLastAccessTime();
CHECK(devices_) << "device_set was null!";
VLOG(1) << "Session " << handle_ << " #local " << env->local_devices.size()
<< " #remote " << remote_devs_->size();
for (auto&& d : *remote_devs_) {
devices_.AddDevice(d.get());
}
int num_local_devices = 0;
for (Device* d : env->local_devices) {
devices_.AddDevice(d);
if (num_local_devices == 0) {
// Uses the first local device as the client device.
devices_.set_client_device(d);
}
num_local_devices++;
}
LOG(INFO) << "Start master session " << handle_
<< " with config: " << std::endl
<< session_opts_.config.DebugString();
@ -1011,7 +1010,8 @@ void MasterSession::UpdateLastAccessTime() {
last_access_time_usec_.store(Env::Default()->NowMicros());
}
Status MasterSession::Create(GraphDef* graph_def) {
Status MasterSession::Create(GraphDef* graph_def,
const WorkerCacheFactoryOptions& options) {
if (session_opts_.config.graph_options().place_pruned_graph()) {
// TODO(b/29900832): Fix this or remove the option.
LOG(WARNING) << "Distributed session does not support the "
@ -1019,17 +1019,93 @@ Status MasterSession::Create(GraphDef* graph_def) {
session_opts_.config.mutable_graph_options()->set_place_pruned_graph(false);
}
SimpleGraphExecutionStateOptions options;
options.device_set = &devices_;
options.session_options = &session_opts_;
SimpleGraphExecutionStateOptions execution_options;
execution_options.device_set = devices_.get();
execution_options.session_options = &session_opts_;
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(SimpleGraphExecutionState::MakeForBaseGraph(
graph_def, options, &execution_state_));
graph_def, execution_options, &execution_state_));
}
if (options.cluster_def != nullptr) {
return CreateWorkerSessions(options);
}
return Status::OK();
}
Status MasterSession::CreateWorkerSessions(
const WorkerCacheFactoryOptions& options) {
CHECK(worker_cache_) << "CreateWorkerSessions should be called only with "
<< "dynamic cluster membership.";
std::vector<string> worker_names;
worker_cache_->ListWorkers(&worker_names);
struct WorkerGroup {
// The worker name. (Not owned.)
const string* name;
// The worker referenced by name. (Not owned.)
WorkerInterface* worker = nullptr;
// Request and responses used for a given worker.
CreateWorkerSessionRequest request;
CreateWorkerSessionResponse response;
Status status = Status::OK();
};
BlockingCounter done(worker_names.size());
std::vector<WorkerGroup> workers(worker_names.size());
// Release the workers.
auto cleanup = gtl::MakeCleanup([this, &workers] {
for (auto&& worker_group : workers) {
if (worker_group.worker != nullptr) {
worker_cache_->ReleaseWorker(*worker_group.name, worker_group.worker);
}
}
});
Status status = Status::OK();
// Create all the workers & kick off the computations.
for (size_t i = 0; i < worker_names.size(); ++i) {
workers[i].name = &worker_names[i];
workers[i].worker = worker_cache_->CreateWorker(worker_names[i]);
workers[i].request.set_session_handle(handle_);
*workers[i].request.mutable_server_def()->mutable_cluster() =
*options.cluster_def;
workers[i].request.mutable_server_def()->set_protocol(*options.protocol);
DeviceNameUtils::ParsedName name;
if (!DeviceNameUtils::ParseFullName(worker_names[i], &name)) {
status = errors::Internal("Could not parse name ", worker_names[i]);
LOG(WARNING) << status;
return status;
}
if (!name.has_job || !name.has_task) {
status = errors::Internal("Incomplete worker name ", worker_names[i]);
LOG(WARNING) << status;
return status;
}
workers[i].request.mutable_server_def()->set_job_name(name.job);
workers[i].request.mutable_server_def()->set_task_index(name.task);
}
for (size_t i = 0; i < worker_names.size(); ++i) {
auto cb = [i, &workers, &done](const Status& s) {
workers[i].status = s;
done.DecrementCount();
};
workers[i].worker->CreateWorkerSessionAsync(&workers[i].request,
&workers[i].response, cb);
}
done.Wait();
for (size_t i = 0; i < workers.size(); ++i) {
status.Update(workers[i].status);
}
return status;
}
Status MasterSession::Extend(const ExtendSessionRequest* req,
ExtendSessionResponse* resp) {
UpdateLastAccessTime();
@ -1059,6 +1135,13 @@ Status MasterSession::Extend(const ExtendSessionRequest* req,
return Status::OK();
}
WorkerCacheInterface* MasterSession::get_worker_cache() const {
if (worker_cache_) {
return worker_cache_.get();
}
return env_->worker_cache;
}
Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** rcg, bool is_partial) {
const uint64 hash = HashBuildGraphOptions(opts);
@ -1082,11 +1165,11 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
<< "\n";
std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
WorkerCacheInterface* worker_cache = get_worker_cache();
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial,
env_->worker_cache);
worker_cache);
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
@ -1161,6 +1244,8 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
return errors::FailedPrecondition("Session is closed.");
}
++num_running_;
// Note: all code paths must eventually call MarkRunCompletion()
// in order to appropriate decrement the num_running_ counter.
}
Status status;
if (!req.partial_run_handle().empty()) {
@ -1168,14 +1253,16 @@ Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
} else {
status = DoRunWithLocalExecution(opts, req, resp);
}
{
return status;
}
// Decrements num_running_ and broadcasts if num_running_ is zero.
void MasterSession::MarkRunCompletion() {
mutex_lock l(mu_);
--num_running_;
if (num_running_ == 0) {
num_running_is_zero_.notify_all();
}
}
return status;
}
Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
@ -1187,7 +1274,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
return strings::StrCat(prefix, "_S", next_node_id_++);
};
popts.get_incarnation = [this](const string& name) -> int64 {
Device* d = devices_.FindDeviceByName(name);
Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) {
return PartitionOptions::kIllegalIncarnation;
} else {
@ -1214,7 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
}
TF_RETURN_IF_ERROR(
rcg->RegisterPartitions(popts, rcg->client_graph()->flib_def->ToProto()));
rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def));
return Status::OK();
}
@ -1222,6 +1309,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
Status MasterSession::DoPartialRun(CallOptions* opts,
const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
const string& prun_handle = req.partial_run_handle();
RunState* run_state = nullptr;
{
@ -1320,12 +1408,14 @@ Status MasterSession::DoPartialRun(CallOptions* opts,
rcg->Ref();
rcg->ProcessStats(run_state->step_id, &run_state->pss, run_state->ph.get(),
req.options(), resp->mutable_metadata());
cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync(
run_state->step_id, [this, rcg, prun_handle](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
MarkRunCompletion();
});
mutex_lock l(mu_);
partial_runs_.erase(prun_handle);
@ -1367,10 +1457,10 @@ Status MasterSession::CreateDebuggerState(
Status MasterSession::DoRunWithLocalExecution(
CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp) {
VLOG(2) << "DoRunWithLocalExecution "
<< "req: " << req.DebugString();
VLOG(2) << "DoRunWithLocalExecution req: " << req.DebugString();
PerStepState pss;
pss.start_micros = Env::Default()->NowMicros();
auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
// Prepare.
BuildGraphOptions bgopts;
@ -1437,11 +1527,13 @@ Status MasterSession::DoRunWithLocalExecution(
}
}
rcg->Ref();
rcg->CleanupPartitionsAsync(step_id, [rcg](const Status& s) {
cleanup.release(); // MarkRunCompletion called in done closure.
rcg->CleanupPartitionsAsync(step_id, [this, rcg](const Status& s) {
if (!s.ok()) {
LOG(ERROR) << "Cleanup partition error: " << s;
}
rcg->Unref();
MarkRunCompletion();
});
return s;
}

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/call_options.h"
#include "tensorflow/core/distributed_runtime/master_env.h"
#include "tensorflow/core/distributed_runtime/message_wrappers.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/master.pb.h"
@ -49,13 +50,15 @@ class MasterSession : public core::RefCounted {
MasterSession(
const SessionOptions& options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set,
StatsPublisherFactory stats_publisher_factory);
// Initialize the MasterSession for "def". Must be called before Extend(),
// Run(), or Close().
//
// After this method returns, `def` will no longer be valid.
Status Create(GraphDef* def);
Status Create(GraphDef* def, const WorkerCacheFactoryOptions& options);
// Returns the session handle.
const string& handle() const { return handle_; }
@ -107,8 +110,14 @@ class MasterSession : public core::RefCounted {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs_;
// The optional session-specific worker cluster.
// TODO(saeta): Convert to std::optional when available.
std::unique_ptr<WorkerCacheInterface> worker_cache_;
// Retrieves either worker_cache_ or the env_->worker_cache as appropriate.
WorkerCacheInterface* get_worker_cache() const;
// The device set used by this session.
DeviceSet devices_;
std::unique_ptr<DeviceSet> devices_;
StatsPublisherFactory stats_publisher_factory_;
@ -181,6 +190,13 @@ class MasterSession : public core::RefCounted {
// Private dtor. The client must call Close().
virtual ~MasterSession();
// Creates sessions on all workers.
//
// If this session is operating using the new ClusterSpec propagation behavior
// call this method in order to propagate the cluster membership to all
// workers.
Status CreateWorkerSessions(const WorkerCacheFactoryOptions& server_def);
Status StartStep(const BuildGraphOptions& opts, int64* count,
ReffedClientGraph** graph, bool is_partial);
void ClearRunsTable(std::vector<ReffedClientGraph*>* to_unref,
@ -190,6 +206,7 @@ class MasterSession : public core::RefCounted {
MutableRunStepResponseWrapper* resp);
Status DoPartialRun(CallOptions* opts, const RunStepRequestWrapper& req,
MutableRunStepResponseWrapper* resp);
void MarkRunCompletion();
void UpdateLastAccessTime();
Status BuildAndRegisterPartitions(ReffedClientGraph* rcg);

View File

@ -252,6 +252,14 @@ string ProtoRunStepRequest::DebugString() const {
const RunStepRequest& ProtoRunStepRequest::ToProto() const { return *request_; }
const string& InMemoryRunGraphRequest::session_handle() const {
return session_handle_;
}
void InMemoryRunGraphRequest::set_session_handle(const string& handle) {
session_handle_ = handle;
}
const string& InMemoryRunGraphRequest::graph_handle() const {
return graph_handle_;
}
@ -320,6 +328,7 @@ void InMemoryRunGraphRequest::set_is_last_partial_run(
const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
if (!proto_version_) {
proto_version_.reset(new RunGraphRequest);
proto_version_->set_session_handle(session_handle());
proto_version_->set_graph_handle(graph_handle());
proto_version_->set_step_id(step_id());
*proto_version_->mutable_exec_opts() = exec_opts();
@ -337,6 +346,14 @@ const RunGraphRequest& InMemoryRunGraphRequest::ToProto() const {
return *proto_version_;
}
const string& MutableProtoRunGraphRequest::session_handle() const {
return request_.session_handle();
}
void MutableProtoRunGraphRequest::set_session_handle(const string& handle) {
request_.set_session_handle(handle);
}
const string& MutableProtoRunGraphRequest::graph_handle() const {
return request_.graph_handle();
}
@ -423,6 +440,10 @@ const RunGraphRequest& MutableProtoRunGraphRequest::ToProto() const {
ProtoRunGraphRequest::ProtoRunGraphRequest(const RunGraphRequest* request)
: request_(request) {}
const string& ProtoRunGraphRequest::session_handle() const {
return request_->session_handle();
}
const string& ProtoRunGraphRequest::graph_handle() const {
return request_->graph_handle();
}

View File

@ -223,6 +223,10 @@ class RunGraphRequestWrapper {
public:
virtual ~RunGraphRequestWrapper() {}
// The session handle used to register the graph. If empty, a single global
// namespace is used.
virtual const string& session_handle() const = 0;
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
virtual const string& graph_handle() const = 0;
@ -262,6 +266,7 @@ class RunGraphRequestWrapper {
// See `RunGraphRequestWrapper` above for a description of the fields.
class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
public:
virtual void set_session_handle(const string& handle) = 0;
virtual void set_graph_handle(const string& handle) = 0;
virtual void set_step_id(int64 step_id) = 0;
virtual ExecutorOpts* mutable_exec_opts() = 0;
@ -280,6 +285,7 @@ class MutableRunGraphRequestWrapper : public RunGraphRequestWrapper {
class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@ -293,6 +299,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@ -304,6 +311,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
void set_is_last_partial_run(bool is_last_partial_run) override;
private:
string session_handle_;
string graph_handle_;
int64 step_id_;
ExecutorOpts exec_opts_;
@ -325,6 +333,7 @@ class InMemoryRunGraphRequest : public MutableRunGraphRequestWrapper {
class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
public:
// RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;
@ -338,6 +347,7 @@ class MutableProtoRunGraphRequest : public MutableRunGraphRequestWrapper {
const RunGraphRequest& ToProto() const override;
// MutableRunGraphRequestWrapper methods.
void set_session_handle(const string& handle) override;
void set_graph_handle(const string& handle) override;
void set_step_id(int64 step_id) override;
ExecutorOpts* mutable_exec_opts() override;
@ -357,6 +367,7 @@ class ProtoRunGraphRequest : public RunGraphRequestWrapper {
ProtoRunGraphRequest(const RunGraphRequest* request);
// RunGraphRequestWrapper methods.
const string& session_handle() const override;
const string& graph_handle() const override;
int64 step_id() const override;
const ExecutorOpts& exec_opts() const override;

View File

@ -16,11 +16,13 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/remote_device.h"
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@ -43,8 +45,7 @@ string GetLocalDeviceName(StringPiece fullname) {
class RemoteDevice : public Device {
public:
RemoteDevice(Env* env, const DeviceAttributes& da)
: Device(env, da, nullptr),
local_dev_name_(GetLocalDeviceName(da.name())) {}
: Device(env, da), local_dev_name_(GetLocalDeviceName(da.name())) {}
Status Sync() override { return Status::OK(); }
Allocator* GetAllocator(AllocatorAttributes attr) override { return nullptr; }
@ -68,18 +69,50 @@ void NewRemoteDevices(Env* env, WorkerCacheInterface* worker_cache,
GetStatusResponse resp;
};
Call* call = new Call;
auto cb = [env, worker_cache, worker_name, done, wi, call](const Status& s) {
auto cb = [env, worker_cache, worker_name, done, wi,
call](const Status& status) {
Status s = status;
std::vector<Device*> remote_devices;
if (s.ok()) {
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
auto d = new RemoteDevice(env, da);
remote_devices.push_back(d);
}
}
auto cleanup = gtl::MakeCleanup(
[&worker_cache, &worker_name, &wi, &done, &remote_devices, &s, call] {
worker_cache->ReleaseWorker(worker_name, wi);
done(s, &remote_devices);
delete call;
});
if (s.ok()) {
DeviceNameUtils::ParsedName worker_name_parsed;
if (!DeviceNameUtils::ParseFullName(worker_name, &worker_name_parsed) ||
!worker_name_parsed.has_job || !worker_name_parsed.has_replica ||
!worker_name_parsed.has_task) {
s = errors::InvalidArgument("Could not parse worker name: ",
worker_name);
LOG(WARNING) << s;
return;
}
remote_devices.reserve(call->resp.device_attributes_size());
for (const DeviceAttributes& da : call->resp.device_attributes()) {
DeviceNameUtils::ParsedName device_name_parsed;
CHECK(DeviceNameUtils::ParseFullName(da.name(), &device_name_parsed))
<< "Device attribute name '" << da.name() << "' could not be "
<< "parsed. Device Attribute: " << da.DebugString();
// Preserve the exact name, if possible.
// TODO(b/37868888): Simplify when legacy device name formats removed.
if (device_name_parsed.job == worker_name_parsed.job &&
device_name_parsed.replica == worker_name_parsed.replica &&
device_name_parsed.task == worker_name_parsed.task) {
auto d = new RemoteDevice(env, da);
remote_devices.push_back(d);
} else {
DeviceAttributes da_rewritten = da;
da_rewritten.set_name(DeviceNameUtils::FullName(
worker_name_parsed.job, worker_name_parsed.replica,
worker_name_parsed.task, device_name_parsed.type,
device_name_parsed.id));
auto d = new RemoteDevice(env, da_rewritten);
remote_devices.push_back(d);
}
}
}
};
wi->GetStatusAsync(&call->req, &call->resp, cb);
}

View File

@ -25,6 +25,23 @@ limitations under the License.
namespace tensorflow {
struct WorkerSession;
// RemoteRendezvous follow a 2-part initialization. First the objects are
// constructed. Eventually, they will be initialized. Clients of the
// RendezvousMgrInterface must guarantee to call Initialize on the returned
// RemoteRendezvous eventually.
//
// Partially initialized RemoteRendezvous must respect the Rendezvous interface
// (i.e. Send() must never block), however implementations are not expected to
// actually perform the underlying operations until after the RemoteRendezvous
// has been Initialize'd.
class RemoteRendezvous : public Rendezvous {
public:
// Fully construct the RemoteRendezvous.
virtual Status Initialize(WorkerSession* session) = 0;
};
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
@ -51,7 +68,10 @@ class RendezvousMgrInterface {
// Returns Rendezvous supporting send and recv among workers in the
// "step_id". The caller takes ownership of one reference on the
// returned Rendezvous instance.
virtual Rendezvous* Find(int64 step_id) = 0;
//
// Note: the caller must guarantee to eventually call Initialize on the
// returned RemoteRendezvous
virtual RemoteRendezvous* Find(int64 step_id) = 0;
// Finds the local rendezvous instance for the "step_id". Runs
// "done" when the tensor for "key" is produced or an error occurs.

View File

@ -63,10 +63,8 @@ class NoReusePortOption : public ::grpc::ServerBuilderOption {
};
// static utility function
RendezvousMgrInterface* NewRpcRendezvousMgr(
const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* worker_cache) {
return new RpcRendezvousMgr(env, worker_name, worker_cache);
RendezvousMgrInterface* NewRpcRendezvousMgr(const WorkerEnv* env) {
return new RpcRendezvousMgr(env);
}
} // namespace
@ -84,6 +82,9 @@ GrpcServer::~GrpcServer() {
// TODO(mrry): Refactor the *Env classes so that it is less fiddly
// to destroy them.
// Shut down all outstanding rendezvous.
delete worker_env_.rendezvous_mgr;
// We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful
@ -91,8 +92,10 @@ GrpcServer::~GrpcServer() {
// OpSegments.)
if (worker_env_.session_mgr != nullptr) {
delete worker_env_.session_mgr; // Deletes graph_mgr's.
}
} else {
// Note: session_mgr's legacy_session_ deletes device_mgr now.
delete worker_env_.device_mgr;
}
// Do not delete (as these are not owned by the server):
// - master_env_.env
@ -100,8 +103,9 @@ GrpcServer::~GrpcServer() {
// - worker_env_.compute_pool
}
Status GrpcServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendevous_mgr_func) {
Status GrpcServer::Init(
ServiceInitFunction service_func,
const RendezvousMgrCreationFunction& rendezvous_mgr_func) {
mutex_lock l(mu_);
CHECK_EQ(state_, NEW);
master_env_.env = env_;
@ -117,7 +121,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
"/task:", server_def_.task_index());
TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(sess_opts, name_prefix,
&master_env_.local_devices));
worker_env_.device_mgr = new DeviceMgr(master_env_.local_devices);
worker_env_.local_devices = master_env_.local_devices;
worker_env_.device_mgr = new DeviceMgr(worker_env_.local_devices);
worker_env_.rendezvous_mgr = rendezvous_mgr_func == nullptr
? new RpcRendezvousMgr(&worker_env_)
: rendezvous_mgr_func(&worker_env_);
string unused;
string default_worker_name;
if (!DeviceNameUtils::SplitDeviceName(master_env_.local_devices[0]->name(),
@ -189,20 +197,18 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
}
WorkerCacheInterface* worker_cache;
TF_RETURN_IF_ERROR(WorkerCacheFactory(server_def_, &worker_cache));
WorkerCacheFactoryOptions worker_cache_factory_options(server_def_);
TF_RETURN_IF_ERROR(
WorkerCacheFactory(worker_cache_factory_options, &worker_cache));
CHECK_NE(nullptr, worker_cache);
// Set up worker environment.
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
rendevous_mgr_func == nullptr ?
new RpcRendezvousMgr(&worker_env_, name_prefix, worker_cache) :
rendevous_mgr_func(&worker_env_, name_prefix, worker_cache));
worker_env_.session_mgr = new SessionMgr(
&worker_env_, SessionMgr::WorkerNameFromServerDef(server_def_),
std::unique_ptr<WorkerCacheInterface>(worker_cache),
std::move(rendezvous_mgr),
[this](const ServerDef& server_def, WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(server_def, worker_cache);
WorkerCacheFactoryOptions options(server_def);
return WorkerCacheFactory(options, worker_cache);
});
worker_env_.compute_pool = ComputePool(sess_opts);
@ -212,11 +218,19 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
master_env_.master_session_factory =
[config](
SessionOptions options, const MasterEnv* env,
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs) {
std::unique_ptr<std::vector<std::unique_ptr<Device>>> remote_devs,
std::unique_ptr<WorkerCacheInterface> worker_cache,
std::unique_ptr<DeviceSet> device_set) {
options.config.MergeFrom(config);
return new MasterSession(options, env, std::move(remote_devs),
std::move(worker_cache), std::move(device_set),
CreateNoOpStatsPublisher);
};
master_env_.worker_cache_factory =
[this](const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
return WorkerCacheFactory(options, worker_cache);
};
// Provide direct access to the master from in-process clients.
LocalMaster::Register(target(), master_impl_.get(),
@ -225,13 +239,11 @@ Status GrpcServer::Init(ServiceInitFunction service_func,
return Status::OK();
}
Status GrpcServer::Init() {
return Init(nullptr, nullptr);
}
Status GrpcServer::Init() { return Init(nullptr, nullptr); }
Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
Status GrpcServer::ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec) {
for (const auto& job : server_def.cluster().job()) {
for (const auto& job : options.cluster_def->job()) {
std::map<int, string> host_ports;
for (const auto& task : job.tasks()) {
string& host_port = host_ports[task.first];
@ -241,8 +253,7 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
task.first, "\": ", host_port, " and ",
task.second);
}
if (job.name() == server_def.job_name() &&
task.first == server_def.task_index()) {
if (job.name() == *options.job_name && task.first == options.task_index) {
host_port = strings::StrCat("localhost:", bound_port_);
} else {
host_port = task.second;
@ -253,17 +264,26 @@ Status GrpcServer::ParseChannelSpec(const ServerDef& server_def,
return Status::OK();
}
Status GrpcServer::WorkerCacheFactory(const ServerDef& server_def,
Status GrpcServer::WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache) {
string name_prefix =
strings::StrCat("/job:", server_def.job_name(), "/replica:0",
"/task:", server_def.task_index());
if (options.job_name == nullptr || options.job_name->empty()) {
Status s = errors::InvalidArgument(
"The master (current machine) is not included in the provided "
"cluster_def. ",
options.cluster_def->DebugString());
LOG(WARNING) << s;
return s;
}
GrpcChannelSpec channel_spec;
TF_RETURN_IF_ERROR(ParseChannelSpec(server_def, &channel_spec));
TF_RETURN_IF_ERROR(ParseChannelSpec(options, &channel_spec));
std::unique_ptr<GrpcChannelCache> channel_cache(
NewGrpcChannelCache(channel_spec, GetChannelCreationFunction()));
string name_prefix = strings::StrCat("/job:", *options.job_name, "/replica:0",
"/task:", options.task_index);
std::unique_ptr<GrpcChannelCache> channel_cache(NewGrpcChannelCache(
channel_spec, GetChannelCreationFunction(server_def)));
const string host_port = channel_cache->TranslateTask(name_prefix);
int requested_port;
@ -349,8 +369,7 @@ std::shared_ptr<::grpc::ServerCredentials> GrpcServer::GetServerCredentials(
return ::grpc::InsecureServerCredentials();
}
ChannelCreationFunction GrpcServer::GetChannelCreationFunction(
const ServerDef& server_def) const {
ChannelCreationFunction GrpcServer::GetChannelCreationFunction() const {
// We can do this because SparseGrpcChannelCache is robust to nullptr being
// returned by the channel creation function
return ConvertToChannelCreationFunction(NewHostPortGrpcChannel);

View File

@ -37,9 +37,7 @@ class GrpcWorker;
class Master;
// function that creates a RendezvousMgr.
typedef std::function<RendezvousMgrInterface*(
const WorkerEnv*, const std::string& worker_name,
WorkerCacheInterface* worker_cache)>
typedef std::function<RendezvousMgrInterface*(const WorkerEnv*)>
RendezvousMgrCreationFunction;
// function that registers a service to the server. The service needs to
@ -67,7 +65,7 @@ class GrpcServer : public ServerInterface {
protected:
Status Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func);
const RendezvousMgrCreationFunction& rendezvous_mgr_func);
Status Init();
@ -75,17 +73,16 @@ class GrpcServer : public ServerInterface {
virtual std::shared_ptr<::grpc::ServerCredentials> GetServerCredentials(
const ServerDef& server_def) const;
virtual ChannelCreationFunction GetChannelCreationFunction(
const ServerDef& server_def) const;
virtual ChannelCreationFunction GetChannelCreationFunction() const;
virtual std::unique_ptr<Master> CreateMaster(MasterEnv* master_env);
// Creates a WorkerCacheInterface for a session.
Status WorkerCacheFactory(const ServerDef& server_def,
Status WorkerCacheFactory(const WorkerCacheFactoryOptions& options,
WorkerCacheInterface** worker_cache);
// Parses a ServerDef into a GrpcChannelSpec.
Status ParseChannelSpec(const ServerDef& server_def,
// Parses a WorkerCacheFactoryOptions into a GrpcChannelSpec.
Status ParseChannelSpec(const WorkerCacheFactoryOptions& options,
GrpcChannelSpec* channel_spec);
// Returns the port to which this server is bound.

View File

@ -43,7 +43,7 @@ const size_t kSchemePrefixLength = strlen(kSchemePrefix);
/* static */
Status GrpcSession::Create(const SessionOptions& options,
std::unique_ptr<GrpcSession>* out_session) {
std::unique_ptr<GrpcSession> ret(new GrpcSession(options));
std::unique_ptr<GrpcSession> session(new GrpcSession(options));
std::unique_ptr<MasterInterface> master;
// For testing, we enable the client to disable the use of the local
// master registry, so that the RPC stack is exercised.
@ -56,8 +56,8 @@ Status GrpcSession::Create(const SessionOptions& options,
options.target.substr(kSchemePrefixLength), &master_channel));
master.reset(NewGrpcMaster(master_channel));
}
ret->SetRemoteMaster(std::move(master));
*out_session = std::move(ret);
session->SetRemoteMaster(std::move(master));
*out_session = std::move(session);
return Status::OK();
}
@ -102,6 +102,7 @@ Status GrpcSession::CreateImpl(CallOptions* call_options,
CreateSessionRequest req;
*req.mutable_config() = options_.config;
*req.mutable_graph_def() = graph;
req.set_target(options_.target);
ReEncodeConsts(req.mutable_graph_def());
CreateSessionResponse resp;
Status s = master_->CreateSession(call_options, &req, &resp);

View File

@ -113,6 +113,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// completes, and we may decide to bound some of the request
// types.
ENQUEUE_REQUEST(GetStatus, false);
ENQUEUE_REQUEST(CreateWorkerSession, false);
ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph, false);
@ -181,6 +182,16 @@ class GrpcWorkerService : public AsyncServiceInterface {
ENQUEUE_REQUEST(GetStatus, false);
}
void CreateWorkerSessionHandler(
WorkerCall<CreateWorkerSessionRequest, CreateWorkerSessionResponse>*
call) {
Schedule([this, call]() {
Status s = worker_->CreateWorkerSession(&call->request, &call->response);
call->SendResponse(ToGrpcStatus(s));
});
ENQUEUE_REQUEST(CreateWorkerSession, false);
}
void CleanupAllHandler(
WorkerCall<CleanupAllRequest, CleanupAllResponse>* call) {
Schedule([this, call]() {
@ -298,7 +309,6 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
::grpc::ByteBuffer* response,
StatusCallback done) {
const int64 step_id = request->step_id();
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
const string& key = request->rendezvous_key();
TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str());
Rendezvous::ParsedKey parsed;
@ -317,7 +327,7 @@ void GrpcWorker::RecvTensorAsync(CallOptions* opts,
// of execution of the callback lambda body below, an RPC
// cancellation should abort the rendezvous.
opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
session->rendezvous_mgr->RecvLocalAsync(
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[opts, response, done, src_dev](const Status& status,
const Rendezvous::Args& send_args,

View File

@ -38,9 +38,8 @@ namespace {
class RpcRemoteRendezvous : public BaseRemoteRendezvous {
public:
RpcRemoteRendezvous(const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* cache, int64 step_id)
: BaseRemoteRendezvous(env, worker_name, step_id, false), cache_(cache) {}
RpcRemoteRendezvous(const WorkerEnv* env, int64 step_id)
: BaseRemoteRendezvous(env, step_id, false) {}
protected:
void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
@ -50,7 +49,6 @@ class RpcRemoteRendezvous : public BaseRemoteRendezvous {
private:
~RpcRemoteRendezvous() override {}
WorkerCacheInterface* const cache_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
};
@ -204,75 +202,10 @@ static RpcRecvTensorFreeList* get_call_freelist() {
return call_freelist;
}
// A private cache that wraps worker_cache and allows reuse of
// WorkerInterface objects.
class WorkerFreeListCache : public WorkerCacheInterface {
public:
explicit WorkerFreeListCache(WorkerCacheInterface* w) : wrapped_(w) {}
~WorkerFreeListCache() {
for (auto p : workers_) {
wrapped_->ReleaseWorker(p.first, p.second.worker);
}
}
void ListWorkers(std::vector<string>* workers) const override {
wrapped_->ListWorkers(workers);
}
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
if (p != workers_.end()) {
return p->second.worker;
}
WorkerState state;
state.worker = wrapped_->CreateWorker(target);
if (state.worker != nullptr) {
workers_.insert(std::make_pair(target, state));
}
return state.worker;
}
void ReleaseWorker(const string& target, WorkerInterface* worker) override {
// TODO(jeff,sanjay): Should decrement ref-count when we implement eviction.
}
bool GetDeviceLocalityNonBlocking(const string& device,
DeviceLocality* locality) override {
return wrapped_->GetDeviceLocalityNonBlocking(device, locality);
}
void GetDeviceLocalityAsync(const string& device, DeviceLocality* locality,
StatusCallback done) override {
wrapped_->GetDeviceLocalityAsync(device, locality, done);
}
void SetLogging(bool active) override { wrapped_->SetLogging(active); }
void ClearLogs() override { wrapped_->ClearLogs(); }
bool RetrieveLogs(int64 step_id, StepStats* ss) override {
return wrapped_->RetrieveLogs(step_id, ss);
}
private:
WorkerCacheInterface* wrapped_;
// Information kept per created WorkerInterface.
struct WorkerState {
WorkerInterface* worker;
// TODO(jeff,sanjay): Add reference count if we support eviction.
};
// TODO(jeff,sanjay): Eviction when the map becomes too big.
mutex mu_;
std::unordered_map<string, WorkerState> workers_ GUARDED_BY(mu_);
};
void RpcRemoteRendezvous::RecvFromRemoteAsync(
const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
DoneCallback done) {
CHECK(is_initialized());
Status s;
// Prepare a RecvTensor call that can handle being aborted.
@ -284,17 +217,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
s = errors::Internal(parsed.src_device,
" is invalid remote source device.");
}
WorkerInterface* rwi = cache_->CreateWorker(call->src_worker_);
WorkerSession* sess = session();
WorkerInterface* rwi = sess->worker_cache->CreateWorker(call->src_worker_);
if (s.ok() && rwi == nullptr) {
s = errors::Internal("No worker known as ", call->src_worker_);
}
Device* dst_device;
if (s.ok()) {
s = env_->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
s = sess->device_mgr->LookupDevice(parsed.dst_device, &dst_device);
}
if (!s.ok()) {
get_call_freelist()->Release(call, cache_);
if (rwi != nullptr) {
sess->worker_cache->ReleaseWorker(call->src_worker_, rwi);
}
get_call_freelist()->Release(call, sess->worker_cache.get());
done(s, Args(), recv_args, Tensor{}, false);
return;
}
@ -314,26 +251,21 @@ void RpcRemoteRendezvous::RecvFromRemoteAsync(
// current status should be bad.
Status s = call->status();
call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
cache_->ReleaseWorker(call->src_worker_, call->wi_);
session()->worker_cache->ReleaseWorker(call->src_worker_, call->wi_);
call->wi_ = nullptr;
get_call_freelist()->Release(call, cache_);
get_call_freelist()->Release(call, session()->worker_cache.get());
Unref();
});
}
} // namespace
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env,
const string& worker_name,
WorkerCacheInterface* worker_cache)
: BaseRendezvousMgr(env, worker_name),
cache_(new WorkerFreeListCache(worker_cache)) {}
RpcRendezvousMgr::RpcRendezvousMgr(const WorkerEnv* env)
: BaseRendezvousMgr(env) {}
BaseRemoteRendezvous* RpcRendezvousMgr::Create(int64 step_id,
const WorkerEnv* worker_env,
const string& worker_name) {
return new RpcRemoteRendezvous(worker_env, worker_name, cache_.get(),
step_id);
const WorkerEnv* worker_env) {
return new RpcRemoteRendezvous(worker_env, step_id);
}
} // end namespace tensorflow

View File

@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_RPC_RENDEZVOUS_MGR_H_
#include "tensorflow/core/distributed_runtime/base_rendezvous_mgr.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_env.h"
#include "tensorflow/core/distributed_runtime/worker_session.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
class DeviceMgr;
// RendezvousMgr keeps track of a set of local rendezvous instances.
// All tensors sent by this worker are buffered in a RendezvousMgr
// until the tensor is received. Each global unique "step_id"
@ -44,17 +44,12 @@ namespace tensorflow {
// RendezvousMgr must have keys generated by Rendezvous::CreateKey.
class RpcRendezvousMgr : public BaseRendezvousMgr {
public:
explicit RpcRendezvousMgr(const WorkerEnv* env, const string& worker_name,
WorkerCacheInterface* worker_cache);
explicit RpcRendezvousMgr(const WorkerEnv* env);
protected:
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env,
const string& session_name) override;
BaseRemoteRendezvous* Create(int64 step_id, const WorkerEnv* worker_env);
private:
// Private cache_ that allows us to reuse WorkerInterface objects.
std::unique_ptr<WorkerCacheInterface> cache_;
TF_DISALLOW_COPY_AND_ASSIGN(RpcRendezvousMgr);
};

View File

@ -68,9 +68,9 @@ class RpcRendezvousMgrTest : public ::testing::Test {
: cache_(new DummyWorkerCache),
worker_session_("/job:mnist/replica:1/task:2",
std::unique_ptr<WorkerCacheInterface>(cache_),
std::unique_ptr<RendezvousMgrInterface>(),
std::unique_ptr<DeviceMgr>(),
std::unique_ptr<GraphMgr>()),
rmgr_(&env, worker_session_.worker_name, cache_) {
rmgr_(&env) {
env.env = Env::Default();
}
@ -87,7 +87,8 @@ TEST_F(RpcRendezvousMgrTest, LocalSendRecv) {
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
Rendezvous* rendez = rmgr_.Find(step_id);
RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@ -107,7 +108,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{ // Explicit Abort().
const int64 step_id = 123;
Rendezvous* rendez = rmgr_.Find(step_id);
RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([this, rendez]() {
env.env->SleepForMicroseconds(100 * 1000);
@ -116,11 +117,12 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
{ // Cleanup causes Abort().
const int64 step_id = 321;
Rendezvous* rendez = rmgr_.Find(step_id);
RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
SchedClosure([this, step_id]() {
env.env->SleepForMicroseconds(100 * 1000);
@ -129,6 +131,7 @@ TEST_F(RpcRendezvousMgrTest, LocalAbort) {
Tensor val(DT_STRING);
bool val_dead = false;
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
EXPECT_TRUE(errors::IsAborted(rendez->Recv(key, args, &val, &val_dead)));
}
}
@ -139,7 +142,8 @@ TEST_F(RpcRendezvousMgrTest, CleanupAll) {
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
const int64 step_id = 123;
Rendezvous* rendez = rmgr_.Find(step_id);
RemoteRendezvous* rendez = rmgr_.Find(step_id);
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
@ -168,10 +172,11 @@ TEST_F(RpcRendezvousMgrTest, TransferDummyDeviceContext) {
"/job:mnist/replica:1/task:2/cpu:0", 7890,
"/job:mnist/replica:1/task:2/cpu:1", "foo", FrameAndIter(0, 0)));
{
Rendezvous* rendez = rmgr_.Find(step_id);
RemoteRendezvous* rendez = rmgr_.Find(step_id);
core::ScopedUnref unref(rendez);
Rendezvous::Args args;
args.device_context = dc;
TF_ASSERT_OK(rendez->Initialize(&worker_session_));
TF_ASSERT_OK(rendez->Send(key, args, V("peach"), false));
}
{

View File

@ -17,8 +17,9 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
#include "tensorflow/core/distributed_runtime/graph_mgr.h"
#include "tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
@ -26,23 +27,12 @@ namespace tensorflow {
SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory)
: SessionMgr(
worker_env, default_worker_name, std::move(default_worker_cache),
default_rendezvous_mgr.release(), std::move(worker_cache_factory)) {}
SessionMgr::SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory)
: worker_env_(worker_env),
legacy_session_(
default_worker_name, std::move(default_worker_cache),
std::unique_ptr<RendezvousMgrInterface>(default_rendezvous_mgr),
legacy_session_(default_worker_name, std::move(default_worker_cache),
std::unique_ptr<DeviceMgr>(worker_env->device_mgr),
std::unique_ptr<GraphMgr>(
new GraphMgr(worker_env, default_rendezvous_mgr))),
new GraphMgr(worker_env, worker_env->device_mgr))),
worker_cache_factory_(std::move(worker_cache_factory)) {}
string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
@ -53,20 +43,28 @@ string SessionMgr::WorkerNameFromServerDef(const ServerDef& server_def) {
Status SessionMgr::CreateSession(const string& session,
const ServerDef& server_def) {
mutex_lock l(mu_);
if (session.empty()) {
return errors::InvalidArgument("Session must be non-empty.");
}
const string worker_name = WorkerNameFromServerDef(server_def);
WorkerCacheInterface* worker_cache = nullptr;
TF_RETURN_IF_ERROR(worker_cache_factory_(server_def, &worker_cache));
std::unique_ptr<RendezvousMgrInterface> rendezvous_mgr(
new RpcRendezvousMgr(worker_env_, worker_name, worker_cache));
std::vector<Device*> renamed_devices;
for (Device* d : worker_env_->local_devices) {
renamed_devices.push_back(
RenamedDevice::NewRenamedDevice(worker_name, d, false));
}
std::unique_ptr<DeviceMgr> device_mgr(new DeviceMgr(renamed_devices));
std::unique_ptr<GraphMgr> graph_mgr(
new GraphMgr(worker_env_, rendezvous_mgr.get()));
new GraphMgr(worker_env_, device_mgr.get()));
std::unique_ptr<WorkerSession> worker_session(new WorkerSession(
worker_name, std::unique_ptr<WorkerCacheInterface>(worker_cache),
std::move(rendezvous_mgr), std::move(graph_mgr)));
std::move(device_mgr), std::move(graph_mgr)));
sessions_.insert(std::make_pair(session, std::move(worker_session)));
return Status::OK();
@ -78,22 +76,6 @@ Status SessionMgr::DeleteSession(const string& session) {
if (it != sessions_.end()) {
sessions_.erase(it);
}
std::set<string> graph_handles;
for (auto graph_handle_it = sessions_by_graph_handle_.begin();
graph_handle_it != sessions_by_graph_handle_.end(); ++graph_handle_it) {
if (graph_handle_it->second == session) {
graph_handles.insert(graph_handle_it->first);
graph_handle_it = sessions_by_graph_handle_.erase(graph_handle_it);
if (graph_handle_it == sessions_by_graph_handle_.end()) break;
}
}
for (auto step_id_it = graphs_by_step_id_.begin();
step_id_it != graphs_by_step_id_.end(); ++step_id_it) {
if (graph_handles.find(step_id_it->second) != graph_handles.end()) {
step_id_it = graphs_by_step_id_.erase(step_id_it);
if (step_id_it == graphs_by_step_id_.end()) break;
}
}
return Status::OK();
}
@ -114,58 +96,4 @@ WorkerSession* SessionMgr::WorkerSessionForSession(const string& session) {
WorkerSession* SessionMgr::LegacySession() { return &legacy_session_; }
WorkerSession* SessionMgr::WorkerSessionForGraphHandleUnlocked(
const string& graph_handle) {
auto it = sessions_by_graph_handle_.find(graph_handle);
if (it == sessions_by_graph_handle_.end()) {
return &legacy_session_;
} else {
return WorkerSessionForSessionUnlocked(it->second);
}
}
WorkerSession* SessionMgr::WorkerSessionForGraphHandle(
const string& graph_handle) {
mutex_lock l(mu_);
return WorkerSessionForGraphHandleUnlocked(graph_handle);
}
WorkerSession* SessionMgr::WorkerSessionForStepId(const int64 step_id) {
mutex_lock l(mu_);
auto it = graphs_by_step_id_.find(step_id);
if (it == graphs_by_step_id_.end()) {
return &legacy_session_;
} else {
return WorkerSessionForGraphHandleUnlocked(it->second);
}
}
void SessionMgr::AssociateGraphWithSession(const string& session,
const string& graph_handle) {
mutex_lock l(mu_);
sessions_by_graph_handle_[graph_handle] = session;
}
void SessionMgr::DisassociateGraphFromSession(const string& graph_handle) {
mutex_lock l(mu_);
auto it = sessions_by_graph_handle_.find(graph_handle);
if (it != sessions_by_graph_handle_.end()) {
sessions_by_graph_handle_.erase(it);
}
}
void SessionMgr::AssociateStepIdWithGraph(const string& graph_handle,
const int64 step_id) {
mutex_lock l(mu_);
graphs_by_step_id_[step_id] = graph_handle;
}
void SessionMgr::DisassociateStepIdFromGraph(const int64 step_id) {
mutex_lock l(mu_);
auto it = graphs_by_step_id_.find(step_id);
if (it != graphs_by_step_id_.end()) {
graphs_by_step_id_.erase(it);
}
}
} // namespace tensorflow

View File

@ -30,6 +30,8 @@ struct WorkerEnv;
// SessionMgr keeps track of information related to a given session.
//
// SessionMgr runs on the workers.
//
// SessionMgr is threadsafe.
class SessionMgr {
public:
@ -39,7 +41,6 @@ class SessionMgr {
explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
std::unique_ptr<RendezvousMgrInterface> default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory);
~SessionMgr() {}
@ -50,49 +51,36 @@ class SessionMgr {
WorkerSession* WorkerSessionForSession(const string& session);
WorkerSession* LegacySession();
// Locates the worker session for a given graph handle
WorkerSession* WorkerSessionForGraphHandle(const string& graph_handle);
void AssociateGraphWithSession(const string& session,
const string& graph_handle);
void DisassociateGraphFromSession(const string& graph_handle);
// Locates a worker session for a given step id
WorkerSession* WorkerSessionForStepId(const int64 step_id);
void AssociateStepIdWithGraph(const string& graph_handle,
const int64 step_id);
void DisassociateStepIdFromGraph(const int64 step_id);
Status DeleteSession(const string& session);
static string WorkerNameFromServerDef(const ServerDef& server_def);
private:
// Private constructor to work around std::unique_ptr ownership issues.
explicit SessionMgr(
WorkerEnv* worker_env, const string& default_worker_name,
std::unique_ptr<WorkerCacheInterface> default_worker_cache,
RendezvousMgrInterface* default_rendezvous_mgr,
WorkerCacheFactory worker_cache_factory);
const WorkerEnv* const worker_env_; // Not owned.
// A note about destruction:
// We must delete graph_mgr before device_mgr, due to shared
// ownership of OpKernels in the executors. (The graph_mgr will
// free all stateless OpKernels, and pass over borrowed stateful
// OpKernels, which are also held in their respective devices'
// OpSegments.)
//
// legacy_session_ owns the worker_env_.device_mgr, and so we must ensure
// that sessions_'s WorkerSessions are deleted (which do not own the
// underlying devices, but instead own RenamedDevices) before
// legacy_session_ is deleted. Further, we must ensure that WorkerSession's
// device_mgr is deleted after WorkerSession's graph_mgr.
WorkerSession legacy_session_;
const WorkerCacheFactory worker_cache_factory_;
WorkerSession* WorkerSessionForSessionUnlocked(const string& session)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
WorkerSession* WorkerSessionForGraphHandleUnlocked(const string& graph_handle)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
// A map from session identifier to internal session structure.
std::map<string, std::unique_ptr<WorkerSession>> sessions_ GUARDED_BY(mu_);
// A map from graph handles to the session that they belong to.
std::map<string, string> sessions_by_graph_handle_ GUARDED_BY(mu_);
// A map from globally-unique step id's to the corresponding graph handles.
std::map<int64, string> graphs_by_step_id_ GUARDED_BY(mu_);
};
} // namespace tensorflow

View File

@ -27,8 +27,6 @@ class SessionMgrTest : public ::testing::Test {
SessionMgrTest()
: mgr_(&env_, "/job:mnist/replica:0/task:0",
std::unique_ptr<WorkerCacheInterface>(),
std::unique_ptr<RendezvousMgrInterface>(new RpcRendezvousMgr(
&env_, "/job:mnist/replica:0/task:0", nullptr)),
factory_),
legacy_session_(mgr_.WorkerSessionForSession("novel_session_id")) {}
@ -48,90 +46,19 @@ TEST_F(SessionMgrTest, CreateSessionSimple) {
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
EXPECT_NE(nullptr, session) << "Session for " << session_handle << "was null";
EXPECT_NE(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
TEST_F(SessionMgrTest, AssociateGraphWithSession) {
TEST_F(SessionMgrTest, LegacySession) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
string session_handle = "";
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(session, graph_session);
EXPECT_EQ(mgr_.LegacySession(), session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
TEST_F(SessionMgrTest, AssociateStepWithGraph) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(session, graph_session);
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(session, step_session);
ASSERT_EQ(graph_session, step_session);
TF_EXPECT_OK(mgr_.DeleteSession(session_handle));
}
TEST_F(SessionMgrTest, AssociateGraphWithSession_MissingSession) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(legacy_session_, graph_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingGraph) {
ServerDef server_def;
string session_handle = "test_session_handle";
TF_EXPECT_OK(mgr_.CreateSession(session_handle, server_def));
WorkerSession* session = mgr_.WorkerSessionForSession(session_handle);
ASSERT_NE(nullptr, session) << "Session for " << session_handle << "was null";
string graph_handle = "test_graph_handle";
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSession) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
mgr_.AssociateGraphWithSession(session_handle, graph_handle);
WorkerSession* graph_session = mgr_.WorkerSessionForGraphHandle(graph_handle);
ASSERT_EQ(legacy_session_, graph_session);
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, AssociateStepWithGraph_MissingSessionAndGraph) {
string session_handle = "test_session_handle";
string graph_handle = "test_graph_handle";
int64 step_id = 1234567890L;
mgr_.AssociateStepIdWithGraph(graph_handle, step_id);
WorkerSession* step_session = mgr_.WorkerSessionForStepId(step_id);
ASSERT_EQ(legacy_session_, step_session);
}
TEST_F(SessionMgrTest, WorkerNameFromServerDef) {
ServerDef server_def;
server_def.set_job_name("worker");

View File

@ -56,10 +56,6 @@ void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
Status s = session->graph_mgr->Register(
request->session_handle(), request->graph_def(), request->graph_options(),
request->debug_options(), response->mutable_graph_handle());
if (s.ok()) {
env_->session_mgr->AssociateGraphWithSession(request->session_handle(),
response->graph_handle());
}
done(s);
}
@ -67,9 +63,8 @@ void Worker::DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) {
WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
env_->session_mgr->WorkerSessionForSession(request->session_handle());
Status s = session->graph_mgr->Deregister(request->graph_handle());
env_->session_mgr->DisassociateGraphFromSession(request->graph_handle());
done(s);
}
@ -141,8 +136,7 @@ void Worker::SetOrCallFinalCallback(const string& graph_handle, int step_id,
}
void Worker::AbortStep(int64 step_id) {
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
Rendezvous* rendez = session->rendezvous_mgr->Find(step_id);
Rendezvous* rendez = env_->rendezvous_mgr->Find(step_id);
SchedNonBlockingClosureAfter(1000000, [rendez, step_id]() {
// Delay a bit before aborting the step. This way, the root
// cause may return first back to the client instead of this
@ -193,8 +187,7 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
const int64 step_id = request->step_id();
TRACEPRINTF("RunGraph: %lld", step_id);
WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(request->graph_handle());
env_->session_mgr->AssociateStepIdWithGraph(request->graph_handle(), step_id);
env_->session_mgr->WorkerSessionForSession(request->session_handle());
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
@ -231,8 +224,8 @@ void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
}
CostGraphDef* cost_graph = response->mutable_cost_graph();
session->graph_mgr->ExecuteAsync(
request->graph_handle(), step_id, request->exec_opts(), collector,
cost_graph, cm, in,
request->graph_handle(), step_id, session, request->exec_opts(),
collector, cost_graph, cm, in,
[this, step_id, response, session, cm, out, token, collector, opts,
done](Status s) {
if (s.ok()) {
@ -267,8 +260,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
const string& graph_handle = request->graph_handle();
TRACEPRINTF("PartialRunGraph: %lld", step_id);
WorkerSession* session =
env_->session_mgr->WorkerSessionForGraphHandle(graph_handle);
env_->session_mgr->AssociateStepIdWithGraph(graph_handle, step_id);
env_->session_mgr->WorkerSessionForSession(request->session_handle());
GraphMgr::NamedTensors in;
GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
Status s = PrepareRunGraph(request, &in, out);
@ -315,8 +308,8 @@ void Worker::DoPartialRunGraph(CallOptions* opts,
[cm]() { cm->StartCancel(); });
}
session->graph_mgr->ExecuteAsync(
graph_handle, step_id, request->exec_opts(), nullptr /* collector */,
nullptr /* cost_graph */, cm, in,
graph_handle, step_id, session, request->exec_opts(),
nullptr /* collector */, nullptr /* cost_graph */, cm, in,
[this, token, graph_handle, step_id, cm](Status s) {
{
mutex_lock l(mu_);
@ -365,8 +358,7 @@ void Worker::CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) {
const int64 step_id = request->step_id();
WorkerSession* session = env_->session_mgr->WorkerSessionForStepId(step_id);
session->rendezvous_mgr->Cleanup(step_id);
env_->rendezvous_mgr->Cleanup(step_id);
done(Status::OK());
}
@ -394,8 +386,8 @@ void Worker::TracingAsync(const TracingRequest* request,
Status Worker::PrepareRecvTensor(const Rendezvous::ParsedKey& parsed,
Device** src_dev) {
// Figures out which device the tensor is hosted on.
TF_RETURN_IF_ERROR(
env_->device_mgr->LookupDevice(parsed.src_device, src_dev));
string local_name = DeviceNameUtils::LocalName(parsed.src_device);
TF_RETURN_IF_ERROR(env_->device_mgr->LookupDevice(local_name, src_dev));
// Does the device have the right incarnation number we expect?
if ((*src_dev)->attributes().incarnation() != parsed.src_incarnation) {

View File

@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_WORKER_ENV_H_
#include <vector>
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@ -24,8 +25,10 @@ namespace thread {
class ThreadPool;
} // namespace thread
class Device;
class DeviceMgr;
class Env;
class RendezvousMgrInterface;
class SessionMgr;
// The worker environment class, which holds a bag of pointers to
@ -38,10 +41,18 @@ struct WorkerEnv {
// session_mgr encapsulates state for each session.
SessionMgr* session_mgr = nullptr;
// The local devices of this worker. Devices are owned by the device_mgr.
//
// REQUIRES: !local_devices.empty().
std::vector<Device*> local_devices;
// device_mgr manages local devices (cpu and gpu). The WorkerService
// is the network interface for managed devices.
DeviceMgr* device_mgr = nullptr;
// A set of rendezvous keyed by step ids.
RendezvousMgrInterface* rendezvous_mgr = nullptr;
// A pool of threads for scheduling compute work.
thread::ThreadPool* compute_pool = nullptr;
};

Some files were not shown because too many files have changed in this diff Show More